diff options
| author | mo khan <mo@mokhan.ca> | 2022-05-16 10:28:31 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2022-05-16 10:28:31 -0600 |
| commit | 59b6a9a68bf585c5baabe57ebc5cdfe45485dcd0 (patch) | |
| tree | 054f86c691cbde7bae15ff3cd3bcdd6a360c6e3b | |
| parent | c664deedf96f5ec088ab221a21a1eb26c3ca12ed (diff) | |
extract functions to dump request/response
| -rw-r--r-- | cmd/ui/main.go | 67 | ||||
| -rw-r--r-- | pkg/x/must.go | 6 | ||||
| -rw-r--r-- | pkg/x/round_tripper.go | 71 |
3 files changed, 83 insertions, 61 deletions
diff --git a/cmd/ui/main.go b/cmd/ui/main.go index 6bb3ec8..6319b58 100644 --- a/cmd/ui/main.go +++ b/cmd/ui/main.go @@ -1,16 +1,13 @@ package main import ( - "bytes" "context" "fmt" "html/template" - "io/ioutil" "log" "net/http" "net/url" "os" - "strings" "github.com/coreos/go-oidc/v3/oidc" "github.com/hashicorp/uuid" @@ -40,54 +37,6 @@ func SessionFor(sessions map[string]*x.Session, r *http.Request, w http.Response return session } -type LoggingRoundTripper struct { - Proxied http.RoundTripper -} - -func (l LoggingRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { - body := x.Must(ioutil.ReadAll(r.Body)) - r.Body = ioutil.NopCloser(bytes.NewBuffer(body)) - - fmt.Println(strings.Repeat("-", 80)) - fmt.Printf("%v %v\n", r.Method, r.URL) - for key, values := range r.Header { - if len(values) == 1 { - fmt.Printf("%v: %v\n", key, values[0]) - } else { - fmt.Printf("%v: %v\n", key, values) - } - } - fmt.Printf("\n") - fmt.Printf("%v\n", string(body)) - - params := x.Must(url.ParseQuery(string(body))) - for key, values := range params { - if len(values) == 1 { - fmt.Printf("\t%v: %v\n", key, values[0]) - } else { - fmt.Printf("\t%v: %v\n", key, values) - } - } - fmt.Println(strings.Repeat("-", 80)) - - response, err := l.Proxied.RoundTrip(r) - - fmt.Printf("%v %v\n", response.StatusCode, http.StatusText(response.StatusCode)) - for key, values := range response.Header { - if len(values) == 1 { - fmt.Printf("%v: %v\n", key, values[0]) - } else { - fmt.Printf("%v: %v\n", key, values) - } - } - fmt.Printf("\n") - responseBody := x.Must(ioutil.ReadAll(response.Body)) - response.Body = ioutil.NopCloser(bytes.NewBuffer(responseBody)) - fmt.Printf("%v\n", string(responseBody)) - - return response, err -} - func main() { sessions := map[string]*x.Session{} @@ -128,15 +77,12 @@ func main() { return } - var profile map[string]interface{} - code := r.URL.Query().Get("code") - client := &http.Client{ - Transport: LoggingRoundTripper{http.DefaultTransport}, - } - - ctx := context.WithValue(r.Context(), oauth2.HTTPClient, client) - - token := x.Must(cfg.Exchange(ctx, code, oauth2.SetAuthURLParam("audience", os.Getenv("AUTH0_AUDIENCE")))) + client := &http.Client{Transport: x.LoggingRoundTripper{http.DefaultTransport}} + token := x.Must(cfg.Exchange( + context.WithValue(r.Context(), oauth2.HTTPClient, client), + r.URL.Query().Get("code"), + oauth2.SetAuthURLParam("audience", os.Getenv("AUTH0_AUDIENCE")), + )) rawIDToken, ok := token.Extra("id_token").(string) if !ok { @@ -147,6 +93,7 @@ func main() { Verifier(&oidc.Config{ClientID: cfg.ClientID}). Verify(r.Context(), rawIDToken)) + var profile map[string]interface{} idToken.Claims(&profile) session.Token = token diff --git a/pkg/x/must.go b/pkg/x/must.go index 4ba6fe2..98363bc 100644 --- a/pkg/x/must.go +++ b/pkg/x/must.go @@ -2,9 +2,13 @@ package x import "log" -func Must[T any](x T, err error) T { +func Check(err error) { if err != nil { log.Fatal(err) } +} + +func Must[T any](x T, err error) T { + Check(err) return x } diff --git a/pkg/x/round_tripper.go b/pkg/x/round_tripper.go new file mode 100644 index 0000000..06b4de7 --- /dev/null +++ b/pkg/x/round_tripper.go @@ -0,0 +1,71 @@ +package x + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "sort" + "strings" +) + +type LoggingRoundTripper struct { + Proxied http.RoundTripper +} + +func dumpHeader(h http.Header) { + keys := make([]string, 0, len(h)) + for key := range h { + keys = append(keys, key) + } + sort.Strings(keys) + for _, key := range keys { + values := h[key] + if len(values) == 1 { + fmt.Printf("%v: %v\n", key, values[0]) + } else { + fmt.Printf("%v: %v\n", key, values) + } + } + fmt.Printf("\n") +} + +func dumpBody(mime string, io io.ReadCloser) io.ReadCloser { + body := Must(ioutil.ReadAll(io)) + + if mime == "application/json" { + var prettyJSON bytes.Buffer + Check(json.Indent(&prettyJSON, body, "", " ")) + fmt.Printf("%v\n", string(prettyJSON.Bytes())) + } else if mime == "application/x-www-form-urlencoded" { + params := Must(url.ParseQuery(string(body))) + for key, values := range params { + if len(values) == 1 { + fmt.Printf("%v: %v\n", key, values[0]) + } else { + fmt.Printf("%v: %v\n", key, values) + } + } + } + + return ioutil.NopCloser(bytes.NewBuffer(body)) +} + +func (l LoggingRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { + fmt.Println(strings.Repeat("-", 80)) + fmt.Printf("%v %v\n", r.Method, r.URL) + dumpHeader(r.Header) + r.Body = dumpBody(r.Header.Get("Content-Type"), r.Body) + + response, err := l.Proxied.RoundTrip(r) + + fmt.Printf("\n%v %v\n", response.StatusCode, http.StatusText(response.StatusCode)) + dumpHeader(response.Header) + response.Body = dumpBody(response.Header.Get("Content-Type"), response.Body) + fmt.Println(strings.Repeat("-", 80)) + + return response, err +} |
