feat(auth): add auth hash (allows purging sessions & more)
Some checks reported errors
continuous-integration/drone/push Build encountered an error

This commit is contained in:
Evan Reichard 2024-01-27 21:02:08 -05:00
parent 9792a6ff19
commit 386b1c46f8
11 changed files with 217 additions and 61 deletions

View File

@ -2,7 +2,6 @@ package api
import ( import (
"context" "context"
"crypto/rand"
"embed" "embed"
"fmt" "fmt"
"html/template" "html/template"
@ -20,23 +19,26 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"reichard.io/antholume/config" "reichard.io/antholume/config"
"reichard.io/antholume/database" "reichard.io/antholume/database"
"reichard.io/antholume/utils"
) )
type API struct { type API struct {
db *database.DBManager db *database.DBManager
cfg *config.Config cfg *config.Config
assets *embed.FS assets *embed.FS
templates map[string]*template.Template httpServer *http.Server
httpServer *http.Server templates map[string]*template.Template
userAuthCache map[string]string
} }
var htmlPolicy = bluemonday.StrictPolicy() var htmlPolicy = bluemonday.StrictPolicy()
func NewApi(db *database.DBManager, c *config.Config, assets *embed.FS) *API { func NewApi(db *database.DBManager, c *config.Config, assets *embed.FS) *API {
api := &API{ api := &API{
db: db, db: db,
cfg: c, cfg: c,
assets: assets, assets: assets,
userAuthCache: make(map[string]string),
} }
// Create Router // Create Router
@ -63,7 +65,7 @@ func NewApi(db *database.DBManager, c *config.Config, assets *embed.FS) *API {
newToken = []byte(c.CookieAuthKey) newToken = []byte(c.CookieAuthKey)
} else { } else {
log.Info("Generating cookie auth key") log.Info("Generating cookie auth key")
newToken, err = generateToken(64) newToken, err = utils.GenerateToken(64)
if err != nil { if err != nil {
log.Panic("Unable to generate cookie auth key") log.Panic("Unable to generate cookie auth key")
} }
@ -102,8 +104,8 @@ func NewApi(db *database.DBManager, c *config.Config, assets *embed.FS) *API {
func (api *API) Start() error { func (api *API) Start() error {
return api.httpServer.ListenAndServe() return api.httpServer.ListenAndServe()
} }
func (api *API) Stop() error { func (api *API) Stop() error {
// Stop Server // Stop Server
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@ -115,7 +117,6 @@ func (api *API) Stop() error {
// Close DB // Close DB
return api.db.DB.Close() return api.db.DB.Close()
} }
func (api *API) registerWebAppRoutes(router *gin.Engine) { func (api *API) registerWebAppRoutes(router *gin.Engine) {
@ -312,12 +313,3 @@ func apiLogger() gin.HandlerFunc {
log.WithFields(logData).Info(fmt.Sprintf("%s %s", c.Request.Method, c.Request.URL.Path)) log.WithFields(logData).Info(fmt.Sprintf("%s %s", c.Request.Method, c.Request.URL.Path))
} }
} }
func generateToken(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return nil, err
}
return b, nil
}

View File

@ -1453,15 +1453,10 @@ func (api *API) processRestoreFile(rAdminAction requestAdminAction, c *gin.Conte
return return
} }
// Close DB
err = api.db.DB.Close()
if err != nil {
appErrorPage(c, http.StatusInternalServerError, "Unable to close DB.")
log.Panic("Unable to close DB: ", err)
}
// Reinit DB // Reinit DB
api.db.Reload() if err := api.db.Reload(); err != nil {
log.Panicf("Unable to reload DB: %v", err)
}
} }
func (api *API) restoreData(zipReader *zip.Reader) error { func (api *API) restoreData(zipReader *zip.Reader) error {

View File

@ -12,12 +12,14 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"reichard.io/antholume/database" "reichard.io/antholume/database"
"reichard.io/antholume/utils"
) )
// Authorization Data // Authorization Data
type authData struct { type authData struct {
UserName string UserName string
IsAdmin bool IsAdmin bool
AuthHash string
} }
// KOSync API Auth Headers // KOSync API Auth Headers
@ -41,9 +43,13 @@ func (api *API) authorizeCredentials(username string, password string) (auth *au
return return
} }
// Update Auth Cache
api.userAuthCache[user.ID] = user.AuthHash
return &authData{ return &authData{
UserName: user.ID, UserName: user.ID,
IsAdmin: user.Admin, IsAdmin: user.Admin,
AuthHash: user.AuthHash,
} }
} }
@ -51,7 +57,7 @@ func (api *API) authKOMiddleware(c *gin.Context) {
session := sessions.Default(c) session := sessions.Default(c)
// Check Session First // Check Session First
if auth, ok := getSession(session); ok == true { if auth, ok := api.getSession(session); ok == true {
c.Set("Authorization", auth) c.Set("Authorization", auth)
c.Header("Cache-Control", "private") c.Header("Cache-Control", "private")
c.Next() c.Next()
@ -76,7 +82,7 @@ func (api *API) authKOMiddleware(c *gin.Context) {
return return
} }
if err := setSession(session, *authData); err != nil { if err := api.setSession(session, *authData); err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return return
} }
@ -114,7 +120,7 @@ func (api *API) authWebAppMiddleware(c *gin.Context) {
session := sessions.Default(c) session := sessions.Default(c)
// Check Session // Check Session
if auth, ok := getSession(session); ok == true { if auth, ok := api.getSession(session); ok == true {
c.Set("Authorization", auth) c.Set("Authorization", auth)
c.Header("Cache-Control", "private") c.Header("Cache-Control", "private")
c.Next() c.Next()
@ -163,7 +169,7 @@ func (api *API) appAuthFormLogin(c *gin.Context) {
// Set Session // Set Session
session := sessions.Default(c) session := sessions.Default(c)
if err := setSession(session, *authData); err != nil { if err := api.setSession(session, *authData); err != nil {
templateVars["Error"] = "Invalid Credentials" templateVars["Error"] = "Invalid Credentials"
c.HTML(http.StatusUnauthorized, "page/login", templateVars) c.HTML(http.StatusUnauthorized, "page/login", templateVars)
return return
@ -199,9 +205,20 @@ func (api *API) appAuthFormRegister(c *gin.Context) {
return return
} }
// Generate Auth Hash
rawAuthHash, err := utils.GenerateToken(64)
if err != nil {
log.Error("Failed to generate user token: ", err)
templateVars["Error"] = "Failed to Create User"
c.HTML(http.StatusBadRequest, "page/login", templateVars)
return
}
// Create User in DB
rows, err := api.db.Queries.CreateUser(api.db.Ctx, database.CreateUserParams{ rows, err := api.db.Queries.CreateUser(api.db.Ctx, database.CreateUserParams{
ID: username, ID: username,
Pass: &hashedPassword, Pass: &hashedPassword,
AuthHash: fmt.Sprintf("%x", rawAuthHash),
}) })
// SQL Error // SQL Error
@ -233,9 +250,10 @@ func (api *API) appAuthFormRegister(c *gin.Context) {
auth := authData{ auth := authData{
UserName: user.ID, UserName: user.ID,
IsAdmin: user.Admin, IsAdmin: user.Admin,
AuthHash: user.AuthHash,
} }
session := sessions.Default(c) session := sessions.Default(c)
if err := setSession(session, auth); err != nil { if err := api.setSession(session, auth); err != nil {
appErrorPage(c, http.StatusUnauthorized, "Unauthorized.") appErrorPage(c, http.StatusUnauthorized, "Unauthorized.")
return return
} }
@ -251,12 +269,13 @@ func (api *API) appAuthLogout(c *gin.Context) {
c.Redirect(http.StatusFound, "/login") c.Redirect(http.StatusFound, "/login")
} }
func getSession(session sessions.Session) (auth authData, ok bool) { func (api *API) getSession(session sessions.Session) (auth authData, ok bool) {
// Check Session // Get Session
authorizedUser := session.Get("authorizedUser") authorizedUser := session.Get("authorizedUser")
isAdmin := session.Get("isAdmin") isAdmin := session.Get("isAdmin")
expiresAt := session.Get("expiresAt") expiresAt := session.Get("expiresAt")
if authorizedUser == nil || isAdmin == nil || expiresAt == nil { authHash := session.Get("authHash")
if authorizedUser == nil || isAdmin == nil || expiresAt == nil || authHash == nil {
return return
} }
@ -264,22 +283,70 @@ func getSession(session sessions.Session) (auth authData, ok bool) {
auth = authData{ auth = authData{
UserName: authorizedUser.(string), UserName: authorizedUser.(string),
IsAdmin: isAdmin.(bool), IsAdmin: isAdmin.(bool),
AuthHash: authHash.(string),
}
// Validate Auth Hash
correctAuthHash, err := api.getUserAuthHash(auth.UserName)
if err != nil || correctAuthHash != auth.AuthHash {
return
} }
// Refresh // Refresh
if expiresAt.(int64)-time.Now().Unix() < 60*60*24 { if expiresAt.(int64)-time.Now().Unix() < 60*60*24 {
log.Info("Refreshing Session") log.Info("Refreshing Session")
setSession(session, auth) api.setSession(session, auth)
} }
// Authorized // Authorized
return auth, true return auth, true
} }
func setSession(session sessions.Session, auth authData) error { func (api *API) setSession(session sessions.Session, auth authData) error {
// Set Session Cookie // Set Session Cookie
session.Set("authorizedUser", auth.UserName) session.Set("authorizedUser", auth.UserName)
session.Set("isAdmin", auth.IsAdmin) session.Set("isAdmin", auth.IsAdmin)
session.Set("expiresAt", time.Now().Unix()+(60*60*24*7)) session.Set("expiresAt", time.Now().Unix()+(60*60*24*7))
session.Set("authHash", auth.AuthHash)
return session.Save() return session.Save()
} }
func (api *API) getUserAuthHash(username string) (string, error) {
// Return Cache
if api.userAuthCache[username] != "" {
return api.userAuthCache[username], nil
}
// Get DB
user, err := api.db.Queries.GetUser(api.db.Ctx, username)
if err != nil {
log.Error("GetUser DB Error:", err)
return "", err
}
// Update Cache
api.userAuthCache[username] = user.AuthHash
return api.userAuthCache[username], nil
}
func (api *API) rotateUserAuthHash(username string) error {
// Generate Auth Hash
rawAuthHash, err := utils.GenerateToken(64)
if err != nil {
log.Error("Failed to generate user token: ", err)
return err
}
// Update User
_, err = api.db.Queries.UpdateUser(api.db.Ctx, database.UpdateUserParams{
UserID: username,
AuthHash: fmt.Sprintf("%x", rawAuthHash),
})
// Update Cache
api.userAuthCache[username] = fmt.Sprintf("%x", rawAuthHash)
return nil
}

View File

@ -13,6 +13,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
_ "modernc.org/sqlite" _ "modernc.org/sqlite"
"reichard.io/antholume/config" "reichard.io/antholume/config"
_ "reichard.io/antholume/database/migrations"
) )
type DBManager struct { type DBManager struct {
@ -36,13 +37,15 @@ func NewMgr(c *config.Config) *DBManager {
cfg: c, cfg: c,
} }
dbm.init() if err := dbm.init(); err != nil {
log.Panic("Unable to init DB")
}
return dbm return dbm
} }
// Init manager // Init manager
func (dbm *DBManager) init() { func (dbm *DBManager) init() error {
if dbm.cfg.DBType == "sqlite" || dbm.cfg.DBType == "memory" { if dbm.cfg.DBType == "sqlite" || dbm.cfg.DBType == "memory" {
var dbLocation string = ":memory:" var dbLocation string = ":memory:"
if dbm.cfg.DBType == "sqlite" { if dbm.cfg.DBType == "sqlite" {
@ -52,7 +55,8 @@ func (dbm *DBManager) init() {
var err error var err error
dbm.DB, err = sql.Open("sqlite", dbLocation) dbm.DB, err = sql.Open("sqlite", dbLocation)
if err != nil { if err != nil {
log.Fatalf("Unable to open DB: %v", err) log.Errorf("Unable to open DB: %v", err)
return err
} }
// Single Open Connection // Single Open Connection
@ -60,22 +64,26 @@ func (dbm *DBManager) init() {
// Execute DDL // Execute DDL
if _, err := dbm.DB.Exec(ddl, nil); err != nil { if _, err := dbm.DB.Exec(ddl, nil); err != nil {
log.Fatalf("Error executing schema: %v", err) log.Errorf("Error executing schema: %v", err)
return err
} }
// Perform Migrations // Perform Migrations
err = dbm.performMigrations() err = dbm.performMigrations()
if err != nil && err != goose.ErrNoMigrationFiles { if err != nil && err != goose.ErrNoMigrationFiles {
log.Fatalf("Error running DB migrations: %v", err) log.Errorf("Error running DB migrations: %v", err)
return err
} }
// Cache Tables // Cache Tables
dbm.CacheTempTables() dbm.CacheTempTables()
} else { } else {
log.Fatal("Unsupported Database") return fmt.Errorf("unsupported database")
} }
dbm.Queries = New(dbm.DB) dbm.Queries = New(dbm.DB)
return nil
} }
// Reload manager (close DB & reinit) // Reload manager (close DB & reinit)
@ -87,7 +95,9 @@ func (dbm *DBManager) Reload() error {
} }
// Reinit DB // Reinit DB
dbm.init() if err := dbm.init(); err != nil {
return err
}
return nil return nil
} }

View File

@ -0,0 +1,63 @@
package migrations
import (
"context"
"database/sql"
"fmt"
"github.com/pressly/goose/v3"
"reichard.io/antholume/utils"
)
func init() {
goose.AddMigrationContext(upUserAuthHash, downUserAuthHash)
}
func upUserAuthHash(ctx context.Context, tx *sql.Tx) error {
// Create now column
_, err := tx.Exec("ALTER TABLE users ADD COLUMN auth_hash TEXT")
if err != nil {
return err
}
// Get current users
rows, err := tx.Query("SELECT id FROM users")
if err != nil {
return err
}
// Query existing users
var users []string
for rows.Next() {
var user string
if err := rows.Scan(&user); err != nil {
return err
}
users = append(users, user)
}
// Create auth hash per user
for _, user := range users {
rawAuthHash, err := utils.GenerateToken(64)
if err != nil {
return err
}
authHash := fmt.Sprintf("%x", rawAuthHash)
_, err = tx.Exec("UPDATE users SET auth_hash = ? WHERE id = ?", authHash, user)
if err != nil {
return err
}
}
return nil
}
func downUserAuthHash(ctx context.Context, tx *sql.Tx) error {
// Drop column
_, err := tx.Exec("ALTER users DROP COLUMN auth_hash")
if err != nil {
return err
}
return nil
}

View File

@ -1,5 +1,9 @@
# DB Migrations # DB Migrations
```bash ```bash
# SQL migration
goose create migration_name sql goose create migration_name sql
# Go migration
goose create migration_name
``` ```

View File

@ -96,6 +96,7 @@ type Metadatum struct {
type User struct { type User struct {
ID string `json:"id"` ID string `json:"id"`
Pass *string `json:"-"` Pass *string `json:"-"`
AuthHash string `json:"auth_hash"`
Admin bool `json:"-"` Admin bool `json:"-"`
TimeOffset *string `json:"time_offset"` TimeOffset *string `json:"time_offset"`
CreatedAt string `json:"created_at"` CreatedAt string `json:"created_at"`

View File

@ -26,8 +26,8 @@ VALUES (?, ?, ?, ?, ?, ?, ?, ?)
RETURNING *; RETURNING *;
-- name: CreateUser :execrows -- name: CreateUser :execrows
INSERT INTO users (id, pass) INSERT INTO users (id, pass, auth_hash)
VALUES (?, ?) VALUES (?, ?, ?)
ON CONFLICT DO NOTHING; ON CONFLICT DO NOTHING;
-- name: DeleteDocument :execrows -- name: DeleteDocument :execrows
@ -368,6 +368,7 @@ RETURNING *;
UPDATE users UPDATE users
SET SET
pass = COALESCE($password, pass), pass = COALESCE($password, pass),
auth_hash = COALESCE($auth_hash, auth_hash),
time_offset = COALESCE($time_offset, time_offset) time_offset = COALESCE($time_offset, time_offset)
WHERE id = $user_id WHERE id = $user_id
RETURNING *; RETURNING *;

View File

@ -113,18 +113,19 @@ func (q *Queries) AddMetadata(ctx context.Context, arg AddMetadataParams) (Metad
} }
const createUser = `-- name: CreateUser :execrows const createUser = `-- name: CreateUser :execrows
INSERT INTO users (id, pass) INSERT INTO users (id, pass, auth_hash)
VALUES (?, ?) VALUES (?, ?, ?)
ON CONFLICT DO NOTHING ON CONFLICT DO NOTHING
` `
type CreateUserParams struct { type CreateUserParams struct {
ID string `json:"id"` ID string `json:"id"`
Pass *string `json:"-"` Pass *string `json:"-"`
AuthHash string `json:"auth_hash"`
} }
func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (int64, error) { func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (int64, error) {
result, err := q.db.ExecContext(ctx, createUser, arg.ID, arg.Pass) result, err := q.db.ExecContext(ctx, createUser, arg.ID, arg.Pass, arg.AuthHash)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -954,7 +955,7 @@ func (q *Queries) GetProgress(ctx context.Context, arg GetProgressParams) ([]Get
} }
const getUser = `-- name: GetUser :one const getUser = `-- name: GetUser :one
SELECT id, pass, admin, time_offset, created_at FROM users SELECT id, pass, auth_hash, admin, time_offset, created_at FROM users
WHERE id = ?1 LIMIT 1 WHERE id = ?1 LIMIT 1
` `
@ -964,6 +965,7 @@ func (q *Queries) GetUser(ctx context.Context, userID string) (User, error) {
err := row.Scan( err := row.Scan(
&i.ID, &i.ID,
&i.Pass, &i.Pass,
&i.AuthHash,
&i.Admin, &i.Admin,
&i.TimeOffset, &i.TimeOffset,
&i.CreatedAt, &i.CreatedAt,
@ -1092,7 +1094,7 @@ func (q *Queries) GetUserStreaks(ctx context.Context, userID string) ([]UserStre
} }
const getUsers = `-- name: GetUsers :many const getUsers = `-- name: GetUsers :many
SELECT id, pass, admin, time_offset, created_at FROM users SELECT id, pass, auth_hash, admin, time_offset, created_at FROM users
` `
func (q *Queries) GetUsers(ctx context.Context) ([]User, error) { func (q *Queries) GetUsers(ctx context.Context) ([]User, error) {
@ -1107,6 +1109,7 @@ func (q *Queries) GetUsers(ctx context.Context) ([]User, error) {
if err := rows.Scan( if err := rows.Scan(
&i.ID, &i.ID,
&i.Pass, &i.Pass,
&i.AuthHash,
&i.Admin, &i.Admin,
&i.TimeOffset, &i.TimeOffset,
&i.CreatedAt, &i.CreatedAt,
@ -1214,23 +1217,31 @@ const updateUser = `-- name: UpdateUser :one
UPDATE users UPDATE users
SET SET
pass = COALESCE(?1, pass), pass = COALESCE(?1, pass),
time_offset = COALESCE(?2, time_offset) auth_hash = COALESCE(?2, auth_hash),
WHERE id = ?3 time_offset = COALESCE(?3, time_offset)
RETURNING id, pass, admin, time_offset, created_at WHERE id = ?4
RETURNING id, pass, auth_hash, admin, time_offset, created_at
` `
type UpdateUserParams struct { type UpdateUserParams struct {
Password *string `json:"-"` Password *string `json:"-"`
AuthHash string `json:"auth_hash"`
TimeOffset *string `json:"time_offset"` TimeOffset *string `json:"time_offset"`
UserID string `json:"user_id"` UserID string `json:"user_id"`
} }
func (q *Queries) UpdateUser(ctx context.Context, arg UpdateUserParams) (User, error) { func (q *Queries) UpdateUser(ctx context.Context, arg UpdateUserParams) (User, error) {
row := q.db.QueryRowContext(ctx, updateUser, arg.Password, arg.TimeOffset, arg.UserID) row := q.db.QueryRowContext(ctx, updateUser,
arg.Password,
arg.AuthHash,
arg.TimeOffset,
arg.UserID,
)
var i User var i User
err := row.Scan( err := row.Scan(
&i.ID, &i.ID,
&i.Pass, &i.Pass,
&i.AuthHash,
&i.Admin, &i.Admin,
&i.TimeOffset, &i.TimeOffset,
&i.CreatedAt, &i.CreatedAt,

View File

@ -10,6 +10,7 @@ CREATE TABLE IF NOT EXISTS users (
id TEXT NOT NULL PRIMARY KEY, id TEXT NOT NULL PRIMARY KEY,
pass TEXT NOT NULL, pass TEXT NOT NULL,
auth_hash TEXT NOT NULL,
admin BOOLEAN NOT NULL DEFAULT 0 CHECK (admin IN (0, 1)), admin BOOLEAN NOT NULL DEFAULT 0 CHECK (admin IN (0, 1)),
time_offset TEXT NOT NULL DEFAULT '0 hours', time_offset TEXT NOT NULL DEFAULT '0 hours',

View File

@ -3,6 +3,7 @@ package utils
import ( import (
"bytes" "bytes"
"crypto/md5" "crypto/md5"
"crypto/rand"
"fmt" "fmt"
"io" "io"
"os" "os"
@ -42,3 +43,13 @@ func CalculatePartialMD5(filePath string) (string, error) {
allBytes := buf.Bytes() allBytes := buf.Bytes()
return fmt.Sprintf("%x", md5.Sum(allBytes)), nil return fmt.Sprintf("%x", md5.Sum(allBytes)), nil
} }
// Creates a token of n size
func GenerateToken(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return nil, err
}
return b, nil
}