summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2022-05-16 10:28:31 -0600
committermo khan <mo@mokhan.ca>2022-05-16 10:28:31 -0600
commit59b6a9a68bf585c5baabe57ebc5cdfe45485dcd0 (patch)
tree054f86c691cbde7bae15ff3cd3bcdd6a360c6e3b
parentc664deedf96f5ec088ab221a21a1eb26c3ca12ed (diff)
extract functions to dump request/response
-rw-r--r--cmd/ui/main.go67
-rw-r--r--pkg/x/must.go6
-rw-r--r--pkg/x/round_tripper.go71
3 files changed, 83 insertions, 61 deletions
diff --git a/cmd/ui/main.go b/cmd/ui/main.go
index 6bb3ec8..6319b58 100644
--- a/cmd/ui/main.go
+++ b/cmd/ui/main.go
@@ -1,16 +1,13 @@
package main
import (
- "bytes"
"context"
"fmt"
"html/template"
- "io/ioutil"
"log"
"net/http"
"net/url"
"os"
- "strings"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/hashicorp/uuid"
@@ -40,54 +37,6 @@ func SessionFor(sessions map[string]*x.Session, r *http.Request, w http.Response
return session
}
-type LoggingRoundTripper struct {
- Proxied http.RoundTripper
-}
-
-func (l LoggingRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
- body := x.Must(ioutil.ReadAll(r.Body))
- r.Body = ioutil.NopCloser(bytes.NewBuffer(body))
-
- fmt.Println(strings.Repeat("-", 80))
- fmt.Printf("%v %v\n", r.Method, r.URL)
- for key, values := range r.Header {
- if len(values) == 1 {
- fmt.Printf("%v: %v\n", key, values[0])
- } else {
- fmt.Printf("%v: %v\n", key, values)
- }
- }
- fmt.Printf("\n")
- fmt.Printf("%v\n", string(body))
-
- params := x.Must(url.ParseQuery(string(body)))
- for key, values := range params {
- if len(values) == 1 {
- fmt.Printf("\t%v: %v\n", key, values[0])
- } else {
- fmt.Printf("\t%v: %v\n", key, values)
- }
- }
- fmt.Println(strings.Repeat("-", 80))
-
- response, err := l.Proxied.RoundTrip(r)
-
- fmt.Printf("%v %v\n", response.StatusCode, http.StatusText(response.StatusCode))
- for key, values := range response.Header {
- if len(values) == 1 {
- fmt.Printf("%v: %v\n", key, values[0])
- } else {
- fmt.Printf("%v: %v\n", key, values)
- }
- }
- fmt.Printf("\n")
- responseBody := x.Must(ioutil.ReadAll(response.Body))
- response.Body = ioutil.NopCloser(bytes.NewBuffer(responseBody))
- fmt.Printf("%v\n", string(responseBody))
-
- return response, err
-}
-
func main() {
sessions := map[string]*x.Session{}
@@ -128,15 +77,12 @@ func main() {
return
}
- var profile map[string]interface{}
- code := r.URL.Query().Get("code")
- client := &http.Client{
- Transport: LoggingRoundTripper{http.DefaultTransport},
- }
-
- ctx := context.WithValue(r.Context(), oauth2.HTTPClient, client)
-
- token := x.Must(cfg.Exchange(ctx, code, oauth2.SetAuthURLParam("audience", os.Getenv("AUTH0_AUDIENCE"))))
+ client := &http.Client{Transport: x.LoggingRoundTripper{http.DefaultTransport}}
+ token := x.Must(cfg.Exchange(
+ context.WithValue(r.Context(), oauth2.HTTPClient, client),
+ r.URL.Query().Get("code"),
+ oauth2.SetAuthURLParam("audience", os.Getenv("AUTH0_AUDIENCE")),
+ ))
rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
@@ -147,6 +93,7 @@ func main() {
Verifier(&oidc.Config{ClientID: cfg.ClientID}).
Verify(r.Context(), rawIDToken))
+ var profile map[string]interface{}
idToken.Claims(&profile)
session.Token = token
diff --git a/pkg/x/must.go b/pkg/x/must.go
index 4ba6fe2..98363bc 100644
--- a/pkg/x/must.go
+++ b/pkg/x/must.go
@@ -2,9 +2,13 @@ package x
import "log"
-func Must[T any](x T, err error) T {
+func Check(err error) {
if err != nil {
log.Fatal(err)
}
+}
+
+func Must[T any](x T, err error) T {
+ Check(err)
return x
}
diff --git a/pkg/x/round_tripper.go b/pkg/x/round_tripper.go
new file mode 100644
index 0000000..06b4de7
--- /dev/null
+++ b/pkg/x/round_tripper.go
@@ -0,0 +1,71 @@
+package x
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "net/url"
+ "sort"
+ "strings"
+)
+
+type LoggingRoundTripper struct {
+ Proxied http.RoundTripper
+}
+
+func dumpHeader(h http.Header) {
+ keys := make([]string, 0, len(h))
+ for key := range h {
+ keys = append(keys, key)
+ }
+ sort.Strings(keys)
+ for _, key := range keys {
+ values := h[key]
+ if len(values) == 1 {
+ fmt.Printf("%v: %v\n", key, values[0])
+ } else {
+ fmt.Printf("%v: %v\n", key, values)
+ }
+ }
+ fmt.Printf("\n")
+}
+
+func dumpBody(mime string, io io.ReadCloser) io.ReadCloser {
+ body := Must(ioutil.ReadAll(io))
+
+ if mime == "application/json" {
+ var prettyJSON bytes.Buffer
+ Check(json.Indent(&prettyJSON, body, "", " "))
+ fmt.Printf("%v\n", string(prettyJSON.Bytes()))
+ } else if mime == "application/x-www-form-urlencoded" {
+ params := Must(url.ParseQuery(string(body)))
+ for key, values := range params {
+ if len(values) == 1 {
+ fmt.Printf("%v: %v\n", key, values[0])
+ } else {
+ fmt.Printf("%v: %v\n", key, values)
+ }
+ }
+ }
+
+ return ioutil.NopCloser(bytes.NewBuffer(body))
+}
+
+func (l LoggingRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
+ fmt.Println(strings.Repeat("-", 80))
+ fmt.Printf("%v %v\n", r.Method, r.URL)
+ dumpHeader(r.Header)
+ r.Body = dumpBody(r.Header.Get("Content-Type"), r.Body)
+
+ response, err := l.Proxied.RoundTrip(r)
+
+ fmt.Printf("\n%v %v\n", response.StatusCode, http.StatusText(response.StatusCode))
+ dumpHeader(response.Header)
+ response.Body = dumpBody(response.Header.Get("Content-Type"), response.Body)
+ fmt.Println(strings.Repeat("-", 80))
+
+ return response, err
+}