From 082f9234829fd3f74522784482ce5eba002c3f64 Mon Sep 17 00:00:00 2001 From: Evan Reichard Date: Thu, 4 Feb 2021 05:16:13 -0500 Subject: [PATCH] Wooo API! --- cmd/main.go | 113 +++++------ graph/generated/generated.go | 353 ++++++++++++++++++++++++++++++----- graph/model/models_auth.go | 11 +- graph/model/models_db.go | 34 ++-- graph/model/models_gen.go | 17 +- graph/resolver.go | 8 +- graph/schema.graphqls | 22 ++- graph/schema.resolvers.go | 90 +++++++-- internal/api/auth.go | 320 ++++++++++--------------------- internal/api/media.go | 71 +++---- internal/api/middlewares.go | 133 ++++++------- internal/api/routes.go | 53 +++--- internal/auth/auth.go | 274 +++++++++++++-------------- internal/db/db.go | 164 ++++++++-------- internal/db/devices.go | 35 ++-- internal/db/errors.go | 2 +- internal/db/media_items.go | 20 +- internal/db/users.go | 52 +++--- 18 files changed, 977 insertions(+), 795 deletions(-) diff --git a/cmd/main.go b/cmd/main.go index d5ff9ab..0d9c8e4 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -1,85 +1,86 @@ package main import ( - "os" - "os/signal" - "github.com/urfave/cli/v2" - log "github.com/sirupsen/logrus" + "os" + "os/signal" - "reichard.io/imagini/cmd/server" + log "github.com/sirupsen/logrus" + "github.com/urfave/cli/v2" - "reichard.io/imagini/plugin" - "github.com/99designs/gqlgen/api" + "reichard.io/imagini/cmd/server" + + "github.com/99designs/gqlgen/api" "github.com/99designs/gqlgen/codegen/config" + "reichard.io/imagini/plugin" ) type UTCFormatter struct { - log.Formatter + log.Formatter } func (u UTCFormatter) Format(e *log.Entry) ([]byte, error) { - e.Time = e.Time.UTC() - return u.Formatter.Format(e) + e.Time = e.Time.UTC() + return u.Formatter.Format(e) } func main() { - log.SetFormatter(UTCFormatter{&log.TextFormatter{FullTimestamp: true}}) + log.SetFormatter(UTCFormatter{&log.TextFormatter{FullTimestamp: true}}) - app := &cli.App{ - Name: "Imagini", - Usage: "A self hosted photo library.", - Commands: []*cli.Command{ - { - Name: "serve", - Aliases: []string{"s"}, - Usage: "Start Imagini web server.", - Action: cmdServer, - }, - { - Name: "generate", - Usage: "generate graphql schema", + app := &cli.App{ + Name: "Imagini", + Usage: "A self hosted photo library.", + Commands: []*cli.Command{ + { + Name: "serve", + Aliases: []string{"s"}, + Usage: "Start Imagini web server.", + Action: cmdServer, + }, + { + Name: "generate", + Usage: "generate graphql schema", Action: cmdGenerate, }, - }, - } - err := app.Run(os.Args) - if err != nil { - log.Fatal(err) - } + }, + } + err := app.Run(os.Args) + if err != nil { + log.Fatal(err) + } } func cmdServer(ctx *cli.Context) error { - log.Info("Starting Imagini Server") - server := server.NewServer() - server.StartServer() + log.Info("Starting Imagini Server") + server := server.NewServer() + server.StartServer() - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt) - <-c + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + <-c - server.StopServer() - os.Exit(0) + server.StopServer() + os.Exit(0) - return nil + return nil } func cmdGenerate(ctx *cli.Context) error { - log.Info("Generating Imagini Models") - gqlgenConf, err := config.LoadConfigFromDefaultLocations() - if err != nil { - log.Panic("Failed to load config", err.Error()) - os.Exit(2) - } + log.Info("Generating Imagini Models") + gqlgenConf, err := config.LoadConfigFromDefaultLocations() + if err != nil { + log.Panic("Failed to load config", err.Error()) + os.Exit(2) + } - log.Info("Generating Schema...") - err = api.Generate(gqlgenConf, - api.AddPlugin(plugin.New()), - ) - log.Info("Schema Generation Done") - if err != nil { - log.Panic(err.Error()) - os.Exit(3) - } - os.Exit(0) - return nil + log.Info("Generating Schema...") + err = api.Generate(gqlgenConf, + api.AddPlugin(plugin.New()), + ) + log.Info("Schema Generation Done") + if err != nil { + log.Panic(err.Error()) + os.Exit(3) + } + os.Exit(0) + return nil } diff --git a/graph/generated/generated.go b/graph/generated/generated.go index bcf5285..dee8a1b 100644 --- a/graph/generated/generated.go +++ b/graph/generated/generated.go @@ -60,6 +60,7 @@ type ComplexityRoot struct { } AuthResponse struct { + Device func(childComplexity int) int Error func(childComplexity int) int Result func(childComplexity int) int } @@ -72,6 +73,7 @@ type ComplexityRoot struct { Type func(childComplexity int) int UpdatedAt func(childComplexity int) int User func(childComplexity int) int + UserID func(childComplexity int) int } DeviceResponse struct { @@ -92,6 +94,7 @@ type ComplexityRoot struct { Tags func(childComplexity int) int UpdatedAt func(childComplexity int) int User func(childComplexity int) int + UserID func(childComplexity int) int } MediaItemResponse struct { @@ -118,7 +121,7 @@ type ComplexityRoot struct { Albums func(childComplexity int, filter *model.AlbumFilter, count *int, page *int) int Device func(childComplexity int, id string) int Devices func(childComplexity int, filter *model.DeviceFilter, count *int, page *int) int - Login func(childComplexity int, user string, password string) int + Login func(childComplexity int, user string, password string, deviceID *string) int Logout func(childComplexity int) int Me func(childComplexity int) int MediaItem func(childComplexity int, id string) int @@ -168,8 +171,8 @@ type MutationResolver interface { CreateUser(ctx context.Context, input model.NewUser) (*model.User, error) } type QueryResolver interface { - Login(ctx context.Context, user string, password string) (model.AuthResult, error) - Logout(ctx context.Context) (model.AuthResult, error) + Login(ctx context.Context, user string, password string, deviceID *string) (*model.AuthResponse, error) + Logout(ctx context.Context) (*model.AuthResponse, error) MediaItem(ctx context.Context, id string) (*model.MediaItem, error) Device(ctx context.Context, id string) (*model.Device, error) Album(ctx context.Context, id string) (*model.Album, error) @@ -240,6 +243,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.AlbumResponse.PageInfo(childComplexity), true + case "AuthResponse.Device": + if e.complexity.AuthResponse.Device == nil { + break + } + + return e.complexity.AuthResponse.Device(childComplexity), true + case "AuthResponse.Error": if e.complexity.AuthResponse.Error == nil { break @@ -303,6 +313,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Device.User(childComplexity), true + case "Device.userID": + if e.complexity.Device.UserID == nil { + break + } + + return e.complexity.Device.UserID(childComplexity), true + case "DeviceResponse.data": if e.complexity.DeviceResponse.Data == nil { break @@ -401,6 +418,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.MediaItem.User(childComplexity), true + case "MediaItem.userID": + if e.complexity.MediaItem.UserID == nil { + break + } + + return e.complexity.MediaItem.UserID(childComplexity), true + case "MediaItemResponse.data": if e.complexity.MediaItemResponse.Data == nil { break @@ -554,7 +578,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 0, false } - return e.complexity.Query.Login(childComplexity, args["user"].(string), args["password"].(string)), true + return e.complexity.Query.Login(childComplexity, args["user"].(string), args["password"].(string), args["deviceID"].(*string)), true case "Query.logout": if e.complexity.Query.Logout == nil { @@ -876,6 +900,7 @@ enum AuthResult { type AuthResponse { Result: AuthResult! + Device: Device Error: String } @@ -972,17 +997,18 @@ input AuthTypeFilter { # ------------------------------------------------------------ type Device { - id: ID @meta(gorm: "primarykey;not null") + id: ID! @meta(gorm: "primaryKey;not null") createdAt: Time updatedAt: Time name: String! @meta(gorm: "not null") type: DeviceType! @meta(gorm: "default:Unknown;not null") - user: User @meta(gorm: "ForeignKey:ID;not null") + userID: ID! @meta(gorm: "not null") + user: User! @meta(gorm: "foreignKey:ID;references:UserID;not null") refreshKey: String } type User { - id: ID @meta(gorm: "primarykey;not null") + id: ID! @meta(gorm: "primaryKey;not null") createdAt: Time updatedAt: Time email: String! @meta(gorm: "not null;unique") @@ -995,7 +1021,7 @@ type User { } type MediaItem { - id: ID @meta(gorm: "primarykey;not null") + id: ID! @meta(gorm: "primaryKey;not null") createdAt: Time updatedAt: Time exifDate: Time @@ -1006,18 +1032,19 @@ type MediaItem { origName: String! @meta(gorm: "not null") tags: [Tag] @meta(gorm: "many2many:media_tags") albums: [Album] @meta(gorm: "many2many:media_albums") - user: User @meta(gorm: "ForeignKey:ID;not null") + userID: ID! @meta(gorm: "not null") + user: User! @meta(gorm: "foreignKey:ID;references:UserID;not null") } type Tag { - id: ID @meta(gorm: "primarykey;not null") + id: ID! @meta(gorm: "primaryKey;not null") createdAt: Time updatedAt: Time name: String! @meta(gorm: "unique;not null") } type Album { - id: ID @meta(gorm: "primarykey;not null") + id: ID! @meta(gorm: "primaryKey;not null") createdAt: Time updatedAt: Time name: String! @meta(gorm: "unique;not null") @@ -1165,8 +1192,9 @@ type Query { login( user: String! password: String! - ): AuthResult! - logout: AuthResult! @hasMinRole(role: User) + deviceID: ID + ): AuthResponse! + logout: AuthResponse! @hasMinRole(role: User) # Single Item mediaItem(id: ID!): MediaItem! @hasMinRole(role: User) @@ -1456,6 +1484,15 @@ func (ec *executionContext) field_Query_login_args(ctx context.Context, rawArgs } } args["password"] = arg1 + var arg2 *string + if tmp, ok := rawArgs["deviceID"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("deviceID")) + arg2, err = ec.unmarshalOID2ᚖstring(ctx, tmp) + if err != nil { + return nil, err + } + } + args["deviceID"] = arg2 return args, nil } @@ -1663,7 +1700,7 @@ func (ec *executionContext) _Album_id(ctx context.Context, field graphql.Collect return obj.ID, nil } directive1 := func(ctx context.Context) (interface{}, error) { - gorm, err := ec.unmarshalOString2ᚖstring(ctx, "primarykey;not null") + gorm, err := ec.unmarshalOString2ᚖstring(ctx, "primaryKey;not null") if err != nil { return nil, err } @@ -1680,21 +1717,24 @@ func (ec *executionContext) _Album_id(ctx context.Context, field graphql.Collect if tmp == nil { return nil, nil } - if data, ok := tmp.(*string); ok { + if data, ok := tmp.(string); ok { return data, nil } - return nil, fmt.Errorf(`unexpected type %T from directive, should be *string`, tmp) + return nil, fmt.Errorf(`unexpected type %T from directive, should be string`, tmp) }) if err != nil { ec.Error(ctx, err) return graphql.Null } if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } return graphql.Null } - res := resTmp.(*string) + res := resTmp.(string) fc.Result = res - return ec.marshalOID2ᚖstring(ctx, field.Selections, res) + return ec.marshalNID2string(ctx, field.Selections, res) } func (ec *executionContext) _Album_createdAt(ctx context.Context, field graphql.CollectedField, obj *model.Album) (ret graphql.Marshaler) { @@ -1922,6 +1962,38 @@ func (ec *executionContext) _AuthResponse_Result(ctx context.Context, field grap return ec.marshalNAuthResult2reichardᚗioᚋimaginiᚋgraphᚋmodelᚐAuthResult(ctx, field.Selections, res) } +func (ec *executionContext) _AuthResponse_Device(ctx context.Context, field graphql.CollectedField, obj *model.AuthResponse) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "AuthResponse", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Device, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*model.Device) + fc.Result = res + return ec.marshalODevice2ᚖreichardᚗioᚋimaginiᚋgraphᚋmodelᚐDevice(ctx, field.Selections, res) +} + func (ec *executionContext) _AuthResponse_Error(ctx context.Context, field graphql.CollectedField, obj *model.AuthResponse) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -1976,7 +2048,7 @@ func (ec *executionContext) _Device_id(ctx context.Context, field graphql.Collec return obj.ID, nil } directive1 := func(ctx context.Context) (interface{}, error) { - gorm, err := ec.unmarshalOString2ᚖstring(ctx, "primarykey;not null") + gorm, err := ec.unmarshalOString2ᚖstring(ctx, "primaryKey;not null") if err != nil { return nil, err } @@ -1993,21 +2065,24 @@ func (ec *executionContext) _Device_id(ctx context.Context, field graphql.Collec if tmp == nil { return nil, nil } - if data, ok := tmp.(*string); ok { + if data, ok := tmp.(string); ok { return data, nil } - return nil, fmt.Errorf(`unexpected type %T from directive, should be *string`, tmp) + return nil, fmt.Errorf(`unexpected type %T from directive, should be string`, tmp) }) if err != nil { ec.Error(ctx, err) return graphql.Null } if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } return graphql.Null } - res := resTmp.(*string) + res := resTmp.(string) fc.Result = res - return ec.marshalOID2ᚖstring(ctx, field.Selections, res) + return ec.marshalNID2string(ctx, field.Selections, res) } func (ec *executionContext) _Device_createdAt(ctx context.Context, field graphql.CollectedField, obj *model.Device) (ret graphql.Marshaler) { @@ -2192,6 +2267,65 @@ func (ec *executionContext) _Device_type(ctx context.Context, field graphql.Coll return ec.marshalNDeviceType2reichardᚗioᚋimaginiᚋgraphᚋmodelᚐDeviceType(ctx, field.Selections, res) } +func (ec *executionContext) _Device_userID(ctx context.Context, field graphql.CollectedField, obj *model.Device) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "Device", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + directive0 := func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.UserID, nil + } + directive1 := func(ctx context.Context) (interface{}, error) { + gorm, err := ec.unmarshalOString2ᚖstring(ctx, "not null") + if err != nil { + return nil, err + } + if ec.directives.Meta == nil { + return nil, errors.New("directive meta is not implemented") + } + return ec.directives.Meta(ctx, obj, directive0, gorm) + } + + tmp, err := directive1(rctx) + if err != nil { + return nil, graphql.ErrorOnPath(ctx, err) + } + if tmp == nil { + return nil, nil + } + if data, ok := tmp.(string); ok { + return data, nil + } + return nil, fmt.Errorf(`unexpected type %T from directive, should be string`, tmp) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return ec.marshalNID2string(ctx, field.Selections, res) +} + func (ec *executionContext) _Device_user(ctx context.Context, field graphql.CollectedField, obj *model.Device) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -2214,7 +2348,7 @@ func (ec *executionContext) _Device_user(ctx context.Context, field graphql.Coll return obj.User, nil } directive1 := func(ctx context.Context) (interface{}, error) { - gorm, err := ec.unmarshalOString2ᚖstring(ctx, "ForeignKey:ID;not null") + gorm, err := ec.unmarshalOString2ᚖstring(ctx, "foreignKey:ID;references:UserID;not null") if err != nil { return nil, err } @@ -2241,11 +2375,14 @@ func (ec *executionContext) _Device_user(ctx context.Context, field graphql.Coll return graphql.Null } if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } return graphql.Null } res := resTmp.(*model.User) fc.Result = res - return ec.marshalOUser2ᚖreichardᚗioᚋimaginiᚋgraphᚋmodelᚐUser(ctx, field.Selections, res) + return ec.marshalNUser2ᚖreichardᚗioᚋimaginiᚋgraphᚋmodelᚐUser(ctx, field.Selections, res) } func (ec *executionContext) _Device_refreshKey(ctx context.Context, field graphql.CollectedField, obj *model.Device) (ret graphql.Marshaler) { @@ -2369,7 +2506,7 @@ func (ec *executionContext) _MediaItem_id(ctx context.Context, field graphql.Col return obj.ID, nil } directive1 := func(ctx context.Context) (interface{}, error) { - gorm, err := ec.unmarshalOString2ᚖstring(ctx, "primarykey;not null") + gorm, err := ec.unmarshalOString2ᚖstring(ctx, "primaryKey;not null") if err != nil { return nil, err } @@ -2386,21 +2523,24 @@ func (ec *executionContext) _MediaItem_id(ctx context.Context, field graphql.Col if tmp == nil { return nil, nil } - if data, ok := tmp.(*string); ok { + if data, ok := tmp.(string); ok { return data, nil } - return nil, fmt.Errorf(`unexpected type %T from directive, should be *string`, tmp) + return nil, fmt.Errorf(`unexpected type %T from directive, should be string`, tmp) }) if err != nil { ec.Error(ctx, err) return graphql.Null } if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } return graphql.Null } - res := resTmp.(*string) + res := resTmp.(string) fc.Result = res - return ec.marshalOID2ᚖstring(ctx, field.Selections, res) + return ec.marshalNID2string(ctx, field.Selections, res) } func (ec *executionContext) _MediaItem_createdAt(ctx context.Context, field graphql.CollectedField, obj *model.MediaItem) (ret graphql.Marshaler) { @@ -2852,6 +2992,65 @@ func (ec *executionContext) _MediaItem_albums(ctx context.Context, field graphql return ec.marshalOAlbum2ᚕᚖreichardᚗioᚋimaginiᚋgraphᚋmodelᚐAlbum(ctx, field.Selections, res) } +func (ec *executionContext) _MediaItem_userID(ctx context.Context, field graphql.CollectedField, obj *model.MediaItem) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "MediaItem", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + directive0 := func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.UserID, nil + } + directive1 := func(ctx context.Context) (interface{}, error) { + gorm, err := ec.unmarshalOString2ᚖstring(ctx, "not null") + if err != nil { + return nil, err + } + if ec.directives.Meta == nil { + return nil, errors.New("directive meta is not implemented") + } + return ec.directives.Meta(ctx, obj, directive0, gorm) + } + + tmp, err := directive1(rctx) + if err != nil { + return nil, graphql.ErrorOnPath(ctx, err) + } + if tmp == nil { + return nil, nil + } + if data, ok := tmp.(string); ok { + return data, nil + } + return nil, fmt.Errorf(`unexpected type %T from directive, should be string`, tmp) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return ec.marshalNID2string(ctx, field.Selections, res) +} + func (ec *executionContext) _MediaItem_user(ctx context.Context, field graphql.CollectedField, obj *model.MediaItem) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -2874,7 +3073,7 @@ func (ec *executionContext) _MediaItem_user(ctx context.Context, field graphql.C return obj.User, nil } directive1 := func(ctx context.Context) (interface{}, error) { - gorm, err := ec.unmarshalOString2ᚖstring(ctx, "ForeignKey:ID;not null") + gorm, err := ec.unmarshalOString2ᚖstring(ctx, "foreignKey:ID;references:UserID;not null") if err != nil { return nil, err } @@ -2901,11 +3100,14 @@ func (ec *executionContext) _MediaItem_user(ctx context.Context, field graphql.C return graphql.Null } if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } return graphql.Null } res := resTmp.(*model.User) fc.Result = res - return ec.marshalOUser2ᚖreichardᚗioᚋimaginiᚋgraphᚋmodelᚐUser(ctx, field.Selections, res) + return ec.marshalNUser2ᚖreichardᚗioᚋimaginiᚋgraphᚋmodelᚐUser(ctx, field.Selections, res) } func (ec *executionContext) _MediaItemResponse_data(ctx context.Context, field graphql.CollectedField, obj *model.MediaItemResponse) (ret graphql.Marshaler) { @@ -3435,7 +3637,7 @@ func (ec *executionContext) _Query_login(ctx context.Context, field graphql.Coll fc.Args = args resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Query().Login(rctx, args["user"].(string), args["password"].(string)) + return ec.resolvers.Query().Login(rctx, args["user"].(string), args["password"].(string), args["deviceID"].(*string)) }) if err != nil { ec.Error(ctx, err) @@ -3447,9 +3649,9 @@ func (ec *executionContext) _Query_login(ctx context.Context, field graphql.Coll } return graphql.Null } - res := resTmp.(model.AuthResult) + res := resTmp.(*model.AuthResponse) fc.Result = res - return ec.marshalNAuthResult2reichardᚗioᚋimaginiᚋgraphᚋmodelᚐAuthResult(ctx, field.Selections, res) + return ec.marshalNAuthResponse2ᚖreichardᚗioᚋimaginiᚋgraphᚋmodelᚐAuthResponse(ctx, field.Selections, res) } func (ec *executionContext) _Query_logout(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { @@ -3491,10 +3693,10 @@ func (ec *executionContext) _Query_logout(ctx context.Context, field graphql.Col if tmp == nil { return nil, nil } - if data, ok := tmp.(model.AuthResult); ok { + if data, ok := tmp.(*model.AuthResponse); ok { return data, nil } - return nil, fmt.Errorf(`unexpected type %T from directive, should be reichard.io/imagini/graph/model.AuthResult`, tmp) + return nil, fmt.Errorf(`unexpected type %T from directive, should be *reichard.io/imagini/graph/model.AuthResponse`, tmp) }) if err != nil { ec.Error(ctx, err) @@ -3506,9 +3708,9 @@ func (ec *executionContext) _Query_logout(ctx context.Context, field graphql.Col } return graphql.Null } - res := resTmp.(model.AuthResult) + res := resTmp.(*model.AuthResponse) fc.Result = res - return ec.marshalNAuthResult2reichardᚗioᚋimaginiᚋgraphᚋmodelᚐAuthResult(ctx, field.Selections, res) + return ec.marshalNAuthResponse2ᚖreichardᚗioᚋimaginiᚋgraphᚋmodelᚐAuthResponse(ctx, field.Selections, res) } func (ec *executionContext) _Query_mediaItem(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { @@ -4323,7 +4525,7 @@ func (ec *executionContext) _Tag_id(ctx context.Context, field graphql.Collected return obj.ID, nil } directive1 := func(ctx context.Context) (interface{}, error) { - gorm, err := ec.unmarshalOString2ᚖstring(ctx, "primarykey;not null") + gorm, err := ec.unmarshalOString2ᚖstring(ctx, "primaryKey;not null") if err != nil { return nil, err } @@ -4340,21 +4542,24 @@ func (ec *executionContext) _Tag_id(ctx context.Context, field graphql.Collected if tmp == nil { return nil, nil } - if data, ok := tmp.(*string); ok { + if data, ok := tmp.(string); ok { return data, nil } - return nil, fmt.Errorf(`unexpected type %T from directive, should be *string`, tmp) + return nil, fmt.Errorf(`unexpected type %T from directive, should be string`, tmp) }) if err != nil { ec.Error(ctx, err) return graphql.Null } if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } return graphql.Null } - res := resTmp.(*string) + res := resTmp.(string) fc.Result = res - return ec.marshalOID2ᚖstring(ctx, field.Selections, res) + return ec.marshalNID2string(ctx, field.Selections, res) } func (ec *executionContext) _Tag_createdAt(ctx context.Context, field graphql.CollectedField, obj *model.Tag) (ret graphql.Marshaler) { @@ -4569,7 +4774,7 @@ func (ec *executionContext) _User_id(ctx context.Context, field graphql.Collecte return obj.ID, nil } directive1 := func(ctx context.Context) (interface{}, error) { - gorm, err := ec.unmarshalOString2ᚖstring(ctx, "primarykey;not null") + gorm, err := ec.unmarshalOString2ᚖstring(ctx, "primaryKey;not null") if err != nil { return nil, err } @@ -4586,21 +4791,24 @@ func (ec *executionContext) _User_id(ctx context.Context, field graphql.Collecte if tmp == nil { return nil, nil } - if data, ok := tmp.(*string); ok { + if data, ok := tmp.(string); ok { return data, nil } - return nil, fmt.Errorf(`unexpected type %T from directive, should be *string`, tmp) + return nil, fmt.Errorf(`unexpected type %T from directive, should be string`, tmp) }) if err != nil { ec.Error(ctx, err) return graphql.Null } if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } return graphql.Null } - res := resTmp.(*string) + res := resTmp.(string) fc.Result = res - return ec.marshalOID2ᚖstring(ctx, field.Selections, res) + return ec.marshalNID2string(ctx, field.Selections, res) } func (ec *executionContext) _User_createdAt(ctx context.Context, field graphql.CollectedField, obj *model.User) (ret graphql.Marshaler) { @@ -7282,6 +7490,9 @@ func (ec *executionContext) _Album(ctx context.Context, sel ast.SelectionSet, ob out.Values[i] = graphql.MarshalString("Album") case "id": out.Values[i] = ec._Album_id(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } case "createdAt": out.Values[i] = ec._Album_createdAt(ctx, field, obj) case "updatedAt": @@ -7347,6 +7558,8 @@ func (ec *executionContext) _AuthResponse(ctx context.Context, sel ast.Selection if out.Values[i] == graphql.Null { invalids++ } + case "Device": + out.Values[i] = ec._AuthResponse_Device(ctx, field, obj) case "Error": out.Values[i] = ec._AuthResponse_Error(ctx, field, obj) default: @@ -7373,6 +7586,9 @@ func (ec *executionContext) _Device(ctx context.Context, sel ast.SelectionSet, o out.Values[i] = graphql.MarshalString("Device") case "id": out.Values[i] = ec._Device_id(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } case "createdAt": out.Values[i] = ec._Device_createdAt(ctx, field, obj) case "updatedAt": @@ -7387,8 +7603,16 @@ func (ec *executionContext) _Device(ctx context.Context, sel ast.SelectionSet, o if out.Values[i] == graphql.Null { invalids++ } + case "userID": + out.Values[i] = ec._Device_userID(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } case "user": out.Values[i] = ec._Device_user(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } case "refreshKey": out.Values[i] = ec._Device_refreshKey(ctx, field, obj) default: @@ -7444,6 +7668,9 @@ func (ec *executionContext) _MediaItem(ctx context.Context, sel ast.SelectionSet out.Values[i] = graphql.MarshalString("MediaItem") case "id": out.Values[i] = ec._MediaItem_id(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } case "createdAt": out.Values[i] = ec._MediaItem_createdAt(ctx, field, obj) case "updatedAt": @@ -7473,8 +7700,16 @@ func (ec *executionContext) _MediaItem(ctx context.Context, sel ast.SelectionSet out.Values[i] = ec._MediaItem_tags(ctx, field, obj) case "albums": out.Values[i] = ec._MediaItem_albums(ctx, field, obj) + case "userID": + out.Values[i] = ec._MediaItem_userID(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } case "user": out.Values[i] = ec._MediaItem_user(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } default: panic("unknown field " + strconv.Quote(field.Name)) } @@ -7828,6 +8063,9 @@ func (ec *executionContext) _Tag(ctx context.Context, sel ast.SelectionSet, obj out.Values[i] = graphql.MarshalString("Tag") case "id": out.Values[i] = ec._Tag_id(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } case "createdAt": out.Values[i] = ec._Tag_createdAt(ctx, field, obj) case "updatedAt": @@ -7890,6 +8128,9 @@ func (ec *executionContext) _User(ctx context.Context, sel ast.SelectionSet, obj out.Values[i] = graphql.MarshalString("User") case "id": out.Values[i] = ec._User_id(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } case "createdAt": out.Values[i] = ec._User_createdAt(ctx, field, obj) case "updatedAt": @@ -8233,6 +8474,20 @@ func (ec *executionContext) marshalNAlbumResponse2ᚖreichardᚗioᚋimaginiᚋg return ec._AlbumResponse(ctx, sel, v) } +func (ec *executionContext) marshalNAuthResponse2reichardᚗioᚋimaginiᚋgraphᚋmodelᚐAuthResponse(ctx context.Context, sel ast.SelectionSet, v model.AuthResponse) graphql.Marshaler { + return ec._AuthResponse(ctx, sel, &v) +} + +func (ec *executionContext) marshalNAuthResponse2ᚖreichardᚗioᚋimaginiᚋgraphᚋmodelᚐAuthResponse(ctx context.Context, sel ast.SelectionSet, v *model.AuthResponse) graphql.Marshaler { + if v == nil { + if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + return ec._AuthResponse(ctx, sel, v) +} + func (ec *executionContext) unmarshalNAuthResult2reichardᚗioᚋimaginiᚋgraphᚋmodelᚐAuthResult(ctx context.Context, v interface{}) (model.AuthResult, error) { var res model.AuthResult err := res.UnmarshalGQL(v) diff --git a/graph/model/models_auth.go b/graph/model/models_auth.go index 5c77bd9..3a40b57 100644 --- a/graph/model/models_auth.go +++ b/graph/model/models_auth.go @@ -1,12 +1,13 @@ package model import ( - "net/http" + "net/http" + + "github.com/lestrrat-go/jwx/jwt" ) type AuthContext struct { - AccessToken string - RefreshToken string - AuthResponse *http.ResponseWriter - AuthRequest *http.Request + AccessToken *jwt.Token + AuthResponse *http.ResponseWriter + AuthRequest *http.Request } diff --git a/graph/model/models_db.go b/graph/model/models_db.go index cd46621..4835610 100644 --- a/graph/model/models_db.go +++ b/graph/model/models_db.go @@ -1,36 +1,36 @@ package model import ( - "gorm.io/gorm" - "github.com/google/uuid" + "github.com/google/uuid" + "gorm.io/gorm" ) func (u *User) BeforeCreate(tx *gorm.DB) (err error) { - newID := uuid.New().String() - u.ID = &newID - return + newID := uuid.New().String() + u.ID = newID + return } func (a *Album) BeforeCreate(tx *gorm.DB) (err error) { - newID := uuid.New().String() - a.ID = &newID - return + newID := uuid.New().String() + a.ID = newID + return } func (m *MediaItem) BeforeCreate(tx *gorm.DB) (err error) { - newID := uuid.New().String() - m.ID = &newID - return + newID := uuid.New().String() + m.ID = newID + return } func (t *Tag) BeforeCreate(tx *gorm.DB) (err error) { - newID := uuid.New().String() - t.ID = &newID - return + newID := uuid.New().String() + t.ID = newID + return } func (d *Device) BeforeCreate(tx *gorm.DB) (err error) { - newID := uuid.New().String() - d.ID = &newID - return + newID := uuid.New().String() + d.ID = newID + return } diff --git a/graph/model/models_gen.go b/graph/model/models_gen.go index ed064a7..50787f3 100644 --- a/graph/model/models_gen.go +++ b/graph/model/models_gen.go @@ -12,7 +12,7 @@ import ( ) type Album struct { - ID *string `json:"id" gorm:"primarykey;not null"` + ID string `json:"id" gorm:"primaryKey;not null"` CreatedAt *time.Time `json:"createdAt" ` UpdatedAt *time.Time `json:"updatedAt" ` Name string `json:"name" gorm:"unique;not null"` @@ -34,6 +34,7 @@ type AlbumResponse struct { type AuthResponse struct { Result AuthResult `json:"Result" ` + Device *Device `json:"Device" ` Error *string `json:"Error" ` } @@ -50,12 +51,13 @@ type BooleanFilter struct { } type Device struct { - ID *string `json:"id" gorm:"primarykey;not null"` + ID string `json:"id" gorm:"primaryKey;not null"` CreatedAt *time.Time `json:"createdAt" ` UpdatedAt *time.Time `json:"updatedAt" ` Name string `json:"name" gorm:"not null"` Type DeviceType `json:"type" gorm:"default:Unknown;not null"` - User *User `json:"user" gorm:"ForeignKey:ID;not null"` + UserID string `json:"userID" gorm:"not null"` + User *User `json:"user" gorm:"foreignKey:ID;references:UserID;not null"` RefreshKey *string `json:"refreshKey" ` } @@ -111,7 +113,7 @@ type IntFilter struct { } type MediaItem struct { - ID *string `json:"id" gorm:"primarykey;not null"` + ID string `json:"id" gorm:"primaryKey;not null"` CreatedAt *time.Time `json:"createdAt" ` UpdatedAt *time.Time `json:"updatedAt" ` ExifDate *time.Time `json:"exifDate" ` @@ -122,7 +124,8 @@ type MediaItem struct { OrigName string `json:"origName" gorm:"not null"` Tags []*Tag `json:"tags" gorm:"many2many:media_tags"` Albums []*Album `json:"albums" gorm:"many2many:media_albums"` - User *User `json:"user" gorm:"ForeignKey:ID;not null"` + UserID string `json:"userID" gorm:"not null"` + User *User `json:"user" gorm:"foreignKey:ID;references:UserID;not null"` } type MediaItemFilter struct { @@ -206,7 +209,7 @@ type StringFilter struct { } type Tag struct { - ID *string `json:"id" gorm:"primarykey;not null"` + ID string `json:"id" gorm:"primaryKey;not null"` CreatedAt *time.Time `json:"createdAt" ` UpdatedAt *time.Time `json:"updatedAt" ` Name string `json:"name" gorm:"unique;not null"` @@ -236,7 +239,7 @@ type TimeFilter struct { } type User struct { - ID *string `json:"id" gorm:"primarykey;not null"` + ID string `json:"id" gorm:"primaryKey;not null"` CreatedAt *time.Time `json:"createdAt" ` UpdatedAt *time.Time `json:"updatedAt" ` Email string `json:"email" gorm:"not null;unique"` diff --git a/graph/resolver.go b/graph/resolver.go index e7fdab5..4249b75 100644 --- a/graph/resolver.go +++ b/graph/resolver.go @@ -1,13 +1,15 @@ package graph import ( - "reichard.io/imagini/internal/db" + "reichard.io/imagini/internal/auth" + "reichard.io/imagini/internal/db" ) // This file will not be regenerated automatically. // // It serves as dependency injection for your app, add any dependencies you require here. -type Resolver struct{ - DB *db.DBManager +type Resolver struct { + Auth *auth.AuthManager + DB *db.DBManager } diff --git a/graph/schema.graphqls b/graph/schema.graphqls index 3d90eb0..92e8493 100644 --- a/graph/schema.graphqls +++ b/graph/schema.graphqls @@ -42,6 +42,7 @@ enum AuthResult { type AuthResponse { Result: AuthResult! + Device: Device Error: String } @@ -138,17 +139,18 @@ input AuthTypeFilter { # ------------------------------------------------------------ type Device { - id: ID @meta(gorm: "primarykey;not null") + id: ID! @meta(gorm: "primaryKey;not null") createdAt: Time updatedAt: Time name: String! @meta(gorm: "not null") type: DeviceType! @meta(gorm: "default:Unknown;not null") - user: User @meta(gorm: "ForeignKey:ID;not null") + userID: ID! @meta(gorm: "not null") + user: User! @meta(gorm: "foreignKey:ID;references:UserID;not null") refreshKey: String } type User { - id: ID @meta(gorm: "primarykey;not null") + id: ID! @meta(gorm: "primaryKey;not null") createdAt: Time updatedAt: Time email: String! @meta(gorm: "not null;unique") @@ -161,7 +163,7 @@ type User { } type MediaItem { - id: ID @meta(gorm: "primarykey;not null") + id: ID! @meta(gorm: "primaryKey;not null") createdAt: Time updatedAt: Time exifDate: Time @@ -172,18 +174,19 @@ type MediaItem { origName: String! @meta(gorm: "not null") tags: [Tag] @meta(gorm: "many2many:media_tags") albums: [Album] @meta(gorm: "many2many:media_albums") - user: User @meta(gorm: "ForeignKey:ID;not null") + userID: ID! @meta(gorm: "not null") + user: User! @meta(gorm: "foreignKey:ID;references:UserID;not null") } type Tag { - id: ID @meta(gorm: "primarykey;not null") + id: ID! @meta(gorm: "primaryKey;not null") createdAt: Time updatedAt: Time name: String! @meta(gorm: "unique;not null") } type Album { - id: ID @meta(gorm: "primarykey;not null") + id: ID! @meta(gorm: "primaryKey;not null") createdAt: Time updatedAt: Time name: String! @meta(gorm: "unique;not null") @@ -331,8 +334,9 @@ type Query { login( user: String! password: String! - ): AuthResult! - logout: AuthResult! @hasMinRole(role: User) + deviceID: ID + ): AuthResponse! + logout: AuthResponse! @hasMinRole(role: User) # Single Item mediaItem(id: ID!): MediaItem! @hasMinRole(role: User) diff --git a/graph/schema.resolvers.go b/graph/schema.resolvers.go index 3f72070..50f185e 100644 --- a/graph/schema.resolvers.go +++ b/graph/schema.resolvers.go @@ -4,10 +4,12 @@ package graph // will be copied through when generating and any unknown code will be moved to the end. import ( - "net/http" "context" "fmt" + "net/http" + "strings" + "github.com/google/uuid" "reichard.io/imagini/graph/generated" "reichard.io/imagini/graph/model" ) @@ -41,26 +43,66 @@ func (r *mutationResolver) CreateUser(ctx context.Context, input model.NewUser) err := r.DB.CreateUser(user) if err != nil { - panic(fmt.Errorf("DB Error")) + return nil, err } return user, nil } -func (r *queryResolver) Login(ctx context.Context, user string, password string) (model.AuthResult, error) { - - // Set Cookie From Context +func (r *queryResolver) Login(ctx context.Context, user string, password string, deviceID *string) (*model.AuthResponse, error) { + // Set Cookie From Context authContext := ctx.Value("auth").(*model.AuthContext) - resp := *authContext.AuthResponse - testCookie := http.Cookie{Name: "TestCookie", Value: "Test123", Path: "/", HttpOnly: true} - http.SetCookie(resp, &testCookie) + resp := authContext.AuthResponse + req := authContext.AuthRequest - return model.AuthResultSuccess, nil + // Do Login + foundUser, success := r.Auth.AuthenticateUser(user, password) + if !success { + return &model.AuthResponse{Result: model.AuthResultFailure}, nil + } + + // Upsert Device + foundDevice := model.Device{} + if deviceID != nil { + parsedDeviceID, err := uuid.Parse(*deviceID) + if err != nil { + return &model.AuthResponse{Result: model.AuthResultFailure}, nil + } + foundDevice.ID = parsedDeviceID.String() + count, err := r.DB.Device(&foundDevice) + if count != 1 || err != nil { + return &model.AuthResponse{Result: model.AuthResultFailure}, nil + } + } else { + foundDevice.Type = deriveDeviceType(req) + err := r.DB.CreateDevice(&foundDevice) + if err != nil { + return &model.AuthResponse{Result: model.AuthResultFailure}, nil + } + } + + // Create Tokens + accessToken, err := r.Auth.CreateJWTAccessToken(foundUser, foundDevice) + if err != nil { + return &model.AuthResponse{Result: model.AuthResultFailure}, nil + } + refreshToken, err := r.Auth.CreateJWTRefreshToken(foundUser, foundDevice) + if err != nil { + return &model.AuthResponse{Result: model.AuthResultFailure}, nil + } + + // Set appropriate cookies + accessCookie := http.Cookie{Name: "AccessToken", Value: accessToken, Path: "/", HttpOnly: true} + refreshCookie := http.Cookie{Name: "RefreshToken", Value: refreshToken, Path: "/", HttpOnly: true} + http.SetCookie(*resp, &accessCookie) + http.SetCookie(*resp, &refreshCookie) + + return &model.AuthResponse{Result: model.AuthResultSuccess, Device: &foundDevice}, nil } -func (r *queryResolver) Logout(ctx context.Context) (model.AuthResult, error) { +func (r *queryResolver) Logout(ctx context.Context) (*model.AuthResponse, error) { // panic(fmt.Errorf("not implemented")) - return model.AuthResultSuccess, nil + return &model.AuthResponse{Result: model.AuthResultSuccess}, nil } func (r *queryResolver) MediaItem(ctx context.Context, id string) (*model.MediaItem, error) { @@ -126,3 +168,29 @@ func (r *Resolver) Query() generated.QueryResolver { return &queryResolver{r} } type mutationResolver struct{ *Resolver } type queryResolver struct{ *Resolver } + +// !!! WARNING !!! +// The code below was going to be deleted when updating resolvers. It has been copied here so you have +// one last chance to move it out of harms way if you want. There are two reasons this happens: +// - When renaming or deleting a resolver the old code will be put in here. You can safely delete +// it when you're done. +// - You have helper methods in this file. Move them out to keep these resolver files clean. +func deriveDeviceType(r *http.Request) model.DeviceType { + userAgent := strings.ToLower(r.Header.Get("User-Agent")) + if strings.Contains(userAgent, "ios-imagini") { + return model.DeviceTypeIOs + } else if strings.Contains(userAgent, "android-imagini") { + return model.DeviceTypeAndroid + } else if strings.Contains(userAgent, "chrome") { + return model.DeviceTypeChrome + } else if strings.Contains(userAgent, "firefox") { + return model.DeviceTypeFirefox + } else if strings.Contains(userAgent, "msie") { + return model.DeviceTypeInternetExplorer + } else if strings.Contains(userAgent, "edge") { + return model.DeviceTypeEdge + } else if strings.Contains(userAgent, "safari") { + return model.DeviceTypeSafari + } + return model.DeviceTypeUnknown +} diff --git a/internal/api/auth.go b/internal/api/auth.go index 81ef498..197bd17 100644 --- a/internal/api/auth.go +++ b/internal/api/auth.go @@ -1,248 +1,124 @@ package api import ( - "fmt" - "time" - "strings" - "context" - "net/http" - "encoding/json" + "context" + "errors" + "fmt" + "net/http" - "github.com/google/uuid" - log "github.com/sirupsen/logrus" - "github.com/lestrrat-go/jwx/jwt" "github.com/99designs/gqlgen/graphql" + "github.com/google/uuid" + "github.com/lestrrat-go/jwx/jwt" "reichard.io/imagini/graph/model" ) -func (api *API) loginHandler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") - if r.Method != http.MethodPost { - errorJSON(w, "Method is not supported.", http.StatusMethodNotAllowed) - return - } +func (api *API) refreshTokens(refreshToken jwt.Token) (string, string, error) { + // Acquire User & Device + did, ok := refreshToken.Get("did") + if !ok { + return "", "", errors.New("Missing DID") + } + uid, ok := refreshToken.Get(jwt.SubjectKey) + if !ok { + return "", "", errors.New("Missing UID") + } + deviceUUID, err := uuid.Parse(fmt.Sprintf("%v", did)) + if err != nil { + return "", "", errors.New("Invalid DID") + } + userUUID, err := uuid.Parse(fmt.Sprintf("%v", uid)) + if err != nil { + return "", "", errors.New("Invalid UID") + } - // Decode into Struct - var creds APICredentials - err := json.NewDecoder(r.Body).Decode(&creds) - if err != nil { - errorJSON(w, "Invalid parameters.", http.StatusBadRequest) - return - } + // Device & User Skeleton + user := model.User{ID: userUUID.String()} + device := model.Device{ID: deviceUUID.String()} - // Validate - if creds.User == "" || creds.Password == "" { - errorJSON(w, "Invalid parameters.", http.StatusBadRequest) - return - } + // Find User + _, err = api.DB.User(&user) + if err != nil { + return "", "", err + } - // Do login - resp, user := api.Auth.AuthenticateUser(creds.User, creds.Password) - if !resp { - errorJSON(w, "Invalid credentials.", http.StatusUnauthorized) - return - } + // Update Access Token + accessTokenCookie, err := api.Auth.CreateJWTAccessToken(user, device) + if err != nil { + return "", "", err + } - // Upsert device - device, err := api.upsertRequestedDevice(user, r) - if err != nil { - log.Error("[api] loginHandler - Failed to upsert device: ", err) - errorJSON(w, "DB error. Unable to proceed.", http.StatusUnauthorized) - return - } - - // Create Tokens - accessToken, err := api.Auth.CreateJWTAccessToken(user, device) - refreshToken, err := api.Auth.CreateJWTRefreshToken(user, device) - - // Set appropriate cookies - accessCookie := http.Cookie{Name: "AccessToken", Value: accessToken, Path: "/", HttpOnly: true} - refreshCookie := http.Cookie{Name: "RefreshToken", Value: refreshToken, Path: "/", HttpOnly: true} - http.SetCookie(w, &accessCookie) - http.SetCookie(w, &refreshCookie) - - // Response success - successJSON(w, "Login success.", http.StatusOK) + return accessTokenCookie, "", err } -func (api *API) logoutHandler(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - http.Error(w, "Method is not supported.", http.StatusMethodNotAllowed) - return - } +func (api *API) validateTokens(w *http.ResponseWriter, r *http.Request) (jwt.Token, error) { + // Validate Access Token + accessCookie, _ := r.Cookie("AccessToken") + if accessCookie != nil { + accessToken, err := api.Auth.ValidateJWTAccessToken(accessCookie.Value) + if err == nil { + return accessToken, nil + } + } - // TODO: Reset Refresh Key + // Validate Refresh Cookie Exists + refreshCookie, _ := r.Cookie("RefreshToken") + if refreshCookie == nil { + return nil, errors.New("Tokens Invalid") + } - // Clear Cookies - http.SetCookie(w, &http.Cookie{Name: "AccessToken", Expires: time.Unix(0, 0)}) - http.SetCookie(w, &http.Cookie{Name: "RefreshToken", Expires: time.Unix(0, 0)}) + // Validate Refresh Token + refreshToken, err := api.Auth.ValidateJWTRefreshToken(refreshCookie.Value) + if err != nil { + return nil, errors.New("Tokens Invalid") + } - successJSON(w, "Logout success.", http.StatusOK) + // Refresh Access Token & Generate New Refresh Token + newAccessCookie, newRefreshCookie, err := api.refreshTokens(refreshToken) + if err != nil { + return nil, err + } + + // TODO: Actually Refresh Refresh Token + newRefreshCookie = refreshCookie.Value + + // Update Access & Refresh Cookies + http.SetCookie(*w, &http.Cookie{ + Name: "AccessToken", + Value: newAccessCookie, + }) + http.SetCookie(*w, &http.Cookie{ + Name: "RefreshToken", + Value: newRefreshCookie, + }) + + return jwt.ParseBytes([]byte(newAccessCookie)) } -/** - * This will find or create the requested device based on ID and User. - **/ -func (api *API) upsertRequestedDevice(user model.User, r *http.Request) (model.Device, error) { - requestedDevice := deriveRequestedDevice(r) - requestedDevice.Type = deriveDeviceType(r) - requestedDevice.User.ID = user.ID +func (api *API) hasMinRoleDirective(ctx context.Context, obj interface{}, next graphql.Resolver, role model.Role) (res interface{}, err error) { + authContext := ctx.Value("auth").(*model.AuthContext) + accessToken, err := api.validateTokens(authContext.AuthResponse, authContext.AuthRequest) + if err != nil { + return nil, errors.New("Access Denied") + } + authContext.AccessToken = &accessToken - if *requestedDevice.ID == "" { - err := api.DB.CreateDevice(&requestedDevice) - createdDevice, err := api.DB.Device(&requestedDevice) - return createdDevice, err - } + userRole, ok := accessToken.Get("role") + if !ok { + return nil, errors.New("Access Denied") + } - foundDevice, err := api.DB.Device(&model.Device{ - ID: requestedDevice.ID, - User: &user, - }) + if userRole == model.RoleAdmin.String() { + return next(ctx) + } - return foundDevice, err + if userRole == role.String() { + return next(ctx) + } + + return nil, errors.New("Role Not Authenticated") } -func deriveDeviceType(r *http.Request) model.DeviceType { - userAgent := strings.ToLower(r.Header.Get("User-Agent")) - if strings.HasPrefix(userAgent, "ios-imagini"){ - return model.DeviceTypeIOs - } else if strings.HasPrefix(userAgent, "android-imagini"){ - return model.DeviceTypeAndroid - } else if strings.HasPrefix(userAgent, "chrome"){ - return model.DeviceTypeChrome - } else if strings.HasPrefix(userAgent, "firefox"){ - return model.DeviceTypeFirefox - } else if strings.HasPrefix(userAgent, "msie"){ - return model.DeviceTypeInternetExplorer - } else if strings.HasPrefix(userAgent, "edge"){ - return model.DeviceTypeEdge - } else if strings.HasPrefix(userAgent, "safari"){ - return model.DeviceTypeSafari - } - return model.DeviceTypeUnknown -} - -func deriveRequestedDevice(r *http.Request) model.Device { - deviceSkeleton := model.Device{} - authHeader := r.Header.Get("X-Imagini-Authorization") - splitAuthInfo := strings.Split(authHeader, ",") - - // For each Key - Value pair - for i := range splitAuthInfo { - - // Split Key - Value - item := strings.TrimSpace(splitAuthInfo[i]) - splitItem := strings.SplitN(item, "=", 2) - if len(splitItem) != 2 { - continue - } - - // Derive Key - key := strings.ToLower(strings.TrimSpace(splitItem[0])) - if key != "deviceid" && key != "devicename" { - continue - } - - // Derive Value - val := trimQuotes(strings.TrimSpace(splitItem[1])) - if key == "deviceid" { - parsedDeviceUUID, err := uuid.Parse(val) - if err != nil { - log.Warn("[auth] deriveRequestedDevice - Unable to parse requested DeviceUUID: ", val) - continue - } - stringDeviceUUID := parsedDeviceUUID.String() - deviceSkeleton.ID = &stringDeviceUUID - } else if key == "devicename" { - deviceSkeleton.Name = val - } - } - - // If name not set, set to type - if deviceSkeleton.Name == "" { - deviceSkeleton.Name = deviceSkeleton.Type.String() - } - - return deviceSkeleton -} - -func (api *API) refreshAccessToken(w http.ResponseWriter, r *http.Request) (jwt.Token, error) { - refreshCookie, err := r.Cookie("RefreshToken") - if err != nil { - log.Warn("[middleware] RefreshToken not found") - return nil, err - } - - // Validate Refresh Token - refreshToken, err := api.Auth.ValidateJWTRefreshToken(refreshCookie.Value) - if err != nil { - http.SetCookie(w, &http.Cookie{Name: "AccessToken", Expires: time.Unix(0, 0)}) - http.SetCookie(w, &http.Cookie{Name: "RefreshToken", Expires: time.Unix(0, 0)}) - return nil, err - } - - // Acquire User & Device (Trusted) - did, ok := refreshToken.Get("did") - if !ok { - return nil, err - } - uid, ok := refreshToken.Get(jwt.SubjectKey) - if !ok { - return nil, err - } - deviceUUID, err := uuid.Parse(fmt.Sprintf("%v", did)) - if err != nil { - return nil, err - } - userUUID, err := uuid.Parse(fmt.Sprintf("%v", uid)) - if err != nil { - return nil, err - } - - stringUserUUID := userUUID.String() - stringDeviceUUID := deviceUUID.String() - - // Device & User Skeleton - user := model.User{ID: &stringUserUUID} - device := model.Device{ID: &stringDeviceUUID} - - // Update token - accessTokenString, err := api.Auth.CreateJWTAccessToken(user, device) - if err != nil { - return nil, err - } - accessCookie := http.Cookie{Name: "AccessToken", Value: accessTokenString} - http.SetCookie(w, &accessCookie) - - // TODO: Update Refresh Key & Token - - // Convert to jwt.Token - accessTokenBytes := []byte(accessTokenString) - accessToken, err := jwt.ParseBytes(accessTokenBytes) - - return accessToken, err -} - -func trimQuotes(s string) string { - if len(s) >= 2 { - if s[0] == '"' && s[len(s)-1] == '"' { - return s[1 : len(s)-1] - } - } - return s -} - -func hasMinRoleDirective(ctx context.Context, obj interface{}, next graphql.Resolver, role model.Role) (res interface{}, err error) { - // if !getCurrentUser(ctx).HasRole(role) { - // // block calling the next resolver - // return nil, fmt.Errorf("Access denied") - // } - - // or let it pass through - return next(ctx) -} - -func metaDirective(ctx context.Context, obj interface{}, next graphql.Resolver, gorm *string) (res interface{}, err error){ - return next(ctx) +func (api *API) metaDirective(ctx context.Context, obj interface{}, next graphql.Resolver, gorm *string) (res interface{}, err error) { + return next(ctx) } diff --git a/internal/api/media.go b/internal/api/media.go index ed714e2..d648e25 100644 --- a/internal/api/media.go +++ b/internal/api/media.go @@ -1,50 +1,53 @@ package api import ( - "os" - "path" "net/http" + "os" + "path" + + "reichard.io/imagini/graph/model" ) // Responsible for serving up static images / videos func (api *API) mediaHandler(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - w.WriteHeader(http.StatusMethodNotAllowed) - return - } + if r.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } - if path.Dir(r.URL.Path) != "/media" { - w.WriteHeader(http.StatusMethodNotAllowed) - return - } + if path.Dir(r.URL.Path) != "/media" { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } - // Acquire Width & Height Parameters - query := r.URL.Query() - width := query["width"] - height := query["height"] - _ = width - _ = height + // Acquire Width & Height Parameters + query := r.URL.Query() + width := query["width"] + height := query["height"] + _ = width + _ = height - // TODO: Caching & Resizing - // - If both, force resize with new scale - // - If one, scale resize proportionally + // TODO: Caching & Resizing + // - If both, force resize with new scale + // - If one, scale resize proportionally - // Pull out UUIDs - reqInfo := r.Context().Value("uuids").(map[string]string) - uid := reqInfo["uid"] + // Pull out userID + authContext := r.Context().Value("auth").(*model.AuthContext) + rawUserID, _ := (*authContext.AccessToken).Get("sub") + userID := rawUserID.(string) - // Derive Path - fileName := path.Base(r.URL.Path) - folderPath := path.Join("/" + api.Config.DataPath + "/media/" + uid) - mediaPath := path.Join(folderPath + "/" + fileName) + // Derive Path + fileName := path.Base(r.URL.Path) + folderPath := path.Join("/" + api.Config.DataPath + "/media/" + userID) + mediaPath := path.Join(folderPath + "/" + fileName) - // Check if File Exists - _, err := os.Stat(mediaPath) - if os.IsNotExist(err) { - // TODO: Different HTTP Response Code? - w.WriteHeader(http.StatusMethodNotAllowed) - return - } + // Check if File Exists + _, err := os.Stat(mediaPath) + if os.IsNotExist(err) { + // TODO: Different HTTP Response Code? + w.WriteHeader(http.StatusMethodNotAllowed) + return + } - http.ServeFile(w, r, mediaPath) + http.ServeFile(w, r, mediaPath) } diff --git a/internal/api/middlewares.go b/internal/api/middlewares.go index 6edb91b..3510133 100644 --- a/internal/api/middlewares.go +++ b/internal/api/middlewares.go @@ -1,104 +1,79 @@ package api import ( - log "github.com/sirupsen/logrus" - "net/http" - "context" - "os" + "context" + "net/http" + "os" - "reichard.io/imagini/graph/model" + log "github.com/sirupsen/logrus" + + "reichard.io/imagini/graph/model" ) type Middleware func(http.Handler) http.HandlerFunc func multipleMiddleware(h http.HandlerFunc, m ...Middleware) http.HandlerFunc { - if len(m) < 1 { - return h - } - wrapped := h - for i := len(m) - 1; i >= 0; i-- { - wrapped = m[i](wrapped) - } - return wrapped + if len(m) < 1 { + return h + } + wrapped := h + for i := len(m) - 1; i >= 0; i-- { + wrapped = m[i](wrapped) + } + return wrapped } +/** + * This is used for the graphQL endpoints that may require access to the + * Request and ResponseWriter variables. These are used to get / set cookies. + **/ func (api *API) injectContextMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.Info("[middleware] Entering testMiddleware...") - authContext := &model.AuthContext{ - AuthResponse: &w, - AuthRequest: r, - } - accessCookie, err := r.Cookie("AccessToken") - if err != nil { - log.Warn("[middleware] AccessToken not found") - } else { - authContext.AccessToken = accessCookie.Value - } - refreshCookie, err := r.Cookie("RefreshToken") - if err != nil { - log.Warn("[middleware] RefreshToken not found") - } else { - authContext.RefreshToken = refreshCookie.Value - } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Add context - ctx := context.WithValue(r.Context(), "auth", authContext) - r = r.WithContext(ctx) + authContext := &model.AuthContext{ + AuthResponse: &w, + AuthRequest: r, + } - log.Info("[middleware] Exiting testMiddleware...") - next.ServeHTTP(w, r) - }) + // Add context + ctx := context.WithValue(r.Context(), "auth", authContext) + r = r.WithContext(ctx) + + next.ServeHTTP(w, r) + + }) } +/** + * This is used for non graphQL endpoints that require authentication. + **/ func (api *API) authMiddleware(next http.Handler) http.HandlerFunc { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Acquire Token - accessCookie, err := r.Cookie("AccessToken") - if err != nil { - log.Warn("[middleware] AccessToken not found") - errorJSON(w, "Invalid token.", http.StatusUnauthorized) - return - } + accessToken, err := api.validateTokens(&w, r) + if err != nil { + errorJSON(w, "Invalid token.", http.StatusUnauthorized) + return + } - // Validate JWT Tokens - accessToken, err := api.Auth.ValidateJWTAccessToken(accessCookie.Value) + // Create Context + authContext := &model.AuthContext{ + AccessToken: &accessToken, + } + ctx := context.WithValue(r.Context(), "auth", authContext) + r = r.WithContext(ctx) - if err != nil && err.Error() == "exp not satisfied" { - log.Info("[middleware] Refreshing AccessToken") - accessToken, err = api.refreshAccessToken(w, r) - if err != nil { - log.Warn("[middleware] Refreshing AccessToken failed: ", err) - errorJSON(w, "Invalid token.", http.StatusUnauthorized) - return - } - log.Info("[middleware] AccessToken Refreshed") - } else if err != nil { - log.Warn("[middleware] AccessToken failed to validate") - errorJSON(w, "Invalid token.", http.StatusUnauthorized) - return - } + next.ServeHTTP(w, r) - // Acquire UserID and DeviceID - reqInfo := make(map[string]string) - uid, _ := accessToken.Get("sub") - did, _ := accessToken.Get("did") - reqInfo["uid"] = uid.(string) - reqInfo["did"] = did.(string) - - // Add context - ctx := context.WithValue(r.Context(), "uuids", reqInfo) - sr := r.WithContext(ctx) - - next.ServeHTTP(w, sr) - }) + }) } func (api *API) logMiddleware(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.SetOutput(os.Stdout) - log.Println(r.Method, r.URL) - h.ServeHTTP(w, r) - }) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + log.SetOutput(os.Stdout) + log.Println(r.Method, r.URL) + h.ServeHTTP(w, r) + + }) } diff --git a/internal/api/routes.go b/internal/api/routes.go index 41bb097..7343c48 100644 --- a/internal/api/routes.go +++ b/internal/api/routes.go @@ -1,8 +1,8 @@ package api import ( - "net/http" - "encoding/json" + "encoding/json" + "net/http" "github.com/99designs/gqlgen/graphql/handler" "github.com/99designs/gqlgen/graphql/playground" @@ -13,39 +13,46 @@ import ( func (api *API) registerRoutes() { - // Set up Directives - c := generated.Config{ Resolvers: &graph.Resolver{ DB: api.DB } } - c.Directives.HasMinRole = hasMinRoleDirective - c.Directives.Meta = metaDirective - srv := handler.NewDefaultServer(generated.NewExecutableSchema(c)) + // Set up Directives + graphConfig := generated.Config{ + Resolvers: &graph.Resolver{ + DB: api.DB, + Auth: api.Auth, + }, + Directives: generated.DirectiveRoot{ + Meta: api.metaDirective, + HasMinRole: api.hasMinRoleDirective, + }, + } + srv := handler.NewDefaultServer(generated.NewExecutableSchema(graphConfig)) - // Handle GraphQL + // Handle GraphQL api.Router.Handle("/playground", playground.Handler("GraphQL playground", "/query")) api.Router.Handle("/query", api.injectContextMiddleware(srv)) - // Handle Resource Route - api.Router.HandleFunc("/media/", multipleMiddleware( - api.mediaHandler, - api.authMiddleware, - )) + // Handle Resource Route + api.Router.HandleFunc("/media/", multipleMiddleware( + api.mediaHandler, + api.authMiddleware, + )) } func errorJSON(w http.ResponseWriter, err string, code int) { - errStruct := &APIResponse{Error: &APIError{Message: err, Code: int64(code)}} - responseJSON(w, errStruct, code) + errStruct := &APIResponse{Error: &APIError{Message: err, Code: int64(code)}} + responseJSON(w, errStruct, code) } func responseJSON(w http.ResponseWriter, msg interface{}, code int) { - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Header().Set("X-Content-Type-Options", "nosniff") - w.WriteHeader(code) - json.NewEncoder(w).Encode(msg) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.WriteHeader(code) + json.NewEncoder(w).Encode(msg) } func successJSON(w http.ResponseWriter, msg string, code int) { - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Header().Set("X-Content-Type-Options", "nosniff") - w.WriteHeader(code) - json.NewEncoder(w).Encode(map[string]interface{}{"success": msg}) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.WriteHeader(code) + json.NewEncoder(w).Encode(map[string]interface{}{"success": msg}) } diff --git a/internal/auth/auth.go b/internal/auth/auth.go index a4dfd8c..062a02f 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -1,185 +1,175 @@ package auth import ( - "fmt" - "time" - "errors" - "encoding/json" + "encoding/json" + "errors" + "fmt" + "time" - "gorm.io/gorm" - "github.com/google/uuid" - log "github.com/sirupsen/logrus" - "github.com/lestrrat-go/jwx/jwa" - "github.com/lestrrat-go/jwx/jwt" + "github.com/google/uuid" + "github.com/lestrrat-go/jwx/jwa" + "github.com/lestrrat-go/jwx/jwt" + log "github.com/sirupsen/logrus" + "gorm.io/gorm" - "reichard.io/imagini/graph/model" - "reichard.io/imagini/internal/db" - "reichard.io/imagini/internal/config" - "reichard.io/imagini/internal/session" + "reichard.io/imagini/graph/model" + "reichard.io/imagini/internal/config" + "reichard.io/imagini/internal/db" ) type AuthManager struct { - DB *db.DBManager - Config *config.Config - Session *session.SessionManager + DB *db.DBManager + Config *config.Config } func NewMgr(db *db.DBManager, c *config.Config) *AuthManager { - session := session.NewMgr() - return &AuthManager{ - DB: db, - Config: c, - Session: session, - } + return &AuthManager{ + DB: db, + Config: c, + } } -func (auth *AuthManager) AuthenticateUser(user, password string) (bool, model.User) { - // Search Objects - userByName := &model.User{} - userByName.Username = user +func (auth *AuthManager) AuthenticateUser(user, password string) (model.User, bool) { + // Find User by Username / Email + foundUser := &model.User{Username: user} + _, err := auth.DB.User(foundUser) - foundUser, err := auth.DB.User(userByName) - if errors.Is(err, gorm.ErrRecordNotFound) { - userByEmail := &model.User{} - userByEmail.Email = user - foundUser, err = auth.DB.User(userByEmail) - } + // By Username + if errors.Is(err, gorm.ErrRecordNotFound) { + foundUser = &model.User{Email: user} + _, err = auth.DB.User(foundUser) + } - // Error Checking - if errors.Is(err, gorm.ErrRecordNotFound) { - log.Warn("[auth] User not found: ", user) - return false, foundUser - } else if err != nil { - log.Error(err) - return false, foundUser - } + // By Email + if errors.Is(err, gorm.ErrRecordNotFound) { + log.Warn("[auth] User not found: ", user) + return *foundUser, false + } else if err != nil { + log.Error(err) + return *foundUser, false + } - log.Info("[auth] Authenticating user: ", foundUser.Username) + log.Info("[auth] Authenticating user: ", foundUser.Username) - // Determine Type - switch foundUser.AuthType { - case "Local": - return authenticateLocalUser(foundUser, password), foundUser - case "LDAP": - return authenticateLDAPUser(foundUser, password), foundUser - default: - return false, foundUser - } -} - -func (auth *AuthManager) getRole(user model.User) string { - // TODO: Lookup role of user - return "User" + // Determine Type + switch foundUser.AuthType { + case "Local": + return *foundUser, authenticateLocalUser(*foundUser, password) + case "LDAP": + return *foundUser, authenticateLDAPUser(*foundUser, password) + default: + return *foundUser, false + } } func (auth *AuthManager) ValidateJWTRefreshToken(refreshJWT string) (jwt.Token, error) { - byteRefreshJWT := []byte(refreshJWT) + byteRefreshJWT := []byte(refreshJWT) - // Acquire Relevant Device - unverifiedToken, err := jwt.ParseBytes(byteRefreshJWT) - did, ok := unverifiedToken.Get("did") - if !ok { - return nil, errors.New("did does not exist") - } - deviceID, err := uuid.Parse(fmt.Sprintf("%v", did)) - if err != nil { - return nil, errors.New("did does not parse") - } - stringDeviceID := deviceID.String() - device, err := auth.DB.Device(&model.Device{ID: &stringDeviceID}) - if err != nil { - return nil, err - } + // Acquire Relevant Device + unverifiedToken, err := jwt.ParseBytes(byteRefreshJWT) + did, ok := unverifiedToken.Get("did") + if !ok { + return nil, errors.New("did does not exist") + } + deviceID, err := uuid.Parse(fmt.Sprintf("%v", did)) + if err != nil { + return nil, errors.New("did does not parse") + } + device := &model.Device{ID: deviceID.String()} + _, err = auth.DB.Device(device) + if err != nil { + return nil, err + } - // Verify & Validate Token - verifiedToken, err := jwt.ParseBytes(byteRefreshJWT, - jwt.WithValidate(true), - jwt.WithVerify(jwa.HS256, []byte(*device.RefreshKey)), - ) - if err != nil { - fmt.Println("failed to parse payload: ", err) - return nil, err - } - return verifiedToken, nil + // Verify & Validate Token + verifiedToken, err := jwt.ParseBytes(byteRefreshJWT, + jwt.WithValidate(true), + jwt.WithVerify(jwa.HS256, []byte(*device.RefreshKey)), + ) + if err != nil { + fmt.Println("failed to parse payload: ", err) + return nil, err + } + return verifiedToken, nil } func (auth *AuthManager) ValidateJWTAccessToken(accessJWT string) (jwt.Token, error) { - byteAccessJWT := []byte(accessJWT) - verifiedToken, err := jwt.ParseBytes(byteAccessJWT, - jwt.WithValidate(true), - jwt.WithVerify(jwa.HS256, []byte(auth.Config.JWTSecret)), - ) + byteAccessJWT := []byte(accessJWT) + verifiedToken, err := jwt.ParseBytes(byteAccessJWT, + jwt.WithValidate(true), + jwt.WithVerify(jwa.HS256, []byte(auth.Config.JWTSecret)), + ) - if err != nil { - return nil, err - } + if err != nil { + return nil, err + } - return verifiedToken, nil + return verifiedToken, nil } func (auth *AuthManager) CreateJWTRefreshToken(user model.User, device model.Device) (string, error) { - // Acquire Refresh Key - byteKey := []byte(*device.RefreshKey) + // Acquire Refresh Key + byteKey := []byte(*device.RefreshKey) - // Create New Token - tm := time.Now() - t := jwt.New() - t.Set(`did`, device.ID) // Device ID - t.Set(jwt.SubjectKey, user.ID) // User ID - t.Set(jwt.AudienceKey, `imagini`) // App ID - t.Set(jwt.IssuedAtKey, tm) // Issued At + // Create New Token + tm := time.Now() + t := jwt.New() + t.Set(`did`, device.ID) // Device ID + t.Set(jwt.SubjectKey, user.ID) // User ID + t.Set(jwt.AudienceKey, `imagini`) // App ID + t.Set(jwt.IssuedAtKey, tm) // Issued At - // iOS & Android = Never Expiring Refresh Token - if device.Type != "iOS" && device.Type != "Android" { - t.Set(jwt.ExpirationKey, tm.Add(time.Hour * 24)) // 1 Day Access Key - } + // iOS & Android = Never Expiring Refresh Token + if device.Type != "iOS" && device.Type != "Android" { + t.Set(jwt.ExpirationKey, tm.Add(time.Hour*24)) // 1 Day Access Key + } - // Validate Token Creation - _, err := json.MarshalIndent(t, "", " ") - if err != nil { - fmt.Printf("failed to generate JSON: %s\n", err) - return "", err - } + // Validate Token Creation + _, err := json.MarshalIndent(t, "", " ") + if err != nil { + fmt.Printf("failed to generate JSON: %s\n", err) + return "", err + } - // Sign Token - signed, err := jwt.Sign(t, jwa.HS256, byteKey) - if err != nil { - log.Printf("failed to sign token: %s", err) - return "", err - } + // Sign Token + signed, err := jwt.Sign(t, jwa.HS256, byteKey) + if err != nil { + log.Printf("failed to sign token: %s", err) + return "", err + } - // Return Token - return string(signed), nil + // Return Token + return string(signed), nil } func (auth *AuthManager) CreateJWTAccessToken(user model.User, device model.Device) (string, error) { - // Create New Token - tm := time.Now() - t := jwt.New() - t.Set(`did`, device.ID) // Device ID - t.Set(`role`, auth.getRole(user)) // User Role (Admin / User) - t.Set(jwt.SubjectKey, user.ID) // User ID - t.Set(jwt.AudienceKey, `imagini`) // App ID - t.Set(jwt.IssuedAtKey, tm) // Issued At - t.Set(jwt.ExpirationKey, tm.Add(time.Hour * 2)) // 2 Hour Access Key + // Create New Token + tm := time.Now() + t := jwt.New() + t.Set(`did`, device.ID) // Device ID + t.Set(`role`, user.Role.String()) // User Role (Admin / User) + t.Set(jwt.SubjectKey, user.ID) // User ID + t.Set(jwt.AudienceKey, `imagini`) // App ID + t.Set(jwt.IssuedAtKey, tm) // Issued At + t.Set(jwt.ExpirationKey, tm.Add(time.Hour*2)) // 2 Hour Access Key - // Validate Token Creation - _, err := json.MarshalIndent(t, "", " ") - if err != nil { - fmt.Printf("failed to generate JSON: %s\n", err) - return "", err - } + // Validate Token Creation + _, err := json.MarshalIndent(t, "", " ") + if err != nil { + fmt.Printf("failed to generate JSON: %s\n", err) + return "", err + } - // Use Server Key - byteKey := []byte(auth.Config.JWTSecret) + // Use Server Key + byteKey := []byte(auth.Config.JWTSecret) - // Sign Token - signed, err := jwt.Sign(t, jwa.HS256, byteKey) - if err != nil { - log.Printf("failed to sign token: %s", err) - return "", err - } + // Sign Token + signed, err := jwt.Sign(t, jwa.HS256, byteKey) + if err != nil { + log.Printf("failed to sign token: %s", err) + return "", err + } - // Return Token - return string(signed), nil + // Return Token + return string(signed), nil } diff --git a/internal/db/db.go b/internal/db/db.go index 1873598..3f3de1a 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -1,113 +1,113 @@ package db import ( - "errors" - "path" - "fmt" + "errors" + "fmt" + "path" - log "github.com/sirupsen/logrus" - // "gorm.io/gorm/logger" - "gorm.io/driver/sqlite" - "gorm.io/gorm" + log "github.com/sirupsen/logrus" + // "gorm.io/gorm/logger" + "gorm.io/driver/sqlite" + "gorm.io/gorm" - "reichard.io/imagini/internal/config" - "reichard.io/imagini/graph/model" + "reichard.io/imagini/graph/model" + "reichard.io/imagini/internal/config" ) type DBManager struct { - db *gorm.DB + db *gorm.DB } func NewMgr(c *config.Config) *DBManager { - gormConfig := &gorm.Config{ - PrepareStmt: true, - // Logger: logger.Default.LogMode(logger.Silent), - } + gormConfig := &gorm.Config{ + PrepareStmt: true, + // Logger: logger.Default.LogMode(logger.Silent), + } - // Create manager - dbm := &DBManager{} + // Create manager + dbm := &DBManager{} - if c.DBType == "SQLite" { - dbLocation := path.Join(c.ConfigPath, "imagini.db") - dbm.db, _ = gorm.Open(sqlite.Open(dbLocation), gormConfig) - } else { - log.Fatal("Unsupported Database") - } + if c.DBType == "SQLite" { + dbLocation := path.Join(c.ConfigPath, "imagini.db") + dbm.db, _ = gorm.Open(sqlite.Open(dbLocation), gormConfig) + } else { + log.Fatal("Unsupported Database") + } - // Initialize database - dbm.db.AutoMigrate(&model.Device{}) - dbm.db.AutoMigrate(&model.User{}) - dbm.db.AutoMigrate(&model.MediaItem{}) - dbm.db.AutoMigrate(&model.Tag{}) - dbm.db.AutoMigrate(&model.Album{}) + // Initialize database + dbm.db.AutoMigrate(&model.Device{}) + dbm.db.AutoMigrate(&model.User{}) + dbm.db.AutoMigrate(&model.MediaItem{}) + dbm.db.AutoMigrate(&model.Tag{}) + dbm.db.AutoMigrate(&model.Album{}) - // Determine whether to bootstrap - var count int64 - dbm.db.Model(&model.User{}).Count(&count) - if count == 0 { - dbm.bootstrapDatabase() - } + // Determine whether to bootstrap + var count int64 + dbm.db.Model(&model.User{}).Count(&count) + if count == 0 { + dbm.bootstrapDatabase() + } - return dbm + return dbm } func (dbm *DBManager) bootstrapDatabase() { - log.Info("[query] Bootstrapping database.") + log.Info("[query] Bootstrapping database.") - password := "admin" - user := &model.User{ - Username: "admin", - AuthType: "Local", - Password: &password, - Role: model.RoleAdmin, - } + password := "admin" + user := &model.User{ + Username: "admin", + AuthType: "Local", + Password: &password, + Role: model.RoleAdmin, + } - err := dbm.CreateUser(user) + err := dbm.CreateUser(user) - if err != nil { - log.Fatal("[query] Unable to bootstrap database.") - } + if err != nil { + log.Fatal("[query] Unable to bootstrap database.") + } } func (dbm *DBManager) QueryBuilder(dest interface{}, params []byte) (int64, error) { - // TODO: - // - Where Filters - // - Sort Filters - // - Paging Filters + // TODO: + // - Where Filters + // - Sort Filters + // - Paging Filters - objType := fmt.Sprintf("%T", dest) - if objType == "*[]model.MediaItem" { - // TODO: Validate MediaItem Type - } else { - // Return Error - return 0, errors.New("Invalid type") - } + objType := fmt.Sprintf("%T", dest) + if objType == "*[]model.MediaItem" { + // TODO: Validate MediaItem Type + } else { + // Return Error + return 0, errors.New("Invalid type") + } - var count int64 - err := dbm.db.Find(dest).Count(&count).Error; - return count, err + var count int64 + err := dbm.db.Find(dest).Count(&count).Error + return count, err - // Paging: - // - Regular Pagination: - // - /api/v1/MediaItems?page[limit]=50&page=2 - // - Meta Count Only - // - /api/v1/MediaItems?page[limit]=0 + // Paging: + // - Regular Pagination: + // - /api/v1/MediaItems?page[limit]=50&page=2 + // - Meta Count Only + // - /api/v1/MediaItems?page[limit]=0 - // Sorting: - // - Ascending Sort: - // - /api/v1/MediaItems?sort=created_at - // - Descending Sort: - // - /api/v1/MediaItems?sort=-created_at + // Sorting: + // - Ascending Sort: + // - /api/v1/MediaItems?sort=created_at + // - Descending Sort: + // - /api/v1/MediaItems?sort=-created_at - // Filters: - // - Greater Than / Less Than (created_at, updated_at, exif_date) - // - /api/v1/MediaItems?filter[created_at]>=2020-01-01&filter[created_at]<=2021-01-01 - // - Long / Lat Range (latitude, longitude) - // - /api/v1/MediaItems?filter[latitude]>=71.1827&filter[latitude]<=72.0000&filter[longitude]>=100.000&filter[longitude]<=101.0000 - // - Image / Video (media_type) - // - /api/v1/MediaItems?filter[media_type]=Image - // - Tags (tags) - // - /api/v1/MediaItems?filter[tags]=id1,id2,id3 - // - Albums (albums) - // - /api/v1/MediaItems?filter[albums]=id1 + // Filters: + // - Greater Than / Less Than (created_at, updated_at, exif_date) + // - /api/v1/MediaItems?filter[created_at]>=2020-01-01&filter[created_at]<=2021-01-01 + // - Long / Lat Range (latitude, longitude) + // - /api/v1/MediaItems?filter[latitude]>=71.1827&filter[latitude]<=72.0000&filter[longitude]>=100.000&filter[longitude]<=101.0000 + // - Image / Video (media_type) + // - /api/v1/MediaItems?filter[media_type]=Image + // - Tags (tags) + // - /api/v1/MediaItems?filter[tags]=id1,id2,id3 + // - Albums (albums) + // - /api/v1/MediaItems?filter[albums]=id1 } diff --git a/internal/db/devices.go b/internal/db/devices.go index 9bb6d35..990cb18 100644 --- a/internal/db/devices.go +++ b/internal/db/devices.go @@ -1,31 +1,30 @@ package db import ( - "github.com/google/uuid" - log "github.com/sirupsen/logrus" + "github.com/google/uuid" + log "github.com/sirupsen/logrus" - "reichard.io/imagini/graph/model" + "reichard.io/imagini/graph/model" ) -func (dbm *DBManager) CreateDevice (device *model.Device) error { - log.Info("[db] Creating device: ", device.Name) - refreshKey := uuid.New().String() - device.RefreshKey = &refreshKey - err := dbm.db.Create(&device).Error - return err +func (dbm *DBManager) CreateDevice(device *model.Device) error { + log.Debug("[db] Creating device: ", device.Name) + refreshKey := uuid.New().String() + device.RefreshKey = &refreshKey + err := dbm.db.Create(device).Error + return err } -func (dbm *DBManager) Device (device *model.Device) (model.Device, error) { - var foundDevice model.Device - var count int64 - err := dbm.db.Where(&device).First(&foundDevice).Count(&count).Error - return foundDevice, err +func (dbm *DBManager) Device(device *model.Device) (int64, error) { + var count int64 + err := dbm.db.Where(device).First(device).Count(&count).Error + return count, err } -func (dbm *DBManager) DeleteDevice (user *model.Device) error { - return nil +func (dbm *DBManager) DeleteDevice(user *model.Device) error { + return nil } -func (dbm *DBManager) UpdateRefreshToken (device *model.Device, refreshToken string) error { - return nil +func (dbm *DBManager) UpdateRefreshToken(device *model.Device, refreshToken string) error { + return nil } diff --git a/internal/db/errors.go b/internal/db/errors.go index 035b3c1..787ae41 100644 --- a/internal/db/errors.go +++ b/internal/db/errors.go @@ -3,5 +3,5 @@ package db import "errors" var ( - ErrUserAlreadyExists = errors.New("user already exists") + ErrUserAlreadyExists = errors.New("user already exists") ) diff --git a/internal/db/media_items.go b/internal/db/media_items.go index 6e65c27..f7f3b80 100644 --- a/internal/db/media_items.go +++ b/internal/db/media_items.go @@ -1,21 +1,21 @@ package db import ( - log "github.com/sirupsen/logrus" + log "github.com/sirupsen/logrus" - "reichard.io/imagini/graph/model" + "reichard.io/imagini/graph/model" ) -func (dbm *DBManager) CreateMediaItem (mediaItem *model.MediaItem) error { - log.Info("[db] Creating media item: ", mediaItem.FileName) - err := dbm.db.Create(&mediaItem).Error - return err +func (dbm *DBManager) CreateMediaItem(mediaItem *model.MediaItem) error { + log.Debug("[db] Creating media item: ", mediaItem.FileName) + err := dbm.db.Create(mediaItem).Error + return err } func (dbm *DBManager) MediaItems(mediaItemFilter *model.MediaItem) ([]model.MediaItem, int64, error) { - var mediaItems []model.MediaItem - var count int64 + var mediaItems []model.MediaItem + var count int64 - err := dbm.db.Where(&mediaItemFilter).Find(&mediaItems).Count(&count).Error; - return mediaItems, count, err + err := dbm.db.Where(mediaItemFilter).Find(&mediaItems).Count(&count).Error + return mediaItems, count, err } diff --git a/internal/db/users.go b/internal/db/users.go index 8853789..e69fa6b 100644 --- a/internal/db/users.go +++ b/internal/db/users.go @@ -1,43 +1,41 @@ package db import ( - "golang.org/x/crypto/bcrypt" - log "github.com/sirupsen/logrus" + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/bcrypt" - "reichard.io/imagini/graph/model" + "reichard.io/imagini/graph/model" ) -func (dbm *DBManager) CreateUser (user *model.User) error { - log.Info("[db] Creating user: ", user.Username) - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(*user.Password), bcrypt.DefaultCost) - if err != nil { - log.Error(err) - return err - } - stringHashedPassword := string(hashedPassword) - user.Password = &stringHashedPassword - err = dbm.db.Create(&user).Error - return err +func (dbm *DBManager) CreateUser(user *model.User) error { + log.Info("[db] Creating user: ", user.Username) + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(*user.Password), bcrypt.DefaultCost) + if err != nil { + log.Error(err) + return err + } + stringHashedPassword := string(hashedPassword) + user.Password = &stringHashedPassword + return dbm.db.Create(user).Error } -func (dbm *DBManager) User (user *model.User) (model.User, error) { - var foundUser model.User - var count int64 - err := dbm.db.Where(&user).First(&foundUser).Count(&count).Error - return foundUser, err +func (dbm *DBManager) User(user *model.User) (int64, error) { + var count int64 + err := dbm.db.Where(user).First(user).Count(&count).Error + return count, err } -func (dbm *DBManager) Users () ([]*model.User, int64, error) { - var foundUsers []*model.User - var count int64 - err := dbm.db.Find(&foundUsers).Count(&count).Error - return foundUsers, count, err +func (dbm *DBManager) Users() ([]*model.User, int64, error) { + var foundUsers []*model.User + var count int64 + err := dbm.db.Find(&foundUsers).Count(&count).Error + return foundUsers, count, err } -func (dbm *DBManager) DeleteUser (user model.User) error { - return nil +func (dbm *DBManager) DeleteUser(user model.User) error { + return nil } -func (dbm *DBManager) UpdatePassword (user model.User, pw string) { +func (dbm *DBManager) UpdatePassword(user model.User, pw string) { }