feat(auth): add auth hash (allows purging sessions & more)
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
_ "modernc.org/sqlite"
|
||||
"reichard.io/antholume/config"
|
||||
_ "reichard.io/antholume/database/migrations"
|
||||
)
|
||||
|
||||
type DBManager struct {
|
||||
@@ -36,13 +37,15 @@ func NewMgr(c *config.Config) *DBManager {
|
||||
cfg: c,
|
||||
}
|
||||
|
||||
dbm.init()
|
||||
if err := dbm.init(); err != nil {
|
||||
log.Panic("Unable to init DB")
|
||||
}
|
||||
|
||||
return dbm
|
||||
}
|
||||
|
||||
// Init manager
|
||||
func (dbm *DBManager) init() {
|
||||
func (dbm *DBManager) init() error {
|
||||
if dbm.cfg.DBType == "sqlite" || dbm.cfg.DBType == "memory" {
|
||||
var dbLocation string = ":memory:"
|
||||
if dbm.cfg.DBType == "sqlite" {
|
||||
@@ -52,7 +55,8 @@ func (dbm *DBManager) init() {
|
||||
var err error
|
||||
dbm.DB, err = sql.Open("sqlite", dbLocation)
|
||||
if err != nil {
|
||||
log.Fatalf("Unable to open DB: %v", err)
|
||||
log.Errorf("Unable to open DB: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Single Open Connection
|
||||
@@ -60,22 +64,36 @@ func (dbm *DBManager) init() {
|
||||
|
||||
// Execute DDL
|
||||
if _, err := dbm.DB.Exec(ddl, nil); err != nil {
|
||||
log.Fatalf("Error executing schema: %v", err)
|
||||
log.Errorf("Error executing schema: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Perform Migrations
|
||||
err = dbm.performMigrations()
|
||||
if err != nil && err != goose.ErrNoMigrationFiles {
|
||||
log.Fatalf("Error running DB migrations: %v", err)
|
||||
log.Errorf("Error running DB migrations: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Set SQLite Settings (After Migrations)
|
||||
pragmaQuery := `
|
||||
PRAGMA foreign_keys = ON;
|
||||
PRAGMA journal_mode = WAL;
|
||||
`
|
||||
if _, err := dbm.DB.Exec(pragmaQuery, nil); err != nil {
|
||||
log.Errorf("Error executing pragma: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Cache Tables
|
||||
dbm.CacheTempTables()
|
||||
} else {
|
||||
log.Fatal("Unsupported Database")
|
||||
return fmt.Errorf("unsupported database")
|
||||
}
|
||||
|
||||
dbm.Queries = New(dbm.DB)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reload manager (close DB & reinit)
|
||||
@@ -87,7 +105,9 @@ func (dbm *DBManager) Reload() error {
|
||||
}
|
||||
|
||||
// Reinit DB
|
||||
dbm.init()
|
||||
if err := dbm.init(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"reichard.io/antholume/config"
|
||||
"reichard.io/antholume/utils"
|
||||
)
|
||||
|
||||
type databaseTest struct {
|
||||
@@ -42,9 +44,16 @@ func TestNewMgr(t *testing.T) {
|
||||
|
||||
func (dt *databaseTest) TestUser() {
|
||||
dt.Run("User", func(t *testing.T) {
|
||||
// Generate Auth Hash
|
||||
rawAuthHash, err := utils.GenerateToken(64)
|
||||
if err != nil {
|
||||
t.Fatalf(`Expected: %v, Got: %v, Error: %v`, nil, err, err)
|
||||
}
|
||||
|
||||
changed, err := dt.dbm.Queries.CreateUser(dt.dbm.Ctx, CreateUserParams{
|
||||
ID: userID,
|
||||
Pass: &userPass,
|
||||
ID: userID,
|
||||
Pass: &userPass,
|
||||
AuthHash: fmt.Sprintf("%x", rawAuthHash),
|
||||
})
|
||||
|
||||
if err != nil || changed != 1 {
|
||||
|
||||
91
database/migrations/20240128012356_user_auth_hash.go
Normal file
91
database/migrations/20240128012356_user_auth_hash.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
"reichard.io/antholume/utils"
|
||||
)
|
||||
|
||||
func init() {
|
||||
goose.AddMigrationContext(upUserAuthHash, downUserAuthHash)
|
||||
}
|
||||
|
||||
func upUserAuthHash(ctx context.Context, tx *sql.Tx) error {
|
||||
// Validate column doesn't already exist
|
||||
hasCol, err := hasColumn(tx, "users", "auth_hash")
|
||||
if err != nil {
|
||||
return err
|
||||
} else if hasCol {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Copy table & create column
|
||||
_, err = tx.Exec(`
|
||||
-- Create Copy Table
|
||||
CREATE TABLE temp_users AS SELECT * FROM users;
|
||||
ALTER TABLE temp_users ADD COLUMN auth_hash TEXT;
|
||||
|
||||
-- Update Schema
|
||||
DELETE FROM users;
|
||||
ALTER TABLE users ADD COLUMN auth_hash TEXT NOT NULL;
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get current users
|
||||
rows, err := tx.Query("SELECT id FROM temp_users")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Query existing users
|
||||
var users []string
|
||||
for rows.Next() {
|
||||
var user string
|
||||
if err := rows.Scan(&user); err != nil {
|
||||
return err
|
||||
}
|
||||
users = append(users, user)
|
||||
}
|
||||
|
||||
// Create auth hash per user
|
||||
for _, user := range users {
|
||||
rawAuthHash, err := utils.GenerateToken(64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
authHash := fmt.Sprintf("%x", rawAuthHash)
|
||||
_, err = tx.Exec("UPDATE temp_users SET auth_hash = ? WHERE id = ?", authHash, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Copy from temp to true table
|
||||
_, err = tx.Exec(`
|
||||
-- Copy Into New
|
||||
INSERT INTO users SELECT * FROM temp_users;
|
||||
|
||||
-- Drop Temp Table
|
||||
DROP TABLE temp_users;
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func downUserAuthHash(ctx context.Context, tx *sql.Tx) error {
|
||||
// Drop column
|
||||
_, err := tx.Exec("ALTER users DROP COLUMN auth_hash")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,5 +1,9 @@
|
||||
# DB Migrations
|
||||
|
||||
```bash
|
||||
# SQL migration
|
||||
goose create migration_name sql
|
||||
|
||||
# Go migration
|
||||
goose create migration_name
|
||||
```
|
||||
|
||||
38
database/migrations/utils.go
Normal file
38
database/migrations/utils.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type columnInfo struct {
|
||||
CID int
|
||||
Name string
|
||||
Type string
|
||||
NotNull int
|
||||
DefaultVal sql.NullString
|
||||
PK int
|
||||
}
|
||||
|
||||
func hasColumn(tx *sql.Tx, table string, column string) (bool, error) {
|
||||
rows, err := tx.Query(fmt.Sprintf("PRAGMA table_info(%s)", table))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
colExists := false
|
||||
for rows.Next() {
|
||||
var col columnInfo
|
||||
if err := rows.Scan(&col.CID, &col.Name, &col.Type, &col.NotNull, &col.DefaultVal, &col.PK); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if col.Name == column {
|
||||
colExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return colExists, nil
|
||||
}
|
||||
@@ -96,6 +96,7 @@ type Metadatum struct {
|
||||
type User struct {
|
||||
ID string `json:"id"`
|
||||
Pass *string `json:"-"`
|
||||
AuthHash string `json:"auth_hash"`
|
||||
Admin bool `json:"-"`
|
||||
TimeOffset *string `json:"time_offset"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
|
||||
@@ -26,8 +26,8 @@ VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
RETURNING *;
|
||||
|
||||
-- name: CreateUser :execrows
|
||||
INSERT INTO users (id, pass)
|
||||
VALUES (?, ?)
|
||||
INSERT INTO users (id, pass, auth_hash)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT DO NOTHING;
|
||||
|
||||
-- name: DeleteDocument :execrows
|
||||
@@ -368,6 +368,7 @@ RETURNING *;
|
||||
UPDATE users
|
||||
SET
|
||||
pass = COALESCE($password, pass),
|
||||
auth_hash = COALESCE($auth_hash, auth_hash),
|
||||
time_offset = COALESCE($time_offset, time_offset)
|
||||
WHERE id = $user_id
|
||||
RETURNING *;
|
||||
|
||||
@@ -113,18 +113,19 @@ func (q *Queries) AddMetadata(ctx context.Context, arg AddMetadataParams) (Metad
|
||||
}
|
||||
|
||||
const createUser = `-- name: CreateUser :execrows
|
||||
INSERT INTO users (id, pass)
|
||||
VALUES (?, ?)
|
||||
INSERT INTO users (id, pass, auth_hash)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT DO NOTHING
|
||||
`
|
||||
|
||||
type CreateUserParams struct {
|
||||
ID string `json:"id"`
|
||||
Pass *string `json:"-"`
|
||||
ID string `json:"id"`
|
||||
Pass *string `json:"-"`
|
||||
AuthHash string `json:"auth_hash"`
|
||||
}
|
||||
|
||||
func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (int64, error) {
|
||||
result, err := q.db.ExecContext(ctx, createUser, arg.ID, arg.Pass)
|
||||
result, err := q.db.ExecContext(ctx, createUser, arg.ID, arg.Pass, arg.AuthHash)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -954,7 +955,7 @@ func (q *Queries) GetProgress(ctx context.Context, arg GetProgressParams) ([]Get
|
||||
}
|
||||
|
||||
const getUser = `-- name: GetUser :one
|
||||
SELECT id, pass, admin, time_offset, created_at FROM users
|
||||
SELECT id, pass, auth_hash, admin, time_offset, created_at FROM users
|
||||
WHERE id = ?1 LIMIT 1
|
||||
`
|
||||
|
||||
@@ -964,6 +965,7 @@ func (q *Queries) GetUser(ctx context.Context, userID string) (User, error) {
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Pass,
|
||||
&i.AuthHash,
|
||||
&i.Admin,
|
||||
&i.TimeOffset,
|
||||
&i.CreatedAt,
|
||||
@@ -1092,7 +1094,7 @@ func (q *Queries) GetUserStreaks(ctx context.Context, userID string) ([]UserStre
|
||||
}
|
||||
|
||||
const getUsers = `-- name: GetUsers :many
|
||||
SELECT id, pass, admin, time_offset, created_at FROM users
|
||||
SELECT id, pass, auth_hash, admin, time_offset, created_at FROM users
|
||||
`
|
||||
|
||||
func (q *Queries) GetUsers(ctx context.Context) ([]User, error) {
|
||||
@@ -1107,6 +1109,7 @@ func (q *Queries) GetUsers(ctx context.Context) ([]User, error) {
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Pass,
|
||||
&i.AuthHash,
|
||||
&i.Admin,
|
||||
&i.TimeOffset,
|
||||
&i.CreatedAt,
|
||||
@@ -1214,23 +1217,31 @@ const updateUser = `-- name: UpdateUser :one
|
||||
UPDATE users
|
||||
SET
|
||||
pass = COALESCE(?1, pass),
|
||||
time_offset = COALESCE(?2, time_offset)
|
||||
WHERE id = ?3
|
||||
RETURNING id, pass, admin, time_offset, created_at
|
||||
auth_hash = COALESCE(?2, auth_hash),
|
||||
time_offset = COALESCE(?3, time_offset)
|
||||
WHERE id = ?4
|
||||
RETURNING id, pass, auth_hash, admin, time_offset, created_at
|
||||
`
|
||||
|
||||
type UpdateUserParams struct {
|
||||
Password *string `json:"-"`
|
||||
AuthHash string `json:"auth_hash"`
|
||||
TimeOffset *string `json:"time_offset"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateUser(ctx context.Context, arg UpdateUserParams) (User, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateUser, arg.Password, arg.TimeOffset, arg.UserID)
|
||||
row := q.db.QueryRowContext(ctx, updateUser,
|
||||
arg.Password,
|
||||
arg.AuthHash,
|
||||
arg.TimeOffset,
|
||||
arg.UserID,
|
||||
)
|
||||
var i User
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Pass,
|
||||
&i.AuthHash,
|
||||
&i.Admin,
|
||||
&i.TimeOffset,
|
||||
&i.CreatedAt,
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
PRAGMA foreign_keys = ON;
|
||||
PRAGMA journal_mode = WAL;
|
||||
|
||||
---------------------------------------------------------------
|
||||
------------------------ Normal Tables ------------------------
|
||||
---------------------------------------------------------------
|
||||
@@ -10,6 +7,7 @@ CREATE TABLE IF NOT EXISTS users (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
|
||||
pass TEXT NOT NULL,
|
||||
auth_hash TEXT NOT NULL,
|
||||
admin BOOLEAN NOT NULL DEFAULT 0 CHECK (admin IN (0, 1)),
|
||||
time_offset TEXT NOT NULL DEFAULT '0 hours',
|
||||
|
||||
|
||||
Reference in New Issue
Block a user