188 lines
5.2 KiB
Go
188 lines
5.2 KiB
Go
package auth
|
|
|
|
import (
|
|
"fmt"
|
|
"time"
|
|
"errors"
|
|
"encoding/json"
|
|
|
|
"gorm.io/gorm"
|
|
"github.com/google/uuid"
|
|
log "github.com/sirupsen/logrus"
|
|
"github.com/lestrrat-go/jwx/jwa"
|
|
"github.com/lestrrat-go/jwx/jwt"
|
|
|
|
"reichard.io/imagini/internal/db"
|
|
"reichard.io/imagini/internal/config"
|
|
|
|
graphql "reichard.io/imagini/graph/model"
|
|
"reichard.io/imagini/internal/models"
|
|
"reichard.io/imagini/internal/session"
|
|
)
|
|
|
|
type AuthManager struct {
|
|
DB *db.DBManager
|
|
Config *config.Config
|
|
Session *session.SessionManager
|
|
}
|
|
|
|
func NewMgr(db *db.DBManager, c *config.Config) *AuthManager {
|
|
session := session.NewMgr()
|
|
return &AuthManager{
|
|
DB: db,
|
|
Config: c,
|
|
Session: session,
|
|
}
|
|
}
|
|
|
|
func (auth *AuthManager) AuthenticateUser(creds models.APICredentials) (bool, graphql.User) {
|
|
// Search Objects
|
|
userByName := &graphql.User{}
|
|
userByName.Username = creds.User
|
|
|
|
foundUser, err := auth.DB.User(userByName)
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
userByEmail := &graphql.User{}
|
|
userByEmail.Email = creds.User
|
|
foundUser, err = auth.DB.User(userByEmail)
|
|
}
|
|
|
|
// Error Checking
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
log.Warn("[auth] User not found: ", creds.User)
|
|
return false, foundUser
|
|
} else if err != nil {
|
|
log.Error(err)
|
|
return false, foundUser
|
|
}
|
|
|
|
log.Info("[auth] Authenticating user: ", foundUser.Username)
|
|
|
|
// Determine Type
|
|
switch foundUser.AuthType {
|
|
case "Local":
|
|
return authenticateLocalUser(foundUser, creds.Password), foundUser
|
|
case "LDAP":
|
|
return authenticateLDAPUser(foundUser, creds.Password), foundUser
|
|
default:
|
|
return false, foundUser
|
|
}
|
|
}
|
|
|
|
func (auth *AuthManager) getRole(user graphql.User) string {
|
|
// TODO: Lookup role of user
|
|
return "User"
|
|
}
|
|
|
|
func (auth *AuthManager) ValidateJWTRefreshToken(refreshJWT string) (jwt.Token, error) {
|
|
byteRefreshJWT := []byte(refreshJWT)
|
|
|
|
// Acquire Relevant Device
|
|
unverifiedToken, err := jwt.ParseBytes(byteRefreshJWT)
|
|
did, ok := unverifiedToken.Get("did")
|
|
if !ok {
|
|
return nil, errors.New("did does not exist")
|
|
}
|
|
deviceID, err := uuid.Parse(fmt.Sprintf("%v", did))
|
|
if err != nil {
|
|
return nil, errors.New("did does not parse")
|
|
}
|
|
stringDeviceID := deviceID.String()
|
|
device, err := auth.DB.Device(&graphql.Device{ID: &stringDeviceID})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Verify & Validate Token
|
|
verifiedToken, err := jwt.ParseBytes(byteRefreshJWT,
|
|
jwt.WithValidate(true),
|
|
jwt.WithVerify(jwa.HS256, []byte(*device.RefreshKey)),
|
|
)
|
|
if err != nil {
|
|
fmt.Println("failed to parse payload: ", err)
|
|
return nil, err
|
|
}
|
|
return verifiedToken, nil
|
|
}
|
|
|
|
func (auth *AuthManager) ValidateJWTAccessToken(accessJWT string) (jwt.Token, error) {
|
|
byteAccessJWT := []byte(accessJWT)
|
|
verifiedToken, err := jwt.ParseBytes(byteAccessJWT,
|
|
jwt.WithValidate(true),
|
|
jwt.WithVerify(jwa.HS256, []byte(auth.Config.JWTSecret)),
|
|
)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return verifiedToken, nil
|
|
}
|
|
|
|
func (auth *AuthManager) CreateJWTRefreshToken(user graphql.User, device graphql.Device) (string, error) {
|
|
// Acquire Refresh Key
|
|
byteKey := []byte(*device.RefreshKey)
|
|
|
|
// Create New Token
|
|
tm := time.Now()
|
|
t := jwt.New()
|
|
t.Set(`did`, device.ID) // Device ID
|
|
t.Set(jwt.SubjectKey, user.ID) // User ID
|
|
t.Set(jwt.AudienceKey, `imagini`) // App ID
|
|
t.Set(jwt.IssuedAtKey, tm) // Issued At
|
|
|
|
// iOS & Android = Never Expiring Refresh Token
|
|
if device.Type != "iOS" && device.Type != "Android" {
|
|
t.Set(jwt.ExpirationKey, tm.Add(time.Hour * 24)) // 1 Day Access Key
|
|
}
|
|
|
|
// Validate Token Creation
|
|
_, err := json.MarshalIndent(t, "", " ")
|
|
if err != nil {
|
|
fmt.Printf("failed to generate JSON: %s\n", err)
|
|
return "", err
|
|
}
|
|
|
|
// Sign Token
|
|
signed, err := jwt.Sign(t, jwa.HS256, byteKey)
|
|
if err != nil {
|
|
log.Printf("failed to sign token: %s", err)
|
|
return "", err
|
|
}
|
|
|
|
// Return Token
|
|
return string(signed), nil
|
|
}
|
|
|
|
func (auth *AuthManager) CreateJWTAccessToken(user graphql.User, device graphql.Device) (string, error) {
|
|
// Create New Token
|
|
tm := time.Now()
|
|
t := jwt.New()
|
|
t.Set(`did`, device.ID) // Device ID
|
|
t.Set(`role`, auth.getRole(user)) // User Role (Admin / User)
|
|
t.Set(jwt.SubjectKey, user.ID) // User ID
|
|
t.Set(jwt.AudienceKey, `imagini`) // App ID
|
|
t.Set(jwt.IssuedAtKey, tm) // Issued At
|
|
t.Set(jwt.ExpirationKey, tm.Add(time.Hour * 2)) // 2 Hour Access Key
|
|
|
|
// Validate Token Creation
|
|
_, err := json.MarshalIndent(t, "", " ")
|
|
if err != nil {
|
|
fmt.Printf("failed to generate JSON: %s\n", err)
|
|
return "", err
|
|
}
|
|
|
|
// Use Server Key
|
|
byteKey := []byte(auth.Config.JWTSecret)
|
|
|
|
// Sign Token
|
|
signed, err := jwt.Sign(t, jwa.HS256, byteKey)
|
|
if err != nil {
|
|
log.Printf("failed to sign token: %s", err)
|
|
return "", err
|
|
}
|
|
|
|
// Return Token
|
|
return string(signed), nil
|
|
}
|