feat(auth): add auth hash (allows purging sessions & more)
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
parent
9792a6ff19
commit
015ca30ac5
20
api/api.go
20
api/api.go
@ -2,7 +2,6 @@ package api
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
|
||||||
"embed"
|
"embed"
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
"html/template"
|
||||||
@ -20,14 +19,16 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"reichard.io/antholume/config"
|
"reichard.io/antholume/config"
|
||||||
"reichard.io/antholume/database"
|
"reichard.io/antholume/database"
|
||||||
|
"reichard.io/antholume/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type API struct {
|
type API struct {
|
||||||
db *database.DBManager
|
db *database.DBManager
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
assets *embed.FS
|
assets *embed.FS
|
||||||
templates map[string]*template.Template
|
|
||||||
httpServer *http.Server
|
httpServer *http.Server
|
||||||
|
templates map[string]*template.Template
|
||||||
|
userAuthCache map[string]string
|
||||||
}
|
}
|
||||||
|
|
||||||
var htmlPolicy = bluemonday.StrictPolicy()
|
var htmlPolicy = bluemonday.StrictPolicy()
|
||||||
@ -37,6 +38,7 @@ func NewApi(db *database.DBManager, c *config.Config, assets *embed.FS) *API {
|
|||||||
db: db,
|
db: db,
|
||||||
cfg: c,
|
cfg: c,
|
||||||
assets: assets,
|
assets: assets,
|
||||||
|
userAuthCache: make(map[string]string),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create Router
|
// Create Router
|
||||||
@ -63,7 +65,7 @@ func NewApi(db *database.DBManager, c *config.Config, assets *embed.FS) *API {
|
|||||||
newToken = []byte(c.CookieAuthKey)
|
newToken = []byte(c.CookieAuthKey)
|
||||||
} else {
|
} else {
|
||||||
log.Info("Generating cookie auth key")
|
log.Info("Generating cookie auth key")
|
||||||
newToken, err = generateToken(64)
|
newToken, err = utils.GenerateToken(64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Panic("Unable to generate cookie auth key")
|
log.Panic("Unable to generate cookie auth key")
|
||||||
}
|
}
|
||||||
@ -102,8 +104,8 @@ func NewApi(db *database.DBManager, c *config.Config, assets *embed.FS) *API {
|
|||||||
|
|
||||||
func (api *API) Start() error {
|
func (api *API) Start() error {
|
||||||
return api.httpServer.ListenAndServe()
|
return api.httpServer.ListenAndServe()
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api *API) Stop() error {
|
func (api *API) Stop() error {
|
||||||
// Stop Server
|
// Stop Server
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
@ -115,7 +117,6 @@ func (api *API) Stop() error {
|
|||||||
|
|
||||||
// Close DB
|
// Close DB
|
||||||
return api.db.DB.Close()
|
return api.db.DB.Close()
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api *API) registerWebAppRoutes(router *gin.Engine) {
|
func (api *API) registerWebAppRoutes(router *gin.Engine) {
|
||||||
@ -312,12 +313,3 @@ func apiLogger() gin.HandlerFunc {
|
|||||||
log.WithFields(logData).Info(fmt.Sprintf("%s %s", c.Request.Method, c.Request.URL.Path))
|
log.WithFields(logData).Info(fmt.Sprintf("%s %s", c.Request.Method, c.Request.URL.Path))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateToken(n int) ([]byte, error) {
|
|
||||||
b := make([]byte, n)
|
|
||||||
_, err := rand.Read(b)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return b, nil
|
|
||||||
}
|
|
||||||
|
@ -1453,15 +1453,10 @@ func (api *API) processRestoreFile(rAdminAction requestAdminAction, c *gin.Conte
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close DB
|
|
||||||
err = api.db.DB.Close()
|
|
||||||
if err != nil {
|
|
||||||
appErrorPage(c, http.StatusInternalServerError, "Unable to close DB.")
|
|
||||||
log.Panic("Unable to close DB: ", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reinit DB
|
// Reinit DB
|
||||||
api.db.Reload()
|
if err := api.db.Reload(); err != nil {
|
||||||
|
log.Panicf("Unable to reload DB: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api *API) restoreData(zipReader *zip.Reader) error {
|
func (api *API) restoreData(zipReader *zip.Reader) error {
|
||||||
|
87
api/auth.go
87
api/auth.go
@ -12,12 +12,14 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"reichard.io/antholume/database"
|
"reichard.io/antholume/database"
|
||||||
|
"reichard.io/antholume/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Authorization Data
|
// Authorization Data
|
||||||
type authData struct {
|
type authData struct {
|
||||||
UserName string
|
UserName string
|
||||||
IsAdmin bool
|
IsAdmin bool
|
||||||
|
AuthHash string
|
||||||
}
|
}
|
||||||
|
|
||||||
// KOSync API Auth Headers
|
// KOSync API Auth Headers
|
||||||
@ -41,9 +43,13 @@ func (api *API) authorizeCredentials(username string, password string) (auth *au
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Update Auth Cache
|
||||||
|
api.userAuthCache[user.ID] = user.AuthHash
|
||||||
|
|
||||||
return &authData{
|
return &authData{
|
||||||
UserName: user.ID,
|
UserName: user.ID,
|
||||||
IsAdmin: user.Admin,
|
IsAdmin: user.Admin,
|
||||||
|
AuthHash: user.AuthHash,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -51,7 +57,7 @@ func (api *API) authKOMiddleware(c *gin.Context) {
|
|||||||
session := sessions.Default(c)
|
session := sessions.Default(c)
|
||||||
|
|
||||||
// Check Session First
|
// Check Session First
|
||||||
if auth, ok := getSession(session); ok == true {
|
if auth, ok := api.getSession(session); ok == true {
|
||||||
c.Set("Authorization", auth)
|
c.Set("Authorization", auth)
|
||||||
c.Header("Cache-Control", "private")
|
c.Header("Cache-Control", "private")
|
||||||
c.Next()
|
c.Next()
|
||||||
@ -76,7 +82,7 @@ func (api *API) authKOMiddleware(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := setSession(session, *authData); err != nil {
|
if err := api.setSession(session, *authData); err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -114,7 +120,7 @@ func (api *API) authWebAppMiddleware(c *gin.Context) {
|
|||||||
session := sessions.Default(c)
|
session := sessions.Default(c)
|
||||||
|
|
||||||
// Check Session
|
// Check Session
|
||||||
if auth, ok := getSession(session); ok == true {
|
if auth, ok := api.getSession(session); ok == true {
|
||||||
c.Set("Authorization", auth)
|
c.Set("Authorization", auth)
|
||||||
c.Header("Cache-Control", "private")
|
c.Header("Cache-Control", "private")
|
||||||
c.Next()
|
c.Next()
|
||||||
@ -163,7 +169,7 @@ func (api *API) appAuthFormLogin(c *gin.Context) {
|
|||||||
|
|
||||||
// Set Session
|
// Set Session
|
||||||
session := sessions.Default(c)
|
session := sessions.Default(c)
|
||||||
if err := setSession(session, *authData); err != nil {
|
if err := api.setSession(session, *authData); err != nil {
|
||||||
templateVars["Error"] = "Invalid Credentials"
|
templateVars["Error"] = "Invalid Credentials"
|
||||||
c.HTML(http.StatusUnauthorized, "page/login", templateVars)
|
c.HTML(http.StatusUnauthorized, "page/login", templateVars)
|
||||||
return
|
return
|
||||||
@ -199,9 +205,20 @@ func (api *API) appAuthFormRegister(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Generate Auth Hash
|
||||||
|
rawAuthHash, err := utils.GenerateToken(64)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to generate user token: ", err)
|
||||||
|
templateVars["Error"] = "Failed to Create User"
|
||||||
|
c.HTML(http.StatusBadRequest, "page/login", templateVars)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create User in DB
|
||||||
rows, err := api.db.Queries.CreateUser(api.db.Ctx, database.CreateUserParams{
|
rows, err := api.db.Queries.CreateUser(api.db.Ctx, database.CreateUserParams{
|
||||||
ID: username,
|
ID: username,
|
||||||
Pass: &hashedPassword,
|
Pass: &hashedPassword,
|
||||||
|
AuthHash: fmt.Sprintf("%x", rawAuthHash),
|
||||||
})
|
})
|
||||||
|
|
||||||
// SQL Error
|
// SQL Error
|
||||||
@ -233,9 +250,10 @@ func (api *API) appAuthFormRegister(c *gin.Context) {
|
|||||||
auth := authData{
|
auth := authData{
|
||||||
UserName: user.ID,
|
UserName: user.ID,
|
||||||
IsAdmin: user.Admin,
|
IsAdmin: user.Admin,
|
||||||
|
AuthHash: user.AuthHash,
|
||||||
}
|
}
|
||||||
session := sessions.Default(c)
|
session := sessions.Default(c)
|
||||||
if err := setSession(session, auth); err != nil {
|
if err := api.setSession(session, auth); err != nil {
|
||||||
appErrorPage(c, http.StatusUnauthorized, "Unauthorized.")
|
appErrorPage(c, http.StatusUnauthorized, "Unauthorized.")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -251,12 +269,13 @@ func (api *API) appAuthLogout(c *gin.Context) {
|
|||||||
c.Redirect(http.StatusFound, "/login")
|
c.Redirect(http.StatusFound, "/login")
|
||||||
}
|
}
|
||||||
|
|
||||||
func getSession(session sessions.Session) (auth authData, ok bool) {
|
func (api *API) getSession(session sessions.Session) (auth authData, ok bool) {
|
||||||
// Check Session
|
// Get Session
|
||||||
authorizedUser := session.Get("authorizedUser")
|
authorizedUser := session.Get("authorizedUser")
|
||||||
isAdmin := session.Get("isAdmin")
|
isAdmin := session.Get("isAdmin")
|
||||||
expiresAt := session.Get("expiresAt")
|
expiresAt := session.Get("expiresAt")
|
||||||
if authorizedUser == nil || isAdmin == nil || expiresAt == nil {
|
authHash := session.Get("authHash")
|
||||||
|
if authorizedUser == nil || isAdmin == nil || expiresAt == nil || authHash == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -264,22 +283,70 @@ func getSession(session sessions.Session) (auth authData, ok bool) {
|
|||||||
auth = authData{
|
auth = authData{
|
||||||
UserName: authorizedUser.(string),
|
UserName: authorizedUser.(string),
|
||||||
IsAdmin: isAdmin.(bool),
|
IsAdmin: isAdmin.(bool),
|
||||||
|
AuthHash: authHash.(string),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate Auth Hash
|
||||||
|
correctAuthHash, err := api.getUserAuthHash(auth.UserName)
|
||||||
|
if err != nil || correctAuthHash != auth.AuthHash {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Refresh
|
// Refresh
|
||||||
if expiresAt.(int64)-time.Now().Unix() < 60*60*24 {
|
if expiresAt.(int64)-time.Now().Unix() < 60*60*24 {
|
||||||
log.Info("Refreshing Session")
|
log.Info("Refreshing Session")
|
||||||
setSession(session, auth)
|
api.setSession(session, auth)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authorized
|
// Authorized
|
||||||
return auth, true
|
return auth, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func setSession(session sessions.Session, auth authData) error {
|
func (api *API) setSession(session sessions.Session, auth authData) error {
|
||||||
// Set Session Cookie
|
// Set Session Cookie
|
||||||
session.Set("authorizedUser", auth.UserName)
|
session.Set("authorizedUser", auth.UserName)
|
||||||
session.Set("isAdmin", auth.IsAdmin)
|
session.Set("isAdmin", auth.IsAdmin)
|
||||||
session.Set("expiresAt", time.Now().Unix()+(60*60*24*7))
|
session.Set("expiresAt", time.Now().Unix()+(60*60*24*7))
|
||||||
|
session.Set("authHash", auth.AuthHash)
|
||||||
|
|
||||||
return session.Save()
|
return session.Save()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (api *API) getUserAuthHash(username string) (string, error) {
|
||||||
|
// Return Cache
|
||||||
|
if api.userAuthCache[username] != "" {
|
||||||
|
return api.userAuthCache[username], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get DB
|
||||||
|
user, err := api.db.Queries.GetUser(api.db.Ctx, username)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("GetUser DB Error:", err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update Cache
|
||||||
|
api.userAuthCache[username] = user.AuthHash
|
||||||
|
|
||||||
|
return api.userAuthCache[username], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (api *API) rotateUserAuthHash(username string) error {
|
||||||
|
// Generate Auth Hash
|
||||||
|
rawAuthHash, err := utils.GenerateToken(64)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to generate user token: ", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update User
|
||||||
|
_, err = api.db.Queries.UpdateUser(api.db.Ctx, database.UpdateUserParams{
|
||||||
|
UserID: username,
|
||||||
|
AuthHash: fmt.Sprintf("%x", rawAuthHash),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Update Cache
|
||||||
|
api.userAuthCache[username] = fmt.Sprintf("%x", rawAuthHash)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -20,6 +20,7 @@ import (
|
|||||||
"golang.org/x/exp/slices"
|
"golang.org/x/exp/slices"
|
||||||
"reichard.io/antholume/database"
|
"reichard.io/antholume/database"
|
||||||
"reichard.io/antholume/metadata"
|
"reichard.io/antholume/metadata"
|
||||||
|
"reichard.io/antholume/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type activityItem struct {
|
type activityItem struct {
|
||||||
@ -107,9 +108,18 @@ func (api *API) koCreateUser(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Generate Auth Hash
|
||||||
|
rawAuthHash, err := utils.GenerateToken(64)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to generate user token: ", err)
|
||||||
|
apiErrorPage(c, http.StatusBadRequest, "Unknown Error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
rows, err := api.db.Queries.CreateUser(api.db.Ctx, database.CreateUserParams{
|
rows, err := api.db.Queries.CreateUser(api.db.Ctx, database.CreateUserParams{
|
||||||
ID: rUser.Username,
|
ID: rUser.Username,
|
||||||
Pass: &hashedPassword,
|
Pass: &hashedPassword,
|
||||||
|
AuthHash: fmt.Sprintf("%x", rawAuthHash),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("CreateUser DB Error:", err)
|
log.Error("CreateUser DB Error:", err)
|
||||||
|
@ -13,6 +13,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
_ "modernc.org/sqlite"
|
_ "modernc.org/sqlite"
|
||||||
"reichard.io/antholume/config"
|
"reichard.io/antholume/config"
|
||||||
|
_ "reichard.io/antholume/database/migrations"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DBManager struct {
|
type DBManager struct {
|
||||||
@ -36,13 +37,15 @@ func NewMgr(c *config.Config) *DBManager {
|
|||||||
cfg: c,
|
cfg: c,
|
||||||
}
|
}
|
||||||
|
|
||||||
dbm.init()
|
if err := dbm.init(); err != nil {
|
||||||
|
log.Panic("Unable to init DB")
|
||||||
|
}
|
||||||
|
|
||||||
return dbm
|
return dbm
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init manager
|
// Init manager
|
||||||
func (dbm *DBManager) init() {
|
func (dbm *DBManager) init() error {
|
||||||
if dbm.cfg.DBType == "sqlite" || dbm.cfg.DBType == "memory" {
|
if dbm.cfg.DBType == "sqlite" || dbm.cfg.DBType == "memory" {
|
||||||
var dbLocation string = ":memory:"
|
var dbLocation string = ":memory:"
|
||||||
if dbm.cfg.DBType == "sqlite" {
|
if dbm.cfg.DBType == "sqlite" {
|
||||||
@ -52,7 +55,8 @@ func (dbm *DBManager) init() {
|
|||||||
var err error
|
var err error
|
||||||
dbm.DB, err = sql.Open("sqlite", dbLocation)
|
dbm.DB, err = sql.Open("sqlite", dbLocation)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Unable to open DB: %v", err)
|
log.Errorf("Unable to open DB: %v", err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Single Open Connection
|
// Single Open Connection
|
||||||
@ -60,22 +64,36 @@ func (dbm *DBManager) init() {
|
|||||||
|
|
||||||
// Execute DDL
|
// Execute DDL
|
||||||
if _, err := dbm.DB.Exec(ddl, nil); err != nil {
|
if _, err := dbm.DB.Exec(ddl, nil); err != nil {
|
||||||
log.Fatalf("Error executing schema: %v", err)
|
log.Errorf("Error executing schema: %v", err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Perform Migrations
|
// Perform Migrations
|
||||||
err = dbm.performMigrations()
|
err = dbm.performMigrations()
|
||||||
if err != nil && err != goose.ErrNoMigrationFiles {
|
if err != nil && err != goose.ErrNoMigrationFiles {
|
||||||
log.Fatalf("Error running DB migrations: %v", err)
|
log.Errorf("Error running DB migrations: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set SQLite Settings (After Migrations)
|
||||||
|
pragmaQuery := `
|
||||||
|
PRAGMA foreign_keys = ON;
|
||||||
|
PRAGMA journal_mode = WAL;
|
||||||
|
`
|
||||||
|
if _, err := dbm.DB.Exec(pragmaQuery, nil); err != nil {
|
||||||
|
log.Errorf("Error executing pragma: %v", err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cache Tables
|
// Cache Tables
|
||||||
dbm.CacheTempTables()
|
dbm.CacheTempTables()
|
||||||
} else {
|
} else {
|
||||||
log.Fatal("Unsupported Database")
|
return fmt.Errorf("unsupported database")
|
||||||
}
|
}
|
||||||
|
|
||||||
dbm.Queries = New(dbm.DB)
|
dbm.Queries = New(dbm.DB)
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reload manager (close DB & reinit)
|
// Reload manager (close DB & reinit)
|
||||||
@ -87,7 +105,9 @@ func (dbm *DBManager) Reload() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Reinit DB
|
// Reinit DB
|
||||||
dbm.init()
|
if err := dbm.init(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
package database
|
package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"reichard.io/antholume/config"
|
"reichard.io/antholume/config"
|
||||||
|
"reichard.io/antholume/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type databaseTest struct {
|
type databaseTest struct {
|
||||||
@ -42,9 +44,16 @@ func TestNewMgr(t *testing.T) {
|
|||||||
|
|
||||||
func (dt *databaseTest) TestUser() {
|
func (dt *databaseTest) TestUser() {
|
||||||
dt.Run("User", func(t *testing.T) {
|
dt.Run("User", func(t *testing.T) {
|
||||||
|
// Generate Auth Hash
|
||||||
|
rawAuthHash, err := utils.GenerateToken(64)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf(`Expected: %v, Got: %v, Error: %v`, nil, err, err)
|
||||||
|
}
|
||||||
|
|
||||||
changed, err := dt.dbm.Queries.CreateUser(dt.dbm.Ctx, CreateUserParams{
|
changed, err := dt.dbm.Queries.CreateUser(dt.dbm.Ctx, CreateUserParams{
|
||||||
ID: userID,
|
ID: userID,
|
||||||
Pass: &userPass,
|
Pass: &userPass,
|
||||||
|
AuthHash: fmt.Sprintf("%x", rawAuthHash),
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil || changed != 1 {
|
if err != nil || changed != 1 {
|
||||||
|
91
database/migrations/20240128012356_user_auth_hash.go
Normal file
91
database/migrations/20240128012356_user_auth_hash.go
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
package migrations
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/pressly/goose/v3"
|
||||||
|
"reichard.io/antholume/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
goose.AddMigrationContext(upUserAuthHash, downUserAuthHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
func upUserAuthHash(ctx context.Context, tx *sql.Tx) error {
|
||||||
|
// Validate column doesn't already exist
|
||||||
|
hasCol, err := hasColumn(tx, "users", "auth_hash")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
} else if hasCol {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy table & create column
|
||||||
|
_, err = tx.Exec(`
|
||||||
|
-- Create Copy Table
|
||||||
|
CREATE TABLE temp_users AS SELECT * FROM users;
|
||||||
|
ALTER TABLE temp_users ADD COLUMN auth_hash TEXT;
|
||||||
|
|
||||||
|
-- Update Schema
|
||||||
|
DELETE FROM users;
|
||||||
|
ALTER TABLE users ADD COLUMN auth_hash TEXT NOT NULL;
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get current users
|
||||||
|
rows, err := tx.Query("SELECT id FROM temp_users")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query existing users
|
||||||
|
var users []string
|
||||||
|
for rows.Next() {
|
||||||
|
var user string
|
||||||
|
if err := rows.Scan(&user); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
users = append(users, user)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create auth hash per user
|
||||||
|
for _, user := range users {
|
||||||
|
rawAuthHash, err := utils.GenerateToken(64)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
authHash := fmt.Sprintf("%x", rawAuthHash)
|
||||||
|
_, err = tx.Exec("UPDATE temp_users SET auth_hash = ? WHERE id = ?", authHash, user)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy from temp to true table
|
||||||
|
_, err = tx.Exec(`
|
||||||
|
-- Copy Into New
|
||||||
|
INSERT INTO users SELECT * FROM temp_users;
|
||||||
|
|
||||||
|
-- Drop Temp Table
|
||||||
|
DROP TABLE temp_users;
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func downUserAuthHash(ctx context.Context, tx *sql.Tx) error {
|
||||||
|
// Drop column
|
||||||
|
_, err := tx.Exec("ALTER users DROP COLUMN auth_hash")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
@ -1,5 +1,9 @@
|
|||||||
# DB Migrations
|
# DB Migrations
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
# SQL migration
|
||||||
goose create migration_name sql
|
goose create migration_name sql
|
||||||
|
|
||||||
|
# Go migration
|
||||||
|
goose create migration_name
|
||||||
```
|
```
|
||||||
|
38
database/migrations/utils.go
Normal file
38
database/migrations/utils.go
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
package migrations
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type columnInfo struct {
|
||||||
|
CID int
|
||||||
|
Name string
|
||||||
|
Type string
|
||||||
|
NotNull int
|
||||||
|
DefaultVal sql.NullString
|
||||||
|
PK int
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasColumn(tx *sql.Tx, table string, column string) (bool, error) {
|
||||||
|
rows, err := tx.Query(fmt.Sprintf("PRAGMA table_info(%s)", table))
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
colExists := false
|
||||||
|
for rows.Next() {
|
||||||
|
var col columnInfo
|
||||||
|
if err := rows.Scan(&col.CID, &col.Name, &col.Type, &col.NotNull, &col.DefaultVal, &col.PK); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if col.Name == column {
|
||||||
|
colExists = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return colExists, nil
|
||||||
|
}
|
@ -96,6 +96,7 @@ type Metadatum struct {
|
|||||||
type User struct {
|
type User struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Pass *string `json:"-"`
|
Pass *string `json:"-"`
|
||||||
|
AuthHash string `json:"auth_hash"`
|
||||||
Admin bool `json:"-"`
|
Admin bool `json:"-"`
|
||||||
TimeOffset *string `json:"time_offset"`
|
TimeOffset *string `json:"time_offset"`
|
||||||
CreatedAt string `json:"created_at"`
|
CreatedAt string `json:"created_at"`
|
||||||
|
@ -26,8 +26,8 @@ VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
|||||||
RETURNING *;
|
RETURNING *;
|
||||||
|
|
||||||
-- name: CreateUser :execrows
|
-- name: CreateUser :execrows
|
||||||
INSERT INTO users (id, pass)
|
INSERT INTO users (id, pass, auth_hash)
|
||||||
VALUES (?, ?)
|
VALUES (?, ?, ?)
|
||||||
ON CONFLICT DO NOTHING;
|
ON CONFLICT DO NOTHING;
|
||||||
|
|
||||||
-- name: DeleteDocument :execrows
|
-- name: DeleteDocument :execrows
|
||||||
@ -368,6 +368,7 @@ RETURNING *;
|
|||||||
UPDATE users
|
UPDATE users
|
||||||
SET
|
SET
|
||||||
pass = COALESCE($password, pass),
|
pass = COALESCE($password, pass),
|
||||||
|
auth_hash = COALESCE($auth_hash, auth_hash),
|
||||||
time_offset = COALESCE($time_offset, time_offset)
|
time_offset = COALESCE($time_offset, time_offset)
|
||||||
WHERE id = $user_id
|
WHERE id = $user_id
|
||||||
RETURNING *;
|
RETURNING *;
|
||||||
|
@ -113,18 +113,19 @@ func (q *Queries) AddMetadata(ctx context.Context, arg AddMetadataParams) (Metad
|
|||||||
}
|
}
|
||||||
|
|
||||||
const createUser = `-- name: CreateUser :execrows
|
const createUser = `-- name: CreateUser :execrows
|
||||||
INSERT INTO users (id, pass)
|
INSERT INTO users (id, pass, auth_hash)
|
||||||
VALUES (?, ?)
|
VALUES (?, ?, ?)
|
||||||
ON CONFLICT DO NOTHING
|
ON CONFLICT DO NOTHING
|
||||||
`
|
`
|
||||||
|
|
||||||
type CreateUserParams struct {
|
type CreateUserParams struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Pass *string `json:"-"`
|
Pass *string `json:"-"`
|
||||||
|
AuthHash string `json:"auth_hash"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (int64, error) {
|
func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (int64, error) {
|
||||||
result, err := q.db.ExecContext(ctx, createUser, arg.ID, arg.Pass)
|
result, err := q.db.ExecContext(ctx, createUser, arg.ID, arg.Pass, arg.AuthHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@ -954,7 +955,7 @@ func (q *Queries) GetProgress(ctx context.Context, arg GetProgressParams) ([]Get
|
|||||||
}
|
}
|
||||||
|
|
||||||
const getUser = `-- name: GetUser :one
|
const getUser = `-- name: GetUser :one
|
||||||
SELECT id, pass, admin, time_offset, created_at FROM users
|
SELECT id, pass, auth_hash, admin, time_offset, created_at FROM users
|
||||||
WHERE id = ?1 LIMIT 1
|
WHERE id = ?1 LIMIT 1
|
||||||
`
|
`
|
||||||
|
|
||||||
@ -964,6 +965,7 @@ func (q *Queries) GetUser(ctx context.Context, userID string) (User, error) {
|
|||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&i.ID,
|
&i.ID,
|
||||||
&i.Pass,
|
&i.Pass,
|
||||||
|
&i.AuthHash,
|
||||||
&i.Admin,
|
&i.Admin,
|
||||||
&i.TimeOffset,
|
&i.TimeOffset,
|
||||||
&i.CreatedAt,
|
&i.CreatedAt,
|
||||||
@ -1092,7 +1094,7 @@ func (q *Queries) GetUserStreaks(ctx context.Context, userID string) ([]UserStre
|
|||||||
}
|
}
|
||||||
|
|
||||||
const getUsers = `-- name: GetUsers :many
|
const getUsers = `-- name: GetUsers :many
|
||||||
SELECT id, pass, admin, time_offset, created_at FROM users
|
SELECT id, pass, auth_hash, admin, time_offset, created_at FROM users
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) GetUsers(ctx context.Context) ([]User, error) {
|
func (q *Queries) GetUsers(ctx context.Context) ([]User, error) {
|
||||||
@ -1107,6 +1109,7 @@ func (q *Queries) GetUsers(ctx context.Context) ([]User, error) {
|
|||||||
if err := rows.Scan(
|
if err := rows.Scan(
|
||||||
&i.ID,
|
&i.ID,
|
||||||
&i.Pass,
|
&i.Pass,
|
||||||
|
&i.AuthHash,
|
||||||
&i.Admin,
|
&i.Admin,
|
||||||
&i.TimeOffset,
|
&i.TimeOffset,
|
||||||
&i.CreatedAt,
|
&i.CreatedAt,
|
||||||
@ -1214,23 +1217,31 @@ const updateUser = `-- name: UpdateUser :one
|
|||||||
UPDATE users
|
UPDATE users
|
||||||
SET
|
SET
|
||||||
pass = COALESCE(?1, pass),
|
pass = COALESCE(?1, pass),
|
||||||
time_offset = COALESCE(?2, time_offset)
|
auth_hash = COALESCE(?2, auth_hash),
|
||||||
WHERE id = ?3
|
time_offset = COALESCE(?3, time_offset)
|
||||||
RETURNING id, pass, admin, time_offset, created_at
|
WHERE id = ?4
|
||||||
|
RETURNING id, pass, auth_hash, admin, time_offset, created_at
|
||||||
`
|
`
|
||||||
|
|
||||||
type UpdateUserParams struct {
|
type UpdateUserParams struct {
|
||||||
Password *string `json:"-"`
|
Password *string `json:"-"`
|
||||||
|
AuthHash string `json:"auth_hash"`
|
||||||
TimeOffset *string `json:"time_offset"`
|
TimeOffset *string `json:"time_offset"`
|
||||||
UserID string `json:"user_id"`
|
UserID string `json:"user_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) UpdateUser(ctx context.Context, arg UpdateUserParams) (User, error) {
|
func (q *Queries) UpdateUser(ctx context.Context, arg UpdateUserParams) (User, error) {
|
||||||
row := q.db.QueryRowContext(ctx, updateUser, arg.Password, arg.TimeOffset, arg.UserID)
|
row := q.db.QueryRowContext(ctx, updateUser,
|
||||||
|
arg.Password,
|
||||||
|
arg.AuthHash,
|
||||||
|
arg.TimeOffset,
|
||||||
|
arg.UserID,
|
||||||
|
)
|
||||||
var i User
|
var i User
|
||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&i.ID,
|
&i.ID,
|
||||||
&i.Pass,
|
&i.Pass,
|
||||||
|
&i.AuthHash,
|
||||||
&i.Admin,
|
&i.Admin,
|
||||||
&i.TimeOffset,
|
&i.TimeOffset,
|
||||||
&i.CreatedAt,
|
&i.CreatedAt,
|
||||||
|
@ -1,6 +1,3 @@
|
|||||||
PRAGMA foreign_keys = ON;
|
|
||||||
PRAGMA journal_mode = WAL;
|
|
||||||
|
|
||||||
---------------------------------------------------------------
|
---------------------------------------------------------------
|
||||||
------------------------ Normal Tables ------------------------
|
------------------------ Normal Tables ------------------------
|
||||||
---------------------------------------------------------------
|
---------------------------------------------------------------
|
||||||
@ -10,6 +7,7 @@ CREATE TABLE IF NOT EXISTS users (
|
|||||||
id TEXT NOT NULL PRIMARY KEY,
|
id TEXT NOT NULL PRIMARY KEY,
|
||||||
|
|
||||||
pass TEXT NOT NULL,
|
pass TEXT NOT NULL,
|
||||||
|
auth_hash TEXT NOT NULL,
|
||||||
admin BOOLEAN NOT NULL DEFAULT 0 CHECK (admin IN (0, 1)),
|
admin BOOLEAN NOT NULL DEFAULT 0 CHECK (admin IN (0, 1)),
|
||||||
time_offset TEXT NOT NULL DEFAULT '0 hours',
|
time_offset TEXT NOT NULL DEFAULT '0 hours',
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ package utils
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/md5"
|
"crypto/md5"
|
||||||
|
"crypto/rand"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
@ -42,3 +43,13 @@ func CalculatePartialMD5(filePath string) (string, error) {
|
|||||||
allBytes := buf.Bytes()
|
allBytes := buf.Bytes()
|
||||||
return fmt.Sprintf("%x", md5.Sum(allBytes)), nil
|
return fmt.Sprintf("%x", md5.Sum(allBytes)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Creates a token of n size
|
||||||
|
func GenerateToken(n int) ([]byte, error) {
|
||||||
|
b := make([]byte, n)
|
||||||
|
_, err := rand.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user