From 5f8a9b7b1419ee35aa50e71cd60d21bed4f5797e Mon Sep 17 00:00:00 2001 From: Evan Reichard Date: Sun, 10 Aug 2025 13:15:46 -0400 Subject: [PATCH] chore(db): use context & add db helper --- api/app-admin-routes.go | 55 +++++++++-------- api/app-routes.go | 92 +++++++++++++--------------- api/auth.go | 39 ++++++------ api/common.go | 8 +-- api/ko-routes.go | 32 +++++----- api/opds-routes.go | 12 ++-- api/utils.go | 2 + database/documents.go | 27 ++++++++ database/documents_test.go | 25 ++++---- database/manager.go | 29 ++++----- database/manager_test.go | 19 +++--- database/query.sql | 57 +++-------------- database/query.sql.go | 122 +++++++------------------------------ database/users_test.go | 41 +++++++------ server/server.go | 9 ++- 15 files changed, 241 insertions(+), 328 deletions(-) create mode 100644 database/documents.go diff --git a/api/app-admin-routes.go b/api/app-admin-routes.go index 323c015..d43645f 100644 --- a/api/app-admin-routes.go +++ b/api/app-admin-routes.go @@ -3,6 +3,7 @@ package api import ( "archive/zip" "bufio" + "context" "crypto/md5" "encoding/json" "fmt" @@ -112,7 +113,7 @@ func (api *API) appPerformAdminAction(c *gin.Context) { // 2. Select all / deselect? case adminCacheTables: go func() { - err := api.db.CacheTempTables() + err := api.db.CacheTempTables(c) if err != nil { log.Error("Unable to cache temp tables: ", err) } @@ -122,7 +123,7 @@ func (api *API) appPerformAdminAction(c *gin.Context) { return case adminBackup: // Vacuum - _, err := api.db.DB.ExecContext(api.db.Ctx, "VACUUM;") + _, err := api.db.DB.ExecContext(c, "VACUUM;") if err != nil { log.Error("Unable to vacuum DB: ", err) appErrorPage(c, http.StatusInternalServerError, "Unable to vacuum database") @@ -144,7 +145,7 @@ func (api *API) appPerformAdminAction(c *gin.Context) { } } - err := api.createBackup(w, directories) + err := api.createBackup(c, w, directories) if err != nil { log.Error("Backup Error: ", err) } @@ -261,7 +262,7 @@ func (api *API) appGetAdminLogs(c *gin.Context) { func (api *API) appGetAdminUsers(c *gin.Context) { templateVars, _ := api.getBaseTemplateVars("admin-users", c) - users, err := api.db.Queries.GetUsers(api.db.Ctx) + users, err := api.db.Queries.GetUsers(c) if err != nil { log.Error("GetUsers DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUsers DB Error: %v", err)) @@ -292,11 +293,11 @@ func (api *API) appUpdateAdminUsers(c *gin.Context) { var err error switch rUpdate.Operation { case opCreate: - err = api.createUser(rUpdate.User, rUpdate.Password, rUpdate.IsAdmin) + err = api.createUser(c, rUpdate.User, rUpdate.Password, rUpdate.IsAdmin) case opUpdate: - err = api.updateUser(rUpdate.User, rUpdate.Password, rUpdate.IsAdmin) + err = api.updateUser(c, rUpdate.User, rUpdate.Password, rUpdate.IsAdmin) case opDelete: - err = api.deleteUser(rUpdate.User) + err = api.deleteUser(c, rUpdate.User) default: appErrorPage(c, http.StatusNotFound, "Unknown user operation") return @@ -307,7 +308,7 @@ func (api *API) appUpdateAdminUsers(c *gin.Context) { return } - users, err := api.db.Queries.GetUsers(api.db.Ctx) + users, err := api.db.Queries.GetUsers(c) if err != nil { log.Error("GetUsers DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUsers DB Error: %v", err)) @@ -448,7 +449,7 @@ func (api *API) appPerformAdminImport(c *gin.Context) { iResult.Name = fmt.Sprintf("%s - %s", *fileMeta.Author, *fileMeta.Title) // Check already exists - _, err = qtx.GetDocument(api.db.Ctx, *fileMeta.PartialMD5) + _, err = qtx.GetDocument(c, *fileMeta.PartialMD5) if err == nil { log.Warnf("document already exists: %s", *fileMeta.PartialMD5) iResult.Status = importExists @@ -492,7 +493,7 @@ func (api *API) appPerformAdminImport(c *gin.Context) { } // Upsert document - if _, err = qtx.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ + if _, err = qtx.UpsertDocument(c, database.UpsertDocumentParams{ ID: *fileMeta.PartialMD5, Title: fileMeta.Title, Author: fileMeta.Author, @@ -627,7 +628,7 @@ func (api *API) processRestoreFile(rAdminAction requestAdminAction, c *gin.Conte // Save Backup File w := bufio.NewWriter(backupFile) - err = api.createBackup(w, []string{"covers", "documents"}) + err = api.createBackup(c, w, []string{"covers", "documents"}) if err != nil { log.Error("Unable to save backup file: ", err) appErrorPage(c, http.StatusInternalServerError, "Unable to save backup file") @@ -650,13 +651,13 @@ func (api *API) processRestoreFile(rAdminAction requestAdminAction, c *gin.Conte } // Reinit DB - if err := api.db.Reload(); err != nil { + if err := api.db.Reload(c); err != nil { appErrorPage(c, http.StatusInternalServerError, "Unable to reload DB") log.Panicf("Unable to reload DB: %v", err) } // Rotate Auth Hashes - if err := api.rotateAllAuthHashes(); err != nil { + if err := api.rotateAllAuthHashes(c); err != nil { appErrorPage(c, http.StatusInternalServerError, "Unable to rotate hashes") log.Panicf("Unable to rotate auth hashes: %v", err) } @@ -717,9 +718,9 @@ func (api *API) removeData() error { return nil } -func (api *API) createBackup(w io.Writer, directories []string) error { +func (api *API) createBackup(ctx context.Context, w io.Writer, directories []string) error { // Vacuum DB - _, err := api.db.DB.ExecContext(api.db.Ctx, "VACUUM;") + _, err := api.db.DB.ExecContext(ctx, "VACUUM;") if err != nil { return errors.Wrap(err, "Unable to vacuum database") } @@ -792,8 +793,8 @@ func (api *API) createBackup(w io.Writer, directories []string) error { return nil } -func (api *API) isLastAdmin(userID string) (bool, error) { - allUsers, err := api.db.Queries.GetUsers(api.db.Ctx) +func (api *API) isLastAdmin(ctx context.Context, userID string) (bool, error) { + allUsers, err := api.db.Queries.GetUsers(ctx) if err != nil { return false, errors.Wrap(err, fmt.Sprintf("GetUsers DB Error: %v", err)) } @@ -809,7 +810,7 @@ func (api *API) isLastAdmin(userID string) (bool, error) { return !hasAdmin, nil } -func (api *API) createUser(user string, rawPassword *string, isAdmin *bool) error { +func (api *API) createUser(ctx context.Context, user string, rawPassword *string, isAdmin *bool) error { // Validate Necessary Parameters if rawPassword == nil || *rawPassword == "" { return fmt.Errorf("password can't be empty") @@ -844,7 +845,7 @@ func (api *API) createUser(user string, rawPassword *string, isAdmin *bool) erro createParams.AuthHash = &authHash // Create user in DB - if rows, err := api.db.Queries.CreateUser(api.db.Ctx, createParams); err != nil { + if rows, err := api.db.Queries.CreateUser(ctx, createParams); err != nil { log.Error("CreateUser DB Error:", err) return fmt.Errorf("unable to create user") } else if rows == 0 { @@ -855,7 +856,7 @@ func (api *API) createUser(user string, rawPassword *string, isAdmin *bool) erro return nil } -func (api *API) updateUser(user string, rawPassword *string, isAdmin *bool) error { +func (api *API) updateUser(ctx context.Context, user string, rawPassword *string, isAdmin *bool) error { // Validate Necessary Parameters if rawPassword == nil && isAdmin == nil { return fmt.Errorf("nothing to update") @@ -870,7 +871,7 @@ func (api *API) updateUser(user string, rawPassword *string, isAdmin *bool) erro if isAdmin != nil { updateParams.Admin = *isAdmin } else { - user, err := api.db.Queries.GetUser(api.db.Ctx, user) + user, err := api.db.Queries.GetUser(ctx, user) if err != nil { return errors.Wrap(err, fmt.Sprintf("GetUser DB Error: %v", err)) } @@ -878,7 +879,7 @@ func (api *API) updateUser(user string, rawPassword *string, isAdmin *bool) erro } // Check Admins - Disallow Demotion - if isLast, err := api.isLastAdmin(user); err != nil { + if isLast, err := api.isLastAdmin(ctx, user); err != nil { return err } else if isLast && !updateParams.Admin { return fmt.Errorf("unable to demote %s - last admin", user) @@ -908,7 +909,7 @@ func (api *API) updateUser(user string, rawPassword *string, isAdmin *bool) erro } // Update User - _, err := api.db.Queries.UpdateUser(api.db.Ctx, updateParams) + _, err := api.db.Queries.UpdateUser(ctx, updateParams) if err != nil { return errors.Wrap(err, fmt.Sprintf("UpdateUser DB Error: %v", err)) } @@ -916,9 +917,9 @@ func (api *API) updateUser(user string, rawPassword *string, isAdmin *bool) erro return nil } -func (api *API) deleteUser(user string) error { +func (api *API) deleteUser(ctx context.Context, user string) error { // Check Admins - if isLast, err := api.isLastAdmin(user); err != nil { + if isLast, err := api.isLastAdmin(ctx, user); err != nil { return err } else if isLast { return fmt.Errorf("unable to delete %s - last admin", user) @@ -934,13 +935,13 @@ func (api *API) deleteUser(user string) error { // Save Backup File (DB Only) w := bufio.NewWriter(backupFile) - err = api.createBackup(w, []string{}) + err = api.createBackup(ctx, w, []string{}) if err != nil { return err } // Delete User - _, err = api.db.Queries.DeleteUser(api.db.Ctx, user) + _, err = api.db.Queries.DeleteUser(ctx, user) if err != nil { return errors.Wrap(err, fmt.Sprintf("DeleteUser DB Error: %v", err)) } diff --git a/api/app-routes.go b/api/app-routes.go index a5f37d3..26ba7fe 100644 --- a/api/app-routes.go +++ b/api/app-routes.go @@ -1,6 +1,7 @@ package api import ( + "context" "crypto/md5" "database/sql" "fmt" @@ -22,6 +23,7 @@ import ( "golang.org/x/exp/slices" "reichard.io/antholume/database" "reichard.io/antholume/metadata" + "reichard.io/antholume/pkg/ptr" "reichard.io/antholume/search" ) @@ -109,11 +111,12 @@ func (api *API) appGetDocuments(c *gin.Context) { query = &search } - documents, err := api.db.Queries.GetDocumentsWithStats(api.db.Ctx, database.GetDocumentsWithStatsParams{ - UserID: auth.UserName, - Query: query, - Offset: (*qParams.Page - 1) * *qParams.Limit, - Limit: *qParams.Limit, + documents, err := api.db.Queries.GetDocumentsWithStats(c, database.GetDocumentsWithStatsParams{ + UserID: auth.UserName, + Query: query, + Deleted: ptr.Of(false), + Offset: (*qParams.Page - 1) * *qParams.Limit, + Limit: *qParams.Limit, }) if err != nil { log.Error("GetDocumentsWithStats DB Error: ", err) @@ -121,14 +124,14 @@ func (api *API) appGetDocuments(c *gin.Context) { return } - length, err := api.db.Queries.GetDocumentsSize(api.db.Ctx, query) + length, err := api.db.Queries.GetDocumentsSize(c, query) if err != nil { log.Error("GetDocumentsSize DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocumentsSize DB Error: %v", err)) return } - if err = api.getDocumentsWordCount(documents); err != nil { + if err = api.getDocumentsWordCount(c, documents); err != nil { log.Error("Unable to Get Word Counts: ", err) } @@ -160,13 +163,10 @@ func (api *API) appGetDocument(c *gin.Context) { return } - document, err := api.db.Queries.GetDocumentWithStats(api.db.Ctx, database.GetDocumentWithStatsParams{ - UserID: auth.UserName, - DocumentID: rDocID.DocumentID, - }) + document, err := api.db.GetDocument(c, rDocID.DocumentID, auth.UserName) if err != nil { - log.Error("GetDocumentWithStats DB Error: ", err) - appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocumentsWithStats DB Error: %v", err)) + log.Error("GetDocument DB Error: ", err) + appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocument DB Error: %v", err)) return } @@ -192,7 +192,7 @@ func (api *API) appGetProgress(c *gin.Context) { progressFilter.DocumentID = *qParams.Document } - progress, err := api.db.Queries.GetProgress(api.db.Ctx, progressFilter) + progress, err := api.db.Queries.GetProgress(c, progressFilter) if err != nil { log.Error("GetProgress DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetActivity DB Error: %v", err)) @@ -219,7 +219,7 @@ func (api *API) appGetActivity(c *gin.Context) { activityFilter.DocumentID = *qParams.Document } - activity, err := api.db.Queries.GetActivity(api.db.Ctx, activityFilter) + activity, err := api.db.Queries.GetActivity(c, activityFilter) if err != nil { log.Error("GetActivity DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetActivity DB Error: %v", err)) @@ -235,7 +235,7 @@ func (api *API) appGetHome(c *gin.Context) { templateVars, auth := api.getBaseTemplateVars("home", c) start := time.Now() - graphData, err := api.db.Queries.GetDailyReadStats(api.db.Ctx, auth.UserName) + graphData, err := api.db.Queries.GetDailyReadStats(c, auth.UserName) if err != nil { log.Error("GetDailyReadStats DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDailyReadStats DB Error: %v", err)) @@ -244,7 +244,7 @@ func (api *API) appGetHome(c *gin.Context) { log.Debug("GetDailyReadStats DB Performance: ", time.Since(start)) start = time.Now() - databaseInfo, err := api.db.Queries.GetDatabaseInfo(api.db.Ctx, auth.UserName) + databaseInfo, err := api.db.Queries.GetDatabaseInfo(c, auth.UserName) if err != nil { log.Error("GetDatabaseInfo DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDatabaseInfo DB Error: %v", err)) @@ -253,7 +253,7 @@ func (api *API) appGetHome(c *gin.Context) { log.Debug("GetDatabaseInfo DB Performance: ", time.Since(start)) start = time.Now() - streaks, err := api.db.Queries.GetUserStreaks(api.db.Ctx, auth.UserName) + streaks, err := api.db.Queries.GetUserStreaks(c, auth.UserName) if err != nil { log.Error("GetUserStreaks DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUserStreaks DB Error: %v", err)) @@ -262,7 +262,7 @@ func (api *API) appGetHome(c *gin.Context) { log.Debug("GetUserStreaks DB Performance: ", time.Since(start)) start = time.Now() - userStatistics, err := api.db.Queries.GetUserStatistics(api.db.Ctx) + userStatistics, err := api.db.Queries.GetUserStatistics(c) if err != nil { log.Error("GetUserStatistics DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUserStatistics DB Error: %v", err)) @@ -283,14 +283,14 @@ func (api *API) appGetHome(c *gin.Context) { func (api *API) appGetSettings(c *gin.Context) { templateVars, auth := api.getBaseTemplateVars("settings", c) - user, err := api.db.Queries.GetUser(api.db.Ctx, auth.UserName) + user, err := api.db.Queries.GetUser(c, auth.UserName) if err != nil { log.Error("GetUser DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUser DB Error: %v", err)) return } - devices, err := api.db.Queries.GetDevices(api.db.Ctx, auth.UserName) + devices, err := api.db.Queries.GetDevices(c, auth.UserName) if err != nil { log.Error("GetDevices DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDevices DB Error: %v", err)) @@ -368,7 +368,7 @@ func (api *API) appGetDocumentProgress(c *gin.Context) { return } - progress, err := api.db.Queries.GetDocumentProgress(api.db.Ctx, database.GetDocumentProgressParams{ + progress, err := api.db.Queries.GetDocumentProgress(c, database.GetDocumentProgressParams{ DocumentID: rDoc.DocumentID, UserID: auth.UserName, }) @@ -378,13 +378,10 @@ func (api *API) appGetDocumentProgress(c *gin.Context) { return } - document, err := api.db.Queries.GetDocumentWithStats(api.db.Ctx, database.GetDocumentWithStatsParams{ - UserID: auth.UserName, - DocumentID: rDoc.DocumentID, - }) + document, err := api.db.GetDocument(c, rDoc.DocumentID, auth.UserName) if err != nil { - log.Error("GetDocumentWithStats DB Error: ", err) - appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocumentWithStats DB Error: %v", err)) + log.Error("GetDocument DB Error: ", err) + appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocument DB Error: %v", err)) return } @@ -404,7 +401,7 @@ func (api *API) appGetDevices(c *gin.Context) { auth = data.(authData) } - devices, err := api.db.Queries.GetDevices(api.db.Ctx, auth.UserName) + devices, err := api.db.Queries.GetDevices(c, auth.UserName) if err != nil && err != sql.ErrNoRows { log.Error("GetDevices DB Error: ", err) @@ -455,7 +452,7 @@ func (api *API) appUploadNewDocument(c *gin.Context) { } // Check Already Exists - _, err = api.db.Queries.GetDocument(api.db.Ctx, *metadataInfo.PartialMD5) + _, err = api.db.Queries.GetDocument(c, *metadataInfo.PartialMD5) if err == nil { log.Warnf("document already exists: %s", *metadataInfo.PartialMD5) c.Redirect(http.StatusFound, fmt.Sprintf("./documents/%s", *metadataInfo.PartialMD5)) @@ -483,7 +480,7 @@ func (api *API) appUploadNewDocument(c *gin.Context) { } // Upsert Document - if _, err = api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ + if _, err = api.db.Queries.UpsertDocument(c, database.UpsertDocumentParams{ ID: *metadataInfo.PartialMD5, Title: metadataInfo.Title, Author: metadataInfo.Author, @@ -573,7 +570,7 @@ func (api *API) appEditDocument(c *gin.Context) { coverFileName = &fileName } else if rDocEdit.CoverGBID != nil { - var coverDir string = filepath.Join(api.cfg.DataPath, "covers") + coverDir := filepath.Join(api.cfg.DataPath, "covers") fileName, err := metadata.CacheCover(*rDocEdit.CoverGBID, coverDir, rDocID.DocumentID, true) if err == nil { coverFileName = fileName @@ -581,7 +578,7 @@ func (api *API) appEditDocument(c *gin.Context) { } // Update Document - if _, err := api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ + if _, err := api.db.Queries.UpsertDocument(c, database.UpsertDocumentParams{ ID: rDocID.DocumentID, Title: api.sanitizeInput(rDocEdit.Title), Author: api.sanitizeInput(rDocEdit.Author), @@ -605,7 +602,7 @@ func (api *API) appDeleteDocument(c *gin.Context) { appErrorPage(c, http.StatusNotFound, "Invalid document") return } - changed, err := api.db.Queries.DeleteDocument(api.db.Ctx, rDocID.DocumentID) + changed, err := api.db.Queries.DeleteDocument(c, rDocID.DocumentID) if err != nil { log.Error("DeleteDocument DB Error") appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("DeleteDocument DB Error: %v", err)) @@ -667,7 +664,7 @@ func (api *API) appIdentifyDocument(c *gin.Context) { firstResult := metadataResults[0] // Store First Metadata Result - if _, err = api.db.Queries.AddMetadata(api.db.Ctx, database.AddMetadataParams{ + if _, err = api.db.Queries.AddMetadata(c, database.AddMetadataParams{ DocumentID: rDocID.DocumentID, Title: firstResult.Title, Author: firstResult.Author, @@ -686,13 +683,10 @@ func (api *API) appIdentifyDocument(c *gin.Context) { templateVars["MetadataError"] = "No Metadata Found" } - document, err := api.db.Queries.GetDocumentWithStats(api.db.Ctx, database.GetDocumentWithStatsParams{ - UserID: auth.UserName, - DocumentID: rDocID.DocumentID, - }) + document, err := api.db.GetDocument(c, rDocID.DocumentID, auth.UserName) if err != nil { - log.Error("GetDocumentWithStats DB Error: ", err) - appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocumentWithStats DB Error: %v", err)) + log.Error("GetDocument DB Error: ", err) + appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocument DB Error: %v", err)) return } @@ -817,7 +811,7 @@ func (api *API) appSaveNewDocument(c *gin.Context) { sendDownloadMessage("Saving to database...", gin.H{"Progress": 99}) // Upsert Document - if _, err = api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ + if _, err = api.db.Queries.UpsertDocument(c, database.UpsertDocumentParams{ ID: *metadata.PartialMD5, Title: &docTitle, Author: &docAuthor, @@ -864,7 +858,7 @@ func (api *API) appEditSettings(c *gin.Context) { // Set New Password if rUserSettings.Password != nil && rUserSettings.NewPassword != nil { password := fmt.Sprintf("%x", md5.Sum([]byte(*rUserSettings.Password))) - data := api.authorizeCredentials(auth.UserName, password) + data := api.authorizeCredentials(c, auth.UserName, password) if data == nil { templateVars["PasswordErrorMessage"] = "Invalid Password" } else { @@ -886,7 +880,7 @@ func (api *API) appEditSettings(c *gin.Context) { } // Update User - _, err := api.db.Queries.UpdateUser(api.db.Ctx, newUserSettings) + _, err := api.db.Queries.UpdateUser(c, newUserSettings) if err != nil { log.Error("UpdateUser DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("UpdateUser DB Error: %v", err)) @@ -894,7 +888,7 @@ func (api *API) appEditSettings(c *gin.Context) { } // Get User - user, err := api.db.Queries.GetUser(api.db.Ctx, auth.UserName) + user, err := api.db.Queries.GetUser(c, auth.UserName) if err != nil { log.Error("GetUser DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUser DB Error: %v", err)) @@ -902,7 +896,7 @@ func (api *API) appEditSettings(c *gin.Context) { } // Get Devices - devices, err := api.db.Queries.GetDevices(api.db.Ctx, auth.UserName) + devices, err := api.db.Queries.GetDevices(c, auth.UserName) if err != nil { log.Error("GetDevices DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDevices DB Error: %v", err)) @@ -921,7 +915,7 @@ func (api *API) appDemoModeError(c *gin.Context) { appErrorPage(c, http.StatusUnauthorized, "Not Allowed in Demo Mode") } -func (api *API) getDocumentsWordCount(documents []database.GetDocumentsWithStatsRow) error { +func (api *API) getDocumentsWordCount(ctx context.Context, documents []database.GetDocumentsWithStatsRow) error { // Do Transaction tx, err := api.db.DB.Begin() if err != nil { @@ -944,7 +938,7 @@ func (api *API) getDocumentsWordCount(documents []database.GetDocumentsWithStats if err != nil { log.Warn("Word Count Error: ", err) } else { - if _, err := qtx.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ + if _, err := qtx.UpsertDocument(ctx, database.UpsertDocumentParams{ ID: item.ID, Words: wordCount, }); err != nil { @@ -1005,7 +999,7 @@ func bindQueryParams(c *gin.Context, defaultLimit int64) queryParams { } func appErrorPage(c *gin.Context, errorCode int, errorMessage string) { - var errorHuman string = "We're not even sure what happened." + errorHuman := "We're not even sure what happened." switch errorCode { case http.StatusInternalServerError: diff --git a/api/auth.go b/api/auth.go index 8dda00a..e3fa854 100644 --- a/api/auth.go +++ b/api/auth.go @@ -1,6 +1,7 @@ package api import ( + "context" "crypto/md5" "fmt" "net/http" @@ -28,8 +29,8 @@ type authKOHeader struct { AuthKey string `header:"x-auth-key"` } -func (api *API) authorizeCredentials(username string, password string) (auth *authData) { - user, err := api.db.Queries.GetUser(api.db.Ctx, username) +func (api *API) authorizeCredentials(ctx context.Context, username string, password string) (auth *authData) { + user, err := api.db.Queries.GetUser(ctx, username) if err != nil { return } @@ -52,7 +53,7 @@ func (api *API) authKOMiddleware(c *gin.Context) { session := sessions.Default(c) // Check Session First - if auth, ok := api.getSession(session); ok { + if auth, ok := api.getSession(c, session); ok { c.Set("Authorization", auth) c.Header("Cache-Control", "private") c.Next() @@ -71,7 +72,7 @@ func (api *API) authKOMiddleware(c *gin.Context) { return } - authData := api.authorizeCredentials(rHeader.AuthUser, rHeader.AuthKey) + authData := api.authorizeCredentials(c, rHeader.AuthUser, rHeader.AuthKey) if authData == nil { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) return @@ -100,7 +101,7 @@ func (api *API) authOPDSMiddleware(c *gin.Context) { // Validate Auth password := fmt.Sprintf("%x", md5.Sum([]byte(rawPassword))) - authData := api.authorizeCredentials(user, password) + authData := api.authorizeCredentials(c, user, password) if authData == nil { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) return @@ -115,7 +116,7 @@ func (api *API) authWebAppMiddleware(c *gin.Context) { session := sessions.Default(c) // Check Session - if auth, ok := api.getSession(session); ok { + if auth, ok := api.getSession(c, session); ok { c.Set("Authorization", auth) c.Header("Cache-Control", "private") c.Next() @@ -153,7 +154,7 @@ func (api *API) appAuthLogin(c *gin.Context) { // MD5 - KOSync Compatiblity password := fmt.Sprintf("%x", md5.Sum([]byte(rawPassword))) - authData := api.authorizeCredentials(username, password) + authData := api.authorizeCredentials(c, username, password) if authData == nil { templateVars["Error"] = "Invalid Credentials" c.HTML(http.StatusUnauthorized, "page/login", templateVars) @@ -208,7 +209,7 @@ func (api *API) appAuthRegister(c *gin.Context) { } // Get current users - currentUsers, err := api.db.Queries.GetUsers(api.db.Ctx) + currentUsers, err := api.db.Queries.GetUsers(c) if err != nil { log.Error("Failed to check all users: ", err) templateVars["Error"] = "Failed to Create User" @@ -224,7 +225,7 @@ func (api *API) appAuthRegister(c *gin.Context) { // Create user in DB authHash := fmt.Sprintf("%x", rawAuthHash) - if rows, err := api.db.Queries.CreateUser(api.db.Ctx, database.CreateUserParams{ + if rows, err := api.db.Queries.CreateUser(c, database.CreateUserParams{ ID: username, Pass: &hashedPassword, AuthHash: &authHash, @@ -242,7 +243,7 @@ func (api *API) appAuthRegister(c *gin.Context) { } // Get user - user, err := api.db.Queries.GetUser(api.db.Ctx, username) + user, err := api.db.Queries.GetUser(c, username) if err != nil { log.Error("GetUser DB Error:", err) templateVars["Error"] = "Registration Disabled or User Already Exists" @@ -312,7 +313,7 @@ func (api *API) koAuthRegister(c *gin.Context) { } // Get current users - currentUsers, err := api.db.Queries.GetUsers(api.db.Ctx) + currentUsers, err := api.db.Queries.GetUsers(c) if err != nil { log.Error("Failed to check all users: ", err) apiErrorPage(c, http.StatusBadRequest, "Failed to Create User") @@ -327,7 +328,7 @@ func (api *API) koAuthRegister(c *gin.Context) { // Create user authHash := fmt.Sprintf("%x", rawAuthHash) - if rows, err := api.db.Queries.CreateUser(api.db.Ctx, database.CreateUserParams{ + if rows, err := api.db.Queries.CreateUser(c, database.CreateUserParams{ ID: rUser.Username, Pass: &hashedPassword, AuthHash: &authHash, @@ -347,7 +348,7 @@ func (api *API) koAuthRegister(c *gin.Context) { }) } -func (api *API) getSession(session sessions.Session) (auth authData, ok bool) { +func (api *API) getSession(ctx context.Context, session sessions.Session) (auth authData, ok bool) { // Get Session authorizedUser := session.Get("authorizedUser") isAdmin := session.Get("isAdmin") @@ -365,7 +366,7 @@ func (api *API) getSession(session sessions.Session) (auth authData, ok bool) { } // Validate Auth Hash - correctAuthHash, err := api.getUserAuthHash(auth.UserName) + correctAuthHash, err := api.getUserAuthHash(ctx, auth.UserName) if err != nil || correctAuthHash != auth.AuthHash { return } @@ -393,14 +394,14 @@ func (api *API) setSession(session sessions.Session, auth authData) error { return session.Save() } -func (api *API) getUserAuthHash(username string) (string, error) { +func (api *API) getUserAuthHash(ctx context.Context, username string) (string, error) { // Return Cache if api.userAuthCache[username] != "" { return api.userAuthCache[username], nil } // Get DB - user, err := api.db.Queries.GetUser(api.db.Ctx, username) + user, err := api.db.Queries.GetUser(ctx, username) if err != nil { log.Error("GetUser DB Error:", err) return "", err @@ -412,7 +413,7 @@ func (api *API) getUserAuthHash(username string) (string, error) { return api.userAuthCache[username], nil } -func (api *API) rotateAllAuthHashes() error { +func (api *API) rotateAllAuthHashes(ctx context.Context) error { // Do Transaction tx, err := api.db.DB.Begin() if err != nil { @@ -428,7 +429,7 @@ func (api *API) rotateAllAuthHashes() error { }() qtx := api.db.Queries.WithTx(tx) - users, err := qtx.GetUsers(api.db.Ctx) + users, err := qtx.GetUsers(ctx) if err != nil { return err } @@ -444,7 +445,7 @@ func (api *API) rotateAllAuthHashes() error { // Update User authHash := fmt.Sprintf("%x", rawAuthHash) - if _, err = qtx.UpdateUser(api.db.Ctx, database.UpdateUserParams{ + if _, err = qtx.UpdateUser(ctx, database.UpdateUserParams{ UserID: user.ID, AuthHash: &authHash, Admin: user.Admin, diff --git a/api/common.go b/api/common.go index 598ff68..0888ed4 100644 --- a/api/common.go +++ b/api/common.go @@ -22,7 +22,7 @@ func (api *API) createDownloadDocumentHandler(errorFunc func(*gin.Context, int, } // Get Document - document, err := api.db.Queries.GetDocument(api.db.Ctx, rDoc.DocumentID) + document, err := api.db.Queries.GetDocument(c, rDoc.DocumentID) if err != nil { log.Error("GetDocument DB Error:", err) errorFunc(c, http.StatusBadRequest, "Unknown Document") @@ -68,7 +68,7 @@ func (api *API) createGetCoverHandler(errorFunc func(*gin.Context, int, string)) } // Validate Document Exists in DB - document, err := api.db.Queries.GetDocument(api.db.Ctx, rDoc.DocumentID) + document, err := api.db.Queries.GetDocument(c, rDoc.DocumentID) if err != nil { log.Error("GetDocument DB Error:", err) errorFunc(c, http.StatusInternalServerError, fmt.Sprintf("GetDocument DB Error: %v", err)) @@ -117,7 +117,7 @@ func (api *API) createGetCoverHandler(errorFunc func(*gin.Context, int, string)) } // Store First Metadata Result - if _, err = api.db.Queries.AddMetadata(api.db.Ctx, database.AddMetadataParams{ + if _, err = api.db.Queries.AddMetadata(c, database.AddMetadataParams{ DocumentID: document.ID, Title: firstResult.Title, Author: firstResult.Author, @@ -132,7 +132,7 @@ func (api *API) createGetCoverHandler(errorFunc func(*gin.Context, int, string)) } // Upsert Document - if _, err = api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ + if _, err = api.db.Queries.UpsertDocument(c, database.UpsertDocumentParams{ ID: document.ID, Coverfile: &coverFile, }); err != nil { diff --git a/api/ko-routes.go b/api/ko-routes.go index 3e3645f..27ad78d 100644 --- a/api/ko-routes.go +++ b/api/ko-routes.go @@ -91,7 +91,7 @@ func (api *API) koSetProgress(c *gin.Context) { } // Upsert Device - if _, err := api.db.Queries.UpsertDevice(api.db.Ctx, database.UpsertDeviceParams{ + if _, err := api.db.Queries.UpsertDevice(c, database.UpsertDeviceParams{ ID: rPosition.DeviceID, UserID: auth.UserName, DeviceName: rPosition.Device, @@ -101,14 +101,14 @@ func (api *API) koSetProgress(c *gin.Context) { } // Upsert Document - if _, err := api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ + if _, err := api.db.Queries.UpsertDocument(c, database.UpsertDocumentParams{ ID: rPosition.DocumentID, }); err != nil { log.Error("UpsertDocument DB Error:", err) } // Create or Replace Progress - progress, err := api.db.Queries.UpdateProgress(api.db.Ctx, database.UpdateProgressParams{ + progress, err := api.db.Queries.UpdateProgress(c, database.UpdateProgressParams{ Percentage: rPosition.Percentage, DocumentID: rPosition.DocumentID, DeviceID: rPosition.DeviceID, @@ -140,7 +140,7 @@ func (api *API) koGetProgress(c *gin.Context) { return } - progress, err := api.db.Queries.GetDocumentProgress(api.db.Ctx, database.GetDocumentProgressParams{ + progress, err := api.db.Queries.GetDocumentProgress(c, database.GetDocumentProgressParams{ DocumentID: rDocID.DocumentID, UserID: auth.UserName, }) @@ -202,7 +202,7 @@ func (api *API) koAddActivities(c *gin.Context) { // Upsert Documents for _, doc := range allDocuments { - if _, err := qtx.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ + if _, err := qtx.UpsertDocument(c, database.UpsertDocumentParams{ ID: doc, }); err != nil { log.Error("UpsertDocument DB Error:", err) @@ -212,7 +212,7 @@ func (api *API) koAddActivities(c *gin.Context) { } // Upsert Device - if _, err = qtx.UpsertDevice(api.db.Ctx, database.UpsertDeviceParams{ + if _, err = qtx.UpsertDevice(c, database.UpsertDeviceParams{ ID: rActivity.DeviceID, UserID: auth.UserName, DeviceName: rActivity.Device, @@ -225,7 +225,7 @@ func (api *API) koAddActivities(c *gin.Context) { // Add All Activity for _, item := range rActivity.Activity { - if _, err := qtx.AddActivity(api.db.Ctx, database.AddActivityParams{ + if _, err := qtx.AddActivity(c, database.AddActivityParams{ UserID: auth.UserName, DocumentID: item.DocumentID, DeviceID: rActivity.DeviceID, @@ -266,7 +266,7 @@ func (api *API) koCheckActivitySync(c *gin.Context) { } // Upsert Device - if _, err := api.db.Queries.UpsertDevice(api.db.Ctx, database.UpsertDeviceParams{ + if _, err := api.db.Queries.UpsertDevice(c, database.UpsertDeviceParams{ ID: rCheckActivity.DeviceID, UserID: auth.UserName, DeviceName: rCheckActivity.Device, @@ -278,7 +278,7 @@ func (api *API) koCheckActivitySync(c *gin.Context) { } // Get Last Device Activity - lastActivity, err := api.db.Queries.GetLastActivity(api.db.Ctx, database.GetLastActivityParams{ + lastActivity, err := api.db.Queries.GetLastActivity(c, database.GetLastActivityParams{ UserID: auth.UserName, DeviceID: rCheckActivity.DeviceID, }) @@ -329,7 +329,7 @@ func (api *API) koAddDocuments(c *gin.Context) { // Upsert Documents for _, doc := range rNewDocs.Documents { - _, err := qtx.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ + _, err := qtx.UpsertDocument(c, database.UpsertDocumentParams{ ID: doc.ID, Title: api.sanitizeInput(doc.Title), Author: api.sanitizeInput(doc.Author), @@ -371,7 +371,7 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) { } // Upsert Device - _, err := api.db.Queries.UpsertDevice(api.db.Ctx, database.UpsertDeviceParams{ + _, err := api.db.Queries.UpsertDevice(c, database.UpsertDeviceParams{ ID: rCheckDocs.DeviceID, UserID: auth.UserName, DeviceName: rCheckDocs.Device, @@ -384,7 +384,7 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) { } // Get Missing Documents - missingDocs, err := api.db.Queries.GetMissingDocuments(api.db.Ctx, rCheckDocs.Have) + missingDocs, err := api.db.Queries.GetMissingDocuments(c, rCheckDocs.Have) if err != nil { log.Error("GetMissingDocuments DB Error", err) apiErrorPage(c, http.StatusBadRequest, "Invalid Request") @@ -392,7 +392,7 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) { } // Get Deleted Documents - deletedDocIDs, err := api.db.Queries.GetDeletedDocuments(api.db.Ctx, rCheckDocs.Have) + deletedDocIDs, err := api.db.Queries.GetDeletedDocuments(c, rCheckDocs.Have) if err != nil { log.Error("GetDeletedDocuments DB Error", err) apiErrorPage(c, http.StatusBadRequest, "Invalid Request") @@ -407,7 +407,7 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) { return } - wantedDocs, err := api.db.Queries.GetWantedDocuments(api.db.Ctx, string(jsonHaves)) + wantedDocs, err := api.db.Queries.GetWantedDocuments(c, string(jsonHaves)) if err != nil { log.Error("GetWantedDocuments DB Error", err) apiErrorPage(c, http.StatusBadRequest, "Invalid Request") @@ -467,7 +467,7 @@ func (api *API) koUploadExistingDocument(c *gin.Context) { } // Validate Document Exists in DB - document, err := api.db.Queries.GetDocument(api.db.Ctx, rDoc.DocumentID) + document, err := api.db.Queries.GetDocument(c, rDoc.DocumentID) if err != nil { log.Error("GetDocument DB Error:", err) apiErrorPage(c, http.StatusBadRequest, "Unknown Document") @@ -522,7 +522,7 @@ func (api *API) koUploadExistingDocument(c *gin.Context) { } // Upsert Document - if _, err = api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ + if _, err = api.db.Queries.UpsertDocument(c, database.UpsertDocumentParams{ ID: document.ID, Md5: metadataInfo.MD5, Words: metadataInfo.WordCount, diff --git a/api/opds-routes.go b/api/opds-routes.go index 31434d4..17cceea 100644 --- a/api/opds-routes.go +++ b/api/opds-routes.go @@ -10,6 +10,7 @@ import ( log "github.com/sirupsen/logrus" "reichard.io/antholume/database" "reichard.io/antholume/opds" + "reichard.io/antholume/pkg/ptr" ) var mimeMapping map[string]string = map[string]string{ @@ -77,11 +78,12 @@ func (api *API) opdsDocuments(c *gin.Context) { } // Get Documents - documents, err := api.db.Queries.GetDocumentsWithStats(api.db.Ctx, database.GetDocumentsWithStatsParams{ - UserID: auth.UserName, - Query: query, - Offset: (*qParams.Page - 1) * *qParams.Limit, - Limit: *qParams.Limit, + documents, err := api.db.Queries.GetDocumentsWithStats(c, database.GetDocumentsWithStatsParams{ + UserID: auth.UserName, + Query: query, + Deleted: ptr.Of(false), + Offset: (*qParams.Page - 1) * *qParams.Limit, + Limit: *qParams.Limit, }) if err != nil { log.Error("GetDocumentsWithStats DB Error:", err) diff --git a/api/utils.go b/api/utils.go index d52d281..3fb3e4a 100644 --- a/api/utils.go +++ b/api/utils.go @@ -55,6 +55,7 @@ func getTimeZones() []string { // niceSeconds takes in an int (in seconds) and returns a string readable // representation. For example 1928371 -> "22d 7h 39m 31s". +// Deprecated: Use formatters.FormatDuration func niceSeconds(input int64) (result string) { if input == 0 { return "N/A" @@ -85,6 +86,7 @@ func niceSeconds(input int64) (result string) { // niceNumbers takes in an int and returns a string representation. For example // 19823 -> "19.8k". +// Deprecated: Use formatters.FormatNumber func niceNumbers(input int64) string { if input == 0 { return "0" diff --git a/database/documents.go b/database/documents.go new file mode 100644 index 0000000..c8f358d --- /dev/null +++ b/database/documents.go @@ -0,0 +1,27 @@ +package database + +import ( + "context" + "fmt" + + "reichard.io/antholume/pkg/ptr" + "reichard.io/antholume/pkg/sliceutils" +) + +func (d *DBManager) GetDocument(ctx context.Context, docID, userID string) (*GetDocumentsWithStatsRow, error) { + documents, err := d.Queries.GetDocumentsWithStats(ctx, GetDocumentsWithStatsParams{ + ID: ptr.Of(docID), + UserID: userID, + Limit: 1, + }) + if err != nil { + return nil, err + } + + document, found := sliceutils.First(documents) + if !found { + return nil, fmt.Errorf("document not found: %s", docID) + } + + return &document, nil +} diff --git a/database/documents_test.go b/database/documents_test.go index a688448..cda01b1 100644 --- a/database/documents_test.go +++ b/database/documents_test.go @@ -1,6 +1,7 @@ package database import ( + "context" "fmt" "testing" @@ -26,7 +27,7 @@ func (suite *DocumentsTestSuite) SetupTest() { suite.dbm = NewMgr(&cfg) // Create Document - _, err := suite.dbm.Queries.UpsertDocument(suite.dbm.Ctx, UpsertDocumentParams{ + _, err := suite.dbm.Queries.UpsertDocument(context.Background(), UpsertDocumentParams{ ID: documentID, Title: &documentTitle, Author: &documentAuthor, @@ -42,7 +43,7 @@ func (suite *DocumentsTestSuite) SetupTest() { // - 󰊕 (q *Queries) GetDocumentsWithStats // - 󰊕 (q *Queries) GetMissingDocuments func (suite *DocumentsTestSuite) TestGetDocument() { - doc, err := suite.dbm.Queries.GetDocument(suite.dbm.Ctx, documentID) + doc, err := suite.dbm.Queries.GetDocument(context.Background(), documentID) suite.Nil(err, "should have nil err") suite.Equal(documentID, doc.ID, "should have changed the document") } @@ -50,7 +51,7 @@ func (suite *DocumentsTestSuite) TestGetDocument() { func (suite *DocumentsTestSuite) TestUpsertDocument() { testDocID := "docid1" - doc, err := suite.dbm.Queries.UpsertDocument(suite.dbm.Ctx, UpsertDocumentParams{ + doc, err := suite.dbm.Queries.UpsertDocument(context.Background(), UpsertDocumentParams{ ID: testDocID, Title: &documentTitle, Author: &documentAuthor, @@ -63,51 +64,51 @@ func (suite *DocumentsTestSuite) TestUpsertDocument() { } func (suite *DocumentsTestSuite) TestDeleteDocument() { - changed, err := suite.dbm.Queries.DeleteDocument(suite.dbm.Ctx, documentID) + changed, err := suite.dbm.Queries.DeleteDocument(context.Background(), documentID) suite.Nil(err, "should have nil err") suite.Equal(int64(1), changed, "should have changed the document") - doc, err := suite.dbm.Queries.GetDocument(suite.dbm.Ctx, documentID) + doc, err := suite.dbm.Queries.GetDocument(context.Background(), documentID) suite.Nil(err, "should have nil err") suite.True(doc.Deleted, "should have deleted the document") } func (suite *DocumentsTestSuite) TestGetDeletedDocuments() { - changed, err := suite.dbm.Queries.DeleteDocument(suite.dbm.Ctx, documentID) + changed, err := suite.dbm.Queries.DeleteDocument(context.Background(), documentID) suite.Nil(err, "should have nil err") suite.Equal(int64(1), changed, "should have changed the document") - deletedDocs, err := suite.dbm.Queries.GetDeletedDocuments(suite.dbm.Ctx, []string{documentID}) + deletedDocs, err := suite.dbm.Queries.GetDeletedDocuments(context.Background(), []string{documentID}) suite.Nil(err, "should have nil err") suite.Len(deletedDocs, 1, "should have one deleted document") } // TODO - Convert GetWantedDocuments -> (sqlc.slice('document_ids')); func (suite *DocumentsTestSuite) TestGetWantedDocuments() { - wantedDocs, err := suite.dbm.Queries.GetWantedDocuments(suite.dbm.Ctx, fmt.Sprintf("[\"%s\"]", documentID)) + wantedDocs, err := suite.dbm.Queries.GetWantedDocuments(context.Background(), fmt.Sprintf("[\"%s\"]", documentID)) suite.Nil(err, "should have nil err") suite.Len(wantedDocs, 1, "should have one wanted document") } func (suite *DocumentsTestSuite) TestGetMissingDocuments() { // Create Document - _, err := suite.dbm.Queries.UpsertDocument(suite.dbm.Ctx, UpsertDocumentParams{ + _, err := suite.dbm.Queries.UpsertDocument(context.Background(), UpsertDocumentParams{ ID: documentID, Filepath: &documentFilepath, }) suite.NoError(err) - missingDocs, err := suite.dbm.Queries.GetMissingDocuments(suite.dbm.Ctx, []string{documentID}) + missingDocs, err := suite.dbm.Queries.GetMissingDocuments(context.Background(), []string{documentID}) suite.Nil(err, "should have nil err") suite.Len(missingDocs, 0, "should have no wanted document") - missingDocs, err = suite.dbm.Queries.GetMissingDocuments(suite.dbm.Ctx, []string{"other"}) + missingDocs, err = suite.dbm.Queries.GetMissingDocuments(context.Background(), []string{"other"}) suite.Nil(err, "should have nil err") suite.Len(missingDocs, 1, "should have one missing document") suite.Equal(documentID, missingDocs[0].ID, "should have missing doc") // TODO - https://github.com/sqlc-dev/sqlc/issues/3451 - // missingDocs, err = suite.dbm.Queries.GetMissingDocuments(suite.dbm.Ctx, []string{}) + // missingDocs, err = suite.dbm.Queries.GetMissingDocuments(context.Background(), []string{}) // suite.Nil(err, "should have nil err") // suite.Len(missingDocs, 1, "should have one missing document") // suite.Equal(documentID, missingDocs[0].ID, "should have missing doc") diff --git a/database/manager.go b/database/manager.go index be4722f..6b88d6e 100644 --- a/database/manager.go +++ b/database/manager.go @@ -5,7 +5,6 @@ import ( "database/sql" "database/sql/driver" "embed" - _ "embed" "errors" "fmt" "path/filepath" @@ -20,7 +19,6 @@ import ( type DBManager struct { DB *sql.DB - Ctx context.Context Queries *Queries cfg *config.Config } @@ -54,12 +52,9 @@ func init() { // NewMgr Returns an initialized manager func NewMgr(c *config.Config) *DBManager { // Create Manager - dbm := &DBManager{ - Ctx: context.Background(), - cfg: c, - } + dbm := &DBManager{cfg: c} - if err := dbm.init(); err != nil { + if err := dbm.init(context.Background()); err != nil { log.Panic("Unable to init DB") } @@ -67,7 +62,7 @@ func NewMgr(c *config.Config) *DBManager { } // init loads the DB manager -func (dbm *DBManager) init() error { +func (dbm *DBManager) init(ctx context.Context) error { // Build DB Location var dbLocation string switch dbm.cfg.DBType { @@ -113,14 +108,14 @@ func (dbm *DBManager) init() error { } // Update settings - err = dbm.updateSettings() + err = dbm.updateSettings(ctx) if err != nil { log.Panicf("Error running DB settings update: %v", err) return err } // Cache tables - if err := dbm.CacheTempTables(); err != nil { + if err := dbm.CacheTempTables(ctx); err != nil { log.Warn("Refreshing temp table cache failed: ", err) } @@ -128,7 +123,7 @@ func (dbm *DBManager) init() error { } // Reload closes the DB & reinits -func (dbm *DBManager) Reload() error { +func (dbm *DBManager) Reload(ctx context.Context) error { // Close handle err := dbm.DB.Close() if err != nil { @@ -136,7 +131,7 @@ func (dbm *DBManager) Reload() error { } // Reinit DB - if err := dbm.init(); err != nil { + if err := dbm.init(ctx); err != nil { return err } @@ -144,15 +139,15 @@ func (dbm *DBManager) Reload() error { } // CacheTempTables clears existing statistics and recalculates -func (dbm *DBManager) CacheTempTables() error { +func (dbm *DBManager) CacheTempTables(ctx context.Context) error { start := time.Now() - if _, err := dbm.DB.ExecContext(dbm.Ctx, user_streaks); err != nil { + if _, err := dbm.DB.ExecContext(ctx, user_streaks); err != nil { return err } log.Debug("Cached 'user_streaks' in: ", time.Since(start)) start = time.Now() - if _, err := dbm.DB.ExecContext(dbm.Ctx, document_user_statistics); err != nil { + if _, err := dbm.DB.ExecContext(ctx, document_user_statistics); err != nil { return err } log.Debug("Cached 'document_user_statistics' in: ", time.Since(start)) @@ -162,7 +157,7 @@ func (dbm *DBManager) CacheTempTables() error { // updateSettings ensures that we're enforcing foreign keys and enable journal // mode. -func (dbm *DBManager) updateSettings() error { +func (dbm *DBManager) updateSettings(ctx context.Context) error { // Set SQLite PRAGMA Settings pragmaQuery := ` PRAGMA foreign_keys = ON; @@ -174,7 +169,7 @@ func (dbm *DBManager) updateSettings() error { } // Update Antholume Version in DB - if _, err := dbm.Queries.UpdateSettings(dbm.Ctx, UpdateSettingsParams{ + if _, err := dbm.Queries.UpdateSettings(ctx, UpdateSettingsParams{ Name: "version", Value: dbm.cfg.Version, }); err != nil { diff --git a/database/manager_test.go b/database/manager_test.go index 8878360..57534d9 100644 --- a/database/manager_test.go +++ b/database/manager_test.go @@ -1,6 +1,7 @@ package database import ( + "context" "fmt" "testing" "time" @@ -46,7 +47,7 @@ func (suite *DatabaseTestSuite) SetupTest() { // Create User rawAuthHash, _ := utils.GenerateToken(64) authHash := fmt.Sprintf("%x", rawAuthHash) - _, err := suite.dbm.Queries.CreateUser(suite.dbm.Ctx, CreateUserParams{ + _, err := suite.dbm.Queries.CreateUser(context.Background(), CreateUserParams{ ID: userID, Pass: &userPass, AuthHash: &authHash, @@ -54,7 +55,7 @@ func (suite *DatabaseTestSuite) SetupTest() { suite.NoError(err) // Create Document - _, err = suite.dbm.Queries.UpsertDocument(suite.dbm.Ctx, UpsertDocumentParams{ + _, err = suite.dbm.Queries.UpsertDocument(context.Background(), UpsertDocumentParams{ ID: documentID, Title: &documentTitle, Author: &documentAuthor, @@ -64,7 +65,7 @@ func (suite *DatabaseTestSuite) SetupTest() { suite.NoError(err) // Create Device - _, err = suite.dbm.Queries.UpsertDevice(suite.dbm.Ctx, UpsertDeviceParams{ + _, err = suite.dbm.Queries.UpsertDevice(context.Background(), UpsertDeviceParams{ ID: deviceID, UserID: userID, DeviceName: deviceName, @@ -80,7 +81,7 @@ func (suite *DatabaseTestSuite) SetupTest() { counter += 1 // Add Item - activity, err := suite.dbm.Queries.AddActivity(suite.dbm.Ctx, AddActivityParams{ + activity, err := suite.dbm.Queries.AddActivity(context.Background(), AddActivityParams{ DocumentID: documentID, DeviceID: deviceID, UserID: userID, @@ -95,7 +96,7 @@ func (suite *DatabaseTestSuite) SetupTest() { } // Initiate Cache - err = suite.dbm.CacheTempTables() + err = suite.dbm.CacheTempTables(context.Background()) suite.NoError(err) } @@ -105,7 +106,7 @@ func (suite *DatabaseTestSuite) SetupTest() { // - 󰊕 (q *Queries) UpsertDevice func (suite *DatabaseTestSuite) TestDevice() { testDevice := "dev123" - device, err := suite.dbm.Queries.UpsertDevice(suite.dbm.Ctx, UpsertDeviceParams{ + device, err := suite.dbm.Queries.UpsertDevice(context.Background(), UpsertDeviceParams{ ID: testDevice, UserID: userID, DeviceName: deviceName, @@ -123,7 +124,7 @@ func (suite *DatabaseTestSuite) TestDevice() { // - 󰊕 (q *Queries) GetLastActivity func (suite *DatabaseTestSuite) TestActivity() { // Validate Exists - existsRows, err := suite.dbm.Queries.GetActivity(suite.dbm.Ctx, GetActivityParams{ + existsRows, err := suite.dbm.Queries.GetActivity(context.Background(), GetActivityParams{ UserID: userID, Offset: 0, Limit: 50, @@ -133,7 +134,7 @@ func (suite *DatabaseTestSuite) TestActivity() { suite.Len(existsRows, 10, "should have correct number of rows get activity") // Validate Doesn't Exist - doesntExistsRows, err := suite.dbm.Queries.GetActivity(suite.dbm.Ctx, GetActivityParams{ + doesntExistsRows, err := suite.dbm.Queries.GetActivity(context.Background(), GetActivityParams{ UserID: userID, DocumentID: "unknownDoc", DocFilter: true, @@ -151,7 +152,7 @@ func (suite *DatabaseTestSuite) TestActivity() { // - 󰊕 (q *Queries) GetDatabaseInfo // - 󰊕 (q *Queries) UpdateSettings func (suite *DatabaseTestSuite) TestGetDailyReadStats() { - readStats, err := suite.dbm.Queries.GetDailyReadStats(suite.dbm.Ctx, userID) + readStats, err := suite.dbm.Queries.GetDailyReadStats(context.Background(), userID) suite.Nil(err, "should have nil err") suite.Len(readStats, 30, "should have length of 30") diff --git a/database/query.sql b/database/query.sql index b46212f..1929773 100644 --- a/database/query.sql +++ b/database/query.sql @@ -163,42 +163,6 @@ ORDER BY DESC LIMIT 1; --- name: GetDocumentWithStats :one -SELECT - docs.id, - docs.title, - docs.author, - docs.description, - docs.isbn10, - docs.isbn13, - docs.filepath, - docs.words, - - CAST(COALESCE(dus.total_wpm, 0.0) AS INTEGER) AS wpm, - COALESCE(dus.read_percentage, 0) AS read_percentage, - COALESCE(dus.total_time_seconds, 0) AS total_time_seconds, - STRFTIME('%Y-%m-%d %H:%M:%S', LOCAL_TIME(COALESCE(dus.last_read, STRFTIME('%Y-%m-%dT%H:%M:%SZ', 0, 'unixepoch')), users.timezone)) - AS last_read, - ROUND(CAST(CASE - WHEN dus.percentage IS NULL THEN 0.0 - WHEN (dus.percentage * 100.0) > 97.0 THEN 100.0 - ELSE dus.percentage * 100.0 - END AS REAL), 2) AS percentage, - CAST(CASE - WHEN dus.total_time_seconds IS NULL THEN 0.0 - ELSE - CAST(dus.total_time_seconds AS REAL) - / (dus.read_percentage * 100.0) - END AS INTEGER) AS seconds_per_percent -FROM documents AS docs -LEFT JOIN users ON users.id = $user_id -LEFT JOIN - document_user_statistics AS dus - ON dus.document_id = docs.id AND dus.user_id = $user_id -WHERE users.id = $user_id -AND docs.id = $document_id -LIMIT 1; - -- name: GetDocuments :many SELECT * FROM documents ORDER BY created_at DESC @@ -236,26 +200,25 @@ SELECT WHEN (dus.percentage * 100.0) > 97.0 THEN 100.0 ELSE dus.percentage * 100.0 END AS REAL), 2) AS percentage, - - CASE + CAST(CASE WHEN dus.total_time_seconds IS NULL THEN 0.0 ELSE - ROUND( - CAST(dus.total_time_seconds AS REAL) - / (dus.read_percentage * 100.0) - ) - END AS seconds_per_percent + CAST(dus.total_time_seconds AS REAL) + / (dus.read_percentage * 100.0) + END AS INTEGER) AS seconds_per_percent FROM documents AS docs LEFT JOIN users ON users.id = $user_id LEFT JOIN document_user_statistics AS dus ON dus.document_id = docs.id AND dus.user_id = $user_id WHERE - docs.deleted = false AND ( - $query IS NULL OR ( - docs.title LIKE $query OR + (docs.id = sqlc.narg('id') OR $id IS NULL) + AND (docs.deleted = sqlc.narg(deleted) OR $deleted IS NULL) + AND ( + ( + docs.title LIKE sqlc.narg('query') OR docs.author LIKE $query - ) + ) OR $query IS NULL ) ORDER BY dus.last_read DESC, docs.created_at DESC LIMIT $limit diff --git a/database/query.sql.go b/database/query.sql.go index 64da387..1f5acb3 100644 --- a/database/query.sql.go +++ b/database/query.sql.go @@ -543,87 +543,6 @@ func (q *Queries) GetDocumentProgress(ctx context.Context, arg GetDocumentProgre return i, err } -const getDocumentWithStats = `-- name: GetDocumentWithStats :one -SELECT - docs.id, - docs.title, - docs.author, - docs.description, - docs.isbn10, - docs.isbn13, - docs.filepath, - docs.words, - - CAST(COALESCE(dus.total_wpm, 0.0) AS INTEGER) AS wpm, - COALESCE(dus.read_percentage, 0) AS read_percentage, - COALESCE(dus.total_time_seconds, 0) AS total_time_seconds, - STRFTIME('%Y-%m-%d %H:%M:%S', LOCAL_TIME(COALESCE(dus.last_read, STRFTIME('%Y-%m-%dT%H:%M:%SZ', 0, 'unixepoch')), users.timezone)) - AS last_read, - ROUND(CAST(CASE - WHEN dus.percentage IS NULL THEN 0.0 - WHEN (dus.percentage * 100.0) > 97.0 THEN 100.0 - ELSE dus.percentage * 100.0 - END AS REAL), 2) AS percentage, - CAST(CASE - WHEN dus.total_time_seconds IS NULL THEN 0.0 - ELSE - CAST(dus.total_time_seconds AS REAL) - / (dus.read_percentage * 100.0) - END AS INTEGER) AS seconds_per_percent -FROM documents AS docs -LEFT JOIN users ON users.id = ?1 -LEFT JOIN - document_user_statistics AS dus - ON dus.document_id = docs.id AND dus.user_id = ?1 -WHERE users.id = ?1 -AND docs.id = ?2 -LIMIT 1 -` - -type GetDocumentWithStatsParams struct { - UserID string `json:"user_id"` - DocumentID string `json:"document_id"` -} - -type GetDocumentWithStatsRow struct { - ID string `json:"id"` - Title *string `json:"title"` - Author *string `json:"author"` - Description *string `json:"description"` - Isbn10 *string `json:"isbn10"` - Isbn13 *string `json:"isbn13"` - Filepath *string `json:"filepath"` - Words *int64 `json:"words"` - Wpm int64 `json:"wpm"` - ReadPercentage float64 `json:"read_percentage"` - TotalTimeSeconds int64 `json:"total_time_seconds"` - LastRead interface{} `json:"last_read"` - Percentage float64 `json:"percentage"` - SecondsPerPercent int64 `json:"seconds_per_percent"` -} - -func (q *Queries) GetDocumentWithStats(ctx context.Context, arg GetDocumentWithStatsParams) (GetDocumentWithStatsRow, error) { - row := q.db.QueryRowContext(ctx, getDocumentWithStats, arg.UserID, arg.DocumentID) - var i GetDocumentWithStatsRow - err := row.Scan( - &i.ID, - &i.Title, - &i.Author, - &i.Description, - &i.Isbn10, - &i.Isbn13, - &i.Filepath, - &i.Words, - &i.Wpm, - &i.ReadPercentage, - &i.TotalTimeSeconds, - &i.LastRead, - &i.Percentage, - &i.SecondsPerPercent, - ) - return i, err -} - const getDocuments = `-- name: GetDocuments :many SELECT id, md5, basepath, filepath, coverfile, title, author, series, series_index, lang, description, words, gbid, olid, isbn10, isbn13, synced, deleted, updated_at, created_at FROM documents ORDER BY created_at DESC @@ -719,37 +638,38 @@ SELECT WHEN (dus.percentage * 100.0) > 97.0 THEN 100.0 ELSE dus.percentage * 100.0 END AS REAL), 2) AS percentage, - - CASE + CAST(CASE WHEN dus.total_time_seconds IS NULL THEN 0.0 ELSE - ROUND( - CAST(dus.total_time_seconds AS REAL) - / (dus.read_percentage * 100.0) - ) - END AS seconds_per_percent + CAST(dus.total_time_seconds AS REAL) + / (dus.read_percentage * 100.0) + END AS INTEGER) AS seconds_per_percent FROM documents AS docs LEFT JOIN users ON users.id = ?1 LEFT JOIN document_user_statistics AS dus ON dus.document_id = docs.id AND dus.user_id = ?1 WHERE - docs.deleted = false AND ( - ?2 IS NULL OR ( - docs.title LIKE ?2 OR - docs.author LIKE ?2 - ) + (docs.id = ?2 OR ?2 IS NULL) + AND (docs.deleted = ?3 OR ?3 IS NULL) + AND ( + ( + docs.title LIKE ?4 OR + docs.author LIKE ?4 + ) OR ?4 IS NULL ) ORDER BY dus.last_read DESC, docs.created_at DESC -LIMIT ?4 -OFFSET ?3 +LIMIT ?6 +OFFSET ?5 ` type GetDocumentsWithStatsParams struct { - UserID string `json:"user_id"` - Query interface{} `json:"query"` - Offset int64 `json:"offset"` - Limit int64 `json:"limit"` + UserID string `json:"user_id"` + ID *string `json:"id"` + Deleted *bool `json:"-"` + Query *string `json:"query"` + Offset int64 `json:"offset"` + Limit int64 `json:"limit"` } type GetDocumentsWithStatsRow struct { @@ -766,12 +686,14 @@ type GetDocumentsWithStatsRow struct { TotalTimeSeconds int64 `json:"total_time_seconds"` LastRead interface{} `json:"last_read"` Percentage float64 `json:"percentage"` - SecondsPerPercent interface{} `json:"seconds_per_percent"` + SecondsPerPercent int64 `json:"seconds_per_percent"` } func (q *Queries) GetDocumentsWithStats(ctx context.Context, arg GetDocumentsWithStatsParams) ([]GetDocumentsWithStatsRow, error) { rows, err := q.db.QueryContext(ctx, getDocumentsWithStats, arg.UserID, + arg.ID, + arg.Deleted, arg.Query, arg.Offset, arg.Limit, diff --git a/database/users_test.go b/database/users_test.go index 7a6de8e..7376ab4 100644 --- a/database/users_test.go +++ b/database/users_test.go @@ -1,6 +1,7 @@ package database import ( + "context" "database/sql" "fmt" "testing" @@ -36,7 +37,7 @@ func (suite *UsersTestSuite) SetupTest() { // Create User rawAuthHash, _ := utils.GenerateToken(64) authHash := fmt.Sprintf("%x", rawAuthHash) - _, err := suite.dbm.Queries.CreateUser(suite.dbm.Ctx, CreateUserParams{ + _, err := suite.dbm.Queries.CreateUser(context.Background(), CreateUserParams{ ID: testUserID, Pass: &testUserPass, AuthHash: &authHash, @@ -44,7 +45,7 @@ func (suite *UsersTestSuite) SetupTest() { suite.NoError(err) // Create Document - _, err = suite.dbm.Queries.UpsertDocument(suite.dbm.Ctx, UpsertDocumentParams{ + _, err = suite.dbm.Queries.UpsertDocument(context.Background(), UpsertDocumentParams{ ID: documentID, Title: &documentTitle, Author: &documentAuthor, @@ -53,7 +54,7 @@ func (suite *UsersTestSuite) SetupTest() { suite.NoError(err) // Create Device - _, err = suite.dbm.Queries.UpsertDevice(suite.dbm.Ctx, UpsertDeviceParams{ + _, err = suite.dbm.Queries.UpsertDevice(context.Background(), UpsertDeviceParams{ ID: deviceID, UserID: testUserID, DeviceName: deviceName, @@ -62,7 +63,7 @@ func (suite *UsersTestSuite) SetupTest() { } func (suite *UsersTestSuite) TestGetUser() { - user, err := suite.dbm.Queries.GetUser(suite.dbm.Ctx, testUserID) + user, err := suite.dbm.Queries.GetUser(context.Background(), testUserID) suite.Nil(err, "should have nil err") suite.Equal(testUserPass, *user.Pass) } @@ -76,7 +77,7 @@ func (suite *UsersTestSuite) TestCreateUser() { suite.Nil(err, "should have nil err") authHash := fmt.Sprintf("%x", rawAuthHash) - changed, err := suite.dbm.Queries.CreateUser(suite.dbm.Ctx, CreateUserParams{ + changed, err := suite.dbm.Queries.CreateUser(context.Background(), CreateUserParams{ ID: testUser, Pass: &testPass, AuthHash: &authHash, @@ -85,29 +86,29 @@ func (suite *UsersTestSuite) TestCreateUser() { suite.Nil(err, "should have nil err") suite.Equal(int64(1), changed) - user, err := suite.dbm.Queries.GetUser(suite.dbm.Ctx, testUser) + user, err := suite.dbm.Queries.GetUser(context.Background(), testUser) suite.Nil(err, "should have nil err") suite.Equal(testPass, *user.Pass) } func (suite *UsersTestSuite) TestDeleteUser() { - changed, err := suite.dbm.Queries.DeleteUser(suite.dbm.Ctx, testUserID) + changed, err := suite.dbm.Queries.DeleteUser(context.Background(), testUserID) suite.Nil(err, "should have nil err") suite.Equal(int64(1), changed, "should have one changed row") - _, err = suite.dbm.Queries.GetUser(suite.dbm.Ctx, testUserID) + _, err = suite.dbm.Queries.GetUser(context.Background(), testUserID) suite.ErrorIs(err, sql.ErrNoRows, "should have no rows error") } func (suite *UsersTestSuite) TestGetUsers() { - users, err := suite.dbm.Queries.GetUsers(suite.dbm.Ctx) + users, err := suite.dbm.Queries.GetUsers(context.Background()) suite.Nil(err, "should have nil err") suite.Len(users, 1, "should have single user") } func (suite *UsersTestSuite) TestUpdateUser() { newPassword := "newPass123" - user, err := suite.dbm.Queries.UpdateUser(suite.dbm.Ctx, UpdateUserParams{ + user, err := suite.dbm.Queries.UpdateUser(context.Background(), UpdateUserParams{ UserID: testUserID, Password: &newPassword, }) @@ -116,11 +117,11 @@ func (suite *UsersTestSuite) TestUpdateUser() { } func (suite *UsersTestSuite) TestGetUserStatistics() { - err := suite.dbm.CacheTempTables() + err := suite.dbm.CacheTempTables(context.Background()) suite.NoError(err) // Ensure Zero Items - userStats, err := suite.dbm.Queries.GetUserStatistics(suite.dbm.Ctx) + userStats, err := suite.dbm.Queries.GetUserStatistics(context.Background()) suite.Nil(err, "should have nil err") suite.Empty(userStats, "should be empty") @@ -133,7 +134,7 @@ func (suite *UsersTestSuite) TestGetUserStatistics() { counter += 1 // Add Item - activity, err := suite.dbm.Queries.AddActivity(suite.dbm.Ctx, AddActivityParams{ + activity, err := suite.dbm.Queries.AddActivity(context.Background(), AddActivityParams{ DocumentID: documentID, DeviceID: deviceID, UserID: testUserID, @@ -147,21 +148,21 @@ func (suite *UsersTestSuite) TestGetUserStatistics() { suite.Equal(counter, activity.ID, fmt.Sprintf("[%d] should have correct id for add activity", counter)) } - err = suite.dbm.CacheTempTables() + err = suite.dbm.CacheTempTables(context.Background()) suite.NoError(err) // Ensure One Item - userStats, err = suite.dbm.Queries.GetUserStatistics(suite.dbm.Ctx) + userStats, err = suite.dbm.Queries.GetUserStatistics(context.Background()) suite.Nil(err, "should have nil err") suite.Len(userStats, 1, "should have length of one") } func (suite *UsersTestSuite) TestGetUsersStreaks() { - err := suite.dbm.CacheTempTables() + err := suite.dbm.CacheTempTables(context.Background()) suite.NoError(err) // Ensure Zero Items - userStats, err := suite.dbm.Queries.GetUserStreaks(suite.dbm.Ctx, testUserID) + userStats, err := suite.dbm.Queries.GetUserStreaks(context.Background(), testUserID) suite.Nil(err, "should have nil err") suite.Empty(userStats, "should be empty") @@ -174,7 +175,7 @@ func (suite *UsersTestSuite) TestGetUsersStreaks() { counter += 1 // Add Item - activity, err := suite.dbm.Queries.AddActivity(suite.dbm.Ctx, AddActivityParams{ + activity, err := suite.dbm.Queries.AddActivity(context.Background(), AddActivityParams{ DocumentID: documentID, DeviceID: deviceID, UserID: testUserID, @@ -188,11 +189,11 @@ func (suite *UsersTestSuite) TestGetUsersStreaks() { suite.Equal(counter, activity.ID, fmt.Sprintf("[%d] should have correct id for add activity", counter)) } - err = suite.dbm.CacheTempTables() + err = suite.dbm.CacheTempTables(context.Background()) suite.NoError(err) // Ensure Two Item - userStats, err = suite.dbm.Queries.GetUserStreaks(suite.dbm.Ctx, testUserID) + userStats, err = suite.dbm.Queries.GetUserStreaks(context.Background(), testUserID) suite.Nil(err, "should have nil err") suite.Len(userStats, 2, "should have length of two") diff --git a/server/server.go b/server/server.go index 56d1693..f9cf34a 100644 --- a/server/server.go +++ b/server/server.go @@ -1,6 +1,7 @@ package server import ( + "context" "io/fs" "net/http" "sync" @@ -52,12 +53,14 @@ func (s *server) Start() { ticker := time.NewTicker(15 * time.Minute) defer ticker.Stop() + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Minute)) for { select { case <-ticker.C: - s.runScheduledTasks() + s.runScheduledTasks(ctx) case <-s.done: log.Info("Stopping task runner...") + cancel() return } } @@ -81,9 +84,9 @@ func (s *server) Stop() { } // Run normal scheduled tasks -func (s *server) runScheduledTasks() { +func (s *server) runScheduledTasks(ctx context.Context) { start := time.Now() - if err := s.db.CacheTempTables(); err != nil { + if err := s.db.CacheTempTables(ctx); err != nil { log.Warn("Refreshing temp table cache failed: ", err) } log.Debug("Completed in: ", time.Since(start))