Basic Access & Refresh Token

This commit is contained in:
Evan Reichard 2021-01-18 16:16:52 -05:00
parent 377903f7a1
commit b05b1eb9b6
9 changed files with 181 additions and 147 deletions

BIN
cmd/imagini.db Normal file

Binary file not shown.

View File

@ -1,13 +1,16 @@
package api package api
import ( import (
"fmt"
"time" "time"
"encoding/json" "strings"
"net/http" "net/http"
"encoding/json"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/lestrrat-go/jwx/jwt"
"reichard.io/imagini/internal/models" "reichard.io/imagini/internal/models"
// "github.com/lestrrat-go/jwx/jwt"
// "github.com/lestrrat-go/jwx/jwa"
// log "github.com/sirupsen/logrus"
) )
// https://www.calhoun.io/pitfalls-of-context-values-and-how-to-avoid-or-mitigate-them/ // https://www.calhoun.io/pitfalls-of-context-values-and-how-to-avoid-or-mitigate-them/
@ -35,20 +38,51 @@ func (api *API) loginHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
// Verify Device Name Exists
deviceHeader := r.Header.Get("X-Imagini-DeviceName")
if deviceHeader == "" {
errorJSON(w, "Missing 'X-Imagini-DeviceName' header.", http.StatusBadRequest)
return
}
// Derive Device Type
var deviceType string
userAgent := strings.ToLower(r.Header.Get("User-Agent"))
if strings.HasPrefix(userAgent, "ios-imagini"){
deviceType = "iOS"
} else if strings.HasPrefix(userAgent, "android-imagini"){
deviceType = "Android"
} else if strings.HasPrefix(userAgent, "chrome"){
deviceType = "Chrome"
} else if strings.HasPrefix(userAgent, "firefox"){
deviceType = "Firefox"
} else if strings.HasPrefix(userAgent, "msie"){
deviceType = "Internet Explorer"
} else if strings.HasPrefix(userAgent, "edge"){
deviceType = "Edge"
} else if strings.HasPrefix(userAgent, "safari"){
deviceType = "Safari"
}else {
deviceType = "Unknown"
}
// Do login // Do login
resp := api.Auth.AuthenticateUser(creds) resp, user := api.Auth.AuthenticateUser(creds)
if !resp { if !resp {
errorJSON(w, "Invalid credentials.", http.StatusUnauthorized) errorJSON(w, "Invalid credentials.", http.StatusUnauthorized)
return return
} }
// Create tokens // Create New Device
accessToken := api.Auth.CreateJWTAccessToken() device, err := api.DB.CreateDevice(models.Device{Name: deviceHeader, Type: deviceType})
refreshToken := api.Auth.CreateRefreshToken()
// Create Tokens
accessToken, err := api.Auth.CreateJWTAccessToken(user, device)
refreshToken, err := api.Auth.CreateJWTRefreshToken(user, device)
// Set appropriate cookies // Set appropriate cookies
accessCookie := http.Cookie{Name: "AccessToken", Value: accessToken} accessCookie := http.Cookie{Name: "AccessToken", Value: accessToken, HttpOnly: true}
refreshCookie := http.Cookie{Name: "RefreshToken", Value: refreshToken} refreshCookie := http.Cookie{Name: "RefreshToken", Value: refreshToken, HttpOnly: true}
http.SetCookie(w, &accessCookie) http.SetCookie(w, &accessCookie)
http.SetCookie(w, &refreshCookie) http.SetCookie(w, &refreshCookie)
@ -78,15 +112,50 @@ func (api *API) logoutHandler(w http.ResponseWriter, r *http.Request) {
} }
func (api *API) refreshLoginHandler(w http.ResponseWriter, r *http.Request) { func (api *API) refreshLoginHandler(w http.ResponseWriter, r *http.Request) {
ok := api.Auth.ValidateRefreshToken() refreshCookie, err := r.Cookie("RefreshToken")
if err != nil {
log.Warn("[middleware] Cookie not found")
w.WriteHeader(http.StatusUnauthorized)
return
}
// Validate Refresh Token
refreshToken, ok := api.Auth.ValidateJWTRefreshToken(refreshCookie.Value)
if !ok { if !ok {
// TODO: Clear Access & Refresh Cookies http.SetCookie(w, &http.Cookie{Name: "AccessToken", Expires: time.Unix(0, 0)})
http.SetCookie(w, &http.Cookie{Name: "RefreshToken", Expires: time.Unix(0, 0)})
errorJSON(w, "Invalid credentials.", http.StatusUnauthorized) errorJSON(w, "Invalid credentials.", http.StatusUnauthorized)
return return
} }
// Acquire User & Device (Trusted)
did, ok := refreshToken.Get("did")
if !ok {
errorJSON(w, "Invalid credentials.", http.StatusUnauthorized)
return
}
uid, ok := refreshToken.Get(jwt.SubjectKey)
if !ok {
errorJSON(w, "Invalid credentials.", http.StatusUnauthorized)
return
}
deviceID, err := uuid.Parse(fmt.Sprintf("%v", did))
if err != nil {
errorJSON(w, "Invalid credentials.", http.StatusUnauthorized)
return
}
userID, err := uuid.Parse(fmt.Sprintf("%v", uid))
if err != nil {
errorJSON(w, "Invalid credentials.", http.StatusUnauthorized)
return
}
// Device Skeleton
user := models.User{Base: models.Base{UUID: userID}}
device := models.Device{Base: models.Base{UUID: deviceID}}
// Update token // Update token
accessToken := api.Auth.CreateJWTAccessToken() accessToken, err := api.Auth.CreateJWTAccessToken(user, device)
accessCookie := http.Cookie{Name: "AccessToken", Value: accessToken} accessCookie := http.Cookie{Name: "AccessToken", Value: accessToken}
http.SetCookie(w, &accessCookie) http.SetCookie(w, &accessCookie)

View File

@ -22,27 +22,22 @@ func multipleMiddleware(h http.HandlerFunc, m ...Middleware) http.HandlerFunc {
func (api *API) authMiddleware(next http.Handler) http.HandlerFunc { func (api *API) authMiddleware(next http.Handler) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie("Token") // Acquire Token
accessCookie, err := r.Cookie("AccessToken")
if err != nil { if err != nil {
log.Warn("[middleware] Cookie not found") log.Warn("[middleware] AccessToken not found")
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
return return
} }
// Validate cookie.Value JWT with // Validate JWT Tokens
api.Auth.ValidateJWTToken(cookie.Value) _, accessOK := api.Auth.ValidateJWTAccessToken(accessCookie.Value)
log.Info("[middleware] Cookie Name: ", cookie.Name)
log.Info("[middleware] Cookie Value: ", cookie.Value)
if accessOK {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
} else {
// if true { w.WriteHeader(http.StatusUnauthorized)
// next.ServeHTTP(w, r) }
// } else {
// w.WriteHeader(http.StatusUnauthorized)
// }
}) })
} }

View File

@ -27,13 +27,4 @@ func (api *API) meHandler(w http.ResponseWriter, r *http.Request) {
errorJSON(w, "Method is not supported.", http.StatusMethodNotAllowed) errorJSON(w, "Method is not supported.", http.StatusMethodNotAllowed)
return return
} }
// Get Authenticated User & Return Object
authCookie, err := r.Cookie("Token")
if err != nil {
log.Error("[api] ", err)
return
}
log.Info("[api] Auth Cookie: ", authCookie)
} }

View File

@ -4,7 +4,6 @@ import (
"errors" "errors"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gorm.io/gorm" "gorm.io/gorm"
"reichard.io/imagini/internal/db" "reichard.io/imagini/internal/db"
@ -35,7 +34,7 @@ func NewMgr(db *db.DBManager, c *config.Config) *AuthManager {
} }
} }
func (auth *AuthManager) AuthenticateUser(creds models.APICredentials) bool { func (auth *AuthManager) AuthenticateUser(creds models.APICredentials) (bool, models.User) {
// By Username // By Username
foundUser, err := auth.DB.User(models.User{Username: creds.User}) foundUser, err := auth.DB.User(models.User{Username: creds.User})
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
@ -45,10 +44,10 @@ func (auth *AuthManager) AuthenticateUser(creds models.APICredentials) bool {
// Error Checking // Error Checking
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn("[auth] User not found: ", creds.User) log.Warn("[auth] User not found: ", creds.User)
return false return false, foundUser
} else if err != nil { } else if err != nil {
log.Error(err) log.Error(err)
return false return false, foundUser
} }
log.Info("[auth] Authenticating user: ", foundUser.Username) log.Info("[auth] Authenticating user: ", foundUser.Username)
@ -56,76 +55,59 @@ func (auth *AuthManager) AuthenticateUser(creds models.APICredentials) bool {
// Determine Type // Determine Type
switch foundUser.AuthType { switch foundUser.AuthType {
case "Local": case "Local":
return authenticateLocalUser(foundUser, creds.Password) return authenticateLocalUser(foundUser, creds.Password), foundUser
case "LDAP": case "LDAP":
return authenticateLDAPUser(foundUser, creds.Password) return authenticateLDAPUser(foundUser, creds.Password), foundUser
default: default:
return false return false, foundUser
} }
} }
func (auth *AuthManager) ValidateJWTToken(userJWT string) bool {
byteUserJWT := []byte(userJWT)
serverToken, err := jwt.ParseBytes(byteUserJWT, jwt.WithVerify(jwa.HS256, auth.Config.JWTSecret)) func (auth *AuthManager) getRole(user models.User) string {
// TODO: Lookup role of user
return "User"
}
func (auth *AuthManager) ValidateJWTAccessToken(accessJWT string) (jwt.Token, bool) {
byteAccessJWT := []byte(accessJWT)
verifiedToken, err := jwt.ParseBytes(byteAccessJWT, jwt.WithVerify(jwa.HS256, []byte(auth.Config.JWTSecret)))
if err != nil { if err != nil {
fmt.Println("failed to parse payload: ", err) fmt.Println("failed to parse payload: ", err)
return nil, false
} }
return verifiedToken, true
uid, ok := serverToken.Get("uid");
if !ok {
fmt.Println("failed to acquire uid")
}
userID := fmt.Sprintf("%v", uid)
userKey := auth.Session.Get(userID)
userToken, err := jwt.ParseBytes(byteUserJWT, jwt.WithVerify(jwa.HS256, userKey))
if err != nil {
fmt.Println("failed to parse payload: ", err)
}
_ = userToken
// TODO:
// - Get User ID from UNVALIDATED token
// - Lookup user key, concat with server key
// - Validate with concatted user & server key
// validatedToken, err := jwt.ParseBytes(byteUserJWT, jwt.WithVerify(jwa.HS256, concatKey))
// if err != nil {
// fmt.Printf("failed to parse payload: %s\n", err)
// }
// userToken := auth.Session.Get(userID)
// log.Info("[auth] DEBUG: ", userToken)
return false
} }
func (auth *AuthManager) RevokeRefreshToken() { func (auth *AuthManager) RevokeRefreshToken() {
} }
func (auth *AuthManager) ValidateRefreshToken(refreshToken, deviceID string) bool { func (auth *AuthManager) ValidateJWTRefreshToken(refreshJWT string) (jwt.Token, bool) {
// Acquire Device byteRefreshJWT := []byte(refreshJWT)
deviceUUID, err := uuid.Parse(deviceID)
device := models.Device{Base: models.Base{UUID: deviceUUID}}
foundDevice, err := auth.DB.Device(device)
// Validate Expiration // Acquire Relevant Device
expTime, err := time.Parse(time.RFC3339, foundDevice.RefreshExp) unverifiedToken, err := jwt.ParseBytes(byteRefreshJWT)
if expTime.Before(time.Now()) { did, ok := unverifiedToken.Get("did")
return false if !ok {
return nil, false
}
deviceID, err := uuid.Parse(fmt.Sprintf("%v", did))
if err != nil {
return nil, false
}
device, err := auth.DB.Device(models.Device{Base: models.Base{UUID: deviceID}})
if err != nil {
return nil, false
} }
// Validate Token // Verify Token
bRefreshToken :=[]byte(refreshToken) verifiedToken, err := jwt.ParseBytes(byteRefreshJWT, jwt.WithVerify(jwa.HS256, []byte(device.RefreshKey)))
err = bcrypt.CompareHashAndPassword([]byte(foundDevice.RefreshToken), bRefreshToken) if err != nil {
if err == nil { fmt.Println("failed to parse payload: ", err)
log.Info("[auth] Refresh Token validation succeeded: ", foundDevice.UUID) return nil, false
return true
} }
log.Warn("[auth] Refresh Token validation failed: ", foundDevice.UUID) return verifiedToken, true
return false
} }
func (auth *AuthManager) UpdateRefreshToken(deviceID string) error { func (auth *AuthManager) UpdateRefreshToken(deviceID string) error {
@ -135,30 +117,49 @@ func (auth *AuthManager) UpdateRefreshToken(deviceID string) error {
return nil return nil
} }
func (auth *AuthManager) CreateRefreshToken(deviceID string) (string, error) { func (auth *AuthManager) CreateJWTRefreshToken(user models.User, device models.Device) (string, error) {
// TODO: // Acquire Refresh Key
// - Create regular bcrypt password byteKey := []byte(device.RefreshKey)
// - Create Expiration (Depends on Device Type)
// - Store in DB: DeviceID, ValidUntil
generatedToken := uuid.New().String()
hashedRefreshToken, err := bcrypt.GenerateFromPassword([]byte(generatedToken), bcrypt.DefaultCost)
if err != nil {
log.Error(err)
return "", err
}
_ = string(hashedRefreshToken)
return "", nil
}
func (auth *AuthManager) CreateJWTAccessToken(user, role, deviceID string) (string, error) {
// Create New Token // Create New Token
tm := time.Now() tm := time.Now()
t := jwt.New() t := jwt.New()
t.Set(`did`, deviceID) // Device ID t.Set(`did`, device.UUID) // Device ID
t.Set(jwt.SubjectKey, user.UUID) // User ID
t.Set(jwt.AudienceKey, `imagini`) // App ID
t.Set(jwt.IssuedAtKey, tm) // Issued At
// TODO: Depends on Device
t.Set(jwt.ExpirationKey, tm.Add(time.Hour * 24)) // 1 Day Access Key
// Validate Token Creation
_, err := json.MarshalIndent(t, "", " ")
if err != nil {
fmt.Printf("failed to generate JSON: %s\n", err)
return "", err
}
// Sign Token
signed, err := jwt.Sign(t, jwa.HS256, byteKey)
if err != nil {
log.Printf("failed to sign token: %s", err)
return "", err
}
// Return Token
return string(signed), nil
}
func (auth *AuthManager) CreateJWTAccessToken(user models.User, device models.Device) (string, error) {
// Acquire Role
role := auth.getRole(user)
// Create New Token
tm := time.Now()
t := jwt.New()
t.Set(`did`, device.UUID) // Device ID
t.Set(`role`, role) // User Role (Admin / User) t.Set(`role`, role) // User Role (Admin / User)
t.Set(jwt.SubjectKey, user) // User ID t.Set(jwt.SubjectKey, user.UUID) // User ID
t.Set(jwt.AudienceKey, `imagini`) // App ID t.Set(jwt.AudienceKey, `imagini`) // App ID
t.Set(jwt.IssuedAtKey, tm) // Issued At t.Set(jwt.IssuedAtKey, tm) // Issued At
t.Set(jwt.ExpirationKey, tm.Add(time.Minute * 30)) // 30 Minute Access Key t.Set(jwt.ExpirationKey, tm.Add(time.Minute * 30)) // 30 Minute Access Key

View File

@ -4,7 +4,7 @@ import (
"path" "path"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/logger" // "gorm.io/gorm/logger"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -19,7 +19,7 @@ type DBManager struct {
func NewMgr(c *config.Config) *DBManager { func NewMgr(c *config.Config) *DBManager {
gormConfig := &gorm.Config{ gormConfig := &gorm.Config{
PrepareStmt: true, PrepareStmt: true,
Logger: logger.Default.LogMode(logger.Silent), // Logger: logger.Default.LogMode(logger.Silent),
} }
// Create manager // Create manager
@ -52,7 +52,7 @@ func NewMgr(c *config.Config) *DBManager {
func (dbm *DBManager) bootstrapDatabase() { func (dbm *DBManager) bootstrapDatabase() {
log.Info("[query] Bootstrapping database.") log.Info("[query] Bootstrapping database.")
err := dbm.CreateUser(models.User{ _, err := dbm.CreateUser(models.User{
Username: "admin", Username: "admin",
Password: "admin", Password: "admin",
AuthType: "Local", AuthType: "Local",

View File

@ -1,31 +1,17 @@
package db package db
import ( import (
"errors" "github.com/google/uuid"
"gorm.io/gorm"
"golang.org/x/crypto/bcrypt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"reichard.io/imagini/internal/models" "reichard.io/imagini/internal/models"
) )
func (dbm *DBManager) CreateDevice(device models.Device) error { func (dbm *DBManager) CreateDevice(device models.Device) (models.Device, error) {
log.Info("[query] Creating device: ", device.Name) log.Info("[query] Creating device: ", device.Name)
_, err := dbm.Device(device) device.RefreshKey = uuid.New().String()
if !errors.Is(err, gorm.ErrRecordNotFound) { err := dbm.db.Create(&device).Error
log.Warn("[query] Device already exists: ", device.Name) return device, err
return errors.New("Device already exists")
}
// Generate random password
refreshToken := "asd123"
hashedToken, err := bcrypt.GenerateFromPassword([]byte(refreshToken), bcrypt.DefaultCost)
if err != nil {
log.Error(err)
return err
}
device.RefreshToken = string(hashedToken)
return dbm.db.Create(&device).Error
} }
func (dbm *DBManager) Device (device models.Device) (models.Device, error) { func (dbm *DBManager) Device (device models.Device) (models.Device, error) {

View File

@ -1,29 +1,22 @@
package db package db
import ( import (
"errors"
"gorm.io/gorm"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"reichard.io/imagini/internal/models" "reichard.io/imagini/internal/models"
) )
func (dbm *DBManager) CreateUser(user models.User) error { func (dbm *DBManager) CreateUser(user models.User) (models.User, error) {
log.Info("[query] Creating user: ", user.Username) log.Info("[query] Creating user: ", user.Username)
_, err := dbm.User(user)
if !errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn("[query] User already exists: ", user.Username)
return errors.New("User already exists")
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost) hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return err return user, err
} }
user.Password = string(hashedPassword) user.Password = string(hashedPassword)
return dbm.db.Create(&user).Error err = dbm.db.Create(&user).Error
return user, err
} }
func (dbm *DBManager) User (user models.User) (models.User, error) { func (dbm *DBManager) User (user models.User) (models.User, error) {

View File

@ -8,7 +8,7 @@ import (
// Base contains common columns for all tables. // Base contains common columns for all tables.
type Base struct { type Base struct {
UUID uuid.UUID `gorm:"type:uuid;default:default:uuid_generate_v4();primarykey"` UUID uuid.UUID `gorm:"type:uuid;primarykey"`
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
DeletedAt gorm.DeletedAt `gorm:"index"` DeletedAt gorm.DeletedAt `gorm:"index"`
@ -31,8 +31,7 @@ type Device struct {
User User `json:"user" gorm:"ForeignKey:UUID"` User User `json:"user" gorm:"ForeignKey:UUID"`
Name string `json:"name"` Name string `json:"name"`
Type string `json:"type"` // Android, iOS, Chrome, FireFox, Edge, etc Type string `json:"type"` // Android, iOS, Chrome, FireFox, Edge, etc
RefreshExp string `json:"refresh_exp"` RefreshKey string `json:"-"`
RefreshToken string `json:"-"`
} }
type User struct { type User struct {