Compare commits

..

No commits in common. "5f8a9b7b1419ee35aa50e71cd60d21bed4f5797e" and "456b6e457c8541eca2f5646aeab3a1f0570e0dad" have entirely different histories.

20 changed files with 328 additions and 371 deletions

View File

@ -3,7 +3,6 @@ package api
import ( import (
"archive/zip" "archive/zip"
"bufio" "bufio"
"context"
"crypto/md5" "crypto/md5"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -113,7 +112,7 @@ func (api *API) appPerformAdminAction(c *gin.Context) {
// 2. Select all / deselect? // 2. Select all / deselect?
case adminCacheTables: case adminCacheTables:
go func() { go func() {
err := api.db.CacheTempTables(c) err := api.db.CacheTempTables()
if err != nil { if err != nil {
log.Error("Unable to cache temp tables: ", err) log.Error("Unable to cache temp tables: ", err)
} }
@ -123,7 +122,7 @@ func (api *API) appPerformAdminAction(c *gin.Context) {
return return
case adminBackup: case adminBackup:
// Vacuum // Vacuum
_, err := api.db.DB.ExecContext(c, "VACUUM;") _, err := api.db.DB.ExecContext(api.db.Ctx, "VACUUM;")
if err != nil { if err != nil {
log.Error("Unable to vacuum DB: ", err) log.Error("Unable to vacuum DB: ", err)
appErrorPage(c, http.StatusInternalServerError, "Unable to vacuum database") appErrorPage(c, http.StatusInternalServerError, "Unable to vacuum database")
@ -145,7 +144,7 @@ func (api *API) appPerformAdminAction(c *gin.Context) {
} }
} }
err := api.createBackup(c, w, directories) err := api.createBackup(w, directories)
if err != nil { if err != nil {
log.Error("Backup Error: ", err) log.Error("Backup Error: ", err)
} }
@ -262,7 +261,7 @@ func (api *API) appGetAdminLogs(c *gin.Context) {
func (api *API) appGetAdminUsers(c *gin.Context) { func (api *API) appGetAdminUsers(c *gin.Context) {
templateVars, _ := api.getBaseTemplateVars("admin-users", c) templateVars, _ := api.getBaseTemplateVars("admin-users", c)
users, err := api.db.Queries.GetUsers(c) users, err := api.db.Queries.GetUsers(api.db.Ctx)
if err != nil { if err != nil {
log.Error("GetUsers DB Error: ", err) log.Error("GetUsers DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUsers DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUsers DB Error: %v", err))
@ -293,11 +292,11 @@ func (api *API) appUpdateAdminUsers(c *gin.Context) {
var err error var err error
switch rUpdate.Operation { switch rUpdate.Operation {
case opCreate: case opCreate:
err = api.createUser(c, rUpdate.User, rUpdate.Password, rUpdate.IsAdmin) err = api.createUser(rUpdate.User, rUpdate.Password, rUpdate.IsAdmin)
case opUpdate: case opUpdate:
err = api.updateUser(c, rUpdate.User, rUpdate.Password, rUpdate.IsAdmin) err = api.updateUser(rUpdate.User, rUpdate.Password, rUpdate.IsAdmin)
case opDelete: case opDelete:
err = api.deleteUser(c, rUpdate.User) err = api.deleteUser(rUpdate.User)
default: default:
appErrorPage(c, http.StatusNotFound, "Unknown user operation") appErrorPage(c, http.StatusNotFound, "Unknown user operation")
return return
@ -308,7 +307,7 @@ func (api *API) appUpdateAdminUsers(c *gin.Context) {
return return
} }
users, err := api.db.Queries.GetUsers(c) users, err := api.db.Queries.GetUsers(api.db.Ctx)
if err != nil { if err != nil {
log.Error("GetUsers DB Error: ", err) log.Error("GetUsers DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUsers DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUsers DB Error: %v", err))
@ -449,7 +448,7 @@ func (api *API) appPerformAdminImport(c *gin.Context) {
iResult.Name = fmt.Sprintf("%s - %s", *fileMeta.Author, *fileMeta.Title) iResult.Name = fmt.Sprintf("%s - %s", *fileMeta.Author, *fileMeta.Title)
// Check already exists // Check already exists
_, err = qtx.GetDocument(c, *fileMeta.PartialMD5) _, err = qtx.GetDocument(api.db.Ctx, *fileMeta.PartialMD5)
if err == nil { if err == nil {
log.Warnf("document already exists: %s", *fileMeta.PartialMD5) log.Warnf("document already exists: %s", *fileMeta.PartialMD5)
iResult.Status = importExists iResult.Status = importExists
@ -493,7 +492,7 @@ func (api *API) appPerformAdminImport(c *gin.Context) {
} }
// Upsert document // Upsert document
if _, err = qtx.UpsertDocument(c, database.UpsertDocumentParams{ if _, err = qtx.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
ID: *fileMeta.PartialMD5, ID: *fileMeta.PartialMD5,
Title: fileMeta.Title, Title: fileMeta.Title,
Author: fileMeta.Author, Author: fileMeta.Author,
@ -628,7 +627,7 @@ func (api *API) processRestoreFile(rAdminAction requestAdminAction, c *gin.Conte
// Save Backup File // Save Backup File
w := bufio.NewWriter(backupFile) w := bufio.NewWriter(backupFile)
err = api.createBackup(c, w, []string{"covers", "documents"}) err = api.createBackup(w, []string{"covers", "documents"})
if err != nil { if err != nil {
log.Error("Unable to save backup file: ", err) log.Error("Unable to save backup file: ", err)
appErrorPage(c, http.StatusInternalServerError, "Unable to save backup file") appErrorPage(c, http.StatusInternalServerError, "Unable to save backup file")
@ -651,13 +650,13 @@ func (api *API) processRestoreFile(rAdminAction requestAdminAction, c *gin.Conte
} }
// Reinit DB // Reinit DB
if err := api.db.Reload(c); err != nil { if err := api.db.Reload(); err != nil {
appErrorPage(c, http.StatusInternalServerError, "Unable to reload DB") appErrorPage(c, http.StatusInternalServerError, "Unable to reload DB")
log.Panicf("Unable to reload DB: %v", err) log.Panicf("Unable to reload DB: %v", err)
} }
// Rotate Auth Hashes // Rotate Auth Hashes
if err := api.rotateAllAuthHashes(c); err != nil { if err := api.rotateAllAuthHashes(); err != nil {
appErrorPage(c, http.StatusInternalServerError, "Unable to rotate hashes") appErrorPage(c, http.StatusInternalServerError, "Unable to rotate hashes")
log.Panicf("Unable to rotate auth hashes: %v", err) log.Panicf("Unable to rotate auth hashes: %v", err)
} }
@ -718,9 +717,9 @@ func (api *API) removeData() error {
return nil return nil
} }
func (api *API) createBackup(ctx context.Context, w io.Writer, directories []string) error { func (api *API) createBackup(w io.Writer, directories []string) error {
// Vacuum DB // Vacuum DB
_, err := api.db.DB.ExecContext(ctx, "VACUUM;") _, err := api.db.DB.ExecContext(api.db.Ctx, "VACUUM;")
if err != nil { if err != nil {
return errors.Wrap(err, "Unable to vacuum database") return errors.Wrap(err, "Unable to vacuum database")
} }
@ -793,8 +792,8 @@ func (api *API) createBackup(ctx context.Context, w io.Writer, directories []str
return nil return nil
} }
func (api *API) isLastAdmin(ctx context.Context, userID string) (bool, error) { func (api *API) isLastAdmin(userID string) (bool, error) {
allUsers, err := api.db.Queries.GetUsers(ctx) allUsers, err := api.db.Queries.GetUsers(api.db.Ctx)
if err != nil { if err != nil {
return false, errors.Wrap(err, fmt.Sprintf("GetUsers DB Error: %v", err)) return false, errors.Wrap(err, fmt.Sprintf("GetUsers DB Error: %v", err))
} }
@ -810,7 +809,7 @@ func (api *API) isLastAdmin(ctx context.Context, userID string) (bool, error) {
return !hasAdmin, nil return !hasAdmin, nil
} }
func (api *API) createUser(ctx context.Context, user string, rawPassword *string, isAdmin *bool) error { func (api *API) createUser(user string, rawPassword *string, isAdmin *bool) error {
// Validate Necessary Parameters // Validate Necessary Parameters
if rawPassword == nil || *rawPassword == "" { if rawPassword == nil || *rawPassword == "" {
return fmt.Errorf("password can't be empty") return fmt.Errorf("password can't be empty")
@ -845,7 +844,7 @@ func (api *API) createUser(ctx context.Context, user string, rawPassword *string
createParams.AuthHash = &authHash createParams.AuthHash = &authHash
// Create user in DB // Create user in DB
if rows, err := api.db.Queries.CreateUser(ctx, createParams); err != nil { if rows, err := api.db.Queries.CreateUser(api.db.Ctx, createParams); err != nil {
log.Error("CreateUser DB Error:", err) log.Error("CreateUser DB Error:", err)
return fmt.Errorf("unable to create user") return fmt.Errorf("unable to create user")
} else if rows == 0 { } else if rows == 0 {
@ -856,7 +855,7 @@ func (api *API) createUser(ctx context.Context, user string, rawPassword *string
return nil return nil
} }
func (api *API) updateUser(ctx context.Context, user string, rawPassword *string, isAdmin *bool) error { func (api *API) updateUser(user string, rawPassword *string, isAdmin *bool) error {
// Validate Necessary Parameters // Validate Necessary Parameters
if rawPassword == nil && isAdmin == nil { if rawPassword == nil && isAdmin == nil {
return fmt.Errorf("nothing to update") return fmt.Errorf("nothing to update")
@ -871,7 +870,7 @@ func (api *API) updateUser(ctx context.Context, user string, rawPassword *string
if isAdmin != nil { if isAdmin != nil {
updateParams.Admin = *isAdmin updateParams.Admin = *isAdmin
} else { } else {
user, err := api.db.Queries.GetUser(ctx, user) user, err := api.db.Queries.GetUser(api.db.Ctx, user)
if err != nil { if err != nil {
return errors.Wrap(err, fmt.Sprintf("GetUser DB Error: %v", err)) return errors.Wrap(err, fmt.Sprintf("GetUser DB Error: %v", err))
} }
@ -879,7 +878,7 @@ func (api *API) updateUser(ctx context.Context, user string, rawPassword *string
} }
// Check Admins - Disallow Demotion // Check Admins - Disallow Demotion
if isLast, err := api.isLastAdmin(ctx, user); err != nil { if isLast, err := api.isLastAdmin(user); err != nil {
return err return err
} else if isLast && !updateParams.Admin { } else if isLast && !updateParams.Admin {
return fmt.Errorf("unable to demote %s - last admin", user) return fmt.Errorf("unable to demote %s - last admin", user)
@ -909,7 +908,7 @@ func (api *API) updateUser(ctx context.Context, user string, rawPassword *string
} }
// Update User // Update User
_, err := api.db.Queries.UpdateUser(ctx, updateParams) _, err := api.db.Queries.UpdateUser(api.db.Ctx, updateParams)
if err != nil { if err != nil {
return errors.Wrap(err, fmt.Sprintf("UpdateUser DB Error: %v", err)) return errors.Wrap(err, fmt.Sprintf("UpdateUser DB Error: %v", err))
} }
@ -917,9 +916,9 @@ func (api *API) updateUser(ctx context.Context, user string, rawPassword *string
return nil return nil
} }
func (api *API) deleteUser(ctx context.Context, user string) error { func (api *API) deleteUser(user string) error {
// Check Admins // Check Admins
if isLast, err := api.isLastAdmin(ctx, user); err != nil { if isLast, err := api.isLastAdmin(user); err != nil {
return err return err
} else if isLast { } else if isLast {
return fmt.Errorf("unable to delete %s - last admin", user) return fmt.Errorf("unable to delete %s - last admin", user)
@ -935,13 +934,13 @@ func (api *API) deleteUser(ctx context.Context, user string) error {
// Save Backup File (DB Only) // Save Backup File (DB Only)
w := bufio.NewWriter(backupFile) w := bufio.NewWriter(backupFile)
err = api.createBackup(ctx, w, []string{}) err = api.createBackup(w, []string{})
if err != nil { if err != nil {
return err return err
} }
// Delete User // Delete User
_, err = api.db.Queries.DeleteUser(ctx, user) _, err = api.db.Queries.DeleteUser(api.db.Ctx, user)
if err != nil { if err != nil {
return errors.Wrap(err, fmt.Sprintf("DeleteUser DB Error: %v", err)) return errors.Wrap(err, fmt.Sprintf("DeleteUser DB Error: %v", err))
} }

View File

@ -1,7 +1,6 @@
package api package api
import ( import (
"context"
"crypto/md5" "crypto/md5"
"database/sql" "database/sql"
"fmt" "fmt"
@ -23,7 +22,6 @@ import (
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"reichard.io/antholume/database" "reichard.io/antholume/database"
"reichard.io/antholume/metadata" "reichard.io/antholume/metadata"
"reichard.io/antholume/pkg/ptr"
"reichard.io/antholume/search" "reichard.io/antholume/search"
) )
@ -111,12 +109,11 @@ func (api *API) appGetDocuments(c *gin.Context) {
query = &search query = &search
} }
documents, err := api.db.Queries.GetDocumentsWithStats(c, database.GetDocumentsWithStatsParams{ documents, err := api.db.Queries.GetDocumentsWithStats(api.db.Ctx, database.GetDocumentsWithStatsParams{
UserID: auth.UserName, UserID: auth.UserName,
Query: query, Query: query,
Deleted: ptr.Of(false), Offset: (*qParams.Page - 1) * *qParams.Limit,
Offset: (*qParams.Page - 1) * *qParams.Limit, Limit: *qParams.Limit,
Limit: *qParams.Limit,
}) })
if err != nil { if err != nil {
log.Error("GetDocumentsWithStats DB Error: ", err) log.Error("GetDocumentsWithStats DB Error: ", err)
@ -124,14 +121,14 @@ func (api *API) appGetDocuments(c *gin.Context) {
return return
} }
length, err := api.db.Queries.GetDocumentsSize(c, query) length, err := api.db.Queries.GetDocumentsSize(api.db.Ctx, query)
if err != nil { if err != nil {
log.Error("GetDocumentsSize DB Error: ", err) log.Error("GetDocumentsSize DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocumentsSize DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocumentsSize DB Error: %v", err))
return return
} }
if err = api.getDocumentsWordCount(c, documents); err != nil { if err = api.getDocumentsWordCount(documents); err != nil {
log.Error("Unable to Get Word Counts: ", err) log.Error("Unable to Get Word Counts: ", err)
} }
@ -163,10 +160,13 @@ func (api *API) appGetDocument(c *gin.Context) {
return return
} }
document, err := api.db.GetDocument(c, rDocID.DocumentID, auth.UserName) document, err := api.db.Queries.GetDocumentWithStats(api.db.Ctx, database.GetDocumentWithStatsParams{
UserID: auth.UserName,
DocumentID: rDocID.DocumentID,
})
if err != nil { if err != nil {
log.Error("GetDocument DB Error: ", err) log.Error("GetDocumentWithStats DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocument DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocumentsWithStats DB Error: %v", err))
return return
} }
@ -192,7 +192,7 @@ func (api *API) appGetProgress(c *gin.Context) {
progressFilter.DocumentID = *qParams.Document progressFilter.DocumentID = *qParams.Document
} }
progress, err := api.db.Queries.GetProgress(c, progressFilter) progress, err := api.db.Queries.GetProgress(api.db.Ctx, progressFilter)
if err != nil { if err != nil {
log.Error("GetProgress DB Error: ", err) log.Error("GetProgress DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetActivity DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetActivity DB Error: %v", err))
@ -219,7 +219,7 @@ func (api *API) appGetActivity(c *gin.Context) {
activityFilter.DocumentID = *qParams.Document activityFilter.DocumentID = *qParams.Document
} }
activity, err := api.db.Queries.GetActivity(c, activityFilter) activity, err := api.db.Queries.GetActivity(api.db.Ctx, activityFilter)
if err != nil { if err != nil {
log.Error("GetActivity DB Error: ", err) log.Error("GetActivity DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetActivity DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetActivity DB Error: %v", err))
@ -235,7 +235,7 @@ func (api *API) appGetHome(c *gin.Context) {
templateVars, auth := api.getBaseTemplateVars("home", c) templateVars, auth := api.getBaseTemplateVars("home", c)
start := time.Now() start := time.Now()
graphData, err := api.db.Queries.GetDailyReadStats(c, auth.UserName) graphData, err := api.db.Queries.GetDailyReadStats(api.db.Ctx, auth.UserName)
if err != nil { if err != nil {
log.Error("GetDailyReadStats DB Error: ", err) log.Error("GetDailyReadStats DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDailyReadStats DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDailyReadStats DB Error: %v", err))
@ -244,7 +244,7 @@ func (api *API) appGetHome(c *gin.Context) {
log.Debug("GetDailyReadStats DB Performance: ", time.Since(start)) log.Debug("GetDailyReadStats DB Performance: ", time.Since(start))
start = time.Now() start = time.Now()
databaseInfo, err := api.db.Queries.GetDatabaseInfo(c, auth.UserName) databaseInfo, err := api.db.Queries.GetDatabaseInfo(api.db.Ctx, auth.UserName)
if err != nil { if err != nil {
log.Error("GetDatabaseInfo DB Error: ", err) log.Error("GetDatabaseInfo DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDatabaseInfo DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDatabaseInfo DB Error: %v", err))
@ -253,7 +253,7 @@ func (api *API) appGetHome(c *gin.Context) {
log.Debug("GetDatabaseInfo DB Performance: ", time.Since(start)) log.Debug("GetDatabaseInfo DB Performance: ", time.Since(start))
start = time.Now() start = time.Now()
streaks, err := api.db.Queries.GetUserStreaks(c, auth.UserName) streaks, err := api.db.Queries.GetUserStreaks(api.db.Ctx, auth.UserName)
if err != nil { if err != nil {
log.Error("GetUserStreaks DB Error: ", err) log.Error("GetUserStreaks DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUserStreaks DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUserStreaks DB Error: %v", err))
@ -262,7 +262,7 @@ func (api *API) appGetHome(c *gin.Context) {
log.Debug("GetUserStreaks DB Performance: ", time.Since(start)) log.Debug("GetUserStreaks DB Performance: ", time.Since(start))
start = time.Now() start = time.Now()
userStatistics, err := api.db.Queries.GetUserStatistics(c) userStatistics, err := api.db.Queries.GetUserStatistics(api.db.Ctx)
if err != nil { if err != nil {
log.Error("GetUserStatistics DB Error: ", err) log.Error("GetUserStatistics DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUserStatistics DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUserStatistics DB Error: %v", err))
@ -283,14 +283,14 @@ func (api *API) appGetHome(c *gin.Context) {
func (api *API) appGetSettings(c *gin.Context) { func (api *API) appGetSettings(c *gin.Context) {
templateVars, auth := api.getBaseTemplateVars("settings", c) templateVars, auth := api.getBaseTemplateVars("settings", c)
user, err := api.db.Queries.GetUser(c, auth.UserName) user, err := api.db.Queries.GetUser(api.db.Ctx, auth.UserName)
if err != nil { if err != nil {
log.Error("GetUser DB Error: ", err) log.Error("GetUser DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUser DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUser DB Error: %v", err))
return return
} }
devices, err := api.db.Queries.GetDevices(c, auth.UserName) devices, err := api.db.Queries.GetDevices(api.db.Ctx, auth.UserName)
if err != nil { if err != nil {
log.Error("GetDevices DB Error: ", err) log.Error("GetDevices DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDevices DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDevices DB Error: %v", err))
@ -368,7 +368,7 @@ func (api *API) appGetDocumentProgress(c *gin.Context) {
return return
} }
progress, err := api.db.Queries.GetDocumentProgress(c, database.GetDocumentProgressParams{ progress, err := api.db.Queries.GetDocumentProgress(api.db.Ctx, database.GetDocumentProgressParams{
DocumentID: rDoc.DocumentID, DocumentID: rDoc.DocumentID,
UserID: auth.UserName, UserID: auth.UserName,
}) })
@ -378,10 +378,13 @@ func (api *API) appGetDocumentProgress(c *gin.Context) {
return return
} }
document, err := api.db.GetDocument(c, rDoc.DocumentID, auth.UserName) document, err := api.db.Queries.GetDocumentWithStats(api.db.Ctx, database.GetDocumentWithStatsParams{
UserID: auth.UserName,
DocumentID: rDoc.DocumentID,
})
if err != nil { if err != nil {
log.Error("GetDocument DB Error: ", err) log.Error("GetDocumentWithStats DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocument DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocumentWithStats DB Error: %v", err))
return return
} }
@ -401,7 +404,7 @@ func (api *API) appGetDevices(c *gin.Context) {
auth = data.(authData) auth = data.(authData)
} }
devices, err := api.db.Queries.GetDevices(c, auth.UserName) devices, err := api.db.Queries.GetDevices(api.db.Ctx, auth.UserName)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
log.Error("GetDevices DB Error: ", err) log.Error("GetDevices DB Error: ", err)
@ -452,7 +455,7 @@ func (api *API) appUploadNewDocument(c *gin.Context) {
} }
// Check Already Exists // Check Already Exists
_, err = api.db.Queries.GetDocument(c, *metadataInfo.PartialMD5) _, err = api.db.Queries.GetDocument(api.db.Ctx, *metadataInfo.PartialMD5)
if err == nil { if err == nil {
log.Warnf("document already exists: %s", *metadataInfo.PartialMD5) log.Warnf("document already exists: %s", *metadataInfo.PartialMD5)
c.Redirect(http.StatusFound, fmt.Sprintf("./documents/%s", *metadataInfo.PartialMD5)) c.Redirect(http.StatusFound, fmt.Sprintf("./documents/%s", *metadataInfo.PartialMD5))
@ -480,7 +483,7 @@ func (api *API) appUploadNewDocument(c *gin.Context) {
} }
// Upsert Document // Upsert Document
if _, err = api.db.Queries.UpsertDocument(c, database.UpsertDocumentParams{ if _, err = api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
ID: *metadataInfo.PartialMD5, ID: *metadataInfo.PartialMD5,
Title: metadataInfo.Title, Title: metadataInfo.Title,
Author: metadataInfo.Author, Author: metadataInfo.Author,
@ -570,7 +573,7 @@ func (api *API) appEditDocument(c *gin.Context) {
coverFileName = &fileName coverFileName = &fileName
} else if rDocEdit.CoverGBID != nil { } else if rDocEdit.CoverGBID != nil {
coverDir := filepath.Join(api.cfg.DataPath, "covers") var coverDir string = filepath.Join(api.cfg.DataPath, "covers")
fileName, err := metadata.CacheCover(*rDocEdit.CoverGBID, coverDir, rDocID.DocumentID, true) fileName, err := metadata.CacheCover(*rDocEdit.CoverGBID, coverDir, rDocID.DocumentID, true)
if err == nil { if err == nil {
coverFileName = fileName coverFileName = fileName
@ -578,7 +581,7 @@ func (api *API) appEditDocument(c *gin.Context) {
} }
// Update Document // Update Document
if _, err := api.db.Queries.UpsertDocument(c, database.UpsertDocumentParams{ if _, err := api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
ID: rDocID.DocumentID, ID: rDocID.DocumentID,
Title: api.sanitizeInput(rDocEdit.Title), Title: api.sanitizeInput(rDocEdit.Title),
Author: api.sanitizeInput(rDocEdit.Author), Author: api.sanitizeInput(rDocEdit.Author),
@ -602,7 +605,7 @@ func (api *API) appDeleteDocument(c *gin.Context) {
appErrorPage(c, http.StatusNotFound, "Invalid document") appErrorPage(c, http.StatusNotFound, "Invalid document")
return return
} }
changed, err := api.db.Queries.DeleteDocument(c, rDocID.DocumentID) changed, err := api.db.Queries.DeleteDocument(api.db.Ctx, rDocID.DocumentID)
if err != nil { if err != nil {
log.Error("DeleteDocument DB Error") log.Error("DeleteDocument DB Error")
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("DeleteDocument DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("DeleteDocument DB Error: %v", err))
@ -664,7 +667,7 @@ func (api *API) appIdentifyDocument(c *gin.Context) {
firstResult := metadataResults[0] firstResult := metadataResults[0]
// Store First Metadata Result // Store First Metadata Result
if _, err = api.db.Queries.AddMetadata(c, database.AddMetadataParams{ if _, err = api.db.Queries.AddMetadata(api.db.Ctx, database.AddMetadataParams{
DocumentID: rDocID.DocumentID, DocumentID: rDocID.DocumentID,
Title: firstResult.Title, Title: firstResult.Title,
Author: firstResult.Author, Author: firstResult.Author,
@ -683,10 +686,13 @@ func (api *API) appIdentifyDocument(c *gin.Context) {
templateVars["MetadataError"] = "No Metadata Found" templateVars["MetadataError"] = "No Metadata Found"
} }
document, err := api.db.GetDocument(c, rDocID.DocumentID, auth.UserName) document, err := api.db.Queries.GetDocumentWithStats(api.db.Ctx, database.GetDocumentWithStatsParams{
UserID: auth.UserName,
DocumentID: rDocID.DocumentID,
})
if err != nil { if err != nil {
log.Error("GetDocument DB Error: ", err) log.Error("GetDocumentWithStats DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocument DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocumentWithStats DB Error: %v", err))
return return
} }
@ -811,7 +817,7 @@ func (api *API) appSaveNewDocument(c *gin.Context) {
sendDownloadMessage("Saving to database...", gin.H{"Progress": 99}) sendDownloadMessage("Saving to database...", gin.H{"Progress": 99})
// Upsert Document // Upsert Document
if _, err = api.db.Queries.UpsertDocument(c, database.UpsertDocumentParams{ if _, err = api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
ID: *metadata.PartialMD5, ID: *metadata.PartialMD5,
Title: &docTitle, Title: &docTitle,
Author: &docAuthor, Author: &docAuthor,
@ -858,7 +864,7 @@ func (api *API) appEditSettings(c *gin.Context) {
// Set New Password // Set New Password
if rUserSettings.Password != nil && rUserSettings.NewPassword != nil { if rUserSettings.Password != nil && rUserSettings.NewPassword != nil {
password := fmt.Sprintf("%x", md5.Sum([]byte(*rUserSettings.Password))) password := fmt.Sprintf("%x", md5.Sum([]byte(*rUserSettings.Password)))
data := api.authorizeCredentials(c, auth.UserName, password) data := api.authorizeCredentials(auth.UserName, password)
if data == nil { if data == nil {
templateVars["PasswordErrorMessage"] = "Invalid Password" templateVars["PasswordErrorMessage"] = "Invalid Password"
} else { } else {
@ -880,7 +886,7 @@ func (api *API) appEditSettings(c *gin.Context) {
} }
// Update User // Update User
_, err := api.db.Queries.UpdateUser(c, newUserSettings) _, err := api.db.Queries.UpdateUser(api.db.Ctx, newUserSettings)
if err != nil { if err != nil {
log.Error("UpdateUser DB Error: ", err) log.Error("UpdateUser DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("UpdateUser DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("UpdateUser DB Error: %v", err))
@ -888,7 +894,7 @@ func (api *API) appEditSettings(c *gin.Context) {
} }
// Get User // Get User
user, err := api.db.Queries.GetUser(c, auth.UserName) user, err := api.db.Queries.GetUser(api.db.Ctx, auth.UserName)
if err != nil { if err != nil {
log.Error("GetUser DB Error: ", err) log.Error("GetUser DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUser DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUser DB Error: %v", err))
@ -896,7 +902,7 @@ func (api *API) appEditSettings(c *gin.Context) {
} }
// Get Devices // Get Devices
devices, err := api.db.Queries.GetDevices(c, auth.UserName) devices, err := api.db.Queries.GetDevices(api.db.Ctx, auth.UserName)
if err != nil { if err != nil {
log.Error("GetDevices DB Error: ", err) log.Error("GetDevices DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDevices DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDevices DB Error: %v", err))
@ -915,7 +921,7 @@ func (api *API) appDemoModeError(c *gin.Context) {
appErrorPage(c, http.StatusUnauthorized, "Not Allowed in Demo Mode") appErrorPage(c, http.StatusUnauthorized, "Not Allowed in Demo Mode")
} }
func (api *API) getDocumentsWordCount(ctx context.Context, documents []database.GetDocumentsWithStatsRow) error { func (api *API) getDocumentsWordCount(documents []database.GetDocumentsWithStatsRow) error {
// Do Transaction // Do Transaction
tx, err := api.db.DB.Begin() tx, err := api.db.DB.Begin()
if err != nil { if err != nil {
@ -938,7 +944,7 @@ func (api *API) getDocumentsWordCount(ctx context.Context, documents []database.
if err != nil { if err != nil {
log.Warn("Word Count Error: ", err) log.Warn("Word Count Error: ", err)
} else { } else {
if _, err := qtx.UpsertDocument(ctx, database.UpsertDocumentParams{ if _, err := qtx.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
ID: item.ID, ID: item.ID,
Words: wordCount, Words: wordCount,
}); err != nil { }); err != nil {
@ -999,7 +1005,7 @@ func bindQueryParams(c *gin.Context, defaultLimit int64) queryParams {
} }
func appErrorPage(c *gin.Context, errorCode int, errorMessage string) { func appErrorPage(c *gin.Context, errorCode int, errorMessage string) {
errorHuman := "We're not even sure what happened." var errorHuman string = "We're not even sure what happened."
switch errorCode { switch errorCode {
case http.StatusInternalServerError: case http.StatusInternalServerError:

View File

@ -1,7 +1,6 @@
package api package api
import ( import (
"context"
"crypto/md5" "crypto/md5"
"fmt" "fmt"
"net/http" "net/http"
@ -29,8 +28,8 @@ type authKOHeader struct {
AuthKey string `header:"x-auth-key"` AuthKey string `header:"x-auth-key"`
} }
func (api *API) authorizeCredentials(ctx context.Context, username string, password string) (auth *authData) { func (api *API) authorizeCredentials(username string, password string) (auth *authData) {
user, err := api.db.Queries.GetUser(ctx, username) user, err := api.db.Queries.GetUser(api.db.Ctx, username)
if err != nil { if err != nil {
return return
} }
@ -53,7 +52,7 @@ func (api *API) authKOMiddleware(c *gin.Context) {
session := sessions.Default(c) session := sessions.Default(c)
// Check Session First // Check Session First
if auth, ok := api.getSession(c, session); ok { if auth, ok := api.getSession(session); ok {
c.Set("Authorization", auth) c.Set("Authorization", auth)
c.Header("Cache-Control", "private") c.Header("Cache-Control", "private")
c.Next() c.Next()
@ -72,7 +71,7 @@ func (api *API) authKOMiddleware(c *gin.Context) {
return return
} }
authData := api.authorizeCredentials(c, rHeader.AuthUser, rHeader.AuthKey) authData := api.authorizeCredentials(rHeader.AuthUser, rHeader.AuthKey)
if authData == nil { if authData == nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return return
@ -101,7 +100,7 @@ func (api *API) authOPDSMiddleware(c *gin.Context) {
// Validate Auth // Validate Auth
password := fmt.Sprintf("%x", md5.Sum([]byte(rawPassword))) password := fmt.Sprintf("%x", md5.Sum([]byte(rawPassword)))
authData := api.authorizeCredentials(c, user, password) authData := api.authorizeCredentials(user, password)
if authData == nil { if authData == nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return return
@ -116,7 +115,7 @@ func (api *API) authWebAppMiddleware(c *gin.Context) {
session := sessions.Default(c) session := sessions.Default(c)
// Check Session // Check Session
if auth, ok := api.getSession(c, session); ok { if auth, ok := api.getSession(session); ok {
c.Set("Authorization", auth) c.Set("Authorization", auth)
c.Header("Cache-Control", "private") c.Header("Cache-Control", "private")
c.Next() c.Next()
@ -154,7 +153,7 @@ func (api *API) appAuthLogin(c *gin.Context) {
// MD5 - KOSync Compatiblity // MD5 - KOSync Compatiblity
password := fmt.Sprintf("%x", md5.Sum([]byte(rawPassword))) password := fmt.Sprintf("%x", md5.Sum([]byte(rawPassword)))
authData := api.authorizeCredentials(c, username, password) authData := api.authorizeCredentials(username, password)
if authData == nil { if authData == nil {
templateVars["Error"] = "Invalid Credentials" templateVars["Error"] = "Invalid Credentials"
c.HTML(http.StatusUnauthorized, "page/login", templateVars) c.HTML(http.StatusUnauthorized, "page/login", templateVars)
@ -209,7 +208,7 @@ func (api *API) appAuthRegister(c *gin.Context) {
} }
// Get current users // Get current users
currentUsers, err := api.db.Queries.GetUsers(c) currentUsers, err := api.db.Queries.GetUsers(api.db.Ctx)
if err != nil { if err != nil {
log.Error("Failed to check all users: ", err) log.Error("Failed to check all users: ", err)
templateVars["Error"] = "Failed to Create User" templateVars["Error"] = "Failed to Create User"
@ -225,7 +224,7 @@ func (api *API) appAuthRegister(c *gin.Context) {
// Create user in DB // Create user in DB
authHash := fmt.Sprintf("%x", rawAuthHash) authHash := fmt.Sprintf("%x", rawAuthHash)
if rows, err := api.db.Queries.CreateUser(c, database.CreateUserParams{ if rows, err := api.db.Queries.CreateUser(api.db.Ctx, database.CreateUserParams{
ID: username, ID: username,
Pass: &hashedPassword, Pass: &hashedPassword,
AuthHash: &authHash, AuthHash: &authHash,
@ -243,7 +242,7 @@ func (api *API) appAuthRegister(c *gin.Context) {
} }
// Get user // Get user
user, err := api.db.Queries.GetUser(c, username) user, err := api.db.Queries.GetUser(api.db.Ctx, username)
if err != nil { if err != nil {
log.Error("GetUser DB Error:", err) log.Error("GetUser DB Error:", err)
templateVars["Error"] = "Registration Disabled or User Already Exists" templateVars["Error"] = "Registration Disabled or User Already Exists"
@ -313,7 +312,7 @@ func (api *API) koAuthRegister(c *gin.Context) {
} }
// Get current users // Get current users
currentUsers, err := api.db.Queries.GetUsers(c) currentUsers, err := api.db.Queries.GetUsers(api.db.Ctx)
if err != nil { if err != nil {
log.Error("Failed to check all users: ", err) log.Error("Failed to check all users: ", err)
apiErrorPage(c, http.StatusBadRequest, "Failed to Create User") apiErrorPage(c, http.StatusBadRequest, "Failed to Create User")
@ -328,7 +327,7 @@ func (api *API) koAuthRegister(c *gin.Context) {
// Create user // Create user
authHash := fmt.Sprintf("%x", rawAuthHash) authHash := fmt.Sprintf("%x", rawAuthHash)
if rows, err := api.db.Queries.CreateUser(c, database.CreateUserParams{ if rows, err := api.db.Queries.CreateUser(api.db.Ctx, database.CreateUserParams{
ID: rUser.Username, ID: rUser.Username,
Pass: &hashedPassword, Pass: &hashedPassword,
AuthHash: &authHash, AuthHash: &authHash,
@ -348,7 +347,7 @@ func (api *API) koAuthRegister(c *gin.Context) {
}) })
} }
func (api *API) getSession(ctx context.Context, session sessions.Session) (auth authData, ok bool) { func (api *API) getSession(session sessions.Session) (auth authData, ok bool) {
// Get Session // Get Session
authorizedUser := session.Get("authorizedUser") authorizedUser := session.Get("authorizedUser")
isAdmin := session.Get("isAdmin") isAdmin := session.Get("isAdmin")
@ -366,7 +365,7 @@ func (api *API) getSession(ctx context.Context, session sessions.Session) (auth
} }
// Validate Auth Hash // Validate Auth Hash
correctAuthHash, err := api.getUserAuthHash(ctx, auth.UserName) correctAuthHash, err := api.getUserAuthHash(auth.UserName)
if err != nil || correctAuthHash != auth.AuthHash { if err != nil || correctAuthHash != auth.AuthHash {
return return
} }
@ -394,14 +393,14 @@ func (api *API) setSession(session sessions.Session, auth authData) error {
return session.Save() return session.Save()
} }
func (api *API) getUserAuthHash(ctx context.Context, username string) (string, error) { func (api *API) getUserAuthHash(username string) (string, error) {
// Return Cache // Return Cache
if api.userAuthCache[username] != "" { if api.userAuthCache[username] != "" {
return api.userAuthCache[username], nil return api.userAuthCache[username], nil
} }
// Get DB // Get DB
user, err := api.db.Queries.GetUser(ctx, username) user, err := api.db.Queries.GetUser(api.db.Ctx, username)
if err != nil { if err != nil {
log.Error("GetUser DB Error:", err) log.Error("GetUser DB Error:", err)
return "", err return "", err
@ -413,7 +412,7 @@ func (api *API) getUserAuthHash(ctx context.Context, username string) (string, e
return api.userAuthCache[username], nil return api.userAuthCache[username], nil
} }
func (api *API) rotateAllAuthHashes(ctx context.Context) error { func (api *API) rotateAllAuthHashes() error {
// Do Transaction // Do Transaction
tx, err := api.db.DB.Begin() tx, err := api.db.DB.Begin()
if err != nil { if err != nil {
@ -429,7 +428,7 @@ func (api *API) rotateAllAuthHashes(ctx context.Context) error {
}() }()
qtx := api.db.Queries.WithTx(tx) qtx := api.db.Queries.WithTx(tx)
users, err := qtx.GetUsers(ctx) users, err := qtx.GetUsers(api.db.Ctx)
if err != nil { if err != nil {
return err return err
} }
@ -445,7 +444,7 @@ func (api *API) rotateAllAuthHashes(ctx context.Context) error {
// Update User // Update User
authHash := fmt.Sprintf("%x", rawAuthHash) authHash := fmt.Sprintf("%x", rawAuthHash)
if _, err = qtx.UpdateUser(ctx, database.UpdateUserParams{ if _, err = qtx.UpdateUser(api.db.Ctx, database.UpdateUserParams{
UserID: user.ID, UserID: user.ID,
AuthHash: &authHash, AuthHash: &authHash,
Admin: user.Admin, Admin: user.Admin,

View File

@ -22,7 +22,7 @@ func (api *API) createDownloadDocumentHandler(errorFunc func(*gin.Context, int,
} }
// Get Document // Get Document
document, err := api.db.Queries.GetDocument(c, rDoc.DocumentID) document, err := api.db.Queries.GetDocument(api.db.Ctx, rDoc.DocumentID)
if err != nil { if err != nil {
log.Error("GetDocument DB Error:", err) log.Error("GetDocument DB Error:", err)
errorFunc(c, http.StatusBadRequest, "Unknown Document") errorFunc(c, http.StatusBadRequest, "Unknown Document")
@ -68,7 +68,7 @@ func (api *API) createGetCoverHandler(errorFunc func(*gin.Context, int, string))
} }
// Validate Document Exists in DB // Validate Document Exists in DB
document, err := api.db.Queries.GetDocument(c, rDoc.DocumentID) document, err := api.db.Queries.GetDocument(api.db.Ctx, rDoc.DocumentID)
if err != nil { if err != nil {
log.Error("GetDocument DB Error:", err) log.Error("GetDocument DB Error:", err)
errorFunc(c, http.StatusInternalServerError, fmt.Sprintf("GetDocument DB Error: %v", err)) errorFunc(c, http.StatusInternalServerError, fmt.Sprintf("GetDocument DB Error: %v", err))
@ -117,7 +117,7 @@ func (api *API) createGetCoverHandler(errorFunc func(*gin.Context, int, string))
} }
// Store First Metadata Result // Store First Metadata Result
if _, err = api.db.Queries.AddMetadata(c, database.AddMetadataParams{ if _, err = api.db.Queries.AddMetadata(api.db.Ctx, database.AddMetadataParams{
DocumentID: document.ID, DocumentID: document.ID,
Title: firstResult.Title, Title: firstResult.Title,
Author: firstResult.Author, Author: firstResult.Author,
@ -132,7 +132,7 @@ func (api *API) createGetCoverHandler(errorFunc func(*gin.Context, int, string))
} }
// Upsert Document // Upsert Document
if _, err = api.db.Queries.UpsertDocument(c, database.UpsertDocumentParams{ if _, err = api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
ID: document.ID, ID: document.ID,
Coverfile: &coverFile, Coverfile: &coverFile,
}); err != nil { }); err != nil {

View File

@ -91,7 +91,7 @@ func (api *API) koSetProgress(c *gin.Context) {
} }
// Upsert Device // Upsert Device
if _, err := api.db.Queries.UpsertDevice(c, database.UpsertDeviceParams{ if _, err := api.db.Queries.UpsertDevice(api.db.Ctx, database.UpsertDeviceParams{
ID: rPosition.DeviceID, ID: rPosition.DeviceID,
UserID: auth.UserName, UserID: auth.UserName,
DeviceName: rPosition.Device, DeviceName: rPosition.Device,
@ -101,14 +101,14 @@ func (api *API) koSetProgress(c *gin.Context) {
} }
// Upsert Document // Upsert Document
if _, err := api.db.Queries.UpsertDocument(c, database.UpsertDocumentParams{ if _, err := api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
ID: rPosition.DocumentID, ID: rPosition.DocumentID,
}); err != nil { }); err != nil {
log.Error("UpsertDocument DB Error:", err) log.Error("UpsertDocument DB Error:", err)
} }
// Create or Replace Progress // Create or Replace Progress
progress, err := api.db.Queries.UpdateProgress(c, database.UpdateProgressParams{ progress, err := api.db.Queries.UpdateProgress(api.db.Ctx, database.UpdateProgressParams{
Percentage: rPosition.Percentage, Percentage: rPosition.Percentage,
DocumentID: rPosition.DocumentID, DocumentID: rPosition.DocumentID,
DeviceID: rPosition.DeviceID, DeviceID: rPosition.DeviceID,
@ -140,7 +140,7 @@ func (api *API) koGetProgress(c *gin.Context) {
return return
} }
progress, err := api.db.Queries.GetDocumentProgress(c, database.GetDocumentProgressParams{ progress, err := api.db.Queries.GetDocumentProgress(api.db.Ctx, database.GetDocumentProgressParams{
DocumentID: rDocID.DocumentID, DocumentID: rDocID.DocumentID,
UserID: auth.UserName, UserID: auth.UserName,
}) })
@ -202,7 +202,7 @@ func (api *API) koAddActivities(c *gin.Context) {
// Upsert Documents // Upsert Documents
for _, doc := range allDocuments { for _, doc := range allDocuments {
if _, err := qtx.UpsertDocument(c, database.UpsertDocumentParams{ if _, err := qtx.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
ID: doc, ID: doc,
}); err != nil { }); err != nil {
log.Error("UpsertDocument DB Error:", err) log.Error("UpsertDocument DB Error:", err)
@ -212,7 +212,7 @@ func (api *API) koAddActivities(c *gin.Context) {
} }
// Upsert Device // Upsert Device
if _, err = qtx.UpsertDevice(c, database.UpsertDeviceParams{ if _, err = qtx.UpsertDevice(api.db.Ctx, database.UpsertDeviceParams{
ID: rActivity.DeviceID, ID: rActivity.DeviceID,
UserID: auth.UserName, UserID: auth.UserName,
DeviceName: rActivity.Device, DeviceName: rActivity.Device,
@ -225,7 +225,7 @@ func (api *API) koAddActivities(c *gin.Context) {
// Add All Activity // Add All Activity
for _, item := range rActivity.Activity { for _, item := range rActivity.Activity {
if _, err := qtx.AddActivity(c, database.AddActivityParams{ if _, err := qtx.AddActivity(api.db.Ctx, database.AddActivityParams{
UserID: auth.UserName, UserID: auth.UserName,
DocumentID: item.DocumentID, DocumentID: item.DocumentID,
DeviceID: rActivity.DeviceID, DeviceID: rActivity.DeviceID,
@ -266,7 +266,7 @@ func (api *API) koCheckActivitySync(c *gin.Context) {
} }
// Upsert Device // Upsert Device
if _, err := api.db.Queries.UpsertDevice(c, database.UpsertDeviceParams{ if _, err := api.db.Queries.UpsertDevice(api.db.Ctx, database.UpsertDeviceParams{
ID: rCheckActivity.DeviceID, ID: rCheckActivity.DeviceID,
UserID: auth.UserName, UserID: auth.UserName,
DeviceName: rCheckActivity.Device, DeviceName: rCheckActivity.Device,
@ -278,7 +278,7 @@ func (api *API) koCheckActivitySync(c *gin.Context) {
} }
// Get Last Device Activity // Get Last Device Activity
lastActivity, err := api.db.Queries.GetLastActivity(c, database.GetLastActivityParams{ lastActivity, err := api.db.Queries.GetLastActivity(api.db.Ctx, database.GetLastActivityParams{
UserID: auth.UserName, UserID: auth.UserName,
DeviceID: rCheckActivity.DeviceID, DeviceID: rCheckActivity.DeviceID,
}) })
@ -329,7 +329,7 @@ func (api *API) koAddDocuments(c *gin.Context) {
// Upsert Documents // Upsert Documents
for _, doc := range rNewDocs.Documents { for _, doc := range rNewDocs.Documents {
_, err := qtx.UpsertDocument(c, database.UpsertDocumentParams{ _, err := qtx.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
ID: doc.ID, ID: doc.ID,
Title: api.sanitizeInput(doc.Title), Title: api.sanitizeInput(doc.Title),
Author: api.sanitizeInput(doc.Author), Author: api.sanitizeInput(doc.Author),
@ -371,7 +371,7 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) {
} }
// Upsert Device // Upsert Device
_, err := api.db.Queries.UpsertDevice(c, database.UpsertDeviceParams{ _, err := api.db.Queries.UpsertDevice(api.db.Ctx, database.UpsertDeviceParams{
ID: rCheckDocs.DeviceID, ID: rCheckDocs.DeviceID,
UserID: auth.UserName, UserID: auth.UserName,
DeviceName: rCheckDocs.Device, DeviceName: rCheckDocs.Device,
@ -384,7 +384,7 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) {
} }
// Get Missing Documents // Get Missing Documents
missingDocs, err := api.db.Queries.GetMissingDocuments(c, rCheckDocs.Have) missingDocs, err := api.db.Queries.GetMissingDocuments(api.db.Ctx, rCheckDocs.Have)
if err != nil { if err != nil {
log.Error("GetMissingDocuments DB Error", err) log.Error("GetMissingDocuments DB Error", err)
apiErrorPage(c, http.StatusBadRequest, "Invalid Request") apiErrorPage(c, http.StatusBadRequest, "Invalid Request")
@ -392,7 +392,7 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) {
} }
// Get Deleted Documents // Get Deleted Documents
deletedDocIDs, err := api.db.Queries.GetDeletedDocuments(c, rCheckDocs.Have) deletedDocIDs, err := api.db.Queries.GetDeletedDocuments(api.db.Ctx, rCheckDocs.Have)
if err != nil { if err != nil {
log.Error("GetDeletedDocuments DB Error", err) log.Error("GetDeletedDocuments DB Error", err)
apiErrorPage(c, http.StatusBadRequest, "Invalid Request") apiErrorPage(c, http.StatusBadRequest, "Invalid Request")
@ -407,7 +407,7 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) {
return return
} }
wantedDocs, err := api.db.Queries.GetWantedDocuments(c, string(jsonHaves)) wantedDocs, err := api.db.Queries.GetWantedDocuments(api.db.Ctx, string(jsonHaves))
if err != nil { if err != nil {
log.Error("GetWantedDocuments DB Error", err) log.Error("GetWantedDocuments DB Error", err)
apiErrorPage(c, http.StatusBadRequest, "Invalid Request") apiErrorPage(c, http.StatusBadRequest, "Invalid Request")
@ -467,7 +467,7 @@ func (api *API) koUploadExistingDocument(c *gin.Context) {
} }
// Validate Document Exists in DB // Validate Document Exists in DB
document, err := api.db.Queries.GetDocument(c, rDoc.DocumentID) document, err := api.db.Queries.GetDocument(api.db.Ctx, rDoc.DocumentID)
if err != nil { if err != nil {
log.Error("GetDocument DB Error:", err) log.Error("GetDocument DB Error:", err)
apiErrorPage(c, http.StatusBadRequest, "Unknown Document") apiErrorPage(c, http.StatusBadRequest, "Unknown Document")
@ -522,7 +522,7 @@ func (api *API) koUploadExistingDocument(c *gin.Context) {
} }
// Upsert Document // Upsert Document
if _, err = api.db.Queries.UpsertDocument(c, database.UpsertDocumentParams{ if _, err = api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
ID: document.ID, ID: document.ID,
Md5: metadataInfo.MD5, Md5: metadataInfo.MD5,
Words: metadataInfo.WordCount, Words: metadataInfo.WordCount,

View File

@ -10,7 +10,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"reichard.io/antholume/database" "reichard.io/antholume/database"
"reichard.io/antholume/opds" "reichard.io/antholume/opds"
"reichard.io/antholume/pkg/ptr"
) )
var mimeMapping map[string]string = map[string]string{ var mimeMapping map[string]string = map[string]string{
@ -78,12 +77,11 @@ func (api *API) opdsDocuments(c *gin.Context) {
} }
// Get Documents // Get Documents
documents, err := api.db.Queries.GetDocumentsWithStats(c, database.GetDocumentsWithStatsParams{ documents, err := api.db.Queries.GetDocumentsWithStats(api.db.Ctx, database.GetDocumentsWithStatsParams{
UserID: auth.UserName, UserID: auth.UserName,
Query: query, Query: query,
Deleted: ptr.Of(false), Offset: (*qParams.Page - 1) * *qParams.Limit,
Offset: (*qParams.Page - 1) * *qParams.Limit, Limit: *qParams.Limit,
Limit: *qParams.Limit,
}) })
if err != nil { if err != nil {
log.Error("GetDocumentsWithStats DB Error:", err) log.Error("GetDocumentsWithStats DB Error:", err)

View File

@ -55,7 +55,6 @@ func getTimeZones() []string {
// niceSeconds takes in an int (in seconds) and returns a string readable // niceSeconds takes in an int (in seconds) and returns a string readable
// representation. For example 1928371 -> "22d 7h 39m 31s". // representation. For example 1928371 -> "22d 7h 39m 31s".
// Deprecated: Use formatters.FormatDuration
func niceSeconds(input int64) (result string) { func niceSeconds(input int64) (result string) {
if input == 0 { if input == 0 {
return "N/A" return "N/A"
@ -86,7 +85,6 @@ func niceSeconds(input int64) (result string) {
// niceNumbers takes in an int and returns a string representation. For example // niceNumbers takes in an int and returns a string representation. For example
// 19823 -> "19.8k". // 19823 -> "19.8k".
// Deprecated: Use formatters.FormatNumber
func niceNumbers(input int64) string { func niceNumbers(input int64) string {
if input == 0 { if input == 0 {
return "0" return "0"

View File

@ -1,27 +0,0 @@
package database
import (
"context"
"fmt"
"reichard.io/antholume/pkg/ptr"
"reichard.io/antholume/pkg/sliceutils"
)
func (d *DBManager) GetDocument(ctx context.Context, docID, userID string) (*GetDocumentsWithStatsRow, error) {
documents, err := d.Queries.GetDocumentsWithStats(ctx, GetDocumentsWithStatsParams{
ID: ptr.Of(docID),
UserID: userID,
Limit: 1,
})
if err != nil {
return nil, err
}
document, found := sliceutils.First(documents)
if !found {
return nil, fmt.Errorf("document not found: %s", docID)
}
return &document, nil
}

View File

@ -1,7 +1,6 @@
package database package database
import ( import (
"context"
"fmt" "fmt"
"testing" "testing"
@ -27,7 +26,7 @@ func (suite *DocumentsTestSuite) SetupTest() {
suite.dbm = NewMgr(&cfg) suite.dbm = NewMgr(&cfg)
// Create Document // Create Document
_, err := suite.dbm.Queries.UpsertDocument(context.Background(), UpsertDocumentParams{ _, err := suite.dbm.Queries.UpsertDocument(suite.dbm.Ctx, UpsertDocumentParams{
ID: documentID, ID: documentID,
Title: &documentTitle, Title: &documentTitle,
Author: &documentAuthor, Author: &documentAuthor,
@ -43,7 +42,7 @@ func (suite *DocumentsTestSuite) SetupTest() {
// - 󰊕 (q *Queries) GetDocumentsWithStats // - 󰊕 (q *Queries) GetDocumentsWithStats
// - 󰊕 (q *Queries) GetMissingDocuments // - 󰊕 (q *Queries) GetMissingDocuments
func (suite *DocumentsTestSuite) TestGetDocument() { func (suite *DocumentsTestSuite) TestGetDocument() {
doc, err := suite.dbm.Queries.GetDocument(context.Background(), documentID) doc, err := suite.dbm.Queries.GetDocument(suite.dbm.Ctx, documentID)
suite.Nil(err, "should have nil err") suite.Nil(err, "should have nil err")
suite.Equal(documentID, doc.ID, "should have changed the document") suite.Equal(documentID, doc.ID, "should have changed the document")
} }
@ -51,7 +50,7 @@ func (suite *DocumentsTestSuite) TestGetDocument() {
func (suite *DocumentsTestSuite) TestUpsertDocument() { func (suite *DocumentsTestSuite) TestUpsertDocument() {
testDocID := "docid1" testDocID := "docid1"
doc, err := suite.dbm.Queries.UpsertDocument(context.Background(), UpsertDocumentParams{ doc, err := suite.dbm.Queries.UpsertDocument(suite.dbm.Ctx, UpsertDocumentParams{
ID: testDocID, ID: testDocID,
Title: &documentTitle, Title: &documentTitle,
Author: &documentAuthor, Author: &documentAuthor,
@ -64,51 +63,51 @@ func (suite *DocumentsTestSuite) TestUpsertDocument() {
} }
func (suite *DocumentsTestSuite) TestDeleteDocument() { func (suite *DocumentsTestSuite) TestDeleteDocument() {
changed, err := suite.dbm.Queries.DeleteDocument(context.Background(), documentID) changed, err := suite.dbm.Queries.DeleteDocument(suite.dbm.Ctx, documentID)
suite.Nil(err, "should have nil err") suite.Nil(err, "should have nil err")
suite.Equal(int64(1), changed, "should have changed the document") suite.Equal(int64(1), changed, "should have changed the document")
doc, err := suite.dbm.Queries.GetDocument(context.Background(), documentID) doc, err := suite.dbm.Queries.GetDocument(suite.dbm.Ctx, documentID)
suite.Nil(err, "should have nil err") suite.Nil(err, "should have nil err")
suite.True(doc.Deleted, "should have deleted the document") suite.True(doc.Deleted, "should have deleted the document")
} }
func (suite *DocumentsTestSuite) TestGetDeletedDocuments() { func (suite *DocumentsTestSuite) TestGetDeletedDocuments() {
changed, err := suite.dbm.Queries.DeleteDocument(context.Background(), documentID) changed, err := suite.dbm.Queries.DeleteDocument(suite.dbm.Ctx, documentID)
suite.Nil(err, "should have nil err") suite.Nil(err, "should have nil err")
suite.Equal(int64(1), changed, "should have changed the document") suite.Equal(int64(1), changed, "should have changed the document")
deletedDocs, err := suite.dbm.Queries.GetDeletedDocuments(context.Background(), []string{documentID}) deletedDocs, err := suite.dbm.Queries.GetDeletedDocuments(suite.dbm.Ctx, []string{documentID})
suite.Nil(err, "should have nil err") suite.Nil(err, "should have nil err")
suite.Len(deletedDocs, 1, "should have one deleted document") suite.Len(deletedDocs, 1, "should have one deleted document")
} }
// TODO - Convert GetWantedDocuments -> (sqlc.slice('document_ids')); // TODO - Convert GetWantedDocuments -> (sqlc.slice('document_ids'));
func (suite *DocumentsTestSuite) TestGetWantedDocuments() { func (suite *DocumentsTestSuite) TestGetWantedDocuments() {
wantedDocs, err := suite.dbm.Queries.GetWantedDocuments(context.Background(), fmt.Sprintf("[\"%s\"]", documentID)) wantedDocs, err := suite.dbm.Queries.GetWantedDocuments(suite.dbm.Ctx, fmt.Sprintf("[\"%s\"]", documentID))
suite.Nil(err, "should have nil err") suite.Nil(err, "should have nil err")
suite.Len(wantedDocs, 1, "should have one wanted document") suite.Len(wantedDocs, 1, "should have one wanted document")
} }
func (suite *DocumentsTestSuite) TestGetMissingDocuments() { func (suite *DocumentsTestSuite) TestGetMissingDocuments() {
// Create Document // Create Document
_, err := suite.dbm.Queries.UpsertDocument(context.Background(), UpsertDocumentParams{ _, err := suite.dbm.Queries.UpsertDocument(suite.dbm.Ctx, UpsertDocumentParams{
ID: documentID, ID: documentID,
Filepath: &documentFilepath, Filepath: &documentFilepath,
}) })
suite.NoError(err) suite.NoError(err)
missingDocs, err := suite.dbm.Queries.GetMissingDocuments(context.Background(), []string{documentID}) missingDocs, err := suite.dbm.Queries.GetMissingDocuments(suite.dbm.Ctx, []string{documentID})
suite.Nil(err, "should have nil err") suite.Nil(err, "should have nil err")
suite.Len(missingDocs, 0, "should have no wanted document") suite.Len(missingDocs, 0, "should have no wanted document")
missingDocs, err = suite.dbm.Queries.GetMissingDocuments(context.Background(), []string{"other"}) missingDocs, err = suite.dbm.Queries.GetMissingDocuments(suite.dbm.Ctx, []string{"other"})
suite.Nil(err, "should have nil err") suite.Nil(err, "should have nil err")
suite.Len(missingDocs, 1, "should have one missing document") suite.Len(missingDocs, 1, "should have one missing document")
suite.Equal(documentID, missingDocs[0].ID, "should have missing doc") suite.Equal(documentID, missingDocs[0].ID, "should have missing doc")
// TODO - https://github.com/sqlc-dev/sqlc/issues/3451 // TODO - https://github.com/sqlc-dev/sqlc/issues/3451
// missingDocs, err = suite.dbm.Queries.GetMissingDocuments(context.Background(), []string{}) // missingDocs, err = suite.dbm.Queries.GetMissingDocuments(suite.dbm.Ctx, []string{})
// suite.Nil(err, "should have nil err") // suite.Nil(err, "should have nil err")
// suite.Len(missingDocs, 1, "should have one missing document") // suite.Len(missingDocs, 1, "should have one missing document")
// suite.Equal(documentID, missingDocs[0].ID, "should have missing doc") // suite.Equal(documentID, missingDocs[0].ID, "should have missing doc")

View File

@ -5,6 +5,7 @@ import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"embed" "embed"
_ "embed"
"errors" "errors"
"fmt" "fmt"
"path/filepath" "path/filepath"
@ -19,6 +20,7 @@ import (
type DBManager struct { type DBManager struct {
DB *sql.DB DB *sql.DB
Ctx context.Context
Queries *Queries Queries *Queries
cfg *config.Config cfg *config.Config
} }
@ -52,9 +54,12 @@ func init() {
// NewMgr Returns an initialized manager // NewMgr Returns an initialized manager
func NewMgr(c *config.Config) *DBManager { func NewMgr(c *config.Config) *DBManager {
// Create Manager // Create Manager
dbm := &DBManager{cfg: c} dbm := &DBManager{
Ctx: context.Background(),
cfg: c,
}
if err := dbm.init(context.Background()); err != nil { if err := dbm.init(); err != nil {
log.Panic("Unable to init DB") log.Panic("Unable to init DB")
} }
@ -62,7 +67,7 @@ func NewMgr(c *config.Config) *DBManager {
} }
// init loads the DB manager // init loads the DB manager
func (dbm *DBManager) init(ctx context.Context) error { func (dbm *DBManager) init() error {
// Build DB Location // Build DB Location
var dbLocation string var dbLocation string
switch dbm.cfg.DBType { switch dbm.cfg.DBType {
@ -108,14 +113,14 @@ func (dbm *DBManager) init(ctx context.Context) error {
} }
// Update settings // Update settings
err = dbm.updateSettings(ctx) err = dbm.updateSettings()
if err != nil { if err != nil {
log.Panicf("Error running DB settings update: %v", err) log.Panicf("Error running DB settings update: %v", err)
return err return err
} }
// Cache tables // Cache tables
if err := dbm.CacheTempTables(ctx); err != nil { if err := dbm.CacheTempTables(); err != nil {
log.Warn("Refreshing temp table cache failed: ", err) log.Warn("Refreshing temp table cache failed: ", err)
} }
@ -123,7 +128,7 @@ func (dbm *DBManager) init(ctx context.Context) error {
} }
// Reload closes the DB & reinits // Reload closes the DB & reinits
func (dbm *DBManager) Reload(ctx context.Context) error { func (dbm *DBManager) Reload() error {
// Close handle // Close handle
err := dbm.DB.Close() err := dbm.DB.Close()
if err != nil { if err != nil {
@ -131,7 +136,7 @@ func (dbm *DBManager) Reload(ctx context.Context) error {
} }
// Reinit DB // Reinit DB
if err := dbm.init(ctx); err != nil { if err := dbm.init(); err != nil {
return err return err
} }
@ -139,15 +144,15 @@ func (dbm *DBManager) Reload(ctx context.Context) error {
} }
// CacheTempTables clears existing statistics and recalculates // CacheTempTables clears existing statistics and recalculates
func (dbm *DBManager) CacheTempTables(ctx context.Context) error { func (dbm *DBManager) CacheTempTables() error {
start := time.Now() start := time.Now()
if _, err := dbm.DB.ExecContext(ctx, user_streaks); err != nil { if _, err := dbm.DB.ExecContext(dbm.Ctx, user_streaks); err != nil {
return err return err
} }
log.Debug("Cached 'user_streaks' in: ", time.Since(start)) log.Debug("Cached 'user_streaks' in: ", time.Since(start))
start = time.Now() start = time.Now()
if _, err := dbm.DB.ExecContext(ctx, document_user_statistics); err != nil { if _, err := dbm.DB.ExecContext(dbm.Ctx, document_user_statistics); err != nil {
return err return err
} }
log.Debug("Cached 'document_user_statistics' in: ", time.Since(start)) log.Debug("Cached 'document_user_statistics' in: ", time.Since(start))
@ -157,7 +162,7 @@ func (dbm *DBManager) CacheTempTables(ctx context.Context) error {
// updateSettings ensures that we're enforcing foreign keys and enable journal // updateSettings ensures that we're enforcing foreign keys and enable journal
// mode. // mode.
func (dbm *DBManager) updateSettings(ctx context.Context) error { func (dbm *DBManager) updateSettings() error {
// Set SQLite PRAGMA Settings // Set SQLite PRAGMA Settings
pragmaQuery := ` pragmaQuery := `
PRAGMA foreign_keys = ON; PRAGMA foreign_keys = ON;
@ -169,7 +174,7 @@ func (dbm *DBManager) updateSettings(ctx context.Context) error {
} }
// Update Antholume Version in DB // Update Antholume Version in DB
if _, err := dbm.Queries.UpdateSettings(ctx, UpdateSettingsParams{ if _, err := dbm.Queries.UpdateSettings(dbm.Ctx, UpdateSettingsParams{
Name: "version", Name: "version",
Value: dbm.cfg.Version, Value: dbm.cfg.Version,
}); err != nil { }); err != nil {

View File

@ -1,7 +1,6 @@
package database package database
import ( import (
"context"
"fmt" "fmt"
"testing" "testing"
"time" "time"
@ -47,7 +46,7 @@ func (suite *DatabaseTestSuite) SetupTest() {
// Create User // Create User
rawAuthHash, _ := utils.GenerateToken(64) rawAuthHash, _ := utils.GenerateToken(64)
authHash := fmt.Sprintf("%x", rawAuthHash) authHash := fmt.Sprintf("%x", rawAuthHash)
_, err := suite.dbm.Queries.CreateUser(context.Background(), CreateUserParams{ _, err := suite.dbm.Queries.CreateUser(suite.dbm.Ctx, CreateUserParams{
ID: userID, ID: userID,
Pass: &userPass, Pass: &userPass,
AuthHash: &authHash, AuthHash: &authHash,
@ -55,7 +54,7 @@ func (suite *DatabaseTestSuite) SetupTest() {
suite.NoError(err) suite.NoError(err)
// Create Document // Create Document
_, err = suite.dbm.Queries.UpsertDocument(context.Background(), UpsertDocumentParams{ _, err = suite.dbm.Queries.UpsertDocument(suite.dbm.Ctx, UpsertDocumentParams{
ID: documentID, ID: documentID,
Title: &documentTitle, Title: &documentTitle,
Author: &documentAuthor, Author: &documentAuthor,
@ -65,7 +64,7 @@ func (suite *DatabaseTestSuite) SetupTest() {
suite.NoError(err) suite.NoError(err)
// Create Device // Create Device
_, err = suite.dbm.Queries.UpsertDevice(context.Background(), UpsertDeviceParams{ _, err = suite.dbm.Queries.UpsertDevice(suite.dbm.Ctx, UpsertDeviceParams{
ID: deviceID, ID: deviceID,
UserID: userID, UserID: userID,
DeviceName: deviceName, DeviceName: deviceName,
@ -81,7 +80,7 @@ func (suite *DatabaseTestSuite) SetupTest() {
counter += 1 counter += 1
// Add Item // Add Item
activity, err := suite.dbm.Queries.AddActivity(context.Background(), AddActivityParams{ activity, err := suite.dbm.Queries.AddActivity(suite.dbm.Ctx, AddActivityParams{
DocumentID: documentID, DocumentID: documentID,
DeviceID: deviceID, DeviceID: deviceID,
UserID: userID, UserID: userID,
@ -96,7 +95,7 @@ func (suite *DatabaseTestSuite) SetupTest() {
} }
// Initiate Cache // Initiate Cache
err = suite.dbm.CacheTempTables(context.Background()) err = suite.dbm.CacheTempTables()
suite.NoError(err) suite.NoError(err)
} }
@ -106,7 +105,7 @@ func (suite *DatabaseTestSuite) SetupTest() {
// - 󰊕 (q *Queries) UpsertDevice // - 󰊕 (q *Queries) UpsertDevice
func (suite *DatabaseTestSuite) TestDevice() { func (suite *DatabaseTestSuite) TestDevice() {
testDevice := "dev123" testDevice := "dev123"
device, err := suite.dbm.Queries.UpsertDevice(context.Background(), UpsertDeviceParams{ device, err := suite.dbm.Queries.UpsertDevice(suite.dbm.Ctx, UpsertDeviceParams{
ID: testDevice, ID: testDevice,
UserID: userID, UserID: userID,
DeviceName: deviceName, DeviceName: deviceName,
@ -124,7 +123,7 @@ func (suite *DatabaseTestSuite) TestDevice() {
// - 󰊕 (q *Queries) GetLastActivity // - 󰊕 (q *Queries) GetLastActivity
func (suite *DatabaseTestSuite) TestActivity() { func (suite *DatabaseTestSuite) TestActivity() {
// Validate Exists // Validate Exists
existsRows, err := suite.dbm.Queries.GetActivity(context.Background(), GetActivityParams{ existsRows, err := suite.dbm.Queries.GetActivity(suite.dbm.Ctx, GetActivityParams{
UserID: userID, UserID: userID,
Offset: 0, Offset: 0,
Limit: 50, Limit: 50,
@ -134,7 +133,7 @@ func (suite *DatabaseTestSuite) TestActivity() {
suite.Len(existsRows, 10, "should have correct number of rows get activity") suite.Len(existsRows, 10, "should have correct number of rows get activity")
// Validate Doesn't Exist // Validate Doesn't Exist
doesntExistsRows, err := suite.dbm.Queries.GetActivity(context.Background(), GetActivityParams{ doesntExistsRows, err := suite.dbm.Queries.GetActivity(suite.dbm.Ctx, GetActivityParams{
UserID: userID, UserID: userID,
DocumentID: "unknownDoc", DocumentID: "unknownDoc",
DocFilter: true, DocFilter: true,
@ -152,7 +151,7 @@ func (suite *DatabaseTestSuite) TestActivity() {
// - 󰊕 (q *Queries) GetDatabaseInfo // - 󰊕 (q *Queries) GetDatabaseInfo
// - 󰊕 (q *Queries) UpdateSettings // - 󰊕 (q *Queries) UpdateSettings
func (suite *DatabaseTestSuite) TestGetDailyReadStats() { func (suite *DatabaseTestSuite) TestGetDailyReadStats() {
readStats, err := suite.dbm.Queries.GetDailyReadStats(context.Background(), userID) readStats, err := suite.dbm.Queries.GetDailyReadStats(suite.dbm.Ctx, userID)
suite.Nil(err, "should have nil err") suite.Nil(err, "should have nil err")
suite.Len(readStats, 30, "should have length of 30") suite.Len(readStats, 30, "should have length of 30")

View File

@ -163,6 +163,42 @@ ORDER BY
DESC DESC
LIMIT 1; LIMIT 1;
-- name: GetDocumentWithStats :one
SELECT
docs.id,
docs.title,
docs.author,
docs.description,
docs.isbn10,
docs.isbn13,
docs.filepath,
docs.words,
CAST(COALESCE(dus.total_wpm, 0.0) AS INTEGER) AS wpm,
COALESCE(dus.read_percentage, 0) AS read_percentage,
COALESCE(dus.total_time_seconds, 0) AS total_time_seconds,
STRFTIME('%Y-%m-%d %H:%M:%S', LOCAL_TIME(COALESCE(dus.last_read, STRFTIME('%Y-%m-%dT%H:%M:%SZ', 0, 'unixepoch')), users.timezone))
AS last_read,
ROUND(CAST(CASE
WHEN dus.percentage IS NULL THEN 0.0
WHEN (dus.percentage * 100.0) > 97.0 THEN 100.0
ELSE dus.percentage * 100.0
END AS REAL), 2) AS percentage,
CAST(CASE
WHEN dus.total_time_seconds IS NULL THEN 0.0
ELSE
CAST(dus.total_time_seconds AS REAL)
/ (dus.read_percentage * 100.0)
END AS INTEGER) AS seconds_per_percent
FROM documents AS docs
LEFT JOIN users ON users.id = $user_id
LEFT JOIN
document_user_statistics AS dus
ON dus.document_id = docs.id AND dus.user_id = $user_id
WHERE users.id = $user_id
AND docs.id = $document_id
LIMIT 1;
-- name: GetDocuments :many -- name: GetDocuments :many
SELECT * FROM documents SELECT * FROM documents
ORDER BY created_at DESC ORDER BY created_at DESC
@ -200,25 +236,26 @@ SELECT
WHEN (dus.percentage * 100.0) > 97.0 THEN 100.0 WHEN (dus.percentage * 100.0) > 97.0 THEN 100.0
ELSE dus.percentage * 100.0 ELSE dus.percentage * 100.0
END AS REAL), 2) AS percentage, END AS REAL), 2) AS percentage,
CAST(CASE
CASE
WHEN dus.total_time_seconds IS NULL THEN 0.0 WHEN dus.total_time_seconds IS NULL THEN 0.0
ELSE ELSE
CAST(dus.total_time_seconds AS REAL) ROUND(
/ (dus.read_percentage * 100.0) CAST(dus.total_time_seconds AS REAL)
END AS INTEGER) AS seconds_per_percent / (dus.read_percentage * 100.0)
)
END AS seconds_per_percent
FROM documents AS docs FROM documents AS docs
LEFT JOIN users ON users.id = $user_id LEFT JOIN users ON users.id = $user_id
LEFT JOIN LEFT JOIN
document_user_statistics AS dus document_user_statistics AS dus
ON dus.document_id = docs.id AND dus.user_id = $user_id ON dus.document_id = docs.id AND dus.user_id = $user_id
WHERE WHERE
(docs.id = sqlc.narg('id') OR $id IS NULL) docs.deleted = false AND (
AND (docs.deleted = sqlc.narg(deleted) OR $deleted IS NULL) $query IS NULL OR (
AND ( docs.title LIKE $query OR
(
docs.title LIKE sqlc.narg('query') OR
docs.author LIKE $query docs.author LIKE $query
) OR $query IS NULL )
) )
ORDER BY dus.last_read DESC, docs.created_at DESC ORDER BY dus.last_read DESC, docs.created_at DESC
LIMIT $limit LIMIT $limit

View File

@ -543,6 +543,87 @@ func (q *Queries) GetDocumentProgress(ctx context.Context, arg GetDocumentProgre
return i, err return i, err
} }
const getDocumentWithStats = `-- name: GetDocumentWithStats :one
SELECT
docs.id,
docs.title,
docs.author,
docs.description,
docs.isbn10,
docs.isbn13,
docs.filepath,
docs.words,
CAST(COALESCE(dus.total_wpm, 0.0) AS INTEGER) AS wpm,
COALESCE(dus.read_percentage, 0) AS read_percentage,
COALESCE(dus.total_time_seconds, 0) AS total_time_seconds,
STRFTIME('%Y-%m-%d %H:%M:%S', LOCAL_TIME(COALESCE(dus.last_read, STRFTIME('%Y-%m-%dT%H:%M:%SZ', 0, 'unixepoch')), users.timezone))
AS last_read,
ROUND(CAST(CASE
WHEN dus.percentage IS NULL THEN 0.0
WHEN (dus.percentage * 100.0) > 97.0 THEN 100.0
ELSE dus.percentage * 100.0
END AS REAL), 2) AS percentage,
CAST(CASE
WHEN dus.total_time_seconds IS NULL THEN 0.0
ELSE
CAST(dus.total_time_seconds AS REAL)
/ (dus.read_percentage * 100.0)
END AS INTEGER) AS seconds_per_percent
FROM documents AS docs
LEFT JOIN users ON users.id = ?1
LEFT JOIN
document_user_statistics AS dus
ON dus.document_id = docs.id AND dus.user_id = ?1
WHERE users.id = ?1
AND docs.id = ?2
LIMIT 1
`
type GetDocumentWithStatsParams struct {
UserID string `json:"user_id"`
DocumentID string `json:"document_id"`
}
type GetDocumentWithStatsRow struct {
ID string `json:"id"`
Title *string `json:"title"`
Author *string `json:"author"`
Description *string `json:"description"`
Isbn10 *string `json:"isbn10"`
Isbn13 *string `json:"isbn13"`
Filepath *string `json:"filepath"`
Words *int64 `json:"words"`
Wpm int64 `json:"wpm"`
ReadPercentage float64 `json:"read_percentage"`
TotalTimeSeconds int64 `json:"total_time_seconds"`
LastRead interface{} `json:"last_read"`
Percentage float64 `json:"percentage"`
SecondsPerPercent int64 `json:"seconds_per_percent"`
}
func (q *Queries) GetDocumentWithStats(ctx context.Context, arg GetDocumentWithStatsParams) (GetDocumentWithStatsRow, error) {
row := q.db.QueryRowContext(ctx, getDocumentWithStats, arg.UserID, arg.DocumentID)
var i GetDocumentWithStatsRow
err := row.Scan(
&i.ID,
&i.Title,
&i.Author,
&i.Description,
&i.Isbn10,
&i.Isbn13,
&i.Filepath,
&i.Words,
&i.Wpm,
&i.ReadPercentage,
&i.TotalTimeSeconds,
&i.LastRead,
&i.Percentage,
&i.SecondsPerPercent,
)
return i, err
}
const getDocuments = `-- name: GetDocuments :many const getDocuments = `-- name: GetDocuments :many
SELECT id, md5, basepath, filepath, coverfile, title, author, series, series_index, lang, description, words, gbid, olid, isbn10, isbn13, synced, deleted, updated_at, created_at FROM documents SELECT id, md5, basepath, filepath, coverfile, title, author, series, series_index, lang, description, words, gbid, olid, isbn10, isbn13, synced, deleted, updated_at, created_at FROM documents
ORDER BY created_at DESC ORDER BY created_at DESC
@ -638,38 +719,37 @@ SELECT
WHEN (dus.percentage * 100.0) > 97.0 THEN 100.0 WHEN (dus.percentage * 100.0) > 97.0 THEN 100.0
ELSE dus.percentage * 100.0 ELSE dus.percentage * 100.0
END AS REAL), 2) AS percentage, END AS REAL), 2) AS percentage,
CAST(CASE
CASE
WHEN dus.total_time_seconds IS NULL THEN 0.0 WHEN dus.total_time_seconds IS NULL THEN 0.0
ELSE ELSE
CAST(dus.total_time_seconds AS REAL) ROUND(
/ (dus.read_percentage * 100.0) CAST(dus.total_time_seconds AS REAL)
END AS INTEGER) AS seconds_per_percent / (dus.read_percentage * 100.0)
)
END AS seconds_per_percent
FROM documents AS docs FROM documents AS docs
LEFT JOIN users ON users.id = ?1 LEFT JOIN users ON users.id = ?1
LEFT JOIN LEFT JOIN
document_user_statistics AS dus document_user_statistics AS dus
ON dus.document_id = docs.id AND dus.user_id = ?1 ON dus.document_id = docs.id AND dus.user_id = ?1
WHERE WHERE
(docs.id = ?2 OR ?2 IS NULL) docs.deleted = false AND (
AND (docs.deleted = ?3 OR ?3 IS NULL) ?2 IS NULL OR (
AND ( docs.title LIKE ?2 OR
( docs.author LIKE ?2
docs.title LIKE ?4 OR )
docs.author LIKE ?4
) OR ?4 IS NULL
) )
ORDER BY dus.last_read DESC, docs.created_at DESC ORDER BY dus.last_read DESC, docs.created_at DESC
LIMIT ?6 LIMIT ?4
OFFSET ?5 OFFSET ?3
` `
type GetDocumentsWithStatsParams struct { type GetDocumentsWithStatsParams struct {
UserID string `json:"user_id"` UserID string `json:"user_id"`
ID *string `json:"id"` Query interface{} `json:"query"`
Deleted *bool `json:"-"` Offset int64 `json:"offset"`
Query *string `json:"query"` Limit int64 `json:"limit"`
Offset int64 `json:"offset"`
Limit int64 `json:"limit"`
} }
type GetDocumentsWithStatsRow struct { type GetDocumentsWithStatsRow struct {
@ -686,14 +766,12 @@ type GetDocumentsWithStatsRow struct {
TotalTimeSeconds int64 `json:"total_time_seconds"` TotalTimeSeconds int64 `json:"total_time_seconds"`
LastRead interface{} `json:"last_read"` LastRead interface{} `json:"last_read"`
Percentage float64 `json:"percentage"` Percentage float64 `json:"percentage"`
SecondsPerPercent int64 `json:"seconds_per_percent"` SecondsPerPercent interface{} `json:"seconds_per_percent"`
} }
func (q *Queries) GetDocumentsWithStats(ctx context.Context, arg GetDocumentsWithStatsParams) ([]GetDocumentsWithStatsRow, error) { func (q *Queries) GetDocumentsWithStats(ctx context.Context, arg GetDocumentsWithStatsParams) ([]GetDocumentsWithStatsRow, error) {
rows, err := q.db.QueryContext(ctx, getDocumentsWithStats, rows, err := q.db.QueryContext(ctx, getDocumentsWithStats,
arg.UserID, arg.UserID,
arg.ID,
arg.Deleted,
arg.Query, arg.Query,
arg.Offset, arg.Offset,
arg.Limit, arg.Limit,

View File

@ -1,7 +1,6 @@
package database package database
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"testing" "testing"
@ -37,7 +36,7 @@ func (suite *UsersTestSuite) SetupTest() {
// Create User // Create User
rawAuthHash, _ := utils.GenerateToken(64) rawAuthHash, _ := utils.GenerateToken(64)
authHash := fmt.Sprintf("%x", rawAuthHash) authHash := fmt.Sprintf("%x", rawAuthHash)
_, err := suite.dbm.Queries.CreateUser(context.Background(), CreateUserParams{ _, err := suite.dbm.Queries.CreateUser(suite.dbm.Ctx, CreateUserParams{
ID: testUserID, ID: testUserID,
Pass: &testUserPass, Pass: &testUserPass,
AuthHash: &authHash, AuthHash: &authHash,
@ -45,7 +44,7 @@ func (suite *UsersTestSuite) SetupTest() {
suite.NoError(err) suite.NoError(err)
// Create Document // Create Document
_, err = suite.dbm.Queries.UpsertDocument(context.Background(), UpsertDocumentParams{ _, err = suite.dbm.Queries.UpsertDocument(suite.dbm.Ctx, UpsertDocumentParams{
ID: documentID, ID: documentID,
Title: &documentTitle, Title: &documentTitle,
Author: &documentAuthor, Author: &documentAuthor,
@ -54,7 +53,7 @@ func (suite *UsersTestSuite) SetupTest() {
suite.NoError(err) suite.NoError(err)
// Create Device // Create Device
_, err = suite.dbm.Queries.UpsertDevice(context.Background(), UpsertDeviceParams{ _, err = suite.dbm.Queries.UpsertDevice(suite.dbm.Ctx, UpsertDeviceParams{
ID: deviceID, ID: deviceID,
UserID: testUserID, UserID: testUserID,
DeviceName: deviceName, DeviceName: deviceName,
@ -63,7 +62,7 @@ func (suite *UsersTestSuite) SetupTest() {
} }
func (suite *UsersTestSuite) TestGetUser() { func (suite *UsersTestSuite) TestGetUser() {
user, err := suite.dbm.Queries.GetUser(context.Background(), testUserID) user, err := suite.dbm.Queries.GetUser(suite.dbm.Ctx, testUserID)
suite.Nil(err, "should have nil err") suite.Nil(err, "should have nil err")
suite.Equal(testUserPass, *user.Pass) suite.Equal(testUserPass, *user.Pass)
} }
@ -77,7 +76,7 @@ func (suite *UsersTestSuite) TestCreateUser() {
suite.Nil(err, "should have nil err") suite.Nil(err, "should have nil err")
authHash := fmt.Sprintf("%x", rawAuthHash) authHash := fmt.Sprintf("%x", rawAuthHash)
changed, err := suite.dbm.Queries.CreateUser(context.Background(), CreateUserParams{ changed, err := suite.dbm.Queries.CreateUser(suite.dbm.Ctx, CreateUserParams{
ID: testUser, ID: testUser,
Pass: &testPass, Pass: &testPass,
AuthHash: &authHash, AuthHash: &authHash,
@ -86,29 +85,29 @@ func (suite *UsersTestSuite) TestCreateUser() {
suite.Nil(err, "should have nil err") suite.Nil(err, "should have nil err")
suite.Equal(int64(1), changed) suite.Equal(int64(1), changed)
user, err := suite.dbm.Queries.GetUser(context.Background(), testUser) user, err := suite.dbm.Queries.GetUser(suite.dbm.Ctx, testUser)
suite.Nil(err, "should have nil err") suite.Nil(err, "should have nil err")
suite.Equal(testPass, *user.Pass) suite.Equal(testPass, *user.Pass)
} }
func (suite *UsersTestSuite) TestDeleteUser() { func (suite *UsersTestSuite) TestDeleteUser() {
changed, err := suite.dbm.Queries.DeleteUser(context.Background(), testUserID) changed, err := suite.dbm.Queries.DeleteUser(suite.dbm.Ctx, testUserID)
suite.Nil(err, "should have nil err") suite.Nil(err, "should have nil err")
suite.Equal(int64(1), changed, "should have one changed row") suite.Equal(int64(1), changed, "should have one changed row")
_, err = suite.dbm.Queries.GetUser(context.Background(), testUserID) _, err = suite.dbm.Queries.GetUser(suite.dbm.Ctx, testUserID)
suite.ErrorIs(err, sql.ErrNoRows, "should have no rows error") suite.ErrorIs(err, sql.ErrNoRows, "should have no rows error")
} }
func (suite *UsersTestSuite) TestGetUsers() { func (suite *UsersTestSuite) TestGetUsers() {
users, err := suite.dbm.Queries.GetUsers(context.Background()) users, err := suite.dbm.Queries.GetUsers(suite.dbm.Ctx)
suite.Nil(err, "should have nil err") suite.Nil(err, "should have nil err")
suite.Len(users, 1, "should have single user") suite.Len(users, 1, "should have single user")
} }
func (suite *UsersTestSuite) TestUpdateUser() { func (suite *UsersTestSuite) TestUpdateUser() {
newPassword := "newPass123" newPassword := "newPass123"
user, err := suite.dbm.Queries.UpdateUser(context.Background(), UpdateUserParams{ user, err := suite.dbm.Queries.UpdateUser(suite.dbm.Ctx, UpdateUserParams{
UserID: testUserID, UserID: testUserID,
Password: &newPassword, Password: &newPassword,
}) })
@ -117,11 +116,11 @@ func (suite *UsersTestSuite) TestUpdateUser() {
} }
func (suite *UsersTestSuite) TestGetUserStatistics() { func (suite *UsersTestSuite) TestGetUserStatistics() {
err := suite.dbm.CacheTempTables(context.Background()) err := suite.dbm.CacheTempTables()
suite.NoError(err) suite.NoError(err)
// Ensure Zero Items // Ensure Zero Items
userStats, err := suite.dbm.Queries.GetUserStatistics(context.Background()) userStats, err := suite.dbm.Queries.GetUserStatistics(suite.dbm.Ctx)
suite.Nil(err, "should have nil err") suite.Nil(err, "should have nil err")
suite.Empty(userStats, "should be empty") suite.Empty(userStats, "should be empty")
@ -134,7 +133,7 @@ func (suite *UsersTestSuite) TestGetUserStatistics() {
counter += 1 counter += 1
// Add Item // Add Item
activity, err := suite.dbm.Queries.AddActivity(context.Background(), AddActivityParams{ activity, err := suite.dbm.Queries.AddActivity(suite.dbm.Ctx, AddActivityParams{
DocumentID: documentID, DocumentID: documentID,
DeviceID: deviceID, DeviceID: deviceID,
UserID: testUserID, UserID: testUserID,
@ -148,21 +147,21 @@ func (suite *UsersTestSuite) TestGetUserStatistics() {
suite.Equal(counter, activity.ID, fmt.Sprintf("[%d] should have correct id for add activity", counter)) suite.Equal(counter, activity.ID, fmt.Sprintf("[%d] should have correct id for add activity", counter))
} }
err = suite.dbm.CacheTempTables(context.Background()) err = suite.dbm.CacheTempTables()
suite.NoError(err) suite.NoError(err)
// Ensure One Item // Ensure One Item
userStats, err = suite.dbm.Queries.GetUserStatistics(context.Background()) userStats, err = suite.dbm.Queries.GetUserStatistics(suite.dbm.Ctx)
suite.Nil(err, "should have nil err") suite.Nil(err, "should have nil err")
suite.Len(userStats, 1, "should have length of one") suite.Len(userStats, 1, "should have length of one")
} }
func (suite *UsersTestSuite) TestGetUsersStreaks() { func (suite *UsersTestSuite) TestGetUsersStreaks() {
err := suite.dbm.CacheTempTables(context.Background()) err := suite.dbm.CacheTempTables()
suite.NoError(err) suite.NoError(err)
// Ensure Zero Items // Ensure Zero Items
userStats, err := suite.dbm.Queries.GetUserStreaks(context.Background(), testUserID) userStats, err := suite.dbm.Queries.GetUserStreaks(suite.dbm.Ctx, testUserID)
suite.Nil(err, "should have nil err") suite.Nil(err, "should have nil err")
suite.Empty(userStats, "should be empty") suite.Empty(userStats, "should be empty")
@ -175,7 +174,7 @@ func (suite *UsersTestSuite) TestGetUsersStreaks() {
counter += 1 counter += 1
// Add Item // Add Item
activity, err := suite.dbm.Queries.AddActivity(context.Background(), AddActivityParams{ activity, err := suite.dbm.Queries.AddActivity(suite.dbm.Ctx, AddActivityParams{
DocumentID: documentID, DocumentID: documentID,
DeviceID: deviceID, DeviceID: deviceID,
UserID: testUserID, UserID: testUserID,
@ -189,11 +188,11 @@ func (suite *UsersTestSuite) TestGetUsersStreaks() {
suite.Equal(counter, activity.ID, fmt.Sprintf("[%d] should have correct id for add activity", counter)) suite.Equal(counter, activity.ID, fmt.Sprintf("[%d] should have correct id for add activity", counter))
} }
err = suite.dbm.CacheTempTables(context.Background()) err = suite.dbm.CacheTempTables()
suite.NoError(err) suite.NoError(err)
// Ensure Two Item // Ensure Two Item
userStats, err = suite.dbm.Queries.GetUserStreaks(context.Background(), testUserID) userStats, err = suite.dbm.Queries.GetUserStreaks(suite.dbm.Ctx, testUserID)
suite.Nil(err, "should have nil err") suite.Nil(err, "should have nil err")
suite.Len(userStats, 2, "should have length of two") suite.Len(userStats, 2, "should have length of two")

View File

@ -1,37 +0,0 @@
package formatters
import (
"fmt"
"strings"
"time"
)
// FormatDuration takes a duration and returns a human-readable duration string.
// For example: 1928371 seconds -> "22d 7h 39m 31s"
func FormatDuration(d time.Duration) string {
if d == 0 {
return "N/A"
}
var parts []string
days := int(d.Hours()) / 24
hours := int(d.Hours()) % 24
minutes := int(d.Minutes()) % 60
seconds := int(d.Seconds()) % 60
if days > 0 {
parts = append(parts, fmt.Sprintf("%dd", days))
}
if hours > 0 {
parts = append(parts, fmt.Sprintf("%dh", hours))
}
if minutes > 0 {
parts = append(parts, fmt.Sprintf("%dm", minutes))
}
if seconds > 0 {
parts = append(parts, fmt.Sprintf("%ds", seconds))
}
return strings.Join(parts, " ")
}

View File

@ -1,45 +0,0 @@
package formatters
import (
"fmt"
"math"
)
// FormatNumber takes an int64 and returns a human-readable string.
// For example: 19823 -> "19.8k", 1500000 -> "1.5M"
func FormatNumber(input int64) string {
if input == 0 {
return "0"
}
// Handle Negative
negative := input < 0
if negative {
input = -input
}
abbreviations := []string{"", "k", "M", "B", "T"}
abbrevIndex := int(math.Log10(float64(input)) / 3)
// Bounds Check
if abbrevIndex >= len(abbreviations) {
abbrevIndex = len(abbreviations) - 1
}
scaledNumber := float64(input) / math.Pow(10, float64(abbrevIndex*3))
var result string
if scaledNumber >= 100 {
result = fmt.Sprintf("%.0f%s", scaledNumber, abbreviations[abbrevIndex])
} else if scaledNumber >= 10 {
result = fmt.Sprintf("%.1f%s", scaledNumber, abbreviations[abbrevIndex])
} else {
result = fmt.Sprintf("%.2f%s", scaledNumber, abbreviations[abbrevIndex])
}
if negative {
result = "-" + result
}
return result
}

View File

@ -1,13 +0,0 @@
package ptr
func Of[T any](v T) *T {
return &v
}
func Deref[T any](v *T) T {
var zeroT T
if v == nil {
return zeroT
}
return *v
}

View File

@ -1,17 +0,0 @@
package sliceutils
func First[T any](s []T) (T, bool) {
if len(s) == 0 {
var zeroT T
return zeroT, false
}
return s[0], true
}
func Map[R, I any](s []I, f func(I) R) []R {
r := make([]R, 0, len(s))
for _, v := range s {
r = append(r, f(v))
}
return r
}

View File

@ -1,18 +0,0 @@
package utils
func Ternary[T any](cond bool, tVal, fVal T) T {
if cond {
return tVal
}
return fVal
}
func FirstNonZero[T comparable](v ...T) T {
var zero T
for _, val := range v {
if val != zero {
return val
}
}
return zero
}

View File

@ -1,7 +1,6 @@
package server package server
import ( import (
"context"
"io/fs" "io/fs"
"net/http" "net/http"
"sync" "sync"
@ -53,14 +52,12 @@ func (s *server) Start() {
ticker := time.NewTicker(15 * time.Minute) ticker := time.NewTicker(15 * time.Minute)
defer ticker.Stop() defer ticker.Stop()
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Minute))
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
s.runScheduledTasks(ctx) s.runScheduledTasks()
case <-s.done: case <-s.done:
log.Info("Stopping task runner...") log.Info("Stopping task runner...")
cancel()
return return
} }
} }
@ -84,9 +81,9 @@ func (s *server) Stop() {
} }
// Run normal scheduled tasks // Run normal scheduled tasks
func (s *server) runScheduledTasks(ctx context.Context) { func (s *server) runScheduledTasks() {
start := time.Now() start := time.Now()
if err := s.db.CacheTempTables(ctx); err != nil { if err := s.db.CacheTempTables(); err != nil {
log.Warn("Refreshing temp table cache failed: ", err) log.Warn("Refreshing temp table cache failed: ", err)
} }
log.Debug("Completed in: ", time.Since(start)) log.Debug("Completed in: ", time.Since(start))