From 015ca30ac5f8d08a43cde1dbd38d38073808d692 Mon Sep 17 00:00:00 2001 From: Evan Reichard Date: Sat, 27 Jan 2024 21:02:08 -0500 Subject: [PATCH] feat(auth): add auth hash (allows purging sessions & more) --- api/api.go | 34 +++---- api/app-routes.go | 11 +-- api/auth.go | 91 ++++++++++++++++--- api/ko-routes.go | 14 ++- database/manager.go | 34 +++++-- database/manager_test.go | 13 ++- .../20240128012356_user_auth_hash.go | 91 +++++++++++++++++++ database/migrations/README.md | 4 + database/migrations/utils.go | 38 ++++++++ database/models.go | 1 + database/query.sql | 5 +- database/query.sql.go | 33 ++++--- database/schema.sql | 4 +- utils/utils.go | 11 +++ 14 files changed, 316 insertions(+), 68 deletions(-) create mode 100644 database/migrations/20240128012356_user_auth_hash.go create mode 100644 database/migrations/utils.go diff --git a/api/api.go b/api/api.go index 75b7e3a..abedddb 100644 --- a/api/api.go +++ b/api/api.go @@ -2,7 +2,6 @@ package api import ( "context" - "crypto/rand" "embed" "fmt" "html/template" @@ -20,23 +19,26 @@ import ( log "github.com/sirupsen/logrus" "reichard.io/antholume/config" "reichard.io/antholume/database" + "reichard.io/antholume/utils" ) type API struct { - db *database.DBManager - cfg *config.Config - assets *embed.FS - templates map[string]*template.Template - httpServer *http.Server + db *database.DBManager + cfg *config.Config + assets *embed.FS + httpServer *http.Server + templates map[string]*template.Template + userAuthCache map[string]string } var htmlPolicy = bluemonday.StrictPolicy() func NewApi(db *database.DBManager, c *config.Config, assets *embed.FS) *API { api := &API{ - db: db, - cfg: c, - assets: assets, + db: db, + cfg: c, + assets: assets, + userAuthCache: make(map[string]string), } // Create Router @@ -63,7 +65,7 @@ func NewApi(db *database.DBManager, c *config.Config, assets *embed.FS) *API { newToken = []byte(c.CookieAuthKey) } else { log.Info("Generating cookie auth key") - newToken, err = generateToken(64) + newToken, err = utils.GenerateToken(64) if err != nil { log.Panic("Unable to generate cookie auth key") } @@ -102,8 +104,8 @@ func NewApi(db *database.DBManager, c *config.Config, assets *embed.FS) *API { func (api *API) Start() error { return api.httpServer.ListenAndServe() - } + func (api *API) Stop() error { // Stop Server ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -115,7 +117,6 @@ func (api *API) Stop() error { // Close DB return api.db.DB.Close() - } func (api *API) registerWebAppRoutes(router *gin.Engine) { @@ -312,12 +313,3 @@ func apiLogger() gin.HandlerFunc { log.WithFields(logData).Info(fmt.Sprintf("%s %s", c.Request.Method, c.Request.URL.Path)) } } - -func generateToken(n int) ([]byte, error) { - b := make([]byte, n) - _, err := rand.Read(b) - if err != nil { - return nil, err - } - return b, nil -} diff --git a/api/app-routes.go b/api/app-routes.go index 1342e98..8cbf97b 100644 --- a/api/app-routes.go +++ b/api/app-routes.go @@ -1453,15 +1453,10 @@ func (api *API) processRestoreFile(rAdminAction requestAdminAction, c *gin.Conte return } - // Close DB - err = api.db.DB.Close() - if err != nil { - appErrorPage(c, http.StatusInternalServerError, "Unable to close DB.") - log.Panic("Unable to close DB: ", err) - } - // Reinit DB - api.db.Reload() + if err := api.db.Reload(); err != nil { + log.Panicf("Unable to reload DB: %v", err) + } } func (api *API) restoreData(zipReader *zip.Reader) error { diff --git a/api/auth.go b/api/auth.go index 8fdc07e..93c5113 100644 --- a/api/auth.go +++ b/api/auth.go @@ -12,12 +12,14 @@ import ( "github.com/gin-gonic/gin" log "github.com/sirupsen/logrus" "reichard.io/antholume/database" + "reichard.io/antholume/utils" ) // Authorization Data type authData struct { UserName string IsAdmin bool + AuthHash string } // KOSync API Auth Headers @@ -41,9 +43,13 @@ func (api *API) authorizeCredentials(username string, password string) (auth *au return } + // Update Auth Cache + api.userAuthCache[user.ID] = user.AuthHash + return &authData{ UserName: user.ID, IsAdmin: user.Admin, + AuthHash: user.AuthHash, } } @@ -51,7 +57,7 @@ func (api *API) authKOMiddleware(c *gin.Context) { session := sessions.Default(c) // Check Session First - if auth, ok := getSession(session); ok == true { + if auth, ok := api.getSession(session); ok == true { c.Set("Authorization", auth) c.Header("Cache-Control", "private") c.Next() @@ -76,7 +82,7 @@ func (api *API) authKOMiddleware(c *gin.Context) { return } - if err := setSession(session, *authData); err != nil { + if err := api.setSession(session, *authData); err != nil { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) return } @@ -114,7 +120,7 @@ func (api *API) authWebAppMiddleware(c *gin.Context) { session := sessions.Default(c) // Check Session - if auth, ok := getSession(session); ok == true { + if auth, ok := api.getSession(session); ok == true { c.Set("Authorization", auth) c.Header("Cache-Control", "private") c.Next() @@ -163,7 +169,7 @@ func (api *API) appAuthFormLogin(c *gin.Context) { // Set Session session := sessions.Default(c) - if err := setSession(session, *authData); err != nil { + if err := api.setSession(session, *authData); err != nil { templateVars["Error"] = "Invalid Credentials" c.HTML(http.StatusUnauthorized, "page/login", templateVars) return @@ -199,9 +205,20 @@ func (api *API) appAuthFormRegister(c *gin.Context) { return } + // Generate Auth Hash + rawAuthHash, err := utils.GenerateToken(64) + if err != nil { + log.Error("Failed to generate user token: ", err) + templateVars["Error"] = "Failed to Create User" + c.HTML(http.StatusBadRequest, "page/login", templateVars) + return + } + + // Create User in DB rows, err := api.db.Queries.CreateUser(api.db.Ctx, database.CreateUserParams{ - ID: username, - Pass: &hashedPassword, + ID: username, + Pass: &hashedPassword, + AuthHash: fmt.Sprintf("%x", rawAuthHash), }) // SQL Error @@ -233,9 +250,10 @@ func (api *API) appAuthFormRegister(c *gin.Context) { auth := authData{ UserName: user.ID, IsAdmin: user.Admin, + AuthHash: user.AuthHash, } session := sessions.Default(c) - if err := setSession(session, auth); err != nil { + if err := api.setSession(session, auth); err != nil { appErrorPage(c, http.StatusUnauthorized, "Unauthorized.") return } @@ -251,12 +269,13 @@ func (api *API) appAuthLogout(c *gin.Context) { c.Redirect(http.StatusFound, "/login") } -func getSession(session sessions.Session) (auth authData, ok bool) { - // Check Session +func (api *API) getSession(session sessions.Session) (auth authData, ok bool) { + // Get Session authorizedUser := session.Get("authorizedUser") isAdmin := session.Get("isAdmin") expiresAt := session.Get("expiresAt") - if authorizedUser == nil || isAdmin == nil || expiresAt == nil { + authHash := session.Get("authHash") + if authorizedUser == nil || isAdmin == nil || expiresAt == nil || authHash == nil { return } @@ -264,22 +283,70 @@ func getSession(session sessions.Session) (auth authData, ok bool) { auth = authData{ UserName: authorizedUser.(string), IsAdmin: isAdmin.(bool), + AuthHash: authHash.(string), + } + + // Validate Auth Hash + correctAuthHash, err := api.getUserAuthHash(auth.UserName) + if err != nil || correctAuthHash != auth.AuthHash { + return } // Refresh if expiresAt.(int64)-time.Now().Unix() < 60*60*24 { log.Info("Refreshing Session") - setSession(session, auth) + api.setSession(session, auth) } // Authorized return auth, true } -func setSession(session sessions.Session, auth authData) error { +func (api *API) setSession(session sessions.Session, auth authData) error { // Set Session Cookie session.Set("authorizedUser", auth.UserName) session.Set("isAdmin", auth.IsAdmin) session.Set("expiresAt", time.Now().Unix()+(60*60*24*7)) + session.Set("authHash", auth.AuthHash) + return session.Save() } + +func (api *API) getUserAuthHash(username string) (string, error) { + // Return Cache + if api.userAuthCache[username] != "" { + return api.userAuthCache[username], nil + } + + // Get DB + user, err := api.db.Queries.GetUser(api.db.Ctx, username) + if err != nil { + log.Error("GetUser DB Error:", err) + return "", err + } + + // Update Cache + api.userAuthCache[username] = user.AuthHash + + return api.userAuthCache[username], nil +} + +func (api *API) rotateUserAuthHash(username string) error { + // Generate Auth Hash + rawAuthHash, err := utils.GenerateToken(64) + if err != nil { + log.Error("Failed to generate user token: ", err) + return err + } + + // Update User + _, err = api.db.Queries.UpdateUser(api.db.Ctx, database.UpdateUserParams{ + UserID: username, + AuthHash: fmt.Sprintf("%x", rawAuthHash), + }) + + // Update Cache + api.userAuthCache[username] = fmt.Sprintf("%x", rawAuthHash) + + return nil +} diff --git a/api/ko-routes.go b/api/ko-routes.go index 14d88fc..e0d52eb 100644 --- a/api/ko-routes.go +++ b/api/ko-routes.go @@ -20,6 +20,7 @@ import ( "golang.org/x/exp/slices" "reichard.io/antholume/database" "reichard.io/antholume/metadata" + "reichard.io/antholume/utils" ) type activityItem struct { @@ -107,9 +108,18 @@ func (api *API) koCreateUser(c *gin.Context) { return } + // Generate Auth Hash + rawAuthHash, err := utils.GenerateToken(64) + if err != nil { + log.Error("Failed to generate user token: ", err) + apiErrorPage(c, http.StatusBadRequest, "Unknown Error") + return + } + rows, err := api.db.Queries.CreateUser(api.db.Ctx, database.CreateUserParams{ - ID: rUser.Username, - Pass: &hashedPassword, + ID: rUser.Username, + Pass: &hashedPassword, + AuthHash: fmt.Sprintf("%x", rawAuthHash), }) if err != nil { log.Error("CreateUser DB Error:", err) diff --git a/database/manager.go b/database/manager.go index 210176a..ceecae5 100644 --- a/database/manager.go +++ b/database/manager.go @@ -13,6 +13,7 @@ import ( log "github.com/sirupsen/logrus" _ "modernc.org/sqlite" "reichard.io/antholume/config" + _ "reichard.io/antholume/database/migrations" ) type DBManager struct { @@ -36,13 +37,15 @@ func NewMgr(c *config.Config) *DBManager { cfg: c, } - dbm.init() + if err := dbm.init(); err != nil { + log.Panic("Unable to init DB") + } return dbm } // Init manager -func (dbm *DBManager) init() { +func (dbm *DBManager) init() error { if dbm.cfg.DBType == "sqlite" || dbm.cfg.DBType == "memory" { var dbLocation string = ":memory:" if dbm.cfg.DBType == "sqlite" { @@ -52,7 +55,8 @@ func (dbm *DBManager) init() { var err error dbm.DB, err = sql.Open("sqlite", dbLocation) if err != nil { - log.Fatalf("Unable to open DB: %v", err) + log.Errorf("Unable to open DB: %v", err) + return err } // Single Open Connection @@ -60,22 +64,36 @@ func (dbm *DBManager) init() { // Execute DDL if _, err := dbm.DB.Exec(ddl, nil); err != nil { - log.Fatalf("Error executing schema: %v", err) + log.Errorf("Error executing schema: %v", err) + return err } // Perform Migrations err = dbm.performMigrations() if err != nil && err != goose.ErrNoMigrationFiles { - log.Fatalf("Error running DB migrations: %v", err) + log.Errorf("Error running DB migrations: %v", err) + return err + } + + // Set SQLite Settings (After Migrations) + pragmaQuery := ` + PRAGMA foreign_keys = ON; + PRAGMA journal_mode = WAL; + ` + if _, err := dbm.DB.Exec(pragmaQuery, nil); err != nil { + log.Errorf("Error executing pragma: %v", err) + return err } // Cache Tables dbm.CacheTempTables() } else { - log.Fatal("Unsupported Database") + return fmt.Errorf("unsupported database") } dbm.Queries = New(dbm.DB) + + return nil } // Reload manager (close DB & reinit) @@ -87,7 +105,9 @@ func (dbm *DBManager) Reload() error { } // Reinit DB - dbm.init() + if err := dbm.init(); err != nil { + return err + } return nil } diff --git a/database/manager_test.go b/database/manager_test.go index 4e82381..86f851f 100644 --- a/database/manager_test.go +++ b/database/manager_test.go @@ -1,10 +1,12 @@ package database import ( + "fmt" "testing" "time" "reichard.io/antholume/config" + "reichard.io/antholume/utils" ) type databaseTest struct { @@ -42,9 +44,16 @@ func TestNewMgr(t *testing.T) { func (dt *databaseTest) TestUser() { dt.Run("User", func(t *testing.T) { + // Generate Auth Hash + rawAuthHash, err := utils.GenerateToken(64) + if err != nil { + t.Fatalf(`Expected: %v, Got: %v, Error: %v`, nil, err, err) + } + changed, err := dt.dbm.Queries.CreateUser(dt.dbm.Ctx, CreateUserParams{ - ID: userID, - Pass: &userPass, + ID: userID, + Pass: &userPass, + AuthHash: fmt.Sprintf("%x", rawAuthHash), }) if err != nil || changed != 1 { diff --git a/database/migrations/20240128012356_user_auth_hash.go b/database/migrations/20240128012356_user_auth_hash.go new file mode 100644 index 0000000..d55263b --- /dev/null +++ b/database/migrations/20240128012356_user_auth_hash.go @@ -0,0 +1,91 @@ +package migrations + +import ( + "context" + "database/sql" + "fmt" + + "github.com/pressly/goose/v3" + "reichard.io/antholume/utils" +) + +func init() { + goose.AddMigrationContext(upUserAuthHash, downUserAuthHash) +} + +func upUserAuthHash(ctx context.Context, tx *sql.Tx) error { + // Validate column doesn't already exist + hasCol, err := hasColumn(tx, "users", "auth_hash") + if err != nil { + return err + } else if hasCol { + return nil + } + + // Copy table & create column + _, err = tx.Exec(` + -- Create Copy Table + CREATE TABLE temp_users AS SELECT * FROM users; + ALTER TABLE temp_users ADD COLUMN auth_hash TEXT; + + -- Update Schema + DELETE FROM users; + ALTER TABLE users ADD COLUMN auth_hash TEXT NOT NULL; + `) + if err != nil { + return err + } + + // Get current users + rows, err := tx.Query("SELECT id FROM temp_users") + if err != nil { + return err + } + + // Query existing users + var users []string + for rows.Next() { + var user string + if err := rows.Scan(&user); err != nil { + return err + } + users = append(users, user) + } + + // Create auth hash per user + for _, user := range users { + rawAuthHash, err := utils.GenerateToken(64) + if err != nil { + return err + } + + authHash := fmt.Sprintf("%x", rawAuthHash) + _, err = tx.Exec("UPDATE temp_users SET auth_hash = ? WHERE id = ?", authHash, user) + if err != nil { + return err + } + } + + // Copy from temp to true table + _, err = tx.Exec(` + -- Copy Into New + INSERT INTO users SELECT * FROM temp_users; + + -- Drop Temp Table + DROP TABLE temp_users; + `) + if err != nil { + return err + } + + return nil +} + +func downUserAuthHash(ctx context.Context, tx *sql.Tx) error { + // Drop column + _, err := tx.Exec("ALTER users DROP COLUMN auth_hash") + if err != nil { + return err + } + return nil +} diff --git a/database/migrations/README.md b/database/migrations/README.md index 7541c0c..25c43d0 100644 --- a/database/migrations/README.md +++ b/database/migrations/README.md @@ -1,5 +1,9 @@ # DB Migrations ```bash +# SQL migration goose create migration_name sql + +# Go migration +goose create migration_name ``` diff --git a/database/migrations/utils.go b/database/migrations/utils.go new file mode 100644 index 0000000..8bbc12a --- /dev/null +++ b/database/migrations/utils.go @@ -0,0 +1,38 @@ +package migrations + +import ( + "database/sql" + "fmt" +) + +type columnInfo struct { + CID int + Name string + Type string + NotNull int + DefaultVal sql.NullString + PK int +} + +func hasColumn(tx *sql.Tx, table string, column string) (bool, error) { + rows, err := tx.Query(fmt.Sprintf("PRAGMA table_info(%s)", table)) + if err != nil { + return false, err + } + defer rows.Close() + + colExists := false + for rows.Next() { + var col columnInfo + if err := rows.Scan(&col.CID, &col.Name, &col.Type, &col.NotNull, &col.DefaultVal, &col.PK); err != nil { + return false, err + } + + if col.Name == column { + colExists = true + break + } + } + + return colExists, nil +} diff --git a/database/models.go b/database/models.go index 63fe65c..6cb7d15 100644 --- a/database/models.go +++ b/database/models.go @@ -96,6 +96,7 @@ type Metadatum struct { type User struct { ID string `json:"id"` Pass *string `json:"-"` + AuthHash string `json:"auth_hash"` Admin bool `json:"-"` TimeOffset *string `json:"time_offset"` CreatedAt string `json:"created_at"` diff --git a/database/query.sql b/database/query.sql index c721c1e..7025728 100644 --- a/database/query.sql +++ b/database/query.sql @@ -26,8 +26,8 @@ VALUES (?, ?, ?, ?, ?, ?, ?, ?) RETURNING *; -- name: CreateUser :execrows -INSERT INTO users (id, pass) -VALUES (?, ?) +INSERT INTO users (id, pass, auth_hash) +VALUES (?, ?, ?) ON CONFLICT DO NOTHING; -- name: DeleteDocument :execrows @@ -368,6 +368,7 @@ RETURNING *; UPDATE users SET pass = COALESCE($password, pass), + auth_hash = COALESCE($auth_hash, auth_hash), time_offset = COALESCE($time_offset, time_offset) WHERE id = $user_id RETURNING *; diff --git a/database/query.sql.go b/database/query.sql.go index a2da4fb..edeb6fc 100644 --- a/database/query.sql.go +++ b/database/query.sql.go @@ -113,18 +113,19 @@ func (q *Queries) AddMetadata(ctx context.Context, arg AddMetadataParams) (Metad } const createUser = `-- name: CreateUser :execrows -INSERT INTO users (id, pass) -VALUES (?, ?) +INSERT INTO users (id, pass, auth_hash) +VALUES (?, ?, ?) ON CONFLICT DO NOTHING ` type CreateUserParams struct { - ID string `json:"id"` - Pass *string `json:"-"` + ID string `json:"id"` + Pass *string `json:"-"` + AuthHash string `json:"auth_hash"` } func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (int64, error) { - result, err := q.db.ExecContext(ctx, createUser, arg.ID, arg.Pass) + result, err := q.db.ExecContext(ctx, createUser, arg.ID, arg.Pass, arg.AuthHash) if err != nil { return 0, err } @@ -954,7 +955,7 @@ func (q *Queries) GetProgress(ctx context.Context, arg GetProgressParams) ([]Get } const getUser = `-- name: GetUser :one -SELECT id, pass, admin, time_offset, created_at FROM users +SELECT id, pass, auth_hash, admin, time_offset, created_at FROM users WHERE id = ?1 LIMIT 1 ` @@ -964,6 +965,7 @@ func (q *Queries) GetUser(ctx context.Context, userID string) (User, error) { err := row.Scan( &i.ID, &i.Pass, + &i.AuthHash, &i.Admin, &i.TimeOffset, &i.CreatedAt, @@ -1092,7 +1094,7 @@ func (q *Queries) GetUserStreaks(ctx context.Context, userID string) ([]UserStre } const getUsers = `-- name: GetUsers :many -SELECT id, pass, admin, time_offset, created_at FROM users +SELECT id, pass, auth_hash, admin, time_offset, created_at FROM users ` func (q *Queries) GetUsers(ctx context.Context) ([]User, error) { @@ -1107,6 +1109,7 @@ func (q *Queries) GetUsers(ctx context.Context) ([]User, error) { if err := rows.Scan( &i.ID, &i.Pass, + &i.AuthHash, &i.Admin, &i.TimeOffset, &i.CreatedAt, @@ -1214,23 +1217,31 @@ const updateUser = `-- name: UpdateUser :one UPDATE users SET pass = COALESCE(?1, pass), - time_offset = COALESCE(?2, time_offset) -WHERE id = ?3 -RETURNING id, pass, admin, time_offset, created_at + auth_hash = COALESCE(?2, auth_hash), + time_offset = COALESCE(?3, time_offset) +WHERE id = ?4 +RETURNING id, pass, auth_hash, admin, time_offset, created_at ` type UpdateUserParams struct { Password *string `json:"-"` + AuthHash string `json:"auth_hash"` TimeOffset *string `json:"time_offset"` UserID string `json:"user_id"` } func (q *Queries) UpdateUser(ctx context.Context, arg UpdateUserParams) (User, error) { - row := q.db.QueryRowContext(ctx, updateUser, arg.Password, arg.TimeOffset, arg.UserID) + row := q.db.QueryRowContext(ctx, updateUser, + arg.Password, + arg.AuthHash, + arg.TimeOffset, + arg.UserID, + ) var i User err := row.Scan( &i.ID, &i.Pass, + &i.AuthHash, &i.Admin, &i.TimeOffset, &i.CreatedAt, diff --git a/database/schema.sql b/database/schema.sql index 8ece7b6..94b82bc 100644 --- a/database/schema.sql +++ b/database/schema.sql @@ -1,6 +1,3 @@ -PRAGMA foreign_keys = ON; -PRAGMA journal_mode = WAL; - --------------------------------------------------------------- ------------------------ Normal Tables ------------------------ --------------------------------------------------------------- @@ -10,6 +7,7 @@ CREATE TABLE IF NOT EXISTS users ( id TEXT NOT NULL PRIMARY KEY, pass TEXT NOT NULL, + auth_hash TEXT NOT NULL, admin BOOLEAN NOT NULL DEFAULT 0 CHECK (admin IN (0, 1)), time_offset TEXT NOT NULL DEFAULT '0 hours', diff --git a/utils/utils.go b/utils/utils.go index 5c9d4c2..07fa056 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -3,6 +3,7 @@ package utils import ( "bytes" "crypto/md5" + "crypto/rand" "fmt" "io" "os" @@ -42,3 +43,13 @@ func CalculatePartialMD5(filePath string) (string, error) { allBytes := buf.Bytes() return fmt.Sprintf("%x", md5.Sum(allBytes)), nil } + +// Creates a token of n size +func GenerateToken(n int) ([]byte, error) { + b := make([]byte, n) + _, err := rand.Read(b) + if err != nil { + return nil, err + } + return b, nil +}