diff options
| author | mo khan <mo@mokhan.ca> | 2025-05-14 13:18:54 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-05-14 13:18:54 -0600 |
| commit | 4b2d609a0efcc1d9b2f1a08f954d067ad1d9cd1e (patch) | |
| tree | 0afacf9217b0569130da6b97d4197331681bf119 /vendor/github.com/playwright-community/playwright-go/connection.go | |
| parent | ab373d1fe698cd3f53258c09bc8515d88a6d0b9e (diff) | |
test: use playwright to test out an OIDC login
Diffstat (limited to 'vendor/github.com/playwright-community/playwright-go/connection.go')
| -rw-r--r-- | vendor/github.com/playwright-community/playwright-go/connection.go | 401 |
1 files changed, 401 insertions, 0 deletions
diff --git a/vendor/github.com/playwright-community/playwright-go/connection.go b/vendor/github.com/playwright-community/playwright-go/connection.go new file mode 100644 index 0000000..ba1e365 --- /dev/null +++ b/vendor/github.com/playwright-community/playwright-go/connection.go @@ -0,0 +1,401 @@ +package playwright + +import ( + "errors" + "fmt" + "reflect" + "regexp" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/go-stack/stack" + "github.com/playwright-community/playwright-go/internal/safe" +) + +var ( + pkgSourcePathPattern = regexp.MustCompile(`.+[\\/]playwright-go[\\/][^\\/]+\.go`) + apiNameTransform = regexp.MustCompile(`(?U)\(\*(.+)(Impl)?\)`) +) + +type connection struct { + transport transport + apiZone sync.Map + objects *safe.SyncMap[string, *channelOwner] + lastID atomic.Uint32 + rootObject *rootChannelOwner + callbacks *safe.SyncMap[uint32, *protocolCallback] + afterClose func() + onClose func() error + isRemote bool + localUtils *localUtilsImpl + tracingCount atomic.Int32 + abort chan struct{} + abortOnce sync.Once + err *safeValue[error] // for event listener error + closedError *safeValue[error] +} + +func (c *connection) Start() (*Playwright, error) { + go func() { + for { + msg, err := c.transport.Poll() + if err != nil { + _ = c.transport.Close() + c.cleanup(err) + return + } + c.Dispatch(msg) + } + }() + + c.onClose = func() error { + if err := c.transport.Close(); err != nil { + return err + } + return nil + } + + return c.rootObject.initialize() +} + +func (c *connection) Stop() error { + if err := c.onClose(); err != nil { + return err + } + c.cleanup() + return nil +} + +func (c *connection) cleanup(cause ...error) { + if len(cause) > 0 { + c.closedError.Set(fmt.Errorf("%w: %w", ErrTargetClosed, cause[0])) + } else { + c.closedError.Set(ErrTargetClosed) + } + if c.afterClose != nil { + c.afterClose() + } + c.abortOnce.Do(func() { + select { + case <-c.abort: + default: + close(c.abort) + } + }) +} + +func (c *connection) Dispatch(msg *message) { + if c.closedError.Get() != nil { + return + } + method := msg.Method + if msg.ID != 0 { + cb, _ := c.callbacks.LoadAndDelete(uint32(msg.ID)) + if cb.noReply { + return + } + if msg.Error != nil { + cb.SetError(parseError(msg.Error.Error)) + } else { + cb.SetResult(c.replaceGuidsWithChannels(msg.Result).(map[string]interface{})) + } + return + } + object, _ := c.objects.Load(msg.GUID) + if method == "__create__" { + c.createRemoteObject( + object, msg.Params["type"].(string), msg.Params["guid"].(string), msg.Params["initializer"], + ) + return + } + if object == nil { + return + } + if method == "__adopt__" { + child, ok := c.objects.Load(msg.Params["guid"].(string)) + if !ok { + return + } + object.adopt(child) + return + } + if method == "__dispose__" { + reason, ok := msg.Params["reason"] + if ok { + object.dispose(reason.(string)) + } else { + object.dispose() + } + return + } + if object.objectType == "JsonPipe" { + object.channel.Emit(method, msg.Params) + } else { + object.channel.Emit(method, c.replaceGuidsWithChannels(msg.Params)) + } +} + +func (c *connection) LocalUtils() *localUtilsImpl { + return c.localUtils +} + +func (c *connection) createRemoteObject(parent *channelOwner, objectType string, guid string, initializer interface{}) interface{} { + initializer = c.replaceGuidsWithChannels(initializer) + result := createObjectFactory(parent, objectType, guid, initializer.(map[string]interface{})) + return result +} + +func (c *connection) WrapAPICall(cb func() (interface{}, error), isInternal bool) (interface{}, error) { + if _, ok := c.apiZone.Load("apiZone"); ok { + return cb() + } + c.apiZone.Store("apiZone", serializeCallStack(isInternal)) + return cb() +} + +func (c *connection) replaceGuidsWithChannels(payload interface{}) interface{} { + if payload == nil { + return nil + } + v := reflect.ValueOf(payload) + if v.Kind() == reflect.Slice { + listV := payload.([]interface{}) + for i := 0; i < len(listV); i++ { + listV[i] = c.replaceGuidsWithChannels(listV[i]) + } + return listV + } + if v.Kind() == reflect.Map { + mapV := payload.(map[string]interface{}) + if guid, hasGUID := mapV["guid"]; hasGUID { + if channelOwner, ok := c.objects.Load(guid.(string)); ok { + return channelOwner.channel + } + } + for key := range mapV { + mapV[key] = c.replaceGuidsWithChannels(mapV[key]) + } + return mapV + } + return payload +} + +func (c *connection) sendMessageToServer(object *channelOwner, method string, params interface{}, noReply bool) (cb *protocolCallback) { + cb = newProtocolCallback(noReply, c.abort) + + if err := c.closedError.Get(); err != nil { + cb.SetError(err) + return + } + if object.wasCollected { + cb.SetError(errors.New("The object has been collected to prevent unbounded heap growth.")) + return + } + + id := c.lastID.Add(1) + c.callbacks.Store(id, cb) + var ( + metadata = make(map[string]interface{}, 0) + stack = make([]map[string]interface{}, 0) + ) + apiZone, ok := c.apiZone.LoadAndDelete("apiZone") + if ok { + for k, v := range apiZone.(parsedStackTrace).metadata { + metadata[k] = v + } + stack = append(stack, apiZone.(parsedStackTrace).frames...) + } + metadata["wallTime"] = time.Now().UnixMilli() + message := map[string]interface{}{ + "id": id, + "guid": object.guid, + "method": method, + "params": params, // channel.MarshalJSON will replace channel with guid + "metadata": metadata, + } + if c.tracingCount.Load() > 0 && len(stack) > 0 && object.guid != "localUtils" { + c.LocalUtils().AddStackToTracingNoReply(id, stack) + } + + if err := c.transport.Send(message); err != nil { + cb.SetError(fmt.Errorf("could not send message: %w", err)) + return + } + + return +} + +func (c *connection) setInTracing(isTracing bool) { + if isTracing { + c.tracingCount.Add(1) + } else { + c.tracingCount.Add(-1) + } +} + +type parsedStackTrace struct { + frames []map[string]interface{} + metadata map[string]interface{} +} + +func serializeCallStack(isInternal bool) parsedStackTrace { + st := stack.Trace().TrimRuntime() + if len(st) == 0 { // https://github.com/go-stack/stack/issues/27 + st = stack.Trace() + } + + lastInternalIndex := 0 + for i, s := range st { + if pkgSourcePathPattern.MatchString(s.Frame().File) { + lastInternalIndex = i + } + } + apiName := "" + if !isInternal { + apiName = fmt.Sprintf("%n", st[lastInternalIndex]) + } + st = st.TrimBelow(st[lastInternalIndex]) + + callStack := make([]map[string]interface{}, 0) + for i, s := range st { + if i == 0 { + continue + } + callStack = append(callStack, map[string]interface{}{ + "file": s.Frame().File, + "line": s.Frame().Line, + "column": 0, + "function": s.Frame().Function, + }) + } + metadata := make(map[string]interface{}) + if len(st) > 1 { + metadata["location"] = serializeCallLocation(st[1]) + } + apiName = apiNameTransform.ReplaceAllString(apiName, "$1") + if len(apiName) > 1 { + apiName = strings.ToUpper(apiName[:1]) + apiName[1:] + } + metadata["apiName"] = apiName + metadata["isInternal"] = isInternal + return parsedStackTrace{ + metadata: metadata, + frames: callStack, + } +} + +func serializeCallLocation(caller stack.Call) map[string]interface{} { + line, _ := strconv.Atoi(fmt.Sprintf("%d", caller)) + return map[string]interface{}{ + "file": fmt.Sprintf("%s", caller), + "line": line, + } +} + +func newConnection(transport transport, localUtils ...*localUtilsImpl) *connection { + connection := &connection{ + abort: make(chan struct{}, 1), + callbacks: safe.NewSyncMap[uint32, *protocolCallback](), + objects: safe.NewSyncMap[string, *channelOwner](), + transport: transport, + isRemote: false, + err: &safeValue[error]{}, + closedError: &safeValue[error]{}, + } + if len(localUtils) > 0 { + connection.localUtils = localUtils[0] + connection.isRemote = true + } + connection.rootObject = newRootChannelOwner(connection) + return connection +} + +func fromChannel(v interface{}) interface{} { + return v.(*channel).object +} + +func fromNullableChannel(v interface{}) interface{} { + if v == nil { + return nil + } + return fromChannel(v) +} + +type protocolCallback struct { + done chan struct{} + noReply bool + abort <-chan struct{} + once sync.Once + value map[string]interface{} + err error +} + +func (pc *protocolCallback) setResultOnce(result map[string]interface{}, err error) { + pc.once.Do(func() { + pc.value = result + pc.err = err + close(pc.done) + }) +} + +func (pc *protocolCallback) waitResult() { + if pc.noReply { + return + } + select { + case <-pc.done: // wait for result + return + case <-pc.abort: + select { + case <-pc.done: + return + default: + pc.err = errors.New("Connection closed") + return + } + } +} + +func (pc *protocolCallback) SetError(err error) { + pc.setResultOnce(nil, err) +} + +func (pc *protocolCallback) SetResult(result map[string]interface{}) { + pc.setResultOnce(result, nil) +} + +func (pc *protocolCallback) GetResult() (map[string]interface{}, error) { + pc.waitResult() + return pc.value, pc.err +} + +// GetResultValue returns value if the map has only one element +func (pc *protocolCallback) GetResultValue() (interface{}, error) { + pc.waitResult() + if len(pc.value) == 0 { // empty map treated as nil + return nil, pc.err + } + if len(pc.value) == 1 { + for key := range pc.value { + return pc.value[key], pc.err + } + } + + return pc.value, pc.err +} + +func newProtocolCallback(noReply bool, abort <-chan struct{}) *protocolCallback { + if noReply { + return &protocolCallback{ + noReply: true, + abort: abort, + } + } + return &protocolCallback{ + done: make(chan struct{}, 1), + abort: abort, + } +} |
