This commit is contained in:
Evan Reichard 2021-02-04 15:31:07 -05:00
parent 082f923482
commit a5692babb8
20 changed files with 319 additions and 348 deletions

View File

@ -1,58 +1,59 @@
package server package server
import ( import (
"time" "context"
"context" "net/http"
"net/http" "time"
log "github.com/sirupsen/logrus"
"reichard.io/imagini/internal/db" log "github.com/sirupsen/logrus"
"reichard.io/imagini/internal/api"
"reichard.io/imagini/internal/auth" "reichard.io/imagini/internal/api"
"reichard.io/imagini/internal/config" "reichard.io/imagini/internal/auth"
"reichard.io/imagini/internal/config"
"reichard.io/imagini/internal/db"
) )
type Server struct { type Server struct {
API *api.API API *api.API
Auth *auth.AuthManager Auth *auth.AuthManager
Config *config.Config Config *config.Config
Database *db.DBManager Database *db.DBManager
httpServer *http.Server httpServer *http.Server
} }
func NewServer() *Server { func NewServer() *Server {
c := config.Load() c := config.Load()
db := db.NewMgr(c) db := db.NewMgr(c)
auth := auth.NewMgr(db, c) auth := auth.NewMgr(db, c)
api := api.NewApi(db, c, auth) api := api.NewApi(db, c, auth)
return &Server{ return &Server{
API: api, API: api,
Auth: auth, Auth: auth,
Config: c, Config: c,
Database: db, Database: db,
} }
} }
func (s *Server) StartServer() { func (s *Server) StartServer() {
listenAddr := (":" + s.Config.ListenPort) listenAddr := (":" + s.Config.ListenPort)
s.httpServer = &http.Server{ s.httpServer = &http.Server{
Handler: s.API.Router, Handler: s.API.Router,
Addr: listenAddr, Addr: listenAddr,
} }
go func() { go func() {
err := s.httpServer.ListenAndServe() err := s.httpServer.ListenAndServe()
if err != nil { if err != nil {
log.Error("Error starting server ", err) log.Error("Error starting server ", err)
return return
} }
}() }()
} }
func (s *Server) StopServer() { func (s *Server) StopServer() {
ctx, cancel := context.WithTimeout(context.Background(), 5 * time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
s.httpServer.Shutdown(ctx) s.httpServer.Shutdown(ctx)
} }

View File

@ -43,6 +43,7 @@ type ResolverRoot interface {
type DirectiveRoot struct { type DirectiveRoot struct {
HasMinRole func(ctx context.Context, obj interface{}, next graphql.Resolver, role model.Role) (res interface{}, err error) HasMinRole func(ctx context.Context, obj interface{}, next graphql.Resolver, role model.Role) (res interface{}, err error)
IsPrivate func(ctx context.Context, obj interface{}, next graphql.Resolver) (res interface{}, err error)
Meta func(ctx context.Context, obj interface{}, next graphql.Resolver, gorm *string) (res interface{}, err error) Meta func(ctx context.Context, obj interface{}, next graphql.Resolver, gorm *string) (res interface{}, err error)
} }
@ -243,21 +244,21 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
return e.complexity.AlbumResponse.PageInfo(childComplexity), true return e.complexity.AlbumResponse.PageInfo(childComplexity), true
case "AuthResponse.Device": case "AuthResponse.device":
if e.complexity.AuthResponse.Device == nil { if e.complexity.AuthResponse.Device == nil {
break break
} }
return e.complexity.AuthResponse.Device(childComplexity), true return e.complexity.AuthResponse.Device(childComplexity), true
case "AuthResponse.Error": case "AuthResponse.error":
if e.complexity.AuthResponse.Error == nil { if e.complexity.AuthResponse.Error == nil {
break break
} }
return e.complexity.AuthResponse.Error(childComplexity), true return e.complexity.AuthResponse.Error(childComplexity), true
case "AuthResponse.Result": case "AuthResponse.result":
if e.complexity.AuthResponse.Result == nil { if e.complexity.AuthResponse.Result == nil {
break break
} }
@ -863,6 +864,7 @@ scalar Upload
# https://gqlgen.com/reference/directives/ # https://gqlgen.com/reference/directives/
directive @hasMinRole(role: Role!) on FIELD_DEFINITION directive @hasMinRole(role: Role!) on FIELD_DEFINITION
directive @isPrivate on FIELD_DEFINITION | INPUT_FIELD_DEFINITION
directive @meta( directive @meta(
gorm: String, gorm: String,
@ -899,9 +901,9 @@ enum AuthResult {
} }
type AuthResponse { type AuthResponse {
Result: AuthResult! result: AuthResult!
Device: Device device: Device
Error: String error: String
} }
# ------------------------------------------------------------ # ------------------------------------------------------------
@ -1004,7 +1006,7 @@ type Device {
type: DeviceType! @meta(gorm: "default:Unknown;not null") type: DeviceType! @meta(gorm: "default:Unknown;not null")
userID: ID! @meta(gorm: "not null") userID: ID! @meta(gorm: "not null")
user: User! @meta(gorm: "foreignKey:ID;references:UserID;not null") user: User! @meta(gorm: "foreignKey:ID;references:UserID;not null")
refreshKey: String refreshKey: String @deprecated(reason: "Private Field") # @isPrivate
} }
type User { type User {
@ -1017,7 +1019,7 @@ type User {
lastName: String lastName: String
role: Role! @meta(gorm: "default:User;not null") role: Role! @meta(gorm: "default:User;not null")
authType: AuthType! @meta(gorm: "default:Local;not null") authType: AuthType! @meta(gorm: "default:Local;not null")
password: String password: String @deprecated(reason: "Private Field") #@isPrivate
} }
type MediaItem { type MediaItem {
@ -1927,7 +1929,7 @@ func (ec *executionContext) _AlbumResponse_pageInfo(ctx context.Context, field g
return ec.marshalNPageInfo2ᚖreichardᚗioᚋimaginiᚋgraphᚋmodelᚐPageInfo(ctx, field.Selections, res) return ec.marshalNPageInfo2ᚖreichardᚗioᚋimaginiᚋgraphᚋmodelᚐPageInfo(ctx, field.Selections, res)
} }
func (ec *executionContext) _AuthResponse_Result(ctx context.Context, field graphql.CollectedField, obj *model.AuthResponse) (ret graphql.Marshaler) { func (ec *executionContext) _AuthResponse_result(ctx context.Context, field graphql.CollectedField, obj *model.AuthResponse) (ret graphql.Marshaler) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r)) ec.Error(ctx, ec.Recover(ctx, r))
@ -1962,7 +1964,7 @@ func (ec *executionContext) _AuthResponse_Result(ctx context.Context, field grap
return ec.marshalNAuthResult2reichardᚗioᚋimaginiᚋgraphᚋmodelᚐAuthResult(ctx, field.Selections, res) return ec.marshalNAuthResult2reichardᚗioᚋimaginiᚋgraphᚋmodelᚐAuthResult(ctx, field.Selections, res)
} }
func (ec *executionContext) _AuthResponse_Device(ctx context.Context, field graphql.CollectedField, obj *model.AuthResponse) (ret graphql.Marshaler) { func (ec *executionContext) _AuthResponse_device(ctx context.Context, field graphql.CollectedField, obj *model.AuthResponse) (ret graphql.Marshaler) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r)) ec.Error(ctx, ec.Recover(ctx, r))
@ -1994,7 +1996,7 @@ func (ec *executionContext) _AuthResponse_Device(ctx context.Context, field grap
return ec.marshalODevice2ᚖreichardᚗioᚋimaginiᚋgraphᚋmodelᚐDevice(ctx, field.Selections, res) return ec.marshalODevice2ᚖreichardᚗioᚋimaginiᚋgraphᚋmodelᚐDevice(ctx, field.Selections, res)
} }
func (ec *executionContext) _AuthResponse_Error(ctx context.Context, field graphql.CollectedField, obj *model.AuthResponse) (ret graphql.Marshaler) { func (ec *executionContext) _AuthResponse_error(ctx context.Context, field graphql.CollectedField, obj *model.AuthResponse) (ret graphql.Marshaler) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r)) ec.Error(ctx, ec.Recover(ctx, r))
@ -7553,15 +7555,15 @@ func (ec *executionContext) _AuthResponse(ctx context.Context, sel ast.Selection
switch field.Name { switch field.Name {
case "__typename": case "__typename":
out.Values[i] = graphql.MarshalString("AuthResponse") out.Values[i] = graphql.MarshalString("AuthResponse")
case "Result": case "result":
out.Values[i] = ec._AuthResponse_Result(ctx, field, obj) out.Values[i] = ec._AuthResponse_result(ctx, field, obj)
if out.Values[i] == graphql.Null { if out.Values[i] == graphql.Null {
invalids++ invalids++
} }
case "Device": case "device":
out.Values[i] = ec._AuthResponse_Device(ctx, field, obj) out.Values[i] = ec._AuthResponse_device(ctx, field, obj)
case "Error": case "error":
out.Values[i] = ec._AuthResponse_Error(ctx, field, obj) out.Values[i] = ec._AuthResponse_error(ctx, field, obj)
default: default:
panic("unknown field " + strconv.Quote(field.Name)) panic("unknown field " + strconv.Quote(field.Name))
} }

View File

@ -33,9 +33,9 @@ type AlbumResponse struct {
} }
type AuthResponse struct { type AuthResponse struct {
Result AuthResult `json:"Result" ` Result AuthResult `json:"result" `
Device *Device `json:"Device" ` Device *Device `json:"device" `
Error *string `json:"Error" ` Error *string `json:"error" `
} }
type AuthTypeFilter struct { type AuthTypeFilter struct {

View File

@ -5,6 +5,7 @@ scalar Upload
# https://gqlgen.com/reference/directives/ # https://gqlgen.com/reference/directives/
directive @hasMinRole(role: Role!) on FIELD_DEFINITION directive @hasMinRole(role: Role!) on FIELD_DEFINITION
directive @isPrivate on FIELD_DEFINITION | INPUT_FIELD_DEFINITION
directive @meta( directive @meta(
gorm: String, gorm: String,
@ -41,9 +42,9 @@ enum AuthResult {
} }
type AuthResponse { type AuthResponse {
Result: AuthResult! result: AuthResult!
Device: Device device: Device
Error: String error: String
} }
# ------------------------------------------------------------ # ------------------------------------------------------------
@ -146,7 +147,7 @@ type Device {
type: DeviceType! @meta(gorm: "default:Unknown;not null") type: DeviceType! @meta(gorm: "default:Unknown;not null")
userID: ID! @meta(gorm: "not null") userID: ID! @meta(gorm: "not null")
user: User! @meta(gorm: "foreignKey:ID;references:UserID;not null") user: User! @meta(gorm: "foreignKey:ID;references:UserID;not null")
refreshKey: String refreshKey: String @isPrivate
} }
type User { type User {
@ -159,7 +160,7 @@ type User {
lastName: String lastName: String
role: Role! @meta(gorm: "default:User;not null") role: Role! @meta(gorm: "default:User;not null")
authType: AuthType! @meta(gorm: "default:Local;not null") authType: AuthType! @meta(gorm: "default:Local;not null")
password: String password: String @isPrivate
} }
type MediaItem { type MediaItem {

View File

@ -75,6 +75,8 @@ func (r *queryResolver) Login(ctx context.Context, user string, password string,
} }
} else { } else {
foundDevice.Type = deriveDeviceType(req) foundDevice.Type = deriveDeviceType(req)
foundDevice.UserID = foundUser.ID
// TODO: foundDevice.User = &foundUser
err := r.DB.CreateDevice(&foundDevice) err := r.DB.CreateDevice(&foundDevice)
if err != nil { if err != nil {
return &model.AuthResponse{Result: model.AuthResultFailure}, nil return &model.AuthResponse{Result: model.AuthResultFailure}, nil
@ -97,6 +99,8 @@ func (r *queryResolver) Login(ctx context.Context, user string, password string,
http.SetCookie(*resp, &accessCookie) http.SetCookie(*resp, &accessCookie)
http.SetCookie(*resp, &refreshCookie) http.SetCookie(*resp, &refreshCookie)
// TODO: Prob bandaid
foundDevice.User = &foundUser
return &model.AuthResponse{Result: model.AuthResultSuccess, Device: &foundDevice}, nil return &model.AuthResponse{Result: model.AuthResultSuccess, Device: &foundDevice}, nil
} }

View File

@ -1,12 +1,10 @@
package api package api
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"github.com/99designs/gqlgen/graphql"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/lestrrat-go/jwx/jwt" "github.com/lestrrat-go/jwx/jwt"
@ -94,31 +92,3 @@ func (api *API) validateTokens(w *http.ResponseWriter, r *http.Request) (jwt.Tok
return jwt.ParseBytes([]byte(newAccessCookie)) return jwt.ParseBytes([]byte(newAccessCookie))
} }
func (api *API) hasMinRoleDirective(ctx context.Context, obj interface{}, next graphql.Resolver, role model.Role) (res interface{}, err error) {
authContext := ctx.Value("auth").(*model.AuthContext)
accessToken, err := api.validateTokens(authContext.AuthResponse, authContext.AuthRequest)
if err != nil {
return nil, errors.New("Access Denied")
}
authContext.AccessToken = &accessToken
userRole, ok := accessToken.Get("role")
if !ok {
return nil, errors.New("Access Denied")
}
if userRole == model.RoleAdmin.String() {
return next(ctx)
}
if userRole == role.String() {
return next(ctx)
}
return nil, errors.New("Role Not Authenticated")
}
func (api *API) metaDirective(ctx context.Context, obj interface{}, next graphql.Resolver, gorm *string) (res interface{}, err error) {
return next(ctx)
}

View File

@ -0,0 +1,50 @@
package api
import (
"context"
"errors"
"github.com/99designs/gqlgen/graphql"
"reichard.io/imagini/graph/model"
)
/**
* This is used to validate whether the users role is adequate for the requested resource.
**/
func (api *API) hasMinRoleDirective(ctx context.Context, obj interface{}, next graphql.Resolver, role model.Role) (res interface{}, err error) {
authContext := ctx.Value("auth").(*model.AuthContext)
accessToken, err := api.validateTokens(authContext.AuthResponse, authContext.AuthRequest)
if err != nil {
return nil, errors.New("Access Denied")
}
authContext.AccessToken = &accessToken
userRole, ok := accessToken.Get("role")
if !ok {
return nil, errors.New("Access Denied")
}
if userRole == model.RoleAdmin.String() {
return next(ctx)
}
if userRole == role.String() {
return next(ctx)
}
return nil, errors.New("Role Not Authenticated")
}
/**
* This is needed but not used. Meta is used for Gorm.
**/
func (api *API) metaDirective(ctx context.Context, obj interface{}, next graphql.Resolver, gorm *string) (res interface{}, err error) {
return next(ctx)
}
/**
* This overrides the response so fields with an @isPrivate directive are always nil.
**/
func (api *API) isPrivateDirective(ctx context.Context, obj interface{}, next graphql.Resolver) (res interface{}, err error) {
return nil, errors.New("Private Field")
}

View File

@ -8,7 +8,9 @@ import (
"reichard.io/imagini/graph/model" "reichard.io/imagini/graph/model"
) )
// Responsible for serving up static images / videos /**
* Responsible for serving up static images / videos
**/
func (api *API) mediaHandler(w http.ResponseWriter, r *http.Request) { func (api *API) mediaHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
w.WriteHeader(http.StatusMethodNotAllowed) w.WriteHeader(http.StatusMethodNotAllowed)

View File

@ -3,9 +3,6 @@ package api
import ( import (
"context" "context"
"net/http" "net/http"
"os"
log "github.com/sirupsen/logrus"
"reichard.io/imagini/graph/model" "reichard.io/imagini/graph/model"
) )
@ -27,9 +24,8 @@ func multipleMiddleware(h http.HandlerFunc, m ...Middleware) http.HandlerFunc {
* This is used for the graphQL endpoints that may require access to the * This is used for the graphQL endpoints that may require access to the
* Request and ResponseWriter variables. These are used to get / set cookies. * Request and ResponseWriter variables. These are used to get / set cookies.
**/ **/
func (api *API) injectContextMiddleware(next http.Handler) http.Handler { func (api *API) contextMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authContext := &model.AuthContext{ authContext := &model.AuthContext{
AuthResponse: &w, AuthResponse: &w,
AuthRequest: r, AuthRequest: r,
@ -40,7 +36,6 @@ func (api *API) injectContextMiddleware(next http.Handler) http.Handler {
r = r.WithContext(ctx) r = r.WithContext(ctx)
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
} }
@ -49,10 +44,10 @@ func (api *API) injectContextMiddleware(next http.Handler) http.Handler {
**/ **/
func (api *API) authMiddleware(next http.Handler) http.HandlerFunc { func (api *API) authMiddleware(next http.Handler) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Validate Tokens
accessToken, err := api.validateTokens(&w, r) accessToken, err := api.validateTokens(&w, r)
if err != nil { if err != nil {
errorJSON(w, "Invalid token.", http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
return return
} }
@ -64,16 +59,5 @@ func (api *API) authMiddleware(next http.Handler) http.HandlerFunc {
r = r.WithContext(ctx) r = r.WithContext(ctx)
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
})
}
func (api *API) logMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.SetOutput(os.Stdout)
log.Println(r.Method, r.URL)
h.ServeHTTP(w, r)
}) })
} }

View File

@ -1,24 +0,0 @@
package api
type APICredentials struct {
User string `json:"user"`
Password string `json:"password"`
}
type APIData interface{}
type APIMeta struct {
Count int64 `json:"count"`
Page int64 `json:"page"`
}
type APIError struct {
Message string `json:"message"`
Code int64 `json:"code"`
}
type APIResponse struct {
Data APIData `json:"data,omitempty"`
Meta *APIMeta `json:"meta,omitempty"`
Error *APIError `json:"error,omitempty"`
}

View File

@ -1,9 +1,6 @@
package api package api
import ( import (
"encoding/json"
"net/http"
"github.com/99designs/gqlgen/graphql/handler" "github.com/99designs/gqlgen/graphql/handler"
"github.com/99designs/gqlgen/graphql/playground" "github.com/99designs/gqlgen/graphql/playground"
@ -12,7 +9,6 @@ import (
) )
func (api *API) registerRoutes() { func (api *API) registerRoutes() {
// Set up Directives // Set up Directives
graphConfig := generated.Config{ graphConfig := generated.Config{
Resolvers: &graph.Resolver{ Resolvers: &graph.Resolver{
@ -21,6 +17,7 @@ func (api *API) registerRoutes() {
}, },
Directives: generated.DirectiveRoot{ Directives: generated.DirectiveRoot{
Meta: api.metaDirective, Meta: api.metaDirective,
IsPrivate: api.isPrivateDirective,
HasMinRole: api.hasMinRoleDirective, HasMinRole: api.hasMinRoleDirective,
}, },
} }
@ -28,31 +25,11 @@ func (api *API) registerRoutes() {
// Handle GraphQL // Handle GraphQL
api.Router.Handle("/playground", playground.Handler("GraphQL playground", "/query")) api.Router.Handle("/playground", playground.Handler("GraphQL playground", "/query"))
api.Router.Handle("/query", api.injectContextMiddleware(srv)) api.Router.Handle("/query", api.contextMiddleware(srv))
// Handle Resource Route // Handle Resource Route
api.Router.HandleFunc("/media/", multipleMiddleware( api.Router.HandleFunc("/media/", multipleMiddleware(
api.mediaHandler, api.mediaHandler,
api.authMiddleware, api.authMiddleware,
)) ))
}
func errorJSON(w http.ResponseWriter, err string, code int) {
errStruct := &APIResponse{Error: &APIError{Message: err, Code: int64(code)}}
responseJSON(w, errStruct, code)
}
func responseJSON(w http.ResponseWriter, msg interface{}, code int) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.WriteHeader(code)
json.NewEncoder(w).Encode(msg)
}
func successJSON(w http.ResponseWriter, msg string, code int) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.WriteHeader(code)
json.NewEncoder(w).Encode(map[string]interface{}{"success": msg})
} }

View File

@ -1,14 +1,8 @@
package auth package auth
import ( import (
"encoding/json"
"errors" "errors"
"fmt"
"time"
"github.com/google/uuid"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gorm.io/gorm" "gorm.io/gorm"
@ -61,115 +55,3 @@ func (auth *AuthManager) AuthenticateUser(user, password string) (model.User, bo
return *foundUser, false return *foundUser, false
} }
} }
func (auth *AuthManager) ValidateJWTRefreshToken(refreshJWT string) (jwt.Token, error) {
byteRefreshJWT := []byte(refreshJWT)
// Acquire Relevant Device
unverifiedToken, err := jwt.ParseBytes(byteRefreshJWT)
did, ok := unverifiedToken.Get("did")
if !ok {
return nil, errors.New("did does not exist")
}
deviceID, err := uuid.Parse(fmt.Sprintf("%v", did))
if err != nil {
return nil, errors.New("did does not parse")
}
device := &model.Device{ID: deviceID.String()}
_, err = auth.DB.Device(device)
if err != nil {
return nil, err
}
// Verify & Validate Token
verifiedToken, err := jwt.ParseBytes(byteRefreshJWT,
jwt.WithValidate(true),
jwt.WithVerify(jwa.HS256, []byte(*device.RefreshKey)),
)
if err != nil {
fmt.Println("failed to parse payload: ", err)
return nil, err
}
return verifiedToken, nil
}
func (auth *AuthManager) ValidateJWTAccessToken(accessJWT string) (jwt.Token, error) {
byteAccessJWT := []byte(accessJWT)
verifiedToken, err := jwt.ParseBytes(byteAccessJWT,
jwt.WithValidate(true),
jwt.WithVerify(jwa.HS256, []byte(auth.Config.JWTSecret)),
)
if err != nil {
return nil, err
}
return verifiedToken, nil
}
func (auth *AuthManager) CreateJWTRefreshToken(user model.User, device model.Device) (string, error) {
// Acquire Refresh Key
byteKey := []byte(*device.RefreshKey)
// Create New Token
tm := time.Now()
t := jwt.New()
t.Set(`did`, device.ID) // Device ID
t.Set(jwt.SubjectKey, user.ID) // User ID
t.Set(jwt.AudienceKey, `imagini`) // App ID
t.Set(jwt.IssuedAtKey, tm) // Issued At
// iOS & Android = Never Expiring Refresh Token
if device.Type != "iOS" && device.Type != "Android" {
t.Set(jwt.ExpirationKey, tm.Add(time.Hour*24)) // 1 Day Access Key
}
// Validate Token Creation
_, err := json.MarshalIndent(t, "", " ")
if err != nil {
fmt.Printf("failed to generate JSON: %s\n", err)
return "", err
}
// Sign Token
signed, err := jwt.Sign(t, jwa.HS256, byteKey)
if err != nil {
log.Printf("failed to sign token: %s", err)
return "", err
}
// Return Token
return string(signed), nil
}
func (auth *AuthManager) CreateJWTAccessToken(user model.User, device model.Device) (string, error) {
// Create New Token
tm := time.Now()
t := jwt.New()
t.Set(`did`, device.ID) // Device ID
t.Set(`role`, user.Role.String()) // User Role (Admin / User)
t.Set(jwt.SubjectKey, user.ID) // User ID
t.Set(jwt.AudienceKey, `imagini`) // App ID
t.Set(jwt.IssuedAtKey, tm) // Issued At
t.Set(jwt.ExpirationKey, tm.Add(time.Hour*2)) // 2 Hour Access Key
// Validate Token Creation
_, err := json.MarshalIndent(t, "", " ")
if err != nil {
fmt.Printf("failed to generate JSON: %s\n", err)
return "", err
}
// Use Server Key
byteKey := []byte(auth.Config.JWTSecret)
// Sign Token
signed, err := jwt.Sign(t, jwa.HS256, byteKey)
if err != nil {
log.Printf("failed to sign token: %s", err)
return "", err
}
// Return Token
return string(signed), nil
}

126
internal/auth/jwt.go Normal file
View File

@ -0,0 +1,126 @@
package auth
import (
"encoding/json"
"errors"
"fmt"
"time"
"github.com/google/uuid"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwt"
log "github.com/sirupsen/logrus"
"reichard.io/imagini/graph/model"
)
func (auth *AuthManager) ValidateJWTRefreshToken(refreshJWT string) (jwt.Token, error) {
byteRefreshJWT := []byte(refreshJWT)
// Acquire Relevant Device
unverifiedToken, err := jwt.ParseBytes(byteRefreshJWT)
did, ok := unverifiedToken.Get("did")
if !ok {
return nil, errors.New("did does not exist")
}
deviceID, err := uuid.Parse(fmt.Sprintf("%v", did))
if err != nil {
return nil, errors.New("did does not parse")
}
device := &model.Device{ID: deviceID.String()}
_, err = auth.DB.Device(device)
if err != nil {
return nil, err
}
// Verify & Validate Token
verifiedToken, err := jwt.ParseBytes(byteRefreshJWT,
jwt.WithValidate(true),
jwt.WithVerify(jwa.HS256, []byte(*device.RefreshKey)),
)
if err != nil {
fmt.Println("failed to parse payload: ", err)
return nil, err
}
return verifiedToken, nil
}
func (auth *AuthManager) ValidateJWTAccessToken(accessJWT string) (jwt.Token, error) {
byteAccessJWT := []byte(accessJWT)
verifiedToken, err := jwt.ParseBytes(byteAccessJWT,
jwt.WithValidate(true),
jwt.WithVerify(jwa.HS256, []byte(auth.Config.JWTSecret)),
)
if err != nil {
return nil, err
}
return verifiedToken, nil
}
func (auth *AuthManager) CreateJWTRefreshToken(user model.User, device model.Device) (string, error) {
// Acquire Refresh Key
byteKey := []byte(*device.RefreshKey)
// Create New Token
tm := time.Now()
t := jwt.New()
t.Set(`did`, device.ID) // Device ID
t.Set(jwt.SubjectKey, user.ID) // User ID
t.Set(jwt.AudienceKey, `imagini`) // App ID
t.Set(jwt.IssuedAtKey, tm) // Issued At
// iOS & Android = Never Expiring Refresh Token
if device.Type != "iOS" && device.Type != "Android" {
t.Set(jwt.ExpirationKey, tm.Add(time.Hour*24)) // 1 Day Access Key
}
// Validate Token Creation
_, err := json.MarshalIndent(t, "", " ")
if err != nil {
fmt.Printf("failed to generate JSON: %s\n", err)
return "", err
}
// Sign Token
signed, err := jwt.Sign(t, jwa.HS256, byteKey)
if err != nil {
log.Printf("failed to sign token: %s", err)
return "", err
}
// Return Token
return string(signed), nil
}
func (auth *AuthManager) CreateJWTAccessToken(user model.User, device model.Device) (string, error) {
// Create New Token
tm := time.Now()
t := jwt.New()
t.Set(`did`, device.ID) // Device ID
t.Set(`role`, user.Role.String()) // User Role (Admin / User)
t.Set(jwt.SubjectKey, user.ID) // User ID
t.Set(jwt.AudienceKey, `imagini`) // App ID
t.Set(jwt.IssuedAtKey, tm) // Issued At
t.Set(jwt.ExpirationKey, tm.Add(time.Hour*2)) // 2 Hour Access Key
// Validate Token Creation
_, err := json.MarshalIndent(t, "", " ")
if err != nil {
fmt.Printf("failed to generate JSON: %s\n", err)
return "", err
}
// Use Server Key
byteKey := []byte(auth.Config.JWTSecret)
// Sign Token
signed, err := jwt.Sign(t, jwa.HS256, byteKey)
if err != nil {
log.Printf("failed to sign token: %s", err)
return "", err
}
// Return Token
return string(signed), nil
}

View File

@ -1,9 +1,9 @@
package auth package auth
import ( import (
"reichard.io/imagini/graph/model" "reichard.io/imagini/graph/model"
) )
func authenticateLDAPUser(user model.User, pw string) bool { func authenticateLDAPUser(user model.User, pw string) bool {
return false return false
} }

View File

@ -1,18 +1,18 @@
package auth package auth
import ( import (
"golang.org/x/crypto/bcrypt" log "github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus" "golang.org/x/crypto/bcrypt"
"reichard.io/imagini/graph/model" "reichard.io/imagini/graph/model"
) )
func authenticateLocalUser(user model.User, pw string) bool { func authenticateLocalUser(user model.User, pw string) bool {
bPassword :=[]byte(pw) bPassword := []byte(pw)
err := bcrypt.CompareHashAndPassword([]byte(*user.Password), bPassword) err := bcrypt.CompareHashAndPassword([]byte(*user.Password), bPassword)
if err == nil { if err == nil {
log.Info("[auth] Authentication successfull: ", user.Username) log.Info("[auth] Authentication successfull: ", user.Username)
return true return true
} }
log.Warn("[auth] Authentication failed: ", user.Username) log.Warn("[auth] Authentication failed: ", user.Username)
return false return false
} }

View File

@ -1,34 +1,34 @@
package config package config
import ( import (
"os" "os"
) )
type Config struct { type Config struct {
DBType string DBType string
DBName string DBName string
DBPassword string DBPassword string
DataPath string DataPath string
ConfigPath string ConfigPath string
JWTSecret string JWTSecret string
ListenPort string ListenPort string
} }
func Load() *Config { func Load() *Config {
return &Config{ return &Config{
DBType: getEnv("DATABASE_TYPE", "SQLite"), DBType: getEnv("DATABASE_TYPE", "SQLite"),
DBName: getEnv("DATABASE_NAME", "imagini"), DBName: getEnv("DATABASE_NAME", "imagini"),
DBPassword: getEnv("DATABASE_PASSWORD", ""), DBPassword: getEnv("DATABASE_PASSWORD", ""),
ConfigPath: getEnv("CONFIG_PATH", "/config"), ConfigPath: getEnv("CONFIG_PATH", "/config"),
DataPath: getEnv("DATA_PATH", "/data"), DataPath: getEnv("DATA_PATH", "/data"),
JWTSecret: getEnv("JWT_SECRET", "58b9340c0472cf045db226bc445966524e780cd38bc3dd707afce80c95d4de6f"), JWTSecret: getEnv("JWT_SECRET", "58b9340c0472cf045db226bc445966524e780cd38bc3dd707afce80c95d4de6f"),
ListenPort: getEnv("LISTEN_PORT", "8484"), ListenPort: getEnv("LISTEN_PORT", "8484"),
} }
} }
func getEnv(key, fallback string) string { func getEnv(key, fallback string) string {
if value, ok := os.LookupEnv(key); ok { if value, ok := os.LookupEnv(key); ok {
return value return value
} }
return fallback return fallback
} }

View File

@ -69,6 +69,7 @@ func (dbm *DBManager) bootstrapDatabase() {
} }
} }
// TODO
func (dbm *DBManager) QueryBuilder(dest interface{}, params []byte) (int64, error) { func (dbm *DBManager) QueryBuilder(dest interface{}, params []byte) (int64, error) {
// TODO: // TODO:
// - Where Filters // - Where Filters

View File

@ -1,7 +0,0 @@
package db
import "errors"
var (
ErrUserAlreadyExists = errors.New("user already exists")
)

View File

@ -1,38 +1,40 @@
package session package session
import ( import (
"sync" "sync"
) )
// Used to maintain a cache of user specific jwt secrets // Used to maintain a cache of user specific jwt secrets
// This will prevent DB lookups on every request // This will prevent DB lookups on every request
// May not actually be needed. Refresh Token is the only
// token that will require proactive DB lookups.
type SessionManager struct { type SessionManager struct {
mutex sync.Mutex mutex sync.Mutex
values map[string]string values map[string]string
} }
func NewMgr() *SessionManager { func NewMgr() *SessionManager {
return &SessionManager{} return &SessionManager{}
} }
func (sm *SessionManager) Set(key, value string) { func (sm *SessionManager) Set(key, value string) {
sm.mutex.Lock() sm.mutex.Lock()
sm.values[key] = value sm.values[key] = value
sm.mutex.Unlock() sm.mutex.Unlock()
} }
func (sm *SessionManager) Get(key string) string { func (sm *SessionManager) Get(key string) string {
sm.mutex.Lock() sm.mutex.Lock()
defer sm.mutex.Unlock() defer sm.mutex.Unlock()
return sm.values[key] return sm.values[key]
} }
func (sm *SessionManager) Delete(key string) { func (sm *SessionManager) Delete(key string) {
sm.mutex.Lock() sm.mutex.Lock()
defer sm.mutex.Unlock() defer sm.mutex.Unlock()
_, exists := sm.values[key] _, exists := sm.values[key]
if !exists { if !exists {
return return
} }
delete(sm.values, key) delete(sm.values, key)
} }

View File

@ -42,7 +42,7 @@ type Field struct {
Name string Name string
Type types.Type Type types.Type
Tag string Tag string
Gorm string Gorm string
} }
type Enum struct { type Enum struct {
@ -163,21 +163,21 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {
typ = types.NewPointer(typ) typ = types.NewPointer(typ)
} }
gormType := "" gormType := ""
directive := field.Directives.ForName("meta") directive := field.Directives.ForName("meta")
if directive != nil { if directive != nil {
arg := directive.Arguments.ForName("gorm") arg := directive.Arguments.ForName("gorm")
if arg != nil { if arg != nil {
gormType = fmt.Sprintf("gorm:\"%s\"", arg.Value.Raw) gormType = fmt.Sprintf("gorm:\"%s\"", arg.Value.Raw)
} }
} }
it.Fields = append(it.Fields, &Field{ it.Fields = append(it.Fields, &Field{
Name: name, Name: name,
Type: typ, Type: typ,
Description: field.Description, Description: field.Description,
Tag: `json:"` + field.Name + `"`, Tag: `json:"` + field.Name + `"`,
Gorm: gormType, Gorm: gormType,
}) })
} }