chore(db): use context & add db helper

This commit is contained in:
2025-08-10 13:15:46 -04:00
parent 7c92c346fa
commit 938dd69e5e
15 changed files with 241 additions and 328 deletions

View File

@@ -3,6 +3,7 @@ package api
import (
"archive/zip"
"bufio"
"context"
"crypto/md5"
"encoding/json"
"fmt"
@@ -112,7 +113,7 @@ func (api *API) appPerformAdminAction(c *gin.Context) {
// 2. Select all / deselect?
case adminCacheTables:
go func() {
err := api.db.CacheTempTables()
err := api.db.CacheTempTables(c)
if err != nil {
log.Error("Unable to cache temp tables: ", err)
}
@@ -122,7 +123,7 @@ func (api *API) appPerformAdminAction(c *gin.Context) {
return
case adminBackup:
// Vacuum
_, err := api.db.DB.ExecContext(api.db.Ctx, "VACUUM;")
_, err := api.db.DB.ExecContext(c, "VACUUM;")
if err != nil {
log.Error("Unable to vacuum DB: ", err)
appErrorPage(c, http.StatusInternalServerError, "Unable to vacuum database")
@@ -144,7 +145,7 @@ func (api *API) appPerformAdminAction(c *gin.Context) {
}
}
err := api.createBackup(w, directories)
err := api.createBackup(c, w, directories)
if err != nil {
log.Error("Backup Error: ", err)
}
@@ -261,7 +262,7 @@ func (api *API) appGetAdminLogs(c *gin.Context) {
func (api *API) appGetAdminUsers(c *gin.Context) {
templateVars, _ := api.getBaseTemplateVars("admin-users", c)
users, err := api.db.Queries.GetUsers(api.db.Ctx)
users, err := api.db.Queries.GetUsers(c)
if err != nil {
log.Error("GetUsers DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUsers DB Error: %v", err))
@@ -292,11 +293,11 @@ func (api *API) appUpdateAdminUsers(c *gin.Context) {
var err error
switch rUpdate.Operation {
case opCreate:
err = api.createUser(rUpdate.User, rUpdate.Password, rUpdate.IsAdmin)
err = api.createUser(c, rUpdate.User, rUpdate.Password, rUpdate.IsAdmin)
case opUpdate:
err = api.updateUser(rUpdate.User, rUpdate.Password, rUpdate.IsAdmin)
err = api.updateUser(c, rUpdate.User, rUpdate.Password, rUpdate.IsAdmin)
case opDelete:
err = api.deleteUser(rUpdate.User)
err = api.deleteUser(c, rUpdate.User)
default:
appErrorPage(c, http.StatusNotFound, "Unknown user operation")
return
@@ -307,7 +308,7 @@ func (api *API) appUpdateAdminUsers(c *gin.Context) {
return
}
users, err := api.db.Queries.GetUsers(api.db.Ctx)
users, err := api.db.Queries.GetUsers(c)
if err != nil {
log.Error("GetUsers DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUsers DB Error: %v", err))
@@ -448,7 +449,7 @@ func (api *API) appPerformAdminImport(c *gin.Context) {
iResult.Name = fmt.Sprintf("%s - %s", *fileMeta.Author, *fileMeta.Title)
// Check already exists
_, err = qtx.GetDocument(api.db.Ctx, *fileMeta.PartialMD5)
_, err = qtx.GetDocument(c, *fileMeta.PartialMD5)
if err == nil {
log.Warnf("document already exists: %s", *fileMeta.PartialMD5)
iResult.Status = importExists
@@ -492,7 +493,7 @@ func (api *API) appPerformAdminImport(c *gin.Context) {
}
// Upsert document
if _, err = qtx.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
if _, err = qtx.UpsertDocument(c, database.UpsertDocumentParams{
ID: *fileMeta.PartialMD5,
Title: fileMeta.Title,
Author: fileMeta.Author,
@@ -627,7 +628,7 @@ func (api *API) processRestoreFile(rAdminAction requestAdminAction, c *gin.Conte
// Save Backup File
w := bufio.NewWriter(backupFile)
err = api.createBackup(w, []string{"covers", "documents"})
err = api.createBackup(c, w, []string{"covers", "documents"})
if err != nil {
log.Error("Unable to save backup file: ", err)
appErrorPage(c, http.StatusInternalServerError, "Unable to save backup file")
@@ -650,13 +651,13 @@ func (api *API) processRestoreFile(rAdminAction requestAdminAction, c *gin.Conte
}
// Reinit DB
if err := api.db.Reload(); err != nil {
if err := api.db.Reload(c); err != nil {
appErrorPage(c, http.StatusInternalServerError, "Unable to reload DB")
log.Panicf("Unable to reload DB: %v", err)
}
// Rotate Auth Hashes
if err := api.rotateAllAuthHashes(); err != nil {
if err := api.rotateAllAuthHashes(c); err != nil {
appErrorPage(c, http.StatusInternalServerError, "Unable to rotate hashes")
log.Panicf("Unable to rotate auth hashes: %v", err)
}
@@ -717,9 +718,9 @@ func (api *API) removeData() error {
return nil
}
func (api *API) createBackup(w io.Writer, directories []string) error {
func (api *API) createBackup(ctx context.Context, w io.Writer, directories []string) error {
// Vacuum DB
_, err := api.db.DB.ExecContext(api.db.Ctx, "VACUUM;")
_, err := api.db.DB.ExecContext(ctx, "VACUUM;")
if err != nil {
return errors.Wrap(err, "Unable to vacuum database")
}
@@ -792,8 +793,8 @@ func (api *API) createBackup(w io.Writer, directories []string) error {
return nil
}
func (api *API) isLastAdmin(userID string) (bool, error) {
allUsers, err := api.db.Queries.GetUsers(api.db.Ctx)
func (api *API) isLastAdmin(ctx context.Context, userID string) (bool, error) {
allUsers, err := api.db.Queries.GetUsers(ctx)
if err != nil {
return false, errors.Wrap(err, fmt.Sprintf("GetUsers DB Error: %v", err))
}
@@ -809,7 +810,7 @@ func (api *API) isLastAdmin(userID string) (bool, error) {
return !hasAdmin, nil
}
func (api *API) createUser(user string, rawPassword *string, isAdmin *bool) error {
func (api *API) createUser(ctx context.Context, user string, rawPassword *string, isAdmin *bool) error {
// Validate Necessary Parameters
if rawPassword == nil || *rawPassword == "" {
return fmt.Errorf("password can't be empty")
@@ -844,7 +845,7 @@ func (api *API) createUser(user string, rawPassword *string, isAdmin *bool) erro
createParams.AuthHash = &authHash
// Create user in DB
if rows, err := api.db.Queries.CreateUser(api.db.Ctx, createParams); err != nil {
if rows, err := api.db.Queries.CreateUser(ctx, createParams); err != nil {
log.Error("CreateUser DB Error:", err)
return fmt.Errorf("unable to create user")
} else if rows == 0 {
@@ -855,7 +856,7 @@ func (api *API) createUser(user string, rawPassword *string, isAdmin *bool) erro
return nil
}
func (api *API) updateUser(user string, rawPassword *string, isAdmin *bool) error {
func (api *API) updateUser(ctx context.Context, user string, rawPassword *string, isAdmin *bool) error {
// Validate Necessary Parameters
if rawPassword == nil && isAdmin == nil {
return fmt.Errorf("nothing to update")
@@ -870,7 +871,7 @@ func (api *API) updateUser(user string, rawPassword *string, isAdmin *bool) erro
if isAdmin != nil {
updateParams.Admin = *isAdmin
} else {
user, err := api.db.Queries.GetUser(api.db.Ctx, user)
user, err := api.db.Queries.GetUser(ctx, user)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("GetUser DB Error: %v", err))
}
@@ -878,7 +879,7 @@ func (api *API) updateUser(user string, rawPassword *string, isAdmin *bool) erro
}
// Check Admins - Disallow Demotion
if isLast, err := api.isLastAdmin(user); err != nil {
if isLast, err := api.isLastAdmin(ctx, user); err != nil {
return err
} else if isLast && !updateParams.Admin {
return fmt.Errorf("unable to demote %s - last admin", user)
@@ -908,7 +909,7 @@ func (api *API) updateUser(user string, rawPassword *string, isAdmin *bool) erro
}
// Update User
_, err := api.db.Queries.UpdateUser(api.db.Ctx, updateParams)
_, err := api.db.Queries.UpdateUser(ctx, updateParams)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("UpdateUser DB Error: %v", err))
}
@@ -916,9 +917,9 @@ func (api *API) updateUser(user string, rawPassword *string, isAdmin *bool) erro
return nil
}
func (api *API) deleteUser(user string) error {
func (api *API) deleteUser(ctx context.Context, user string) error {
// Check Admins
if isLast, err := api.isLastAdmin(user); err != nil {
if isLast, err := api.isLastAdmin(ctx, user); err != nil {
return err
} else if isLast {
return fmt.Errorf("unable to delete %s - last admin", user)
@@ -934,13 +935,13 @@ func (api *API) deleteUser(user string) error {
// Save Backup File (DB Only)
w := bufio.NewWriter(backupFile)
err = api.createBackup(w, []string{})
err = api.createBackup(ctx, w, []string{})
if err != nil {
return err
}
// Delete User
_, err = api.db.Queries.DeleteUser(api.db.Ctx, user)
_, err = api.db.Queries.DeleteUser(ctx, user)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("DeleteUser DB Error: %v", err))
}

View File

@@ -1,6 +1,7 @@
package api
import (
"context"
"crypto/md5"
"database/sql"
"fmt"
@@ -22,6 +23,7 @@ import (
"golang.org/x/exp/slices"
"reichard.io/antholume/database"
"reichard.io/antholume/metadata"
"reichard.io/antholume/pkg/ptr"
"reichard.io/antholume/search"
)
@@ -109,11 +111,12 @@ func (api *API) appGetDocuments(c *gin.Context) {
query = &search
}
documents, err := api.db.Queries.GetDocumentsWithStats(api.db.Ctx, database.GetDocumentsWithStatsParams{
UserID: auth.UserName,
Query: query,
Offset: (*qParams.Page - 1) * *qParams.Limit,
Limit: *qParams.Limit,
documents, err := api.db.Queries.GetDocumentsWithStats(c, database.GetDocumentsWithStatsParams{
UserID: auth.UserName,
Query: query,
Deleted: ptr.Of(false),
Offset: (*qParams.Page - 1) * *qParams.Limit,
Limit: *qParams.Limit,
})
if err != nil {
log.Error("GetDocumentsWithStats DB Error: ", err)
@@ -121,14 +124,14 @@ func (api *API) appGetDocuments(c *gin.Context) {
return
}
length, err := api.db.Queries.GetDocumentsSize(api.db.Ctx, query)
length, err := api.db.Queries.GetDocumentsSize(c, query)
if err != nil {
log.Error("GetDocumentsSize DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocumentsSize DB Error: %v", err))
return
}
if err = api.getDocumentsWordCount(documents); err != nil {
if err = api.getDocumentsWordCount(c, documents); err != nil {
log.Error("Unable to Get Word Counts: ", err)
}
@@ -160,13 +163,10 @@ func (api *API) appGetDocument(c *gin.Context) {
return
}
document, err := api.db.Queries.GetDocumentWithStats(api.db.Ctx, database.GetDocumentWithStatsParams{
UserID: auth.UserName,
DocumentID: rDocID.DocumentID,
})
document, err := api.db.GetDocument(c, rDocID.DocumentID, auth.UserName)
if err != nil {
log.Error("GetDocumentWithStats DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocumentsWithStats DB Error: %v", err))
log.Error("GetDocument DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocument DB Error: %v", err))
return
}
@@ -192,7 +192,7 @@ func (api *API) appGetProgress(c *gin.Context) {
progressFilter.DocumentID = *qParams.Document
}
progress, err := api.db.Queries.GetProgress(api.db.Ctx, progressFilter)
progress, err := api.db.Queries.GetProgress(c, progressFilter)
if err != nil {
log.Error("GetProgress DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetActivity DB Error: %v", err))
@@ -219,7 +219,7 @@ func (api *API) appGetActivity(c *gin.Context) {
activityFilter.DocumentID = *qParams.Document
}
activity, err := api.db.Queries.GetActivity(api.db.Ctx, activityFilter)
activity, err := api.db.Queries.GetActivity(c, activityFilter)
if err != nil {
log.Error("GetActivity DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetActivity DB Error: %v", err))
@@ -235,7 +235,7 @@ func (api *API) appGetHome(c *gin.Context) {
templateVars, auth := api.getBaseTemplateVars("home", c)
start := time.Now()
graphData, err := api.db.Queries.GetDailyReadStats(api.db.Ctx, auth.UserName)
graphData, err := api.db.Queries.GetDailyReadStats(c, auth.UserName)
if err != nil {
log.Error("GetDailyReadStats DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDailyReadStats DB Error: %v", err))
@@ -244,7 +244,7 @@ func (api *API) appGetHome(c *gin.Context) {
log.Debug("GetDailyReadStats DB Performance: ", time.Since(start))
start = time.Now()
databaseInfo, err := api.db.Queries.GetDatabaseInfo(api.db.Ctx, auth.UserName)
databaseInfo, err := api.db.Queries.GetDatabaseInfo(c, auth.UserName)
if err != nil {
log.Error("GetDatabaseInfo DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDatabaseInfo DB Error: %v", err))
@@ -253,7 +253,7 @@ func (api *API) appGetHome(c *gin.Context) {
log.Debug("GetDatabaseInfo DB Performance: ", time.Since(start))
start = time.Now()
streaks, err := api.db.Queries.GetUserStreaks(api.db.Ctx, auth.UserName)
streaks, err := api.db.Queries.GetUserStreaks(c, auth.UserName)
if err != nil {
log.Error("GetUserStreaks DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUserStreaks DB Error: %v", err))
@@ -262,7 +262,7 @@ func (api *API) appGetHome(c *gin.Context) {
log.Debug("GetUserStreaks DB Performance: ", time.Since(start))
start = time.Now()
userStatistics, err := api.db.Queries.GetUserStatistics(api.db.Ctx)
userStatistics, err := api.db.Queries.GetUserStatistics(c)
if err != nil {
log.Error("GetUserStatistics DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUserStatistics DB Error: %v", err))
@@ -283,14 +283,14 @@ func (api *API) appGetHome(c *gin.Context) {
func (api *API) appGetSettings(c *gin.Context) {
templateVars, auth := api.getBaseTemplateVars("settings", c)
user, err := api.db.Queries.GetUser(api.db.Ctx, auth.UserName)
user, err := api.db.Queries.GetUser(c, auth.UserName)
if err != nil {
log.Error("GetUser DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUser DB Error: %v", err))
return
}
devices, err := api.db.Queries.GetDevices(api.db.Ctx, auth.UserName)
devices, err := api.db.Queries.GetDevices(c, auth.UserName)
if err != nil {
log.Error("GetDevices DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDevices DB Error: %v", err))
@@ -368,7 +368,7 @@ func (api *API) appGetDocumentProgress(c *gin.Context) {
return
}
progress, err := api.db.Queries.GetDocumentProgress(api.db.Ctx, database.GetDocumentProgressParams{
progress, err := api.db.Queries.GetDocumentProgress(c, database.GetDocumentProgressParams{
DocumentID: rDoc.DocumentID,
UserID: auth.UserName,
})
@@ -378,13 +378,10 @@ func (api *API) appGetDocumentProgress(c *gin.Context) {
return
}
document, err := api.db.Queries.GetDocumentWithStats(api.db.Ctx, database.GetDocumentWithStatsParams{
UserID: auth.UserName,
DocumentID: rDoc.DocumentID,
})
document, err := api.db.GetDocument(c, rDoc.DocumentID, auth.UserName)
if err != nil {
log.Error("GetDocumentWithStats DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocumentWithStats DB Error: %v", err))
log.Error("GetDocument DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocument DB Error: %v", err))
return
}
@@ -404,7 +401,7 @@ func (api *API) appGetDevices(c *gin.Context) {
auth = data.(authData)
}
devices, err := api.db.Queries.GetDevices(api.db.Ctx, auth.UserName)
devices, err := api.db.Queries.GetDevices(c, auth.UserName)
if err != nil && err != sql.ErrNoRows {
log.Error("GetDevices DB Error: ", err)
@@ -455,7 +452,7 @@ func (api *API) appUploadNewDocument(c *gin.Context) {
}
// Check Already Exists
_, err = api.db.Queries.GetDocument(api.db.Ctx, *metadataInfo.PartialMD5)
_, err = api.db.Queries.GetDocument(c, *metadataInfo.PartialMD5)
if err == nil {
log.Warnf("document already exists: %s", *metadataInfo.PartialMD5)
c.Redirect(http.StatusFound, fmt.Sprintf("./documents/%s", *metadataInfo.PartialMD5))
@@ -483,7 +480,7 @@ func (api *API) appUploadNewDocument(c *gin.Context) {
}
// Upsert Document
if _, err = api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
if _, err = api.db.Queries.UpsertDocument(c, database.UpsertDocumentParams{
ID: *metadataInfo.PartialMD5,
Title: metadataInfo.Title,
Author: metadataInfo.Author,
@@ -573,7 +570,7 @@ func (api *API) appEditDocument(c *gin.Context) {
coverFileName = &fileName
} else if rDocEdit.CoverGBID != nil {
var coverDir string = filepath.Join(api.cfg.DataPath, "covers")
coverDir := filepath.Join(api.cfg.DataPath, "covers")
fileName, err := metadata.CacheCover(*rDocEdit.CoverGBID, coverDir, rDocID.DocumentID, true)
if err == nil {
coverFileName = fileName
@@ -581,7 +578,7 @@ func (api *API) appEditDocument(c *gin.Context) {
}
// Update Document
if _, err := api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
if _, err := api.db.Queries.UpsertDocument(c, database.UpsertDocumentParams{
ID: rDocID.DocumentID,
Title: api.sanitizeInput(rDocEdit.Title),
Author: api.sanitizeInput(rDocEdit.Author),
@@ -605,7 +602,7 @@ func (api *API) appDeleteDocument(c *gin.Context) {
appErrorPage(c, http.StatusNotFound, "Invalid document")
return
}
changed, err := api.db.Queries.DeleteDocument(api.db.Ctx, rDocID.DocumentID)
changed, err := api.db.Queries.DeleteDocument(c, rDocID.DocumentID)
if err != nil {
log.Error("DeleteDocument DB Error")
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("DeleteDocument DB Error: %v", err))
@@ -667,7 +664,7 @@ func (api *API) appIdentifyDocument(c *gin.Context) {
firstResult := metadataResults[0]
// Store First Metadata Result
if _, err = api.db.Queries.AddMetadata(api.db.Ctx, database.AddMetadataParams{
if _, err = api.db.Queries.AddMetadata(c, database.AddMetadataParams{
DocumentID: rDocID.DocumentID,
Title: firstResult.Title,
Author: firstResult.Author,
@@ -686,13 +683,10 @@ func (api *API) appIdentifyDocument(c *gin.Context) {
templateVars["MetadataError"] = "No Metadata Found"
}
document, err := api.db.Queries.GetDocumentWithStats(api.db.Ctx, database.GetDocumentWithStatsParams{
UserID: auth.UserName,
DocumentID: rDocID.DocumentID,
})
document, err := api.db.GetDocument(c, rDocID.DocumentID, auth.UserName)
if err != nil {
log.Error("GetDocumentWithStats DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocumentWithStats DB Error: %v", err))
log.Error("GetDocument DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocument DB Error: %v", err))
return
}
@@ -817,7 +811,7 @@ func (api *API) appSaveNewDocument(c *gin.Context) {
sendDownloadMessage("Saving to database...", gin.H{"Progress": 99})
// Upsert Document
if _, err = api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
if _, err = api.db.Queries.UpsertDocument(c, database.UpsertDocumentParams{
ID: *metadata.PartialMD5,
Title: &docTitle,
Author: &docAuthor,
@@ -864,7 +858,7 @@ func (api *API) appEditSettings(c *gin.Context) {
// Set New Password
if rUserSettings.Password != nil && rUserSettings.NewPassword != nil {
password := fmt.Sprintf("%x", md5.Sum([]byte(*rUserSettings.Password)))
data := api.authorizeCredentials(auth.UserName, password)
data := api.authorizeCredentials(c, auth.UserName, password)
if data == nil {
templateVars["PasswordErrorMessage"] = "Invalid Password"
} else {
@@ -886,7 +880,7 @@ func (api *API) appEditSettings(c *gin.Context) {
}
// Update User
_, err := api.db.Queries.UpdateUser(api.db.Ctx, newUserSettings)
_, err := api.db.Queries.UpdateUser(c, newUserSettings)
if err != nil {
log.Error("UpdateUser DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("UpdateUser DB Error: %v", err))
@@ -894,7 +888,7 @@ func (api *API) appEditSettings(c *gin.Context) {
}
// Get User
user, err := api.db.Queries.GetUser(api.db.Ctx, auth.UserName)
user, err := api.db.Queries.GetUser(c, auth.UserName)
if err != nil {
log.Error("GetUser DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUser DB Error: %v", err))
@@ -902,7 +896,7 @@ func (api *API) appEditSettings(c *gin.Context) {
}
// Get Devices
devices, err := api.db.Queries.GetDevices(api.db.Ctx, auth.UserName)
devices, err := api.db.Queries.GetDevices(c, auth.UserName)
if err != nil {
log.Error("GetDevices DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDevices DB Error: %v", err))
@@ -921,7 +915,7 @@ func (api *API) appDemoModeError(c *gin.Context) {
appErrorPage(c, http.StatusUnauthorized, "Not Allowed in Demo Mode")
}
func (api *API) getDocumentsWordCount(documents []database.GetDocumentsWithStatsRow) error {
func (api *API) getDocumentsWordCount(ctx context.Context, documents []database.GetDocumentsWithStatsRow) error {
// Do Transaction
tx, err := api.db.DB.Begin()
if err != nil {
@@ -944,7 +938,7 @@ func (api *API) getDocumentsWordCount(documents []database.GetDocumentsWithStats
if err != nil {
log.Warn("Word Count Error: ", err)
} else {
if _, err := qtx.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
if _, err := qtx.UpsertDocument(ctx, database.UpsertDocumentParams{
ID: item.ID,
Words: wordCount,
}); err != nil {
@@ -1005,7 +999,7 @@ func bindQueryParams(c *gin.Context, defaultLimit int64) queryParams {
}
func appErrorPage(c *gin.Context, errorCode int, errorMessage string) {
var errorHuman string = "We're not even sure what happened."
errorHuman := "We're not even sure what happened."
switch errorCode {
case http.StatusInternalServerError:

View File

@@ -1,6 +1,7 @@
package api
import (
"context"
"crypto/md5"
"fmt"
"net/http"
@@ -28,8 +29,8 @@ type authKOHeader struct {
AuthKey string `header:"x-auth-key"`
}
func (api *API) authorizeCredentials(username string, password string) (auth *authData) {
user, err := api.db.Queries.GetUser(api.db.Ctx, username)
func (api *API) authorizeCredentials(ctx context.Context, username string, password string) (auth *authData) {
user, err := api.db.Queries.GetUser(ctx, username)
if err != nil {
return
}
@@ -52,7 +53,7 @@ func (api *API) authKOMiddleware(c *gin.Context) {
session := sessions.Default(c)
// Check Session First
if auth, ok := api.getSession(session); ok {
if auth, ok := api.getSession(c, session); ok {
c.Set("Authorization", auth)
c.Header("Cache-Control", "private")
c.Next()
@@ -71,7 +72,7 @@ func (api *API) authKOMiddleware(c *gin.Context) {
return
}
authData := api.authorizeCredentials(rHeader.AuthUser, rHeader.AuthKey)
authData := api.authorizeCredentials(c, rHeader.AuthUser, rHeader.AuthKey)
if authData == nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
@@ -100,7 +101,7 @@ func (api *API) authOPDSMiddleware(c *gin.Context) {
// Validate Auth
password := fmt.Sprintf("%x", md5.Sum([]byte(rawPassword)))
authData := api.authorizeCredentials(user, password)
authData := api.authorizeCredentials(c, user, password)
if authData == nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
@@ -115,7 +116,7 @@ func (api *API) authWebAppMiddleware(c *gin.Context) {
session := sessions.Default(c)
// Check Session
if auth, ok := api.getSession(session); ok {
if auth, ok := api.getSession(c, session); ok {
c.Set("Authorization", auth)
c.Header("Cache-Control", "private")
c.Next()
@@ -153,7 +154,7 @@ func (api *API) appAuthLogin(c *gin.Context) {
// MD5 - KOSync Compatiblity
password := fmt.Sprintf("%x", md5.Sum([]byte(rawPassword)))
authData := api.authorizeCredentials(username, password)
authData := api.authorizeCredentials(c, username, password)
if authData == nil {
templateVars["Error"] = "Invalid Credentials"
c.HTML(http.StatusUnauthorized, "page/login", templateVars)
@@ -208,7 +209,7 @@ func (api *API) appAuthRegister(c *gin.Context) {
}
// Get current users
currentUsers, err := api.db.Queries.GetUsers(api.db.Ctx)
currentUsers, err := api.db.Queries.GetUsers(c)
if err != nil {
log.Error("Failed to check all users: ", err)
templateVars["Error"] = "Failed to Create User"
@@ -224,7 +225,7 @@ func (api *API) appAuthRegister(c *gin.Context) {
// Create user in DB
authHash := fmt.Sprintf("%x", rawAuthHash)
if rows, err := api.db.Queries.CreateUser(api.db.Ctx, database.CreateUserParams{
if rows, err := api.db.Queries.CreateUser(c, database.CreateUserParams{
ID: username,
Pass: &hashedPassword,
AuthHash: &authHash,
@@ -242,7 +243,7 @@ func (api *API) appAuthRegister(c *gin.Context) {
}
// Get user
user, err := api.db.Queries.GetUser(api.db.Ctx, username)
user, err := api.db.Queries.GetUser(c, username)
if err != nil {
log.Error("GetUser DB Error:", err)
templateVars["Error"] = "Registration Disabled or User Already Exists"
@@ -312,7 +313,7 @@ func (api *API) koAuthRegister(c *gin.Context) {
}
// Get current users
currentUsers, err := api.db.Queries.GetUsers(api.db.Ctx)
currentUsers, err := api.db.Queries.GetUsers(c)
if err != nil {
log.Error("Failed to check all users: ", err)
apiErrorPage(c, http.StatusBadRequest, "Failed to Create User")
@@ -327,7 +328,7 @@ func (api *API) koAuthRegister(c *gin.Context) {
// Create user
authHash := fmt.Sprintf("%x", rawAuthHash)
if rows, err := api.db.Queries.CreateUser(api.db.Ctx, database.CreateUserParams{
if rows, err := api.db.Queries.CreateUser(c, database.CreateUserParams{
ID: rUser.Username,
Pass: &hashedPassword,
AuthHash: &authHash,
@@ -347,7 +348,7 @@ func (api *API) koAuthRegister(c *gin.Context) {
})
}
func (api *API) getSession(session sessions.Session) (auth authData, ok bool) {
func (api *API) getSession(ctx context.Context, session sessions.Session) (auth authData, ok bool) {
// Get Session
authorizedUser := session.Get("authorizedUser")
isAdmin := session.Get("isAdmin")
@@ -365,7 +366,7 @@ func (api *API) getSession(session sessions.Session) (auth authData, ok bool) {
}
// Validate Auth Hash
correctAuthHash, err := api.getUserAuthHash(auth.UserName)
correctAuthHash, err := api.getUserAuthHash(ctx, auth.UserName)
if err != nil || correctAuthHash != auth.AuthHash {
return
}
@@ -393,14 +394,14 @@ func (api *API) setSession(session sessions.Session, auth authData) error {
return session.Save()
}
func (api *API) getUserAuthHash(username string) (string, error) {
func (api *API) getUserAuthHash(ctx context.Context, username string) (string, error) {
// Return Cache
if api.userAuthCache[username] != "" {
return api.userAuthCache[username], nil
}
// Get DB
user, err := api.db.Queries.GetUser(api.db.Ctx, username)
user, err := api.db.Queries.GetUser(ctx, username)
if err != nil {
log.Error("GetUser DB Error:", err)
return "", err
@@ -412,7 +413,7 @@ func (api *API) getUserAuthHash(username string) (string, error) {
return api.userAuthCache[username], nil
}
func (api *API) rotateAllAuthHashes() error {
func (api *API) rotateAllAuthHashes(ctx context.Context) error {
// Do Transaction
tx, err := api.db.DB.Begin()
if err != nil {
@@ -428,7 +429,7 @@ func (api *API) rotateAllAuthHashes() error {
}()
qtx := api.db.Queries.WithTx(tx)
users, err := qtx.GetUsers(api.db.Ctx)
users, err := qtx.GetUsers(ctx)
if err != nil {
return err
}
@@ -444,7 +445,7 @@ func (api *API) rotateAllAuthHashes() error {
// Update User
authHash := fmt.Sprintf("%x", rawAuthHash)
if _, err = qtx.UpdateUser(api.db.Ctx, database.UpdateUserParams{
if _, err = qtx.UpdateUser(ctx, database.UpdateUserParams{
UserID: user.ID,
AuthHash: &authHash,
Admin: user.Admin,

View File

@@ -22,7 +22,7 @@ func (api *API) createDownloadDocumentHandler(errorFunc func(*gin.Context, int,
}
// Get Document
document, err := api.db.Queries.GetDocument(api.db.Ctx, rDoc.DocumentID)
document, err := api.db.Queries.GetDocument(c, rDoc.DocumentID)
if err != nil {
log.Error("GetDocument DB Error:", err)
errorFunc(c, http.StatusBadRequest, "Unknown Document")
@@ -68,7 +68,7 @@ func (api *API) createGetCoverHandler(errorFunc func(*gin.Context, int, string))
}
// Validate Document Exists in DB
document, err := api.db.Queries.GetDocument(api.db.Ctx, rDoc.DocumentID)
document, err := api.db.Queries.GetDocument(c, rDoc.DocumentID)
if err != nil {
log.Error("GetDocument DB Error:", err)
errorFunc(c, http.StatusInternalServerError, fmt.Sprintf("GetDocument DB Error: %v", err))
@@ -117,7 +117,7 @@ func (api *API) createGetCoverHandler(errorFunc func(*gin.Context, int, string))
}
// Store First Metadata Result
if _, err = api.db.Queries.AddMetadata(api.db.Ctx, database.AddMetadataParams{
if _, err = api.db.Queries.AddMetadata(c, database.AddMetadataParams{
DocumentID: document.ID,
Title: firstResult.Title,
Author: firstResult.Author,
@@ -132,7 +132,7 @@ func (api *API) createGetCoverHandler(errorFunc func(*gin.Context, int, string))
}
// Upsert Document
if _, err = api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
if _, err = api.db.Queries.UpsertDocument(c, database.UpsertDocumentParams{
ID: document.ID,
Coverfile: &coverFile,
}); err != nil {

View File

@@ -91,7 +91,7 @@ func (api *API) koSetProgress(c *gin.Context) {
}
// Upsert Device
if _, err := api.db.Queries.UpsertDevice(api.db.Ctx, database.UpsertDeviceParams{
if _, err := api.db.Queries.UpsertDevice(c, database.UpsertDeviceParams{
ID: rPosition.DeviceID,
UserID: auth.UserName,
DeviceName: rPosition.Device,
@@ -101,14 +101,14 @@ func (api *API) koSetProgress(c *gin.Context) {
}
// Upsert Document
if _, err := api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
if _, err := api.db.Queries.UpsertDocument(c, database.UpsertDocumentParams{
ID: rPosition.DocumentID,
}); err != nil {
log.Error("UpsertDocument DB Error:", err)
}
// Create or Replace Progress
progress, err := api.db.Queries.UpdateProgress(api.db.Ctx, database.UpdateProgressParams{
progress, err := api.db.Queries.UpdateProgress(c, database.UpdateProgressParams{
Percentage: rPosition.Percentage,
DocumentID: rPosition.DocumentID,
DeviceID: rPosition.DeviceID,
@@ -140,7 +140,7 @@ func (api *API) koGetProgress(c *gin.Context) {
return
}
progress, err := api.db.Queries.GetDocumentProgress(api.db.Ctx, database.GetDocumentProgressParams{
progress, err := api.db.Queries.GetDocumentProgress(c, database.GetDocumentProgressParams{
DocumentID: rDocID.DocumentID,
UserID: auth.UserName,
})
@@ -202,7 +202,7 @@ func (api *API) koAddActivities(c *gin.Context) {
// Upsert Documents
for _, doc := range allDocuments {
if _, err := qtx.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
if _, err := qtx.UpsertDocument(c, database.UpsertDocumentParams{
ID: doc,
}); err != nil {
log.Error("UpsertDocument DB Error:", err)
@@ -212,7 +212,7 @@ func (api *API) koAddActivities(c *gin.Context) {
}
// Upsert Device
if _, err = qtx.UpsertDevice(api.db.Ctx, database.UpsertDeviceParams{
if _, err = qtx.UpsertDevice(c, database.UpsertDeviceParams{
ID: rActivity.DeviceID,
UserID: auth.UserName,
DeviceName: rActivity.Device,
@@ -225,7 +225,7 @@ func (api *API) koAddActivities(c *gin.Context) {
// Add All Activity
for _, item := range rActivity.Activity {
if _, err := qtx.AddActivity(api.db.Ctx, database.AddActivityParams{
if _, err := qtx.AddActivity(c, database.AddActivityParams{
UserID: auth.UserName,
DocumentID: item.DocumentID,
DeviceID: rActivity.DeviceID,
@@ -266,7 +266,7 @@ func (api *API) koCheckActivitySync(c *gin.Context) {
}
// Upsert Device
if _, err := api.db.Queries.UpsertDevice(api.db.Ctx, database.UpsertDeviceParams{
if _, err := api.db.Queries.UpsertDevice(c, database.UpsertDeviceParams{
ID: rCheckActivity.DeviceID,
UserID: auth.UserName,
DeviceName: rCheckActivity.Device,
@@ -278,7 +278,7 @@ func (api *API) koCheckActivitySync(c *gin.Context) {
}
// Get Last Device Activity
lastActivity, err := api.db.Queries.GetLastActivity(api.db.Ctx, database.GetLastActivityParams{
lastActivity, err := api.db.Queries.GetLastActivity(c, database.GetLastActivityParams{
UserID: auth.UserName,
DeviceID: rCheckActivity.DeviceID,
})
@@ -329,7 +329,7 @@ func (api *API) koAddDocuments(c *gin.Context) {
// Upsert Documents
for _, doc := range rNewDocs.Documents {
_, err := qtx.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
_, err := qtx.UpsertDocument(c, database.UpsertDocumentParams{
ID: doc.ID,
Title: api.sanitizeInput(doc.Title),
Author: api.sanitizeInput(doc.Author),
@@ -371,7 +371,7 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) {
}
// Upsert Device
_, err := api.db.Queries.UpsertDevice(api.db.Ctx, database.UpsertDeviceParams{
_, err := api.db.Queries.UpsertDevice(c, database.UpsertDeviceParams{
ID: rCheckDocs.DeviceID,
UserID: auth.UserName,
DeviceName: rCheckDocs.Device,
@@ -384,7 +384,7 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) {
}
// Get Missing Documents
missingDocs, err := api.db.Queries.GetMissingDocuments(api.db.Ctx, rCheckDocs.Have)
missingDocs, err := api.db.Queries.GetMissingDocuments(c, rCheckDocs.Have)
if err != nil {
log.Error("GetMissingDocuments DB Error", err)
apiErrorPage(c, http.StatusBadRequest, "Invalid Request")
@@ -392,7 +392,7 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) {
}
// Get Deleted Documents
deletedDocIDs, err := api.db.Queries.GetDeletedDocuments(api.db.Ctx, rCheckDocs.Have)
deletedDocIDs, err := api.db.Queries.GetDeletedDocuments(c, rCheckDocs.Have)
if err != nil {
log.Error("GetDeletedDocuments DB Error", err)
apiErrorPage(c, http.StatusBadRequest, "Invalid Request")
@@ -407,7 +407,7 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) {
return
}
wantedDocs, err := api.db.Queries.GetWantedDocuments(api.db.Ctx, string(jsonHaves))
wantedDocs, err := api.db.Queries.GetWantedDocuments(c, string(jsonHaves))
if err != nil {
log.Error("GetWantedDocuments DB Error", err)
apiErrorPage(c, http.StatusBadRequest, "Invalid Request")
@@ -467,7 +467,7 @@ func (api *API) koUploadExistingDocument(c *gin.Context) {
}
// Validate Document Exists in DB
document, err := api.db.Queries.GetDocument(api.db.Ctx, rDoc.DocumentID)
document, err := api.db.Queries.GetDocument(c, rDoc.DocumentID)
if err != nil {
log.Error("GetDocument DB Error:", err)
apiErrorPage(c, http.StatusBadRequest, "Unknown Document")
@@ -522,7 +522,7 @@ func (api *API) koUploadExistingDocument(c *gin.Context) {
}
// Upsert Document
if _, err = api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
if _, err = api.db.Queries.UpsertDocument(c, database.UpsertDocumentParams{
ID: document.ID,
Md5: metadataInfo.MD5,
Words: metadataInfo.WordCount,

View File

@@ -10,6 +10,7 @@ import (
log "github.com/sirupsen/logrus"
"reichard.io/antholume/database"
"reichard.io/antholume/opds"
"reichard.io/antholume/pkg/ptr"
)
var mimeMapping map[string]string = map[string]string{
@@ -77,11 +78,12 @@ func (api *API) opdsDocuments(c *gin.Context) {
}
// Get Documents
documents, err := api.db.Queries.GetDocumentsWithStats(api.db.Ctx, database.GetDocumentsWithStatsParams{
UserID: auth.UserName,
Query: query,
Offset: (*qParams.Page - 1) * *qParams.Limit,
Limit: *qParams.Limit,
documents, err := api.db.Queries.GetDocumentsWithStats(c, database.GetDocumentsWithStatsParams{
UserID: auth.UserName,
Query: query,
Deleted: ptr.Of(false),
Offset: (*qParams.Page - 1) * *qParams.Limit,
Limit: *qParams.Limit,
})
if err != nil {
log.Error("GetDocumentsWithStats DB Error:", err)

View File

@@ -55,6 +55,7 @@ func getTimeZones() []string {
// niceSeconds takes in an int (in seconds) and returns a string readable
// representation. For example 1928371 -> "22d 7h 39m 31s".
// Deprecated: Use formatters.FormatDuration
func niceSeconds(input int64) (result string) {
if input == 0 {
return "N/A"
@@ -85,6 +86,7 @@ func niceSeconds(input int64) (result string) {
// niceNumbers takes in an int and returns a string representation. For example
// 19823 -> "19.8k".
// Deprecated: Use formatters.FormatNumber
func niceNumbers(input int64) string {
if input == 0 {
return "0"