diff --git a/cmd/imagini.db b/cmd/imagini.db new file mode 100644 index 0000000..ecb2d8f Binary files /dev/null and b/cmd/imagini.db differ diff --git a/internal/api/auth.go b/internal/api/auth.go index b9c276e..cee9d33 100644 --- a/internal/api/auth.go +++ b/internal/api/auth.go @@ -1,13 +1,16 @@ package api import ( + "fmt" "time" - "encoding/json" + "strings" "net/http" + "encoding/json" + "github.com/google/uuid" + log "github.com/sirupsen/logrus" + "github.com/lestrrat-go/jwx/jwt" + "reichard.io/imagini/internal/models" - // "github.com/lestrrat-go/jwx/jwt" - // "github.com/lestrrat-go/jwx/jwa" - // log "github.com/sirupsen/logrus" ) // https://www.calhoun.io/pitfalls-of-context-values-and-how-to-avoid-or-mitigate-them/ @@ -35,20 +38,51 @@ func (api *API) loginHandler(w http.ResponseWriter, r *http.Request) { return } + // Verify Device Name Exists + deviceHeader := r.Header.Get("X-Imagini-DeviceName") + if deviceHeader == "" { + errorJSON(w, "Missing 'X-Imagini-DeviceName' header.", http.StatusBadRequest) + return + } + + // Derive Device Type + var deviceType string + userAgent := strings.ToLower(r.Header.Get("User-Agent")) + if strings.HasPrefix(userAgent, "ios-imagini"){ + deviceType = "iOS" + } else if strings.HasPrefix(userAgent, "android-imagini"){ + deviceType = "Android" + } else if strings.HasPrefix(userAgent, "chrome"){ + deviceType = "Chrome" + } else if strings.HasPrefix(userAgent, "firefox"){ + deviceType = "Firefox" + } else if strings.HasPrefix(userAgent, "msie"){ + deviceType = "Internet Explorer" + } else if strings.HasPrefix(userAgent, "edge"){ + deviceType = "Edge" + } else if strings.HasPrefix(userAgent, "safari"){ + deviceType = "Safari" + }else { + deviceType = "Unknown" + } + // Do login - resp := api.Auth.AuthenticateUser(creds) + resp, user := api.Auth.AuthenticateUser(creds) if !resp { errorJSON(w, "Invalid credentials.", http.StatusUnauthorized) return } - // Create tokens - accessToken := api.Auth.CreateJWTAccessToken() - refreshToken := api.Auth.CreateRefreshToken() + // Create New Device + device, err := api.DB.CreateDevice(models.Device{Name: deviceHeader, Type: deviceType}) + + // Create Tokens + accessToken, err := api.Auth.CreateJWTAccessToken(user, device) + refreshToken, err := api.Auth.CreateJWTRefreshToken(user, device) // Set appropriate cookies - accessCookie := http.Cookie{Name: "AccessToken", Value: accessToken} - refreshCookie := http.Cookie{Name: "RefreshToken", Value: refreshToken} + accessCookie := http.Cookie{Name: "AccessToken", Value: accessToken, HttpOnly: true} + refreshCookie := http.Cookie{Name: "RefreshToken", Value: refreshToken, HttpOnly: true} http.SetCookie(w, &accessCookie) http.SetCookie(w, &refreshCookie) @@ -78,15 +112,50 @@ func (api *API) logoutHandler(w http.ResponseWriter, r *http.Request) { } func (api *API) refreshLoginHandler(w http.ResponseWriter, r *http.Request) { - ok := api.Auth.ValidateRefreshToken() + refreshCookie, err := r.Cookie("RefreshToken") + if err != nil { + log.Warn("[middleware] Cookie not found") + w.WriteHeader(http.StatusUnauthorized) + return + } + + // Validate Refresh Token + refreshToken, ok := api.Auth.ValidateJWTRefreshToken(refreshCookie.Value) if !ok { - // TODO: Clear Access & Refresh Cookies + http.SetCookie(w, &http.Cookie{Name: "AccessToken", Expires: time.Unix(0, 0)}) + http.SetCookie(w, &http.Cookie{Name: "RefreshToken", Expires: time.Unix(0, 0)}) errorJSON(w, "Invalid credentials.", http.StatusUnauthorized) return } + // Acquire User & Device (Trusted) + did, ok := refreshToken.Get("did") + if !ok { + errorJSON(w, "Invalid credentials.", http.StatusUnauthorized) + return + } + uid, ok := refreshToken.Get(jwt.SubjectKey) + if !ok { + errorJSON(w, "Invalid credentials.", http.StatusUnauthorized) + return + } + deviceID, err := uuid.Parse(fmt.Sprintf("%v", did)) + if err != nil { + errorJSON(w, "Invalid credentials.", http.StatusUnauthorized) + return + } + userID, err := uuid.Parse(fmt.Sprintf("%v", uid)) + if err != nil { + errorJSON(w, "Invalid credentials.", http.StatusUnauthorized) + return + } + + // Device Skeleton + user := models.User{Base: models.Base{UUID: userID}} + device := models.Device{Base: models.Base{UUID: deviceID}} + // Update token - accessToken := api.Auth.CreateJWTAccessToken() + accessToken, err := api.Auth.CreateJWTAccessToken(user, device) accessCookie := http.Cookie{Name: "AccessToken", Value: accessToken} http.SetCookie(w, &accessCookie) diff --git a/internal/api/middlewares.go b/internal/api/middlewares.go index 5ffbcb2..41f3847 100644 --- a/internal/api/middlewares.go +++ b/internal/api/middlewares.go @@ -22,27 +22,22 @@ func multipleMiddleware(h http.HandlerFunc, m ...Middleware) http.HandlerFunc { func (api *API) authMiddleware(next http.Handler) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - cookie, err := r.Cookie("Token") + // Acquire Token + accessCookie, err := r.Cookie("AccessToken") if err != nil { - log.Warn("[middleware] Cookie not found") + log.Warn("[middleware] AccessToken not found") w.WriteHeader(http.StatusUnauthorized) return } - // Validate cookie.Value JWT with - api.Auth.ValidateJWTToken(cookie.Value) + // Validate JWT Tokens + _, accessOK := api.Auth.ValidateJWTAccessToken(accessCookie.Value) - - log.Info("[middleware] Cookie Name: ", cookie.Name) - log.Info("[middleware] Cookie Value: ", cookie.Value) - - next.ServeHTTP(w, r) - - // if true { - // next.ServeHTTP(w, r) - // } else { - // w.WriteHeader(http.StatusUnauthorized) - // } + if accessOK { + next.ServeHTTP(w, r) + } else { + w.WriteHeader(http.StatusUnauthorized) + } }) } diff --git a/internal/api/users.go b/internal/api/users.go index 0220829..0aa6758 100644 --- a/internal/api/users.go +++ b/internal/api/users.go @@ -27,13 +27,4 @@ func (api *API) meHandler(w http.ResponseWriter, r *http.Request) { errorJSON(w, "Method is not supported.", http.StatusMethodNotAllowed) return } - - // Get Authenticated User & Return Object - authCookie, err := r.Cookie("Token") - if err != nil { - log.Error("[api] ", err) - return - } - - log.Info("[api] Auth Cookie: ", authCookie) } diff --git a/internal/auth/auth.go b/internal/auth/auth.go index a6fce77..1845972 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -4,7 +4,6 @@ import ( "errors" "github.com/google/uuid" - "golang.org/x/crypto/bcrypt" log "github.com/sirupsen/logrus" "gorm.io/gorm" "reichard.io/imagini/internal/db" @@ -35,7 +34,7 @@ func NewMgr(db *db.DBManager, c *config.Config) *AuthManager { } } -func (auth *AuthManager) AuthenticateUser(creds models.APICredentials) bool { +func (auth *AuthManager) AuthenticateUser(creds models.APICredentials) (bool, models.User) { // By Username foundUser, err := auth.DB.User(models.User{Username: creds.User}) if errors.Is(err, gorm.ErrRecordNotFound) { @@ -45,10 +44,10 @@ func (auth *AuthManager) AuthenticateUser(creds models.APICredentials) bool { // Error Checking if errors.Is(err, gorm.ErrRecordNotFound) { log.Warn("[auth] User not found: ", creds.User) - return false + return false, foundUser } else if err != nil { log.Error(err) - return false + return false, foundUser } log.Info("[auth] Authenticating user: ", foundUser.Username) @@ -56,76 +55,59 @@ func (auth *AuthManager) AuthenticateUser(creds models.APICredentials) bool { // Determine Type switch foundUser.AuthType { case "Local": - return authenticateLocalUser(foundUser, creds.Password) + return authenticateLocalUser(foundUser, creds.Password), foundUser case "LDAP": - return authenticateLDAPUser(foundUser, creds.Password) + return authenticateLDAPUser(foundUser, creds.Password), foundUser default: - return false + return false, foundUser } } -func (auth *AuthManager) ValidateJWTToken(userJWT string) bool { - byteUserJWT := []byte(userJWT) - serverToken, err := jwt.ParseBytes(byteUserJWT, jwt.WithVerify(jwa.HS256, auth.Config.JWTSecret)) +func (auth *AuthManager) getRole(user models.User) string { + // TODO: Lookup role of user + return "User" +} + +func (auth *AuthManager) ValidateJWTAccessToken(accessJWT string) (jwt.Token, bool) { + byteAccessJWT := []byte(accessJWT) + verifiedToken, err := jwt.ParseBytes(byteAccessJWT, jwt.WithVerify(jwa.HS256, []byte(auth.Config.JWTSecret))) if err != nil { fmt.Println("failed to parse payload: ", err) + return nil, false } - - uid, ok := serverToken.Get("uid"); - if !ok { - fmt.Println("failed to acquire uid") - } - - userID := fmt.Sprintf("%v", uid) - userKey := auth.Session.Get(userID) - - userToken, err := jwt.ParseBytes(byteUserJWT, jwt.WithVerify(jwa.HS256, userKey)) - if err != nil { - fmt.Println("failed to parse payload: ", err) - } - _ = userToken - - // TODO: - // - Get User ID from UNVALIDATED token - // - Lookup user key, concat with server key - // - Validate with concatted user & server key - // validatedToken, err := jwt.ParseBytes(byteUserJWT, jwt.WithVerify(jwa.HS256, concatKey)) - // if err != nil { - // fmt.Printf("failed to parse payload: %s\n", err) - // } - - // userToken := auth.Session.Get(userID) - // log.Info("[auth] DEBUG: ", userToken) - - return false + return verifiedToken, true } func (auth *AuthManager) RevokeRefreshToken() { } -func (auth *AuthManager) ValidateRefreshToken(refreshToken, deviceID string) bool { - // Acquire Device - deviceUUID, err := uuid.Parse(deviceID) - device := models.Device{Base: models.Base{UUID: deviceUUID}} - foundDevice, err := auth.DB.Device(device) +func (auth *AuthManager) ValidateJWTRefreshToken(refreshJWT string) (jwt.Token, bool) { + byteRefreshJWT := []byte(refreshJWT) - // Validate Expiration - expTime, err := time.Parse(time.RFC3339, foundDevice.RefreshExp) - if expTime.Before(time.Now()) { - return false + // Acquire Relevant Device + unverifiedToken, err := jwt.ParseBytes(byteRefreshJWT) + did, ok := unverifiedToken.Get("did") + if !ok { + return nil, false + } + deviceID, err := uuid.Parse(fmt.Sprintf("%v", did)) + if err != nil { + return nil, false + } + device, err := auth.DB.Device(models.Device{Base: models.Base{UUID: deviceID}}) + if err != nil { + return nil, false } - // Validate Token - bRefreshToken :=[]byte(refreshToken) - err = bcrypt.CompareHashAndPassword([]byte(foundDevice.RefreshToken), bRefreshToken) - if err == nil { - log.Info("[auth] Refresh Token validation succeeded: ", foundDevice.UUID) - return true + // Verify Token + verifiedToken, err := jwt.ParseBytes(byteRefreshJWT, jwt.WithVerify(jwa.HS256, []byte(device.RefreshKey))) + if err != nil { + fmt.Println("failed to parse payload: ", err) + return nil, false } - log.Warn("[auth] Refresh Token validation failed: ", foundDevice.UUID) - return false + return verifiedToken, true } func (auth *AuthManager) UpdateRefreshToken(deviceID string) error { @@ -135,30 +117,49 @@ func (auth *AuthManager) UpdateRefreshToken(deviceID string) error { return nil } -func (auth *AuthManager) CreateRefreshToken(deviceID string) (string, error) { - // TODO: - // - Create regular bcrypt password - // - Create Expiration (Depends on Device Type) - // - Store in DB: DeviceID, ValidUntil +func (auth *AuthManager) CreateJWTRefreshToken(user models.User, device models.Device) (string, error) { + // Acquire Refresh Key + byteKey := []byte(device.RefreshKey) - generatedToken := uuid.New().String() - hashedRefreshToken, err := bcrypt.GenerateFromPassword([]byte(generatedToken), bcrypt.DefaultCost) - if err != nil { - log.Error(err) - return "", err - } - _ = string(hashedRefreshToken) - - return "", nil -} - -func (auth *AuthManager) CreateJWTAccessToken(user, role, deviceID string) (string, error) { // Create New Token tm := time.Now() t := jwt.New() - t.Set(`did`, deviceID) // Device ID + t.Set(`did`, device.UUID) // Device ID + t.Set(jwt.SubjectKey, user.UUID) // User ID + t.Set(jwt.AudienceKey, `imagini`) // App ID + t.Set(jwt.IssuedAtKey, tm) // Issued At + + // TODO: Depends on Device + t.Set(jwt.ExpirationKey, tm.Add(time.Hour * 24)) // 1 Day Access Key + + // Validate Token Creation + _, err := json.MarshalIndent(t, "", " ") + if err != nil { + fmt.Printf("failed to generate JSON: %s\n", err) + return "", err + } + + // Sign Token + signed, err := jwt.Sign(t, jwa.HS256, byteKey) + if err != nil { + log.Printf("failed to sign token: %s", err) + return "", err + } + + // Return Token + return string(signed), nil +} + +func (auth *AuthManager) CreateJWTAccessToken(user models.User, device models.Device) (string, error) { + // Acquire Role + role := auth.getRole(user) + + // Create New Token + tm := time.Now() + t := jwt.New() + t.Set(`did`, device.UUID) // Device ID t.Set(`role`, role) // User Role (Admin / User) - t.Set(jwt.SubjectKey, user) // User ID + t.Set(jwt.SubjectKey, user.UUID) // User ID t.Set(jwt.AudienceKey, `imagini`) // App ID t.Set(jwt.IssuedAtKey, tm) // Issued At t.Set(jwt.ExpirationKey, tm.Add(time.Minute * 30)) // 30 Minute Access Key diff --git a/internal/db/db.go b/internal/db/db.go index 8963b9a..62b6a6c 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -4,7 +4,7 @@ import ( "path" "gorm.io/gorm" - "gorm.io/gorm/logger" + // "gorm.io/gorm/logger" "gorm.io/driver/sqlite" log "github.com/sirupsen/logrus" @@ -19,7 +19,7 @@ type DBManager struct { func NewMgr(c *config.Config) *DBManager { gormConfig := &gorm.Config{ PrepareStmt: true, - Logger: logger.Default.LogMode(logger.Silent), + // Logger: logger.Default.LogMode(logger.Silent), } // Create manager @@ -52,7 +52,7 @@ func NewMgr(c *config.Config) *DBManager { func (dbm *DBManager) bootstrapDatabase() { log.Info("[query] Bootstrapping database.") - err := dbm.CreateUser(models.User{ + _, err := dbm.CreateUser(models.User{ Username: "admin", Password: "admin", AuthType: "Local", diff --git a/internal/db/devices.go b/internal/db/devices.go index 3d8728d..2fada0b 100644 --- a/internal/db/devices.go +++ b/internal/db/devices.go @@ -1,31 +1,17 @@ package db import ( - "errors" - "gorm.io/gorm" - "golang.org/x/crypto/bcrypt" + "github.com/google/uuid" log "github.com/sirupsen/logrus" "reichard.io/imagini/internal/models" ) -func (dbm *DBManager) CreateDevice(device models.Device) error { +func (dbm *DBManager) CreateDevice(device models.Device) (models.Device, error) { log.Info("[query] Creating device: ", device.Name) - _, err := dbm.Device(device) - if !errors.Is(err, gorm.ErrRecordNotFound) { - log.Warn("[query] Device already exists: ", device.Name) - return errors.New("Device already exists") - } - - // Generate random password - refreshToken := "asd123" - hashedToken, err := bcrypt.GenerateFromPassword([]byte(refreshToken), bcrypt.DefaultCost) - if err != nil { - log.Error(err) - return err - } - device.RefreshToken = string(hashedToken) - return dbm.db.Create(&device).Error + device.RefreshKey = uuid.New().String() + err := dbm.db.Create(&device).Error + return device, err } func (dbm *DBManager) Device (device models.Device) (models.Device, error) { diff --git a/internal/db/users.go b/internal/db/users.go index 3ef2c64..77d35be 100644 --- a/internal/db/users.go +++ b/internal/db/users.go @@ -1,29 +1,22 @@ package db import ( - "errors" - "gorm.io/gorm" "golang.org/x/crypto/bcrypt" log "github.com/sirupsen/logrus" "reichard.io/imagini/internal/models" ) -func (dbm *DBManager) CreateUser(user models.User) error { +func (dbm *DBManager) CreateUser(user models.User) (models.User, error) { log.Info("[query] Creating user: ", user.Username) - _, err := dbm.User(user) - if !errors.Is(err, gorm.ErrRecordNotFound) { - log.Warn("[query] User already exists: ", user.Username) - return errors.New("User already exists") - } - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost) if err != nil { log.Error(err) - return err + return user, err } user.Password = string(hashedPassword) - return dbm.db.Create(&user).Error + err = dbm.db.Create(&user).Error + return user, err } func (dbm *DBManager) User (user models.User) (models.User, error) { diff --git a/internal/models/db.go b/internal/models/db.go index b6c66e1..bd02bc2 100644 --- a/internal/models/db.go +++ b/internal/models/db.go @@ -8,7 +8,7 @@ import ( // Base contains common columns for all tables. type Base struct { - UUID uuid.UUID `gorm:"type:uuid;default:default:uuid_generate_v4();primarykey"` + UUID uuid.UUID `gorm:"type:uuid;primarykey"` CreatedAt time.Time UpdatedAt time.Time DeletedAt gorm.DeletedAt `gorm:"index"` @@ -31,8 +31,7 @@ type Device struct { User User `json:"user" gorm:"ForeignKey:UUID"` Name string `json:"name"` Type string `json:"type"` // Android, iOS, Chrome, FireFox, Edge, etc - RefreshExp string `json:"refresh_exp"` - RefreshToken string `json:"-"` + RefreshKey string `json:"-"` } type User struct {