package main import ( "context" "fmt" "io/ioutil" "log" "net/http" "net/url" "os" "path/filepath" "time" "github.com/casbin/casbin/v2" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/cors" "github.com/go-chi/jwtauth/v5" "github.com/joho/godotenv" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/xlgmokha/api-auth0/pkg/x" ) func ValidAudience(audiences []string) bool { for _, aud := range audiences { if aud == os.Getenv("AUTH0_AUDIENCE") { return true } } return false } func Enforce(subject, resource, action string) bool { e := x.Must(casbin.NewEnforcer("pkg/api/policy.conf", "pkg/api/policy.csv")) ok := x.Must(e.Enforce(subject, resource, action)) fmt.Printf("%v: %v %v %v\n", ok, subject, action, resource) return ok } func Authz() func(next http.Handler) http.Handler { url := x.Must(url.Parse("https://" + os.Getenv("AUTH0_DOMAIN") + "/.well-known/jwks.json")).String() cache := jwk.NewCache(context.Background()) cache.Register(url, jwk.WithMinRefreshInterval(15*time.Minute)) cache.Refresh(context.Background(), url) keySet := jwk.NewCachedSet(cache, url) return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var subject string raw := jwtauth.TokenFromHeader(r) if raw == "" { subject = "*" } else { token, err := jwt.ParseString(raw, jwt.WithKeySet(keySet)) if err != nil { w.WriteHeader(http.StatusUnauthorized) w.Write([]byte(fmt.Sprintf(`{"message":"%v"}`, err))) return } if !ValidAudience(token.Audience()) { w.WriteHeader(http.StatusUnauthorized) w.Write([]byte(fmt.Sprintf(`{"message":"%v"}`, "invalid audience"))) return } subject = token.Subject() } if Enforce(subject, r.URL.Path, r.Method) { next.ServeHTTP(w, r) } else { w.WriteHeader(http.StatusUnauthorized) w.Write([]byte(fmt.Sprintf(`{"message":"%v"}`, http.StatusText(http.StatusUnauthorized)))) } }) } } func readFixture(path string) []byte { finalPath := filepath.Join(x.Must(os.Getwd()), "pkg/api/fixtures", path) return x.Must(ioutil.ReadFile(finalPath)) } func main() { if err := godotenv.Load(); err != nil { log.Fatal(err) } router := chi.NewRouter() router.Use(middleware.RequestID) router.Use(middleware.RealIP) router.Use(middleware.Logger) router.Use(middleware.Recoverer) router.Use(middleware.Timeout(30 * time.Second)) router.Use(middleware.AllowContentType("application/json")) router.Use(middleware.Heartbeat("/health")) router.Use(cors.AllowAll().Handler) router.Use(Authz()) router.Route("/api", func(router chi.Router) { router.Use(middleware.SetHeader("Content-Type", "application/json")) router.Get("/public", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte(`{"message":"public"}`)) }) router.Route("/users", func(router chi.Router) { // router.Use(Authorize(token, "read:users")) router.Get("/", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte(`{"message":"users"}`)) }) }) router.Route("/incidents", func(router chi.Router) { // router.Use(Authorize(token, "read:incidents")) router.Get("/", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte(`{"message":"incidents"}`)) }) }) // vercel hosted ? router.Route("/auth", func(router chi.Router) { router.Get("/providers", func(w http.ResponseWriter, r *http.Request) { }) router.Get("/csrf", func(w http.ResponseWriter, r *http.Request) { }) router.Get("/signin/auth0", func(w http.ResponseWriter, r *http.Request) { }) }) router.Route("/session", func(router chi.Router) { router.Get("/", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write(readFixture("session.json")) }) }) router.Route("/notifications", func(router chi.Router) { router.Get("/", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write(readFixture("notifications.json")) }) router.Get("/{id}", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write(readFixture(fmt.Sprintf("notifications/%v.json", chi.URLParam(r, "id")))) }) }) router.Route("/resources", func(router chi.Router) { router.Get("/{id}", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write(readFixture(fmt.Sprintf("resources/%v.json", chi.URLParam(r, "id")))) }) }) router.Route("/atlas", func(router chi.Router) { router.Get("/", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write(readFixture("atlas.json")) }) router.Get("/{id}/capabilities", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write(readFixture(fmt.Sprintf("atlas/%v/capabilities.json", chi.URLParam(r, "id")))) }) router.Get("/{id}/investigations", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write(readFixture(fmt.Sprintf("atlas/%v/investigations.json", chi.URLParam(r, "id")))) }) router.Get("/{id}/investigations/{iid}", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write(readFixture(fmt.Sprintf("atlas/%v/investigations/%v.json", chi.URLParam(r, "id"), chi.URLParam(r, "iid")))) }) router.Get("/{id}/investigations/{iid}/conversation", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write(readFixture(fmt.Sprintf("atlas/%v/investigations/%v/conversation.json", chi.URLParam(r, "id"), chi.URLParam(r, "iid")))) }) }) }) router.NotFound(func(w http.ResponseWriter, r *http.Request) { fmt.Printf("%v %v Not Found\n", r.Method, r.URL.Path) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusNotFound) w.Write([]byte("[]")) }) log.Fatal(http.ListenAndServe("localhost:4000", router)) }