chore(db): use context & add db helper
This commit is contained in:
39
api/auth.go
39
api/auth.go
@@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -28,8 +29,8 @@ type authKOHeader struct {
|
||||
AuthKey string `header:"x-auth-key"`
|
||||
}
|
||||
|
||||
func (api *API) authorizeCredentials(username string, password string) (auth *authData) {
|
||||
user, err := api.db.Queries.GetUser(api.db.Ctx, username)
|
||||
func (api *API) authorizeCredentials(ctx context.Context, username string, password string) (auth *authData) {
|
||||
user, err := api.db.Queries.GetUser(ctx, username)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -52,7 +53,7 @@ func (api *API) authKOMiddleware(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
|
||||
// Check Session First
|
||||
if auth, ok := api.getSession(session); ok {
|
||||
if auth, ok := api.getSession(c, session); ok {
|
||||
c.Set("Authorization", auth)
|
||||
c.Header("Cache-Control", "private")
|
||||
c.Next()
|
||||
@@ -71,7 +72,7 @@ func (api *API) authKOMiddleware(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
authData := api.authorizeCredentials(rHeader.AuthUser, rHeader.AuthKey)
|
||||
authData := api.authorizeCredentials(c, rHeader.AuthUser, rHeader.AuthKey)
|
||||
if authData == nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
@@ -100,7 +101,7 @@ func (api *API) authOPDSMiddleware(c *gin.Context) {
|
||||
|
||||
// Validate Auth
|
||||
password := fmt.Sprintf("%x", md5.Sum([]byte(rawPassword)))
|
||||
authData := api.authorizeCredentials(user, password)
|
||||
authData := api.authorizeCredentials(c, user, password)
|
||||
if authData == nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
@@ -115,7 +116,7 @@ func (api *API) authWebAppMiddleware(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
|
||||
// Check Session
|
||||
if auth, ok := api.getSession(session); ok {
|
||||
if auth, ok := api.getSession(c, session); ok {
|
||||
c.Set("Authorization", auth)
|
||||
c.Header("Cache-Control", "private")
|
||||
c.Next()
|
||||
@@ -153,7 +154,7 @@ func (api *API) appAuthLogin(c *gin.Context) {
|
||||
|
||||
// MD5 - KOSync Compatiblity
|
||||
password := fmt.Sprintf("%x", md5.Sum([]byte(rawPassword)))
|
||||
authData := api.authorizeCredentials(username, password)
|
||||
authData := api.authorizeCredentials(c, username, password)
|
||||
if authData == nil {
|
||||
templateVars["Error"] = "Invalid Credentials"
|
||||
c.HTML(http.StatusUnauthorized, "page/login", templateVars)
|
||||
@@ -208,7 +209,7 @@ func (api *API) appAuthRegister(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get current users
|
||||
currentUsers, err := api.db.Queries.GetUsers(api.db.Ctx)
|
||||
currentUsers, err := api.db.Queries.GetUsers(c)
|
||||
if err != nil {
|
||||
log.Error("Failed to check all users: ", err)
|
||||
templateVars["Error"] = "Failed to Create User"
|
||||
@@ -224,7 +225,7 @@ func (api *API) appAuthRegister(c *gin.Context) {
|
||||
|
||||
// Create user in DB
|
||||
authHash := fmt.Sprintf("%x", rawAuthHash)
|
||||
if rows, err := api.db.Queries.CreateUser(api.db.Ctx, database.CreateUserParams{
|
||||
if rows, err := api.db.Queries.CreateUser(c, database.CreateUserParams{
|
||||
ID: username,
|
||||
Pass: &hashedPassword,
|
||||
AuthHash: &authHash,
|
||||
@@ -242,7 +243,7 @@ func (api *API) appAuthRegister(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get user
|
||||
user, err := api.db.Queries.GetUser(api.db.Ctx, username)
|
||||
user, err := api.db.Queries.GetUser(c, username)
|
||||
if err != nil {
|
||||
log.Error("GetUser DB Error:", err)
|
||||
templateVars["Error"] = "Registration Disabled or User Already Exists"
|
||||
@@ -312,7 +313,7 @@ func (api *API) koAuthRegister(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Get current users
|
||||
currentUsers, err := api.db.Queries.GetUsers(api.db.Ctx)
|
||||
currentUsers, err := api.db.Queries.GetUsers(c)
|
||||
if err != nil {
|
||||
log.Error("Failed to check all users: ", err)
|
||||
apiErrorPage(c, http.StatusBadRequest, "Failed to Create User")
|
||||
@@ -327,7 +328,7 @@ func (api *API) koAuthRegister(c *gin.Context) {
|
||||
|
||||
// Create user
|
||||
authHash := fmt.Sprintf("%x", rawAuthHash)
|
||||
if rows, err := api.db.Queries.CreateUser(api.db.Ctx, database.CreateUserParams{
|
||||
if rows, err := api.db.Queries.CreateUser(c, database.CreateUserParams{
|
||||
ID: rUser.Username,
|
||||
Pass: &hashedPassword,
|
||||
AuthHash: &authHash,
|
||||
@@ -347,7 +348,7 @@ func (api *API) koAuthRegister(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func (api *API) getSession(session sessions.Session) (auth authData, ok bool) {
|
||||
func (api *API) getSession(ctx context.Context, session sessions.Session) (auth authData, ok bool) {
|
||||
// Get Session
|
||||
authorizedUser := session.Get("authorizedUser")
|
||||
isAdmin := session.Get("isAdmin")
|
||||
@@ -365,7 +366,7 @@ func (api *API) getSession(session sessions.Session) (auth authData, ok bool) {
|
||||
}
|
||||
|
||||
// Validate Auth Hash
|
||||
correctAuthHash, err := api.getUserAuthHash(auth.UserName)
|
||||
correctAuthHash, err := api.getUserAuthHash(ctx, auth.UserName)
|
||||
if err != nil || correctAuthHash != auth.AuthHash {
|
||||
return
|
||||
}
|
||||
@@ -393,14 +394,14 @@ func (api *API) setSession(session sessions.Session, auth authData) error {
|
||||
return session.Save()
|
||||
}
|
||||
|
||||
func (api *API) getUserAuthHash(username string) (string, error) {
|
||||
func (api *API) getUserAuthHash(ctx context.Context, username string) (string, error) {
|
||||
// Return Cache
|
||||
if api.userAuthCache[username] != "" {
|
||||
return api.userAuthCache[username], nil
|
||||
}
|
||||
|
||||
// Get DB
|
||||
user, err := api.db.Queries.GetUser(api.db.Ctx, username)
|
||||
user, err := api.db.Queries.GetUser(ctx, username)
|
||||
if err != nil {
|
||||
log.Error("GetUser DB Error:", err)
|
||||
return "", err
|
||||
@@ -412,7 +413,7 @@ func (api *API) getUserAuthHash(username string) (string, error) {
|
||||
return api.userAuthCache[username], nil
|
||||
}
|
||||
|
||||
func (api *API) rotateAllAuthHashes() error {
|
||||
func (api *API) rotateAllAuthHashes(ctx context.Context) error {
|
||||
// Do Transaction
|
||||
tx, err := api.db.DB.Begin()
|
||||
if err != nil {
|
||||
@@ -428,7 +429,7 @@ func (api *API) rotateAllAuthHashes() error {
|
||||
}()
|
||||
qtx := api.db.Queries.WithTx(tx)
|
||||
|
||||
users, err := qtx.GetUsers(api.db.Ctx)
|
||||
users, err := qtx.GetUsers(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -444,7 +445,7 @@ func (api *API) rotateAllAuthHashes() error {
|
||||
|
||||
// Update User
|
||||
authHash := fmt.Sprintf("%x", rawAuthHash)
|
||||
if _, err = qtx.UpdateUser(api.db.Ctx, database.UpdateUserParams{
|
||||
if _, err = qtx.UpdateUser(ctx, database.UpdateUserParams{
|
||||
UserID: user.ID,
|
||||
AuthHash: &authHash,
|
||||
Admin: user.Admin,
|
||||
|
||||
Reference in New Issue
Block a user