package api import ( log "github.com/sirupsen/logrus" "net/http" "context" "os" "reichard.io/imagini/graph/model" ) type Middleware func(http.Handler) http.HandlerFunc func multipleMiddleware(h http.HandlerFunc, m ...Middleware) http.HandlerFunc { if len(m) < 1 { return h } wrapped := h for i := len(m) - 1; i >= 0; i-- { wrapped = m[i](wrapped) } return wrapped } func (api *API) injectContextMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Info("[middleware] Entering testMiddleware...") authContext := &model.AuthContext{ AuthResponse: &w, AuthRequest: r, } accessCookie, err := r.Cookie("AccessToken") if err != nil { log.Warn("[middleware] AccessToken not found") } else { authContext.AccessToken = accessCookie.Value } refreshCookie, err := r.Cookie("RefreshToken") if err != nil { log.Warn("[middleware] RefreshToken not found") } else { authContext.RefreshToken = refreshCookie.Value } // Add context ctx := context.WithValue(r.Context(), "auth", authContext) r = r.WithContext(ctx) log.Info("[middleware] Exiting testMiddleware...") next.ServeHTTP(w, r) }) } func (api *API) authMiddleware(next http.Handler) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Acquire Token accessCookie, err := r.Cookie("AccessToken") if err != nil { log.Warn("[middleware] AccessToken not found") errorJSON(w, "Invalid token.", http.StatusUnauthorized) return } // Validate JWT Tokens accessToken, err := api.Auth.ValidateJWTAccessToken(accessCookie.Value) if err != nil && err.Error() == "exp not satisfied" { log.Info("[middleware] Refreshing AccessToken") accessToken, err = api.refreshAccessToken(w, r) if err != nil { log.Warn("[middleware] Refreshing AccessToken failed: ", err) errorJSON(w, "Invalid token.", http.StatusUnauthorized) return } log.Info("[middleware] AccessToken Refreshed") } else if err != nil { log.Warn("[middleware] AccessToken failed to validate") errorJSON(w, "Invalid token.", http.StatusUnauthorized) return } // Acquire UserID and DeviceID reqInfo := make(map[string]string) uid, _ := accessToken.Get("sub") did, _ := accessToken.Get("did") reqInfo["uid"] = uid.(string) reqInfo["did"] = did.(string) // Add context ctx := context.WithValue(r.Context(), "uuids", reqInfo) sr := r.WithContext(ctx) next.ServeHTTP(w, sr) }) } func (api *API) logMiddleware(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.SetOutput(os.Stdout) log.Println(r.Method, r.URL) h.ServeHTTP(w, r) }) }