diff options
| author | mo khan <mo@mokhan.ca> | 2025-05-11 21:12:57 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-05-11 21:12:57 -0600 |
| commit | 60440f90dca28e99a31dd328c5f6d5dc0f9b6a2e (patch) | |
| tree | 2f54adf55086516f162f0a55a5347e6b25f7f176 /vendor/github.com/oauth2-proxy/mockoidc/mockoidc.go | |
| parent | 05ca9b8d3a9c7203a3a3b590beaa400900bd9007 (diff) | |
chore: vendor go dependencies
Diffstat (limited to 'vendor/github.com/oauth2-proxy/mockoidc/mockoidc.go')
| -rw-r--r-- | vendor/github.com/oauth2-proxy/mockoidc/mockoidc.go | 273 |
1 files changed, 273 insertions, 0 deletions
diff --git a/vendor/github.com/oauth2-proxy/mockoidc/mockoidc.go b/vendor/github.com/oauth2-proxy/mockoidc/mockoidc.go new file mode 100644 index 0000000..e66ca58 --- /dev/null +++ b/vendor/github.com/oauth2-proxy/mockoidc/mockoidc.go @@ -0,0 +1,273 @@ +package mockoidc + +import ( + "context" + "crypto/rsa" + "crypto/tls" + "errors" + "fmt" + "net" + "net/http" + "time" +) + +// NowFunc is an overrideable version of `time.Now`. Tests that need to +// manipulate time can use their own `func() Time` function. +var NowFunc = time.Now + +// MockOIDC is a minimal OIDC server for use in OIDC authentication +// integration testing. +type MockOIDC struct { + ClientID string + ClientSecret string + + AccessTTL time.Duration + RefreshTTL time.Duration + + CodeChallengeMethodsSupported []string + + // Normally, these would be private. Expose them publicly for + // power users. + Server *http.Server + Keypair *Keypair + SessionStore *SessionStore + UserQueue *UserQueue + ErrorQueue *ErrorQueue + + tlsConfig *tls.Config + middleware []func(http.Handler) http.Handler + fastForward time.Duration +} + +// Config gives the various settings MockOIDC starts with that a test +// application server would need to be configured with. +type Config struct { + ClientID string + ClientSecret string + Issuer string + + AccessTTL time.Duration + RefreshTTL time.Duration + + CodeChallengeMethodsSupported []string +} + +// NewServer configures a new MockOIDC that isn't started. An existing +// rsa.PrivateKey can be passed for token signing operations in case +// the default Keypair isn't desired. +func NewServer(key *rsa.PrivateKey) (*MockOIDC, error) { + clientID, err := randomNonce(24) + if err != nil { + return nil, err + } + clientSecret, err := randomNonce(24) + if err != nil { + return nil, err + } + keypair, err := NewKeypair(key) + if err != nil { + return nil, err + } + + return &MockOIDC{ + ClientID: clientID, + ClientSecret: clientSecret, + AccessTTL: time.Duration(10) * time.Minute, + RefreshTTL: time.Duration(60) * time.Minute, + CodeChallengeMethodsSupported: []string{"plain", "S256"}, + Keypair: keypair, + SessionStore: NewSessionStore(), + UserQueue: &UserQueue{}, + ErrorQueue: &ErrorQueue{}, + }, nil +} + +// Run creates a default MockOIDC server and starts it +func Run() (*MockOIDC, error) { + return RunTLS(nil) +} + +// RunTLS creates a default MockOIDC server and starts it. It takes a +// tester configured tls.Config for TLS support. +func RunTLS(cfg *tls.Config) (*MockOIDC, error) { + m, err := NewServer(nil) + if err != nil { + return nil, err + } + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, err + } + return m, m.Start(ln, cfg) +} + +// Start starts the MockOIDC server in its own Goroutine on the provided +// net.Listener. In generic `Run`, this defaults to `127.0.0.1:0` +func (m *MockOIDC) Start(ln net.Listener, cfg *tls.Config) error { + if m.Server != nil { + return errors.New("server already started") + } + + handler := http.NewServeMux() + handler.Handle(AuthorizationEndpoint, m.chainMiddleware(m.Authorize)) + handler.Handle(TokenEndpoint, m.chainMiddleware(m.Token)) + handler.Handle(UserinfoEndpoint, m.chainMiddleware(m.Userinfo)) + handler.Handle(JWKSEndpoint, m.chainMiddleware(m.JWKS)) + handler.Handle(DiscoveryEndpoint, m.chainMiddleware(m.Discovery)) + + m.Server = &http.Server{ + Addr: ln.Addr().String(), + Handler: handler, + TLSConfig: cfg, + } + // Track this to know if we are https + m.tlsConfig = cfg + + go func() { + err := m.Server.Serve(ln) + if err != nil && err != http.ErrServerClosed { + panic(err) + } + }() + + return nil +} + +// Shutdown stops the MockOIDC server. Use this to cleanup test runs. +func (m *MockOIDC) Shutdown() error { + return m.Server.Shutdown(context.Background()) +} + +func (m *MockOIDC) AddMiddleware(mw func(http.Handler) http.Handler) error { + if m.Server != nil { + return errors.New("server already started") + } + + m.middleware = append(m.middleware, mw) + return nil +} + +// Config returns the Config with options a connection application or unit +// tests need to be aware of. +func (m *MockOIDC) Config() *Config { + return &Config{ + ClientID: m.ClientID, + ClientSecret: m.ClientSecret, + Issuer: m.Issuer(), + CodeChallengeMethodsSupported: m.CodeChallengeMethodsSupported, + AccessTTL: m.AccessTTL, + RefreshTTL: m.RefreshTTL, + } +} + +// QueueUser allows adding mock User objects to the authentication queue. +// Calls to the `authorization_endpoint` will pop these mock User objects +// off the queue and create a session with them. +func (m *MockOIDC) QueueUser(user User) { + m.UserQueue.Push(user) +} + +// QueueCode allows adding mock code strings to the authentication queue. +// Calls to the `authorization_endpoint` will pop these code strings +// off the queue and create a session with them and return them as the +// code parameter in the response. +func (m *MockOIDC) QueueCode(code string) { + m.SessionStore.CodeQueue.Push(code) +} + +// QueueError allows queueing arbitrary errors for the next handler calls +// to return. +func (m *MockOIDC) QueueError(se *ServerError) { + m.ErrorQueue.Push(se) +} + +// FastForward moves the MockOIDC's internal view of time forward. +// Use this to test token expirations in your tests. +func (m *MockOIDC) FastForward(d time.Duration) time.Duration { + m.fastForward = m.fastForward + d + return m.fastForward +} + +// Now is what MockOIDC thinks time.Now is +func (m *MockOIDC) Now() time.Time { + return NowFunc().Add(m.fastForward) +} + +// Addr returns the server address (if started) +func (m *MockOIDC) Addr() string { + if m.Server == nil { + return "" + } + proto := "http" + if m.tlsConfig != nil { + proto = "https" + } + return fmt.Sprintf("%s://%s", proto, m.Server.Addr) +} + +// Issuer returns the OIDC Issuer that will be in `iss` token claims +func (m *MockOIDC) Issuer() string { + if m.Server == nil { + return "" + } + return m.Addr() + IssuerBase +} + +// DiscoveryEndpoint returns the full `/.well-known/openid-configuration` URL +func (m *MockOIDC) DiscoveryEndpoint() string { + if m.Server == nil { + return "" + } + return m.Addr() + DiscoveryEndpoint +} + +// AuthorizationEndpoint returns the OIDC `authorization_endpoint` +func (m *MockOIDC) AuthorizationEndpoint() string { + if m.Server == nil { + return "" + } + return m.Addr() + AuthorizationEndpoint +} + +// TokenEndpoint returns the OIDC `token_endpoint` +func (m *MockOIDC) TokenEndpoint() string { + if m.Server == nil { + return "" + } + return m.Addr() + TokenEndpoint +} + +// UserinfoEndpoint returns the OIDC `userinfo_endpoint` +func (m *MockOIDC) UserinfoEndpoint() string { + if m.Server == nil { + return "" + } + return m.Addr() + UserinfoEndpoint +} + +// JWKSEndpoint returns the OIDC `jwks_uri` +func (m *MockOIDC) JWKSEndpoint() string { + if m.Server == nil { + return "" + } + return m.Addr() + JWKSEndpoint +} + +func (m *MockOIDC) chainMiddleware(endpoint func(http.ResponseWriter, *http.Request)) http.Handler { + chain := m.forceError(http.HandlerFunc(endpoint)) + for i := len(m.middleware) - 1; i >= 0; i-- { + mw := m.middleware[i] + chain = mw(chain) + } + return chain +} + +func (m *MockOIDC) forceError(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if se := m.ErrorQueue.Pop(); se != nil { + errorResponse(rw, se.Error, se.Description, se.Code) + } else { + next.ServeHTTP(rw, req) + } + }) +} |
