From a5692babb809d4bd31d5f515e35d70b36609dd41 Mon Sep 17 00:00:00 2001 From: Evan Reichard Date: Thu, 4 Feb 2021 15:31:07 -0500 Subject: [PATCH] Clean Up --- cmd/server/server.go | 77 ++++++++++----------- graph/generated/generated.go | 36 +++++----- graph/model/models_gen.go | 6 +- graph/schema.graphqls | 11 +-- graph/schema.resolvers.go | 4 ++ internal/api/auth.go | 30 --------- internal/api/directives.go | 50 ++++++++++++++ internal/api/media.go | 4 +- internal/api/middlewares.go | 22 +----- internal/api/models.go | 24 ------- internal/api/routes.go | 27 +------- internal/auth/auth.go | 118 -------------------------------- internal/auth/jwt.go | 126 +++++++++++++++++++++++++++++++++++ internal/auth/ldap.go | 4 +- internal/auth/local.go | 22 +++--- internal/config/config.go | 42 ++++++------ internal/db/db.go | 1 + internal/db/errors.go | 7 -- internal/session/session.go | 36 +++++----- plugin/models.go | 20 +++--- 20 files changed, 319 insertions(+), 348 deletions(-) create mode 100644 internal/api/directives.go delete mode 100644 internal/api/models.go create mode 100644 internal/auth/jwt.go delete mode 100644 internal/db/errors.go diff --git a/cmd/server/server.go b/cmd/server/server.go index 4229c1e..95d3b65 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -1,58 +1,59 @@ package server import ( - "time" - "context" - "net/http" - log "github.com/sirupsen/logrus" + "context" + "net/http" + "time" - "reichard.io/imagini/internal/db" - "reichard.io/imagini/internal/api" - "reichard.io/imagini/internal/auth" - "reichard.io/imagini/internal/config" + log "github.com/sirupsen/logrus" + + "reichard.io/imagini/internal/api" + "reichard.io/imagini/internal/auth" + "reichard.io/imagini/internal/config" + "reichard.io/imagini/internal/db" ) type Server struct { - API *api.API - Auth *auth.AuthManager - Config *config.Config - Database *db.DBManager - httpServer *http.Server + API *api.API + Auth *auth.AuthManager + Config *config.Config + Database *db.DBManager + httpServer *http.Server } func NewServer() *Server { - c := config.Load() - db := db.NewMgr(c) - auth := auth.NewMgr(db, c) - api := api.NewApi(db, c, auth) + c := config.Load() + db := db.NewMgr(c) + auth := auth.NewMgr(db, c) + api := api.NewApi(db, c, auth) - return &Server{ - API: api, - Auth: auth, - Config: c, - Database: db, - } + return &Server{ + API: api, + Auth: auth, + Config: c, + Database: db, + } } func (s *Server) StartServer() { - listenAddr := (":" + s.Config.ListenPort) + listenAddr := (":" + s.Config.ListenPort) - s.httpServer = &http.Server{ - Handler: s.API.Router, - Addr: listenAddr, - } + s.httpServer = &http.Server{ + Handler: s.API.Router, + Addr: listenAddr, + } - go func() { - err := s.httpServer.ListenAndServe() - if err != nil { - log.Error("Error starting server ", err) - return - } - }() + go func() { + err := s.httpServer.ListenAndServe() + if err != nil { + log.Error("Error starting server ", err) + return + } + }() } func (s *Server) StopServer() { - ctx, cancel := context.WithTimeout(context.Background(), 5 * time.Second) - defer cancel() - s.httpServer.Shutdown(ctx) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s.httpServer.Shutdown(ctx) } diff --git a/graph/generated/generated.go b/graph/generated/generated.go index dee8a1b..8aa1d49 100644 --- a/graph/generated/generated.go +++ b/graph/generated/generated.go @@ -43,6 +43,7 @@ type ResolverRoot interface { type DirectiveRoot struct { HasMinRole func(ctx context.Context, obj interface{}, next graphql.Resolver, role model.Role) (res interface{}, err error) + IsPrivate func(ctx context.Context, obj interface{}, next graphql.Resolver) (res interface{}, err error) Meta func(ctx context.Context, obj interface{}, next graphql.Resolver, gorm *string) (res interface{}, err error) } @@ -243,21 +244,21 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.AlbumResponse.PageInfo(childComplexity), true - case "AuthResponse.Device": + case "AuthResponse.device": if e.complexity.AuthResponse.Device == nil { break } return e.complexity.AuthResponse.Device(childComplexity), true - case "AuthResponse.Error": + case "AuthResponse.error": if e.complexity.AuthResponse.Error == nil { break } return e.complexity.AuthResponse.Error(childComplexity), true - case "AuthResponse.Result": + case "AuthResponse.result": if e.complexity.AuthResponse.Result == nil { break } @@ -863,6 +864,7 @@ scalar Upload # https://gqlgen.com/reference/directives/ directive @hasMinRole(role: Role!) on FIELD_DEFINITION +directive @isPrivate on FIELD_DEFINITION | INPUT_FIELD_DEFINITION directive @meta( gorm: String, @@ -899,9 +901,9 @@ enum AuthResult { } type AuthResponse { - Result: AuthResult! - Device: Device - Error: String + result: AuthResult! + device: Device + error: String } # ------------------------------------------------------------ @@ -1004,7 +1006,7 @@ type Device { type: DeviceType! @meta(gorm: "default:Unknown;not null") userID: ID! @meta(gorm: "not null") user: User! @meta(gorm: "foreignKey:ID;references:UserID;not null") - refreshKey: String + refreshKey: String @deprecated(reason: "Private Field") # @isPrivate } type User { @@ -1017,7 +1019,7 @@ type User { lastName: String role: Role! @meta(gorm: "default:User;not null") authType: AuthType! @meta(gorm: "default:Local;not null") - password: String + password: String @deprecated(reason: "Private Field") #@isPrivate } type MediaItem { @@ -1927,7 +1929,7 @@ func (ec *executionContext) _AlbumResponse_pageInfo(ctx context.Context, field g return ec.marshalNPageInfo2ᚖreichardᚗioᚋimaginiᚋgraphᚋmodelᚐPageInfo(ctx, field.Selections, res) } -func (ec *executionContext) _AuthResponse_Result(ctx context.Context, field graphql.CollectedField, obj *model.AuthResponse) (ret graphql.Marshaler) { +func (ec *executionContext) _AuthResponse_result(ctx context.Context, field graphql.CollectedField, obj *model.AuthResponse) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { ec.Error(ctx, ec.Recover(ctx, r)) @@ -1962,7 +1964,7 @@ func (ec *executionContext) _AuthResponse_Result(ctx context.Context, field grap return ec.marshalNAuthResult2reichardᚗioᚋimaginiᚋgraphᚋmodelᚐAuthResult(ctx, field.Selections, res) } -func (ec *executionContext) _AuthResponse_Device(ctx context.Context, field graphql.CollectedField, obj *model.AuthResponse) (ret graphql.Marshaler) { +func (ec *executionContext) _AuthResponse_device(ctx context.Context, field graphql.CollectedField, obj *model.AuthResponse) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { ec.Error(ctx, ec.Recover(ctx, r)) @@ -1994,7 +1996,7 @@ func (ec *executionContext) _AuthResponse_Device(ctx context.Context, field grap return ec.marshalODevice2ᚖreichardᚗioᚋimaginiᚋgraphᚋmodelᚐDevice(ctx, field.Selections, res) } -func (ec *executionContext) _AuthResponse_Error(ctx context.Context, field graphql.CollectedField, obj *model.AuthResponse) (ret graphql.Marshaler) { +func (ec *executionContext) _AuthResponse_error(ctx context.Context, field graphql.CollectedField, obj *model.AuthResponse) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { ec.Error(ctx, ec.Recover(ctx, r)) @@ -7553,15 +7555,15 @@ func (ec *executionContext) _AuthResponse(ctx context.Context, sel ast.Selection switch field.Name { case "__typename": out.Values[i] = graphql.MarshalString("AuthResponse") - case "Result": - out.Values[i] = ec._AuthResponse_Result(ctx, field, obj) + case "result": + out.Values[i] = ec._AuthResponse_result(ctx, field, obj) if out.Values[i] == graphql.Null { invalids++ } - case "Device": - out.Values[i] = ec._AuthResponse_Device(ctx, field, obj) - case "Error": - out.Values[i] = ec._AuthResponse_Error(ctx, field, obj) + case "device": + out.Values[i] = ec._AuthResponse_device(ctx, field, obj) + case "error": + out.Values[i] = ec._AuthResponse_error(ctx, field, obj) default: panic("unknown field " + strconv.Quote(field.Name)) } diff --git a/graph/model/models_gen.go b/graph/model/models_gen.go index 50787f3..d54d2f4 100644 --- a/graph/model/models_gen.go +++ b/graph/model/models_gen.go @@ -33,9 +33,9 @@ type AlbumResponse struct { } type AuthResponse struct { - Result AuthResult `json:"Result" ` - Device *Device `json:"Device" ` - Error *string `json:"Error" ` + Result AuthResult `json:"result" ` + Device *Device `json:"device" ` + Error *string `json:"error" ` } type AuthTypeFilter struct { diff --git a/graph/schema.graphqls b/graph/schema.graphqls index 92e8493..4d67510 100644 --- a/graph/schema.graphqls +++ b/graph/schema.graphqls @@ -5,6 +5,7 @@ scalar Upload # https://gqlgen.com/reference/directives/ directive @hasMinRole(role: Role!) on FIELD_DEFINITION +directive @isPrivate on FIELD_DEFINITION | INPUT_FIELD_DEFINITION directive @meta( gorm: String, @@ -41,9 +42,9 @@ enum AuthResult { } type AuthResponse { - Result: AuthResult! - Device: Device - Error: String + result: AuthResult! + device: Device + error: String } # ------------------------------------------------------------ @@ -146,7 +147,7 @@ type Device { type: DeviceType! @meta(gorm: "default:Unknown;not null") userID: ID! @meta(gorm: "not null") user: User! @meta(gorm: "foreignKey:ID;references:UserID;not null") - refreshKey: String + refreshKey: String @isPrivate } type User { @@ -159,7 +160,7 @@ type User { lastName: String role: Role! @meta(gorm: "default:User;not null") authType: AuthType! @meta(gorm: "default:Local;not null") - password: String + password: String @isPrivate } type MediaItem { diff --git a/graph/schema.resolvers.go b/graph/schema.resolvers.go index 50f185e..6f54faa 100644 --- a/graph/schema.resolvers.go +++ b/graph/schema.resolvers.go @@ -75,6 +75,8 @@ func (r *queryResolver) Login(ctx context.Context, user string, password string, } } else { foundDevice.Type = deriveDeviceType(req) + foundDevice.UserID = foundUser.ID + // TODO: foundDevice.User = &foundUser err := r.DB.CreateDevice(&foundDevice) if err != nil { return &model.AuthResponse{Result: model.AuthResultFailure}, nil @@ -97,6 +99,8 @@ func (r *queryResolver) Login(ctx context.Context, user string, password string, http.SetCookie(*resp, &accessCookie) http.SetCookie(*resp, &refreshCookie) + // TODO: Prob bandaid + foundDevice.User = &foundUser return &model.AuthResponse{Result: model.AuthResultSuccess, Device: &foundDevice}, nil } diff --git a/internal/api/auth.go b/internal/api/auth.go index 197bd17..66f4b66 100644 --- a/internal/api/auth.go +++ b/internal/api/auth.go @@ -1,12 +1,10 @@ package api import ( - "context" "errors" "fmt" "net/http" - "github.com/99designs/gqlgen/graphql" "github.com/google/uuid" "github.com/lestrrat-go/jwx/jwt" @@ -94,31 +92,3 @@ func (api *API) validateTokens(w *http.ResponseWriter, r *http.Request) (jwt.Tok return jwt.ParseBytes([]byte(newAccessCookie)) } - -func (api *API) hasMinRoleDirective(ctx context.Context, obj interface{}, next graphql.Resolver, role model.Role) (res interface{}, err error) { - authContext := ctx.Value("auth").(*model.AuthContext) - accessToken, err := api.validateTokens(authContext.AuthResponse, authContext.AuthRequest) - if err != nil { - return nil, errors.New("Access Denied") - } - authContext.AccessToken = &accessToken - - userRole, ok := accessToken.Get("role") - if !ok { - return nil, errors.New("Access Denied") - } - - if userRole == model.RoleAdmin.String() { - return next(ctx) - } - - if userRole == role.String() { - return next(ctx) - } - - return nil, errors.New("Role Not Authenticated") -} - -func (api *API) metaDirective(ctx context.Context, obj interface{}, next graphql.Resolver, gorm *string) (res interface{}, err error) { - return next(ctx) -} diff --git a/internal/api/directives.go b/internal/api/directives.go new file mode 100644 index 0000000..0c07f9a --- /dev/null +++ b/internal/api/directives.go @@ -0,0 +1,50 @@ +package api + +import ( + "context" + "errors" + + "github.com/99designs/gqlgen/graphql" + "reichard.io/imagini/graph/model" +) + +/** + * This is used to validate whether the users role is adequate for the requested resource. + **/ +func (api *API) hasMinRoleDirective(ctx context.Context, obj interface{}, next graphql.Resolver, role model.Role) (res interface{}, err error) { + authContext := ctx.Value("auth").(*model.AuthContext) + accessToken, err := api.validateTokens(authContext.AuthResponse, authContext.AuthRequest) + if err != nil { + return nil, errors.New("Access Denied") + } + authContext.AccessToken = &accessToken + + userRole, ok := accessToken.Get("role") + if !ok { + return nil, errors.New("Access Denied") + } + + if userRole == model.RoleAdmin.String() { + return next(ctx) + } + + if userRole == role.String() { + return next(ctx) + } + + return nil, errors.New("Role Not Authenticated") +} + +/** + * This is needed but not used. Meta is used for Gorm. + **/ +func (api *API) metaDirective(ctx context.Context, obj interface{}, next graphql.Resolver, gorm *string) (res interface{}, err error) { + return next(ctx) +} + +/** + * This overrides the response so fields with an @isPrivate directive are always nil. + **/ +func (api *API) isPrivateDirective(ctx context.Context, obj interface{}, next graphql.Resolver) (res interface{}, err error) { + return nil, errors.New("Private Field") +} diff --git a/internal/api/media.go b/internal/api/media.go index d648e25..873b924 100644 --- a/internal/api/media.go +++ b/internal/api/media.go @@ -8,7 +8,9 @@ import ( "reichard.io/imagini/graph/model" ) -// Responsible for serving up static images / videos +/** + * Responsible for serving up static images / videos + **/ func (api *API) mediaHandler(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { w.WriteHeader(http.StatusMethodNotAllowed) diff --git a/internal/api/middlewares.go b/internal/api/middlewares.go index 3510133..a845e24 100644 --- a/internal/api/middlewares.go +++ b/internal/api/middlewares.go @@ -3,9 +3,6 @@ package api import ( "context" "net/http" - "os" - - log "github.com/sirupsen/logrus" "reichard.io/imagini/graph/model" ) @@ -27,9 +24,8 @@ func multipleMiddleware(h http.HandlerFunc, m ...Middleware) http.HandlerFunc { * This is used for the graphQL endpoints that may require access to the * Request and ResponseWriter variables. These are used to get / set cookies. **/ -func (api *API) injectContextMiddleware(next http.Handler) http.Handler { +func (api *API) contextMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - authContext := &model.AuthContext{ AuthResponse: &w, AuthRequest: r, @@ -40,7 +36,6 @@ func (api *API) injectContextMiddleware(next http.Handler) http.Handler { r = r.WithContext(ctx) next.ServeHTTP(w, r) - }) } @@ -49,10 +44,10 @@ func (api *API) injectContextMiddleware(next http.Handler) http.Handler { **/ func (api *API) authMiddleware(next http.Handler) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - + // Validate Tokens accessToken, err := api.validateTokens(&w, r) if err != nil { - errorJSON(w, "Invalid token.", http.StatusUnauthorized) + w.WriteHeader(http.StatusUnauthorized) return } @@ -64,16 +59,5 @@ func (api *API) authMiddleware(next http.Handler) http.HandlerFunc { r = r.WithContext(ctx) next.ServeHTTP(w, r) - - }) -} - -func (api *API) logMiddleware(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - - log.SetOutput(os.Stdout) - log.Println(r.Method, r.URL) - h.ServeHTTP(w, r) - }) } diff --git a/internal/api/models.go b/internal/api/models.go deleted file mode 100644 index 8a2345b..0000000 --- a/internal/api/models.go +++ /dev/null @@ -1,24 +0,0 @@ -package api - -type APICredentials struct { - User string `json:"user"` - Password string `json:"password"` -} - -type APIData interface{} - -type APIMeta struct { - Count int64 `json:"count"` - Page int64 `json:"page"` -} - -type APIError struct { - Message string `json:"message"` - Code int64 `json:"code"` -} - -type APIResponse struct { - Data APIData `json:"data,omitempty"` - Meta *APIMeta `json:"meta,omitempty"` - Error *APIError `json:"error,omitempty"` -} diff --git a/internal/api/routes.go b/internal/api/routes.go index 7343c48..2e342f2 100644 --- a/internal/api/routes.go +++ b/internal/api/routes.go @@ -1,9 +1,6 @@ package api import ( - "encoding/json" - "net/http" - "github.com/99designs/gqlgen/graphql/handler" "github.com/99designs/gqlgen/graphql/playground" @@ -12,7 +9,6 @@ import ( ) func (api *API) registerRoutes() { - // Set up Directives graphConfig := generated.Config{ Resolvers: &graph.Resolver{ @@ -21,6 +17,7 @@ func (api *API) registerRoutes() { }, Directives: generated.DirectiveRoot{ Meta: api.metaDirective, + IsPrivate: api.isPrivateDirective, HasMinRole: api.hasMinRoleDirective, }, } @@ -28,31 +25,11 @@ func (api *API) registerRoutes() { // Handle GraphQL api.Router.Handle("/playground", playground.Handler("GraphQL playground", "/query")) - api.Router.Handle("/query", api.injectContextMiddleware(srv)) + api.Router.Handle("/query", api.contextMiddleware(srv)) // Handle Resource Route api.Router.HandleFunc("/media/", multipleMiddleware( api.mediaHandler, api.authMiddleware, )) - -} - -func errorJSON(w http.ResponseWriter, err string, code int) { - errStruct := &APIResponse{Error: &APIError{Message: err, Code: int64(code)}} - responseJSON(w, errStruct, code) -} - -func responseJSON(w http.ResponseWriter, msg interface{}, code int) { - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Header().Set("X-Content-Type-Options", "nosniff") - w.WriteHeader(code) - json.NewEncoder(w).Encode(msg) -} - -func successJSON(w http.ResponseWriter, msg string, code int) { - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Header().Set("X-Content-Type-Options", "nosniff") - w.WriteHeader(code) - json.NewEncoder(w).Encode(map[string]interface{}{"success": msg}) } diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 062a02f..f444995 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -1,14 +1,8 @@ package auth import ( - "encoding/json" "errors" - "fmt" - "time" - "github.com/google/uuid" - "github.com/lestrrat-go/jwx/jwa" - "github.com/lestrrat-go/jwx/jwt" log "github.com/sirupsen/logrus" "gorm.io/gorm" @@ -61,115 +55,3 @@ func (auth *AuthManager) AuthenticateUser(user, password string) (model.User, bo return *foundUser, false } } - -func (auth *AuthManager) ValidateJWTRefreshToken(refreshJWT string) (jwt.Token, error) { - byteRefreshJWT := []byte(refreshJWT) - - // Acquire Relevant Device - unverifiedToken, err := jwt.ParseBytes(byteRefreshJWT) - did, ok := unverifiedToken.Get("did") - if !ok { - return nil, errors.New("did does not exist") - } - deviceID, err := uuid.Parse(fmt.Sprintf("%v", did)) - if err != nil { - return nil, errors.New("did does not parse") - } - device := &model.Device{ID: deviceID.String()} - _, err = auth.DB.Device(device) - if err != nil { - return nil, err - } - - // Verify & Validate Token - verifiedToken, err := jwt.ParseBytes(byteRefreshJWT, - jwt.WithValidate(true), - jwt.WithVerify(jwa.HS256, []byte(*device.RefreshKey)), - ) - if err != nil { - fmt.Println("failed to parse payload: ", err) - return nil, err - } - return verifiedToken, nil -} - -func (auth *AuthManager) ValidateJWTAccessToken(accessJWT string) (jwt.Token, error) { - byteAccessJWT := []byte(accessJWT) - verifiedToken, err := jwt.ParseBytes(byteAccessJWT, - jwt.WithValidate(true), - jwt.WithVerify(jwa.HS256, []byte(auth.Config.JWTSecret)), - ) - - if err != nil { - return nil, err - } - - return verifiedToken, nil -} - -func (auth *AuthManager) CreateJWTRefreshToken(user model.User, device model.Device) (string, error) { - // Acquire Refresh Key - byteKey := []byte(*device.RefreshKey) - - // Create New Token - tm := time.Now() - t := jwt.New() - t.Set(`did`, device.ID) // Device ID - t.Set(jwt.SubjectKey, user.ID) // User ID - t.Set(jwt.AudienceKey, `imagini`) // App ID - t.Set(jwt.IssuedAtKey, tm) // Issued At - - // iOS & Android = Never Expiring Refresh Token - if device.Type != "iOS" && device.Type != "Android" { - t.Set(jwt.ExpirationKey, tm.Add(time.Hour*24)) // 1 Day Access Key - } - - // Validate Token Creation - _, err := json.MarshalIndent(t, "", " ") - if err != nil { - fmt.Printf("failed to generate JSON: %s\n", err) - return "", err - } - - // Sign Token - signed, err := jwt.Sign(t, jwa.HS256, byteKey) - if err != nil { - log.Printf("failed to sign token: %s", err) - return "", err - } - - // Return Token - return string(signed), nil -} - -func (auth *AuthManager) CreateJWTAccessToken(user model.User, device model.Device) (string, error) { - // Create New Token - tm := time.Now() - t := jwt.New() - t.Set(`did`, device.ID) // Device ID - t.Set(`role`, user.Role.String()) // User Role (Admin / User) - t.Set(jwt.SubjectKey, user.ID) // User ID - t.Set(jwt.AudienceKey, `imagini`) // App ID - t.Set(jwt.IssuedAtKey, tm) // Issued At - t.Set(jwt.ExpirationKey, tm.Add(time.Hour*2)) // 2 Hour Access Key - - // Validate Token Creation - _, err := json.MarshalIndent(t, "", " ") - if err != nil { - fmt.Printf("failed to generate JSON: %s\n", err) - return "", err - } - - // Use Server Key - byteKey := []byte(auth.Config.JWTSecret) - - // Sign Token - signed, err := jwt.Sign(t, jwa.HS256, byteKey) - if err != nil { - log.Printf("failed to sign token: %s", err) - return "", err - } - - // Return Token - return string(signed), nil -} diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go new file mode 100644 index 0000000..cd04c13 --- /dev/null +++ b/internal/auth/jwt.go @@ -0,0 +1,126 @@ +package auth + +import ( + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/google/uuid" + "github.com/lestrrat-go/jwx/jwa" + "github.com/lestrrat-go/jwx/jwt" + log "github.com/sirupsen/logrus" + "reichard.io/imagini/graph/model" +) + +func (auth *AuthManager) ValidateJWTRefreshToken(refreshJWT string) (jwt.Token, error) { + byteRefreshJWT := []byte(refreshJWT) + + // Acquire Relevant Device + unverifiedToken, err := jwt.ParseBytes(byteRefreshJWT) + did, ok := unverifiedToken.Get("did") + if !ok { + return nil, errors.New("did does not exist") + } + deviceID, err := uuid.Parse(fmt.Sprintf("%v", did)) + if err != nil { + return nil, errors.New("did does not parse") + } + device := &model.Device{ID: deviceID.String()} + _, err = auth.DB.Device(device) + if err != nil { + return nil, err + } + + // Verify & Validate Token + verifiedToken, err := jwt.ParseBytes(byteRefreshJWT, + jwt.WithValidate(true), + jwt.WithVerify(jwa.HS256, []byte(*device.RefreshKey)), + ) + if err != nil { + fmt.Println("failed to parse payload: ", err) + return nil, err + } + return verifiedToken, nil +} + +func (auth *AuthManager) ValidateJWTAccessToken(accessJWT string) (jwt.Token, error) { + byteAccessJWT := []byte(accessJWT) + verifiedToken, err := jwt.ParseBytes(byteAccessJWT, + jwt.WithValidate(true), + jwt.WithVerify(jwa.HS256, []byte(auth.Config.JWTSecret)), + ) + + if err != nil { + return nil, err + } + + return verifiedToken, nil +} + +func (auth *AuthManager) CreateJWTRefreshToken(user model.User, device model.Device) (string, error) { + // Acquire Refresh Key + byteKey := []byte(*device.RefreshKey) + + // Create New Token + tm := time.Now() + t := jwt.New() + t.Set(`did`, device.ID) // Device ID + t.Set(jwt.SubjectKey, user.ID) // User ID + t.Set(jwt.AudienceKey, `imagini`) // App ID + t.Set(jwt.IssuedAtKey, tm) // Issued At + + // iOS & Android = Never Expiring Refresh Token + if device.Type != "iOS" && device.Type != "Android" { + t.Set(jwt.ExpirationKey, tm.Add(time.Hour*24)) // 1 Day Access Key + } + + // Validate Token Creation + _, err := json.MarshalIndent(t, "", " ") + if err != nil { + fmt.Printf("failed to generate JSON: %s\n", err) + return "", err + } + + // Sign Token + signed, err := jwt.Sign(t, jwa.HS256, byteKey) + if err != nil { + log.Printf("failed to sign token: %s", err) + return "", err + } + + // Return Token + return string(signed), nil +} + +func (auth *AuthManager) CreateJWTAccessToken(user model.User, device model.Device) (string, error) { + // Create New Token + tm := time.Now() + t := jwt.New() + t.Set(`did`, device.ID) // Device ID + t.Set(`role`, user.Role.String()) // User Role (Admin / User) + t.Set(jwt.SubjectKey, user.ID) // User ID + t.Set(jwt.AudienceKey, `imagini`) // App ID + t.Set(jwt.IssuedAtKey, tm) // Issued At + t.Set(jwt.ExpirationKey, tm.Add(time.Hour*2)) // 2 Hour Access Key + + // Validate Token Creation + _, err := json.MarshalIndent(t, "", " ") + if err != nil { + fmt.Printf("failed to generate JSON: %s\n", err) + return "", err + } + + // Use Server Key + byteKey := []byte(auth.Config.JWTSecret) + + // Sign Token + signed, err := jwt.Sign(t, jwa.HS256, byteKey) + if err != nil { + log.Printf("failed to sign token: %s", err) + return "", err + } + + // Return Token + return string(signed), nil +} diff --git a/internal/auth/ldap.go b/internal/auth/ldap.go index 0de965b..c3927e6 100644 --- a/internal/auth/ldap.go +++ b/internal/auth/ldap.go @@ -1,9 +1,9 @@ package auth import ( - "reichard.io/imagini/graph/model" + "reichard.io/imagini/graph/model" ) func authenticateLDAPUser(user model.User, pw string) bool { - return false + return false } diff --git a/internal/auth/local.go b/internal/auth/local.go index b5f06c9..644c2dc 100644 --- a/internal/auth/local.go +++ b/internal/auth/local.go @@ -1,18 +1,18 @@ package auth import ( - "golang.org/x/crypto/bcrypt" - log "github.com/sirupsen/logrus" - "reichard.io/imagini/graph/model" + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/bcrypt" + "reichard.io/imagini/graph/model" ) func authenticateLocalUser(user model.User, pw string) bool { - bPassword :=[]byte(pw) - err := bcrypt.CompareHashAndPassword([]byte(*user.Password), bPassword) - if err == nil { - log.Info("[auth] Authentication successfull: ", user.Username) - return true - } - log.Warn("[auth] Authentication failed: ", user.Username) - return false + bPassword := []byte(pw) + err := bcrypt.CompareHashAndPassword([]byte(*user.Password), bPassword) + if err == nil { + log.Info("[auth] Authentication successfull: ", user.Username) + return true + } + log.Warn("[auth] Authentication failed: ", user.Username) + return false } diff --git a/internal/config/config.go b/internal/config/config.go index 4243563..416af51 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,34 +1,34 @@ package config import ( - "os" + "os" ) type Config struct { - DBType string - DBName string - DBPassword string - DataPath string - ConfigPath string - JWTSecret string - ListenPort string + DBType string + DBName string + DBPassword string + DataPath string + ConfigPath string + JWTSecret string + ListenPort string } func Load() *Config { - return &Config{ - DBType: getEnv("DATABASE_TYPE", "SQLite"), - DBName: getEnv("DATABASE_NAME", "imagini"), - DBPassword: getEnv("DATABASE_PASSWORD", ""), - ConfigPath: getEnv("CONFIG_PATH", "/config"), - DataPath: getEnv("DATA_PATH", "/data"), - JWTSecret: getEnv("JWT_SECRET", "58b9340c0472cf045db226bc445966524e780cd38bc3dd707afce80c95d4de6f"), - ListenPort: getEnv("LISTEN_PORT", "8484"), - } + return &Config{ + DBType: getEnv("DATABASE_TYPE", "SQLite"), + DBName: getEnv("DATABASE_NAME", "imagini"), + DBPassword: getEnv("DATABASE_PASSWORD", ""), + ConfigPath: getEnv("CONFIG_PATH", "/config"), + DataPath: getEnv("DATA_PATH", "/data"), + JWTSecret: getEnv("JWT_SECRET", "58b9340c0472cf045db226bc445966524e780cd38bc3dd707afce80c95d4de6f"), + ListenPort: getEnv("LISTEN_PORT", "8484"), + } } func getEnv(key, fallback string) string { - if value, ok := os.LookupEnv(key); ok { - return value - } - return fallback + if value, ok := os.LookupEnv(key); ok { + return value + } + return fallback } diff --git a/internal/db/db.go b/internal/db/db.go index 3f3de1a..cda921c 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -69,6 +69,7 @@ func (dbm *DBManager) bootstrapDatabase() { } } +// TODO func (dbm *DBManager) QueryBuilder(dest interface{}, params []byte) (int64, error) { // TODO: // - Where Filters diff --git a/internal/db/errors.go b/internal/db/errors.go deleted file mode 100644 index 787ae41..0000000 --- a/internal/db/errors.go +++ /dev/null @@ -1,7 +0,0 @@ -package db - -import "errors" - -var ( - ErrUserAlreadyExists = errors.New("user already exists") -) diff --git a/internal/session/session.go b/internal/session/session.go index 477a22c..e9aa8af 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -1,38 +1,40 @@ package session import ( - "sync" + "sync" ) // Used to maintain a cache of user specific jwt secrets // This will prevent DB lookups on every request +// May not actually be needed. Refresh Token is the only +// token that will require proactive DB lookups. type SessionManager struct { - mutex sync.Mutex - values map[string]string + mutex sync.Mutex + values map[string]string } func NewMgr() *SessionManager { - return &SessionManager{} + return &SessionManager{} } func (sm *SessionManager) Set(key, value string) { - sm.mutex.Lock() - sm.values[key] = value - sm.mutex.Unlock() + sm.mutex.Lock() + sm.values[key] = value + sm.mutex.Unlock() } func (sm *SessionManager) Get(key string) string { - sm.mutex.Lock() - defer sm.mutex.Unlock() - return sm.values[key] + sm.mutex.Lock() + defer sm.mutex.Unlock() + return sm.values[key] } func (sm *SessionManager) Delete(key string) { - sm.mutex.Lock() - defer sm.mutex.Unlock() - _, exists := sm.values[key] - if !exists { - return - } - delete(sm.values, key) + sm.mutex.Lock() + defer sm.mutex.Unlock() + _, exists := sm.values[key] + if !exists { + return + } + delete(sm.values, key) } diff --git a/plugin/models.go b/plugin/models.go index 3682a2d..30cd6f8 100644 --- a/plugin/models.go +++ b/plugin/models.go @@ -42,7 +42,7 @@ type Field struct { Name string Type types.Type Tag string - Gorm string + Gorm string } type Enum struct { @@ -163,21 +163,21 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error { typ = types.NewPointer(typ) } - gormType := "" - directive := field.Directives.ForName("meta") - if directive != nil { - arg := directive.Arguments.ForName("gorm") - if arg != nil { - gormType = fmt.Sprintf("gorm:\"%s\"", arg.Value.Raw) - } - } + gormType := "" + directive := field.Directives.ForName("meta") + if directive != nil { + arg := directive.Arguments.ForName("gorm") + if arg != nil { + gormType = fmt.Sprintf("gorm:\"%s\"", arg.Value.Raw) + } + } it.Fields = append(it.Fields, &Field{ Name: name, Type: typ, Description: field.Description, Tag: `json:"` + field.Name + `"`, - Gorm: gormType, + Gorm: gormType, }) }