diff --git a/api/api.go b/api/api.go index 5aecdcd..75b7e3a 100644 --- a/api/api.go +++ b/api/api.go @@ -1,6 +1,7 @@ package api import ( + "context" "crypto/rand" "embed" "fmt" @@ -22,29 +23,37 @@ import ( ) type API struct { - Router *gin.Engine - Config *config.Config - DB *database.DBManager - HTMLPolicy *bluemonday.Policy - Assets *embed.FS - Templates map[string]*template.Template + db *database.DBManager + cfg *config.Config + assets *embed.FS + templates map[string]*template.Template + httpServer *http.Server } +var htmlPolicy = bluemonday.StrictPolicy() + func NewApi(db *database.DBManager, c *config.Config, assets *embed.FS) *API { api := &API{ - HTMLPolicy: bluemonday.StrictPolicy(), - Router: gin.New(), - Config: c, - DB: db, - Assets: assets, + db: db, + cfg: c, + assets: assets, + } + + // Create Router + router := gin.New() + + // Add Server + api.httpServer = &http.Server{ + Handler: router, + Addr: (":" + c.ListenPort), } // Add Logger - api.Router.Use(apiLogger()) + router.Use(apiLogger()) // Assets & Web App Templates assetsDir, _ := fs.Sub(assets, "assets") - api.Router.StaticFS("/assets", http.FS(assetsDir)) + router.StaticFS("/assets", http.FS(assetsDir)) // Generate Auth Token var newToken []byte @@ -78,74 +87,92 @@ func NewApi(db *database.DBManager, c *config.Config, assets *embed.FS) *API { HttpOnly: c.CookieHTTPOnly, SameSite: http.SameSiteStrictMode, }) - api.Router.Use(sessions.Sessions("token", store)) + router.Use(sessions.Sessions("token", store)) // Register Web App Route - api.registerWebAppRoutes() + api.registerWebAppRoutes(router) // Register API Routes - apiGroup := api.Router.Group("/api") + apiGroup := router.Group("/api") api.registerKOAPIRoutes(apiGroup) api.registerOPDSRoutes(apiGroup) return api } -func (api *API) registerWebAppRoutes() { +func (api *API) Start() error { + return api.httpServer.ListenAndServe() + +} +func (api *API) Stop() error { + // Stop Server + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + err := api.httpServer.Shutdown(ctx) + if err != nil { + return err + } + + // Close DB + return api.db.DB.Close() + +} + +func (api *API) registerWebAppRoutes(router *gin.Engine) { // Generate Templates - api.Router.HTMLRender = *api.generateTemplates() + router.HTMLRender = *api.generateTemplates() // Static Assets (Required @ Root) - api.Router.GET("/manifest.json", api.appWebManifest) - api.Router.GET("/favicon.ico", api.appFaviconIcon) - api.Router.GET("/sw.js", api.appServiceWorker) + router.GET("/manifest.json", api.appWebManifest) + router.GET("/favicon.ico", api.appFaviconIcon) + router.GET("/sw.js", api.appServiceWorker) // Local / Offline Static Pages (No Template, No Auth) - api.Router.GET("/local", api.appLocalDocuments) + router.GET("/local", api.appLocalDocuments) // Reader (Reader Page, Document Progress, Devices) - api.Router.GET("/reader", api.appDocumentReader) - api.Router.GET("/reader/devices", api.authWebAppMiddleware, api.appGetDevices) - api.Router.GET("/reader/progress/:document", api.authWebAppMiddleware, api.appGetDocumentProgress) + router.GET("/reader", api.appDocumentReader) + router.GET("/reader/devices", api.authWebAppMiddleware, api.appGetDevices) + router.GET("/reader/progress/:document", api.authWebAppMiddleware, api.appGetDocumentProgress) // Web App - api.Router.GET("/", api.authWebAppMiddleware, api.appGetHome) - api.Router.GET("/activity", api.authWebAppMiddleware, api.appGetActivity) - api.Router.GET("/progress", api.authWebAppMiddleware, api.appGetProgress) - api.Router.GET("/documents", api.authWebAppMiddleware, api.appGetDocuments) - api.Router.GET("/documents/:document", api.authWebAppMiddleware, api.appGetDocument) - api.Router.GET("/documents/:document/cover", api.authWebAppMiddleware, api.createGetCoverHandler(appErrorPage)) - api.Router.GET("/documents/:document/file", api.authWebAppMiddleware, api.createDownloadDocumentHandler(appErrorPage)) - api.Router.GET("/login", api.appGetLogin) - api.Router.GET("/logout", api.authWebAppMiddleware, api.appAuthLogout) - api.Router.GET("/register", api.appGetRegister) - api.Router.GET("/settings", api.authWebAppMiddleware, api.appGetSettings) - api.Router.GET("/admin/logs", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appGetAdminLogs) - api.Router.GET("/admin/users", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appGetAdminUsers) - api.Router.GET("/admin", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appGetAdmin) - api.Router.POST("/admin", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appPerformAdminAction) - api.Router.POST("/login", api.appAuthFormLogin) - api.Router.POST("/register", api.appAuthFormRegister) + router.GET("/", api.authWebAppMiddleware, api.appGetHome) + router.GET("/activity", api.authWebAppMiddleware, api.appGetActivity) + router.GET("/progress", api.authWebAppMiddleware, api.appGetProgress) + router.GET("/documents", api.authWebAppMiddleware, api.appGetDocuments) + router.GET("/documents/:document", api.authWebAppMiddleware, api.appGetDocument) + router.GET("/documents/:document/cover", api.authWebAppMiddleware, api.createGetCoverHandler(appErrorPage)) + router.GET("/documents/:document/file", api.authWebAppMiddleware, api.createDownloadDocumentHandler(appErrorPage)) + router.GET("/login", api.appGetLogin) + router.GET("/logout", api.authWebAppMiddleware, api.appAuthLogout) + router.GET("/register", api.appGetRegister) + router.GET("/settings", api.authWebAppMiddleware, api.appGetSettings) + router.GET("/admin/logs", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appGetAdminLogs) + router.GET("/admin/users", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appGetAdminUsers) + router.GET("/admin", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appGetAdmin) + router.POST("/admin", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appPerformAdminAction) + router.POST("/login", api.appAuthFormLogin) + router.POST("/register", api.appAuthFormRegister) // Demo Mode Enabled Configuration - if api.Config.DemoMode { - api.Router.POST("/documents", api.authWebAppMiddleware, api.appDemoModeError) - api.Router.POST("/documents/:document/delete", api.authWebAppMiddleware, api.appDemoModeError) - api.Router.POST("/documents/:document/edit", api.authWebAppMiddleware, api.appDemoModeError) - api.Router.POST("/documents/:document/identify", api.authWebAppMiddleware, api.appDemoModeError) - api.Router.POST("/settings", api.authWebAppMiddleware, api.appDemoModeError) + if api.cfg.DemoMode { + router.POST("/documents", api.authWebAppMiddleware, api.appDemoModeError) + router.POST("/documents/:document/delete", api.authWebAppMiddleware, api.appDemoModeError) + router.POST("/documents/:document/edit", api.authWebAppMiddleware, api.appDemoModeError) + router.POST("/documents/:document/identify", api.authWebAppMiddleware, api.appDemoModeError) + router.POST("/settings", api.authWebAppMiddleware, api.appDemoModeError) } else { - api.Router.POST("/documents", api.authWebAppMiddleware, api.appUploadNewDocument) - api.Router.POST("/documents/:document/delete", api.authWebAppMiddleware, api.appDeleteDocument) - api.Router.POST("/documents/:document/edit", api.authWebAppMiddleware, api.appEditDocument) - api.Router.POST("/documents/:document/identify", api.authWebAppMiddleware, api.appIdentifyDocument) - api.Router.POST("/settings", api.authWebAppMiddleware, api.appEditSettings) + router.POST("/documents", api.authWebAppMiddleware, api.appUploadNewDocument) + router.POST("/documents/:document/delete", api.authWebAppMiddleware, api.appDeleteDocument) + router.POST("/documents/:document/edit", api.authWebAppMiddleware, api.appEditDocument) + router.POST("/documents/:document/identify", api.authWebAppMiddleware, api.appIdentifyDocument) + router.POST("/settings", api.authWebAppMiddleware, api.appEditSettings) } // Search Enabled Configuration - if api.Config.SearchEnabled { - api.Router.GET("/search", api.authWebAppMiddleware, api.appGetSearch) - api.Router.POST("/search", api.authWebAppMiddleware, api.appSaveNewDocument) + if api.cfg.SearchEnabled { + router.GET("/search", api.authWebAppMiddleware, api.appGetSearch) + router.POST("/search", api.authWebAppMiddleware, api.appSaveNewDocument) } } @@ -162,7 +189,7 @@ func (api *API) registerKOAPIRoutes(apiGroup *gin.RouterGroup) { koGroup.PUT("/syncs/progress", api.authKOMiddleware, api.koSetProgress) // Demo Mode Enabled Configuration - if api.Config.DemoMode { + if api.cfg.DemoMode { koGroup.POST("/documents", api.authKOMiddleware, api.koDemoModeJSONError) koGroup.POST("/syncs/documents", api.authKOMiddleware, api.koDemoModeJSONError) koGroup.PUT("/documents/:document/file", api.authKOMiddleware, api.koDemoModeJSONError) @@ -200,50 +227,50 @@ func (api *API) generateTemplates() *multitemplate.Renderer { } // Load Base - b, _ := api.Assets.ReadFile("templates/base.tmpl") + b, _ := api.assets.ReadFile("templates/base.tmpl") baseTemplate := template.Must(template.New("base").Funcs(helperFuncs).Parse(string(b))) // Load SVGs - svgs, _ := api.Assets.ReadDir("templates/svgs") + svgs, _ := api.assets.ReadDir("templates/svgs") for _, item := range svgs { basename := item.Name() path := fmt.Sprintf("templates/svgs/%s", basename) name := strings.TrimSuffix(basename, filepath.Ext(basename)) - b, _ := api.Assets.ReadFile(path) + b, _ := api.assets.ReadFile(path) baseTemplate = template.Must(baseTemplate.New("svg/" + name).Parse(string(b))) templates["svg/"+name] = baseTemplate } // Load Components - components, _ := api.Assets.ReadDir("templates/components") + components, _ := api.assets.ReadDir("templates/components") for _, item := range components { basename := item.Name() path := fmt.Sprintf("templates/components/%s", basename) name := strings.TrimSuffix(basename, filepath.Ext(basename)) // Clone Base Template - b, _ := api.Assets.ReadFile(path) + b, _ := api.assets.ReadFile(path) baseTemplate = template.Must(baseTemplate.New("component/" + name).Parse(string(b))) render.Add("component/"+name, baseTemplate) templates["component/"+name] = baseTemplate } // Load Pages - pages, _ := api.Assets.ReadDir("templates/pages") + pages, _ := api.assets.ReadDir("templates/pages") for _, item := range pages { basename := item.Name() path := fmt.Sprintf("templates/pages/%s", basename) name := strings.TrimSuffix(basename, filepath.Ext(basename)) // Clone Base Template - b, _ := api.Assets.ReadFile(path) + b, _ := api.assets.ReadFile(path) pageTemplate, _ := template.Must(baseTemplate.Clone()).New("page/" + name).Parse(string(b)) render.Add("page/"+name, pageTemplate) templates["page/"+name] = pageTemplate } - api.Templates = templates + api.templates = templates return &render } diff --git a/api/app-routes.go b/api/app-routes.go index 8afd938..1342e98 100644 --- a/api/app-routes.go +++ b/api/app-routes.go @@ -116,23 +116,23 @@ type requestDocumentAdd struct { func (api *API) appWebManifest(c *gin.Context) { c.Header("Content-Type", "application/manifest+json") - c.FileFromFS("assets/manifest.json", http.FS(api.Assets)) + c.FileFromFS("assets/manifest.json", http.FS(api.assets)) } func (api *API) appServiceWorker(c *gin.Context) { - c.FileFromFS("assets/sw.js", http.FS(api.Assets)) + c.FileFromFS("assets/sw.js", http.FS(api.assets)) } func (api *API) appFaviconIcon(c *gin.Context) { - c.FileFromFS("assets/icons/favicon.ico", http.FS(api.Assets)) + c.FileFromFS("assets/icons/favicon.ico", http.FS(api.assets)) } func (api *API) appLocalDocuments(c *gin.Context) { - c.FileFromFS("assets/local/index.htm", http.FS(api.Assets)) + c.FileFromFS("assets/local/index.htm", http.FS(api.assets)) } func (api *API) appDocumentReader(c *gin.Context) { - c.FileFromFS("assets/reader/index.htm", http.FS(api.Assets)) + c.FileFromFS("assets/reader/index.htm", http.FS(api.assets)) } func (api *API) appGetDocuments(c *gin.Context) { @@ -145,7 +145,7 @@ func (api *API) appGetDocuments(c *gin.Context) { query = &search } - documents, err := api.DB.Queries.GetDocumentsWithStats(api.DB.Ctx, database.GetDocumentsWithStatsParams{ + documents, err := api.db.Queries.GetDocumentsWithStats(api.db.Ctx, database.GetDocumentsWithStatsParams{ UserID: auth.UserName, Query: query, Offset: (*qParams.Page - 1) * *qParams.Limit, @@ -157,7 +157,7 @@ func (api *API) appGetDocuments(c *gin.Context) { return } - length, err := api.DB.Queries.GetDocumentsSize(api.DB.Ctx, query) + length, err := api.db.Queries.GetDocumentsSize(api.db.Ctx, query) if err != nil { log.Error("GetDocumentsSize DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocumentsSize DB Error: %v", err)) @@ -196,7 +196,7 @@ func (api *API) appGetDocument(c *gin.Context) { return } - document, err := api.DB.Queries.GetDocumentWithStats(api.DB.Ctx, database.GetDocumentWithStatsParams{ + document, err := api.db.Queries.GetDocumentWithStats(api.db.Ctx, database.GetDocumentWithStatsParams{ UserID: auth.UserName, DocumentID: rDocID.DocumentID, }) @@ -228,7 +228,7 @@ func (api *API) appGetProgress(c *gin.Context) { progressFilter.DocumentID = *qParams.Document } - progress, err := api.DB.Queries.GetProgress(api.DB.Ctx, progressFilter) + progress, err := api.db.Queries.GetProgress(api.db.Ctx, progressFilter) if err != nil { log.Error("GetProgress DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetActivity DB Error: %v", err)) @@ -255,7 +255,7 @@ func (api *API) appGetActivity(c *gin.Context) { activityFilter.DocumentID = *qParams.Document } - activity, err := api.DB.Queries.GetActivity(api.DB.Ctx, activityFilter) + activity, err := api.db.Queries.GetActivity(api.db.Ctx, activityFilter) if err != nil { log.Error("GetActivity DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetActivity DB Error: %v", err)) @@ -271,7 +271,7 @@ func (api *API) appGetHome(c *gin.Context) { templateVars, auth := api.getBaseTemplateVars("home", c) start := time.Now() - graphData, err := api.DB.Queries.GetDailyReadStats(api.DB.Ctx, auth.UserName) + graphData, err := api.db.Queries.GetDailyReadStats(api.db.Ctx, auth.UserName) if err != nil { log.Error("GetDailyReadStats DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDailyReadStats DB Error: %v", err)) @@ -280,7 +280,7 @@ func (api *API) appGetHome(c *gin.Context) { log.Debug("GetDailyReadStats DB Performance: ", time.Since(start)) start = time.Now() - databaseInfo, err := api.DB.Queries.GetDatabaseInfo(api.DB.Ctx, auth.UserName) + databaseInfo, err := api.db.Queries.GetDatabaseInfo(api.db.Ctx, auth.UserName) if err != nil { log.Error("GetDatabaseInfo DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDatabaseInfo DB Error: %v", err)) @@ -289,7 +289,7 @@ func (api *API) appGetHome(c *gin.Context) { log.Debug("GetDatabaseInfo DB Performance: ", time.Since(start)) start = time.Now() - streaks, err := api.DB.Queries.GetUserStreaks(api.DB.Ctx, auth.UserName) + streaks, err := api.db.Queries.GetUserStreaks(api.db.Ctx, auth.UserName) if err != nil { log.Error("GetUserStreaks DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUserStreaks DB Error: %v", err)) @@ -298,7 +298,7 @@ func (api *API) appGetHome(c *gin.Context) { log.Debug("GetUserStreaks DB Performance: ", time.Since(start)) start = time.Now() - userStatistics, err := api.DB.Queries.GetUserStatistics(api.DB.Ctx) + userStatistics, err := api.db.Queries.GetUserStatistics(api.db.Ctx) if err != nil { log.Error("GetUserStatistics DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUserStatistics DB Error: %v", err)) @@ -319,14 +319,14 @@ func (api *API) appGetHome(c *gin.Context) { func (api *API) appGetSettings(c *gin.Context) { templateVars, auth := api.getBaseTemplateVars("settings", c) - user, err := api.DB.Queries.GetUser(api.DB.Ctx, auth.UserName) + user, err := api.db.Queries.GetUser(api.db.Ctx, auth.UserName) if err != nil { log.Error("GetUser DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUser DB Error: %v", err)) return } - devices, err := api.DB.Queries.GetDevices(api.DB.Ctx, auth.UserName) + devices, err := api.db.Queries.GetDevices(api.db.Ctx, auth.UserName) if err != nil { log.Error("GetDevices DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDevices DB Error: %v", err)) @@ -350,7 +350,7 @@ func (api *API) appGetAdminLogs(c *gin.Context) { templateVars, _ := api.getBaseTemplateVars("admin-logs", c) // Open Log File - logPath := filepath.Join(api.Config.ConfigPath, "logs/antholume.log") + logPath := filepath.Join(api.cfg.ConfigPath, "logs/antholume.log") logFile, err := os.Open(logPath) if err != nil { appErrorPage(c, http.StatusBadRequest, "Missing AnthoLume log file.") @@ -388,7 +388,7 @@ func (api *API) appGetAdminLogs(c *gin.Context) { func (api *API) appGetAdminUsers(c *gin.Context) { templateVars, _ := api.getBaseTemplateVars("admin-users", c) - users, err := api.DB.Queries.GetUsers(api.DB.Ctx) + users, err := api.db.Queries.GetUsers(api.db.Ctx) if err != nil { log.Error("GetUsers DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUsers DB Error: %v", err)) @@ -422,12 +422,21 @@ func (api *API) appPerformAdminAction(c *gin.Context) { // 1. Documents xref most recent metadata table? // 2. Select all / deselect? case adminCacheTables: - go api.DB.CacheTempTables() + go api.db.CacheTempTables() case adminRestore: api.processRestoreFile(rAdminAction, c) case adminBackup: + // Vacuum + _, err := api.db.DB.ExecContext(api.db.Ctx, "VACUUM;") + if err != nil { + log.Error("Unable to vacuum DB: ", err) + appErrorPage(c, http.StatusInternalServerError, "Unable to vacuum database.") + return + } + + // Set Headers c.Header("Content-type", "application/octet-stream") - c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"AnthoLumeBackup_%s.zip\"", time.Now().Format("20060102"))) + c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"AnthoLumeBackup_%s.zip\"", time.Now().Format("20060102150405"))) // Stream Backup ZIP Archive c.Stream(func(w io.Writer) bool { @@ -479,18 +488,18 @@ func (api *API) appGetSearch(c *gin.Context) { func (api *API) appGetLogin(c *gin.Context) { templateVars, _ := api.getBaseTemplateVars("login", c) - templateVars["RegistrationEnabled"] = api.Config.RegistrationEnabled + templateVars["RegistrationEnabled"] = api.cfg.RegistrationEnabled c.HTML(http.StatusOK, "page/login", templateVars) } func (api *API) appGetRegister(c *gin.Context) { - if !api.Config.RegistrationEnabled { + if !api.cfg.RegistrationEnabled { c.Redirect(http.StatusFound, "/login") return } templateVars, _ := api.getBaseTemplateVars("login", c) - templateVars["RegistrationEnabled"] = api.Config.RegistrationEnabled + templateVars["RegistrationEnabled"] = api.cfg.RegistrationEnabled templateVars["Register"] = true c.HTML(http.StatusOK, "page/login", templateVars) } @@ -508,7 +517,7 @@ func (api *API) appGetDocumentProgress(c *gin.Context) { return } - progress, err := api.DB.Queries.GetDocumentProgress(api.DB.Ctx, database.GetDocumentProgressParams{ + progress, err := api.db.Queries.GetDocumentProgress(api.db.Ctx, database.GetDocumentProgressParams{ DocumentID: rDoc.DocumentID, UserID: auth.UserName, }) @@ -519,7 +528,7 @@ func (api *API) appGetDocumentProgress(c *gin.Context) { return } - document, err := api.DB.Queries.GetDocumentWithStats(api.DB.Ctx, database.GetDocumentWithStatsParams{ + document, err := api.db.Queries.GetDocumentWithStats(api.db.Ctx, database.GetDocumentWithStatsParams{ UserID: auth.UserName, DocumentID: rDoc.DocumentID, }) @@ -545,7 +554,7 @@ func (api *API) appGetDevices(c *gin.Context) { auth = data.(authData) } - devices, err := api.DB.Queries.GetDevices(api.DB.Ctx, auth.UserName) + devices, err := api.db.Queries.GetDevices(api.db.Ctx, auth.UserName) if err != nil && err != sql.ErrNoRows { log.Error("GetDevices DB Error: ", err) @@ -627,7 +636,7 @@ func (api *API) appUploadNewDocument(c *gin.Context) { } // Check Exists - _, err = api.DB.Queries.GetDocument(api.DB.Ctx, partialMD5) + _, err = api.db.Queries.GetDocument(api.db.Ctx, partialMD5) if err == nil { c.Redirect(http.StatusFound, fmt.Sprintf("./documents/%s", partialMD5)) return @@ -670,7 +679,7 @@ func (api *API) appUploadNewDocument(c *gin.Context) { fileName = "." + filepath.Clean(fmt.Sprintf("/%s [%s]%s", fileName, partialMD5, fileExtension)) // Generate Storage Path & Open File - safePath := filepath.Join(api.Config.DataPath, "documents", fileName) + safePath := filepath.Join(api.cfg.DataPath, "documents", fileName) destFile, err := os.Create(safePath) if err != nil { log.Error("Dest File Error: ", err) @@ -687,7 +696,7 @@ func (api *API) appUploadNewDocument(c *gin.Context) { } // Upsert Document - if _, err = api.DB.Queries.UpsertDocument(api.DB.Ctx, database.UpsertDocumentParams{ + if _, err = api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ ID: partialMD5, Title: metadataInfo.Title, Author: metadataInfo.Author, @@ -764,7 +773,7 @@ func (api *API) appEditDocument(c *gin.Context) { // Generate Storage Path fileName := fmt.Sprintf("%s%s", rDocID.DocumentID, fileExtension) - safePath := filepath.Join(api.Config.DataPath, "covers", fileName) + safePath := filepath.Join(api.cfg.DataPath, "covers", fileName) // Save err = c.SaveUploadedFile(rDocEdit.CoverFile, safePath) @@ -776,7 +785,7 @@ func (api *API) appEditDocument(c *gin.Context) { coverFileName = &fileName } else if rDocEdit.CoverGBID != nil { - var coverDir string = filepath.Join(api.Config.DataPath, "covers") + var coverDir string = filepath.Join(api.cfg.DataPath, "covers") fileName, err := metadata.CacheCover(*rDocEdit.CoverGBID, coverDir, rDocID.DocumentID, true) if err == nil { coverFileName = fileName @@ -784,7 +793,7 @@ func (api *API) appEditDocument(c *gin.Context) { } // Update Document - if _, err := api.DB.Queries.UpsertDocument(api.DB.Ctx, database.UpsertDocumentParams{ + if _, err := api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ ID: rDocID.DocumentID, Title: api.sanitizeInput(rDocEdit.Title), Author: api.sanitizeInput(rDocEdit.Author), @@ -809,7 +818,7 @@ func (api *API) appDeleteDocument(c *gin.Context) { appErrorPage(c, http.StatusNotFound, "Invalid document.") return } - changed, err := api.DB.Queries.DeleteDocument(api.DB.Ctx, rDocID.DocumentID) + changed, err := api.db.Queries.DeleteDocument(api.db.Ctx, rDocID.DocumentID) if err != nil { log.Error("DeleteDocument DB Error") appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("DeleteDocument DB Error: %v", err)) @@ -871,7 +880,7 @@ func (api *API) appIdentifyDocument(c *gin.Context) { firstResult := metadataResults[0] // Store First Metadata Result - if _, err = api.DB.Queries.AddMetadata(api.DB.Ctx, database.AddMetadataParams{ + if _, err = api.db.Queries.AddMetadata(api.db.Ctx, database.AddMetadataParams{ DocumentID: rDocID.DocumentID, Title: firstResult.Title, Author: firstResult.Author, @@ -890,7 +899,7 @@ func (api *API) appIdentifyDocument(c *gin.Context) { templateVars["MetadataError"] = "No Metadata Found" } - document, err := api.DB.Queries.GetDocumentWithStats(api.DB.Ctx, database.GetDocumentWithStatsParams{ + document, err := api.db.Queries.GetDocumentWithStats(api.db.Ctx, database.GetDocumentWithStatsParams{ UserID: auth.UserName, DocumentID: rDocID.DocumentID, }) @@ -1001,7 +1010,7 @@ func (api *API) appSaveNewDocument(c *gin.Context) { defer sourceFile.Close() // Generate Storage Path & Open File - safePath := filepath.Join(api.Config.DataPath, "documents", fileName) + safePath := filepath.Join(api.cfg.DataPath, "documents", fileName) destFile, err := os.Create(safePath) if err != nil { log.Error("Dest File Error: ", err) @@ -1043,7 +1052,7 @@ func (api *API) appSaveNewDocument(c *gin.Context) { sendDownloadMessage("Saving to database...", gin.H{"Progress": 90}) // Upsert Document - if _, err = api.DB.Queries.UpsertDocument(api.DB.Ctx, database.UpsertDocumentParams{ + if _, err = api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ ID: partialMD5, Title: rDocAdd.Title, Author: rDocAdd.Author, @@ -1110,7 +1119,7 @@ func (api *API) appEditSettings(c *gin.Context) { } // Update User - _, err := api.DB.Queries.UpdateUser(api.DB.Ctx, newUserSettings) + _, err := api.db.Queries.UpdateUser(api.db.Ctx, newUserSettings) if err != nil { log.Error("UpdateUser DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("UpdateUser DB Error: %v", err)) @@ -1118,7 +1127,7 @@ func (api *API) appEditSettings(c *gin.Context) { } // Get User - user, err := api.DB.Queries.GetUser(api.DB.Ctx, auth.UserName) + user, err := api.db.Queries.GetUser(api.db.Ctx, auth.UserName) if err != nil { log.Error("GetUser DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUser DB Error: %v", err)) @@ -1126,7 +1135,7 @@ func (api *API) appEditSettings(c *gin.Context) { } // Get Devices - devices, err := api.DB.Queries.GetDevices(api.DB.Ctx, auth.UserName) + devices, err := api.db.Queries.GetDevices(api.db.Ctx, auth.UserName) if err != nil { log.Error("GetDevices DB Error: ", err) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDevices DB Error: %v", err)) @@ -1147,7 +1156,7 @@ func (api *API) appDemoModeError(c *gin.Context) { func (api *API) getDocumentsWordCount(documents []database.GetDocumentsWithStatsRow) error { // Do Transaction - tx, err := api.DB.DB.Begin() + tx, err := api.db.DB.Begin() if err != nil { log.Error("Transaction Begin DB Error: ", err) return err @@ -1155,16 +1164,16 @@ func (api *API) getDocumentsWordCount(documents []database.GetDocumentsWithStats // Defer & Start Transaction defer tx.Rollback() - qtx := api.DB.Queries.WithTx(tx) + qtx := api.db.Queries.WithTx(tx) for _, item := range documents { if item.Words == nil && item.Filepath != nil { - filePath := filepath.Join(api.Config.DataPath, "documents", *item.Filepath) + filePath := filepath.Join(api.cfg.DataPath, "documents", *item.Filepath) wordCount, err := metadata.GetWordCount(filePath) if err != nil { log.Warn("Word Count Error: ", err) } else { - if _, err := qtx.UpsertDocument(api.DB.Ctx, database.UpsertDocumentParams{ + if _, err := qtx.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ ID: item.ID, Words: &wordCount, }); err != nil { @@ -1194,9 +1203,9 @@ func (api *API) getBaseTemplateVars(routeName string, c *gin.Context) (gin.H, au "Authorization": auth, "RouteName": routeName, "Config": gin.H{ - "Version": api.Config.Version, - "SearchEnabled": api.Config.SearchEnabled, - "RegistrationEnabled": api.Config.RegistrationEnabled, + "Version": api.cfg.Version, + "SearchEnabled": api.cfg.SearchEnabled, + "RegistrationEnabled": api.cfg.RegistrationEnabled, }, }, auth } @@ -1402,7 +1411,7 @@ func (api *API) processRestoreFile(rAdminAction requestAdminAction, c *gin.Conte } // Create Backup File - backupFilePath := filepath.Join(api.Config.ConfigPath, fmt.Sprintf("backup/AnthoLumeBackup_%s.zip", time.Now().Format("20060102"))) + backupFilePath := filepath.Join(api.cfg.ConfigPath, fmt.Sprintf("backups/AnthoLumeBackup_%s.zip", time.Now().Format("20060102150405"))) backupFile, err := os.Create(backupFilePath) if err != nil { log.Error("Unable to create backup file: ", err) @@ -1411,6 +1420,14 @@ func (api *API) processRestoreFile(rAdminAction requestAdminAction, c *gin.Conte } defer backupFile.Close() + // Vacuum DB + _, err = api.db.DB.ExecContext(api.db.Ctx, "VACUUM;") + if err != nil { + log.Error("Unable to vacuum DB: ", err) + appErrorPage(c, http.StatusInternalServerError, "Unable to vacuum database.") + return + } + // Save Backup File w := bufio.NewWriter(backupFile) err = api.createBackup(w, []string{"covers", "documents"}) @@ -1423,6 +1440,7 @@ func (api *API) processRestoreFile(rAdminAction requestAdminAction, c *gin.Conte // Remove Data err = api.removeData() if err != nil { + log.Error("Unable to delete data: ", err) appErrorPage(c, http.StatusInternalServerError, "Unable to delete data.") return } @@ -1431,19 +1449,26 @@ func (api *API) processRestoreFile(rAdminAction requestAdminAction, c *gin.Conte err = api.restoreData(zipReader) if err != nil { appErrorPage(c, http.StatusInternalServerError, "Unable to restore data.") - - // Panic? - - log.Panic("Oh no") - + log.Panic("Unable to restore data: ", err) return } - // TODO: - // - Extract from temp directory + // Close DB + err = api.db.DB.Close() + if err != nil { + appErrorPage(c, http.StatusInternalServerError, "Unable to close DB.") + log.Panic("Unable to close DB: ", err) + } + + // Reinit DB + api.db.Reload() } func (api *API) restoreData(zipReader *zip.Reader) error { + // Ensure Directories + api.cfg.EnsureDirectories() + + // Restore Data for _, file := range zipReader.File { rc, err := file.Open() if err != nil { @@ -1451,7 +1476,7 @@ func (api *API) restoreData(zipReader *zip.Reader) error { } defer rc.Close() - destPath := filepath.Join(api.Config.DataPath, file.Name) + destPath := filepath.Join(api.cfg.DataPath, file.Name) destFile, err := os.Create(destPath) if err != nil { fmt.Println("Error creating destination file:", err) @@ -1481,7 +1506,7 @@ func (api *API) removeData() error { } for _, name := range allPaths { - fullPath := filepath.Join(api.Config.DataPath, name) + fullPath := filepath.Join(api.cfg.DataPath, name) err := os.RemoveAll(fullPath) if err != nil { log.Errorf("Unable to delete %s: %v", name, err) @@ -1531,8 +1556,8 @@ func (api *API) createBackup(w io.Writer, directories []string) error { } // Get DB Path - fileName := fmt.Sprintf("%s.db", api.Config.DBName) - dbLocation := filepath.Join(api.Config.ConfigPath, fileName) + fileName := fmt.Sprintf("%s.db", api.cfg.DBName) + dbLocation := filepath.Join(api.cfg.ConfigPath, fileName) // Copy Database File dbFile, err := os.Open(dbLocation) @@ -1549,7 +1574,7 @@ func (api *API) createBackup(w io.Writer, directories []string) error { // Backup Covers & Documents for _, dir := range directories { - err = filepath.WalkDir(filepath.Join(api.Config.DataPath, dir), exportWalker) + err = filepath.WalkDir(filepath.Join(api.cfg.DataPath, dir), exportWalker) if err != nil { return err } diff --git a/api/auth.go b/api/auth.go index 663bf53..8fdc07e 100644 --- a/api/auth.go +++ b/api/auth.go @@ -32,7 +32,7 @@ type authOPDSHeader struct { } func (api *API) authorizeCredentials(username string, password string) (auth *authData) { - user, err := api.DB.Queries.GetUser(api.DB.Ctx, username) + user, err := api.db.Queries.GetUser(api.db.Ctx, username) if err != nil { return } @@ -174,7 +174,7 @@ func (api *API) appAuthFormLogin(c *gin.Context) { } func (api *API) appAuthFormRegister(c *gin.Context) { - if !api.Config.RegistrationEnabled { + if !api.cfg.RegistrationEnabled { appErrorPage(c, http.StatusUnauthorized, "Nice try. Registration is disabled.") return } @@ -199,7 +199,7 @@ func (api *API) appAuthFormRegister(c *gin.Context) { return } - rows, err := api.DB.Queries.CreateUser(api.DB.Ctx, database.CreateUserParams{ + rows, err := api.db.Queries.CreateUser(api.db.Ctx, database.CreateUserParams{ ID: username, Pass: &hashedPassword, }) @@ -221,7 +221,7 @@ func (api *API) appAuthFormRegister(c *gin.Context) { } // Get User - user, err := api.DB.Queries.GetUser(api.DB.Ctx, username) + user, err := api.db.Queries.GetUser(api.db.Ctx, username) if err != nil { log.Error("GetUser DB Error:", err) templateVars["Error"] = "Registration Disabled or User Already Exists" diff --git a/api/common.go b/api/common.go index 5c59f4a..f145dc7 100644 --- a/api/common.go +++ b/api/common.go @@ -21,7 +21,7 @@ func (api *API) createDownloadDocumentHandler(errorFunc func(*gin.Context, int, } // Get Document - document, err := api.DB.Queries.GetDocument(api.DB.Ctx, rDoc.DocumentID) + document, err := api.db.Queries.GetDocument(api.db.Ctx, rDoc.DocumentID) if err != nil { log.Error("GetDocument DB Error:", err) errorFunc(c, http.StatusBadRequest, "Unknown Document") @@ -35,12 +35,12 @@ func (api *API) createDownloadDocumentHandler(errorFunc func(*gin.Context, int, } // Derive Storage Location - filePath := filepath.Join(api.Config.DataPath, "documents", *document.Filepath) + filePath := filepath.Join(api.cfg.DataPath, "documents", *document.Filepath) // Validate File Exists _, err = os.Stat(filePath) if os.IsNotExist(err) { - log.Error("File Doesn't Exist:", rDoc.DocumentID) + log.Error("File should but doesn't exist: ", err) errorFunc(c, http.StatusBadRequest, "Document Doesn't Exist") return } @@ -61,7 +61,7 @@ func (api *API) createGetCoverHandler(errorFunc func(*gin.Context, int, string)) } // Validate Document Exists in DB - document, err := api.DB.Queries.GetDocument(api.DB.Ctx, rDoc.DocumentID) + document, err := api.db.Queries.GetDocument(api.db.Ctx, rDoc.DocumentID) if err != nil { log.Error("GetDocument DB Error:", err) errorFunc(c, http.StatusInternalServerError, fmt.Sprintf("GetDocument DB Error: %v", err)) @@ -71,18 +71,18 @@ func (api *API) createGetCoverHandler(errorFunc func(*gin.Context, int, string)) // Handle Identified Document if document.Coverfile != nil { if *document.Coverfile == "UNKNOWN" { - c.FileFromFS("assets/images/no-cover.jpg", http.FS(api.Assets)) + c.FileFromFS("assets/images/no-cover.jpg", http.FS(api.assets)) return } // Derive Path - safePath := filepath.Join(api.Config.DataPath, "covers", *document.Coverfile) + safePath := filepath.Join(api.cfg.DataPath, "covers", *document.Coverfile) // Validate File Exists _, err = os.Stat(safePath) if err != nil { - log.Error("File Should But Doesn't Exist:", err) - c.FileFromFS("assets/images/no-cover.jpg", http.FS(api.Assets)) + log.Error("File should but doesn't exist: ", err) + c.FileFromFS("assets/images/no-cover.jpg", http.FS(api.assets)) return } @@ -91,7 +91,7 @@ func (api *API) createGetCoverHandler(errorFunc func(*gin.Context, int, string)) } // Attempt Metadata - var coverDir string = filepath.Join(api.Config.DataPath, "covers") + var coverDir string = filepath.Join(api.cfg.DataPath, "covers") var coverFile string = "UNKNOWN" // Identify Documents & Save Covers @@ -110,7 +110,7 @@ func (api *API) createGetCoverHandler(errorFunc func(*gin.Context, int, string)) } // Store First Metadata Result - if _, err = api.DB.Queries.AddMetadata(api.DB.Ctx, database.AddMetadataParams{ + if _, err = api.db.Queries.AddMetadata(api.db.Ctx, database.AddMetadataParams{ DocumentID: document.ID, Title: firstResult.Title, Author: firstResult.Author, @@ -125,7 +125,7 @@ func (api *API) createGetCoverHandler(errorFunc func(*gin.Context, int, string)) } // Upsert Document - if _, err = api.DB.Queries.UpsertDocument(api.DB.Ctx, database.UpsertDocumentParams{ + if _, err = api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ ID: document.ID, Coverfile: &coverFile, }); err != nil { @@ -134,7 +134,7 @@ func (api *API) createGetCoverHandler(errorFunc func(*gin.Context, int, string)) // Return Unknown Cover if coverFile == "UNKNOWN" { - c.FileFromFS("assets/images/no-cover.jpg", http.FS(api.Assets)) + c.FileFromFS("assets/images/no-cover.jpg", http.FS(api.assets)) return } diff --git a/api/ko-routes.go b/api/ko-routes.go index a394a8a..14d88fc 100644 --- a/api/ko-routes.go +++ b/api/ko-routes.go @@ -82,7 +82,7 @@ func (api *API) koAuthorizeUser(c *gin.Context) { } func (api *API) koCreateUser(c *gin.Context) { - if !api.Config.RegistrationEnabled { + if !api.cfg.RegistrationEnabled { c.AbortWithStatus(http.StatusConflict) return } @@ -107,7 +107,7 @@ func (api *API) koCreateUser(c *gin.Context) { return } - rows, err := api.DB.Queries.CreateUser(api.DB.Ctx, database.CreateUserParams{ + rows, err := api.db.Queries.CreateUser(api.db.Ctx, database.CreateUserParams{ ID: rUser.Username, Pass: &hashedPassword, }) @@ -143,7 +143,7 @@ func (api *API) koSetProgress(c *gin.Context) { } // Upsert Device - if _, err := api.DB.Queries.UpsertDevice(api.DB.Ctx, database.UpsertDeviceParams{ + if _, err := api.db.Queries.UpsertDevice(api.db.Ctx, database.UpsertDeviceParams{ ID: rPosition.DeviceID, UserID: auth.UserName, DeviceName: rPosition.Device, @@ -153,14 +153,14 @@ func (api *API) koSetProgress(c *gin.Context) { } // Upsert Document - if _, err := api.DB.Queries.UpsertDocument(api.DB.Ctx, database.UpsertDocumentParams{ + if _, err := api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ ID: rPosition.DocumentID, }); err != nil { log.Error("UpsertDocument DB Error:", err) } // Create or Replace Progress - progress, err := api.DB.Queries.UpdateProgress(api.DB.Ctx, database.UpdateProgressParams{ + progress, err := api.db.Queries.UpdateProgress(api.db.Ctx, database.UpdateProgressParams{ Percentage: rPosition.Percentage, DocumentID: rPosition.DocumentID, DeviceID: rPosition.DeviceID, @@ -192,7 +192,7 @@ func (api *API) koGetProgress(c *gin.Context) { return } - progress, err := api.DB.Queries.GetDocumentProgress(api.DB.Ctx, database.GetDocumentProgressParams{ + progress, err := api.db.Queries.GetDocumentProgress(api.db.Ctx, database.GetDocumentProgressParams{ DocumentID: rDocID.DocumentID, UserID: auth.UserName, }) @@ -230,7 +230,7 @@ func (api *API) koAddActivities(c *gin.Context) { } // Do Transaction - tx, err := api.DB.DB.Begin() + tx, err := api.db.DB.Begin() if err != nil { log.Error("Transaction Begin DB Error:", err) apiErrorPage(c, http.StatusBadRequest, "Unknown Error") @@ -246,11 +246,11 @@ func (api *API) koAddActivities(c *gin.Context) { // Defer & Start Transaction defer tx.Rollback() - qtx := api.DB.Queries.WithTx(tx) + qtx := api.db.Queries.WithTx(tx) // Upsert Documents for _, doc := range allDocuments { - if _, err := qtx.UpsertDocument(api.DB.Ctx, database.UpsertDocumentParams{ + if _, err := qtx.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ ID: doc, }); err != nil { log.Error("UpsertDocument DB Error:", err) @@ -260,7 +260,7 @@ func (api *API) koAddActivities(c *gin.Context) { } // Upsert Device - if _, err = qtx.UpsertDevice(api.DB.Ctx, database.UpsertDeviceParams{ + if _, err = qtx.UpsertDevice(api.db.Ctx, database.UpsertDeviceParams{ ID: rActivity.DeviceID, UserID: auth.UserName, DeviceName: rActivity.Device, @@ -273,7 +273,7 @@ func (api *API) koAddActivities(c *gin.Context) { // Add All Activity for _, item := range rActivity.Activity { - if _, err := qtx.AddActivity(api.DB.Ctx, database.AddActivityParams{ + if _, err := qtx.AddActivity(api.db.Ctx, database.AddActivityParams{ UserID: auth.UserName, DocumentID: item.DocumentID, DeviceID: rActivity.DeviceID, @@ -314,7 +314,7 @@ func (api *API) koCheckActivitySync(c *gin.Context) { } // Upsert Device - if _, err := api.DB.Queries.UpsertDevice(api.DB.Ctx, database.UpsertDeviceParams{ + if _, err := api.db.Queries.UpsertDevice(api.db.Ctx, database.UpsertDeviceParams{ ID: rCheckActivity.DeviceID, UserID: auth.UserName, DeviceName: rCheckActivity.Device, @@ -326,7 +326,7 @@ func (api *API) koCheckActivitySync(c *gin.Context) { } // Get Last Device Activity - lastActivity, err := api.DB.Queries.GetLastActivity(api.DB.Ctx, database.GetLastActivityParams{ + lastActivity, err := api.db.Queries.GetLastActivity(api.db.Ctx, database.GetLastActivityParams{ UserID: auth.UserName, DeviceID: rCheckActivity.DeviceID, }) @@ -360,7 +360,7 @@ func (api *API) koAddDocuments(c *gin.Context) { } // Do Transaction - tx, err := api.DB.DB.Begin() + tx, err := api.db.DB.Begin() if err != nil { log.Error("Transaction Begin DB Error:", err) apiErrorPage(c, http.StatusBadRequest, "Unknown Error") @@ -369,11 +369,11 @@ func (api *API) koAddDocuments(c *gin.Context) { // Defer & Start Transaction defer tx.Rollback() - qtx := api.DB.Queries.WithTx(tx) + qtx := api.db.Queries.WithTx(tx) // Upsert Documents for _, doc := range rNewDocs.Documents { - _, err := qtx.UpsertDocument(api.DB.Ctx, database.UpsertDocumentParams{ + _, err := qtx.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ ID: doc.ID, Title: api.sanitizeInput(doc.Title), Author: api.sanitizeInput(doc.Author), @@ -415,7 +415,7 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) { } // Upsert Device - _, err := api.DB.Queries.UpsertDevice(api.DB.Ctx, database.UpsertDeviceParams{ + _, err := api.db.Queries.UpsertDevice(api.db.Ctx, database.UpsertDeviceParams{ ID: rCheckDocs.DeviceID, UserID: auth.UserName, DeviceName: rCheckDocs.Device, @@ -431,7 +431,7 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) { deletedDocIDs := []string{} // Get Missing Documents - missingDocs, err = api.DB.Queries.GetMissingDocuments(api.DB.Ctx, rCheckDocs.Have) + missingDocs, err = api.db.Queries.GetMissingDocuments(api.db.Ctx, rCheckDocs.Have) if err != nil { log.Error("GetMissingDocuments DB Error", err) apiErrorPage(c, http.StatusBadRequest, "Invalid Request") @@ -439,7 +439,7 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) { } // Get Deleted Documents - deletedDocIDs, err = api.DB.Queries.GetDeletedDocuments(api.DB.Ctx, rCheckDocs.Have) + deletedDocIDs, err = api.db.Queries.GetDeletedDocuments(api.db.Ctx, rCheckDocs.Have) if err != nil { log.Error("GetDeletedDocuments DB Error", err) apiErrorPage(c, http.StatusBadRequest, "Invalid Request") @@ -454,7 +454,7 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) { return } - wantedDocs, err := api.DB.Queries.GetWantedDocuments(api.DB.Ctx, string(jsonHaves)) + wantedDocs, err := api.db.Queries.GetWantedDocuments(api.db.Ctx, string(jsonHaves)) if err != nil { log.Error("GetWantedDocuments DB Error", err) apiErrorPage(c, http.StatusBadRequest, "Invalid Request") @@ -524,7 +524,7 @@ func (api *API) koUploadExistingDocument(c *gin.Context) { } // Validate Document Exists in DB - document, err := api.DB.Queries.GetDocument(api.DB.Ctx, rDoc.DocumentID) + document, err := api.db.Queries.GetDocument(api.db.Ctx, rDoc.DocumentID) if err != nil { log.Error("GetDocument DB Error:", err) apiErrorPage(c, http.StatusBadRequest, "Unknown Document") @@ -552,7 +552,7 @@ func (api *API) koUploadExistingDocument(c *gin.Context) { fileName = "." + filepath.Clean(fmt.Sprintf("/%s [%s]%s", fileName, document.ID, fileExtension)) // Generate Storage Path - safePath := filepath.Join(api.Config.DataPath, "documents", fileName) + safePath := filepath.Join(api.cfg.DataPath, "documents", fileName) // Save & Prevent Overwrites _, err = os.Stat(safePath) @@ -582,7 +582,7 @@ func (api *API) koUploadExistingDocument(c *gin.Context) { } // Upsert Document - if _, err = api.DB.Queries.UpsertDocument(api.DB.Ctx, database.UpsertDocumentParams{ + if _, err = api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ ID: document.ID, Md5: fileHash, Filepath: &fileName, @@ -610,12 +610,12 @@ func (api *API) sanitizeInput(val any) *string { switch v := val.(type) { case *string: if v != nil { - newString := html.UnescapeString(api.HTMLPolicy.Sanitize(string(*v))) + newString := html.UnescapeString(htmlPolicy.Sanitize(string(*v))) return &newString } case string: if v != "" { - newString := html.UnescapeString(api.HTMLPolicy.Sanitize(string(v))) + newString := html.UnescapeString(htmlPolicy.Sanitize(string(v))) return &newString } } diff --git a/api/opds-routes.go b/api/opds-routes.go index 85505e8..31434d4 100644 --- a/api/opds-routes.go +++ b/api/opds-routes.go @@ -77,7 +77,7 @@ func (api *API) opdsDocuments(c *gin.Context) { } // Get Documents - documents, err := api.DB.Queries.GetDocumentsWithStats(api.DB.Ctx, database.GetDocumentsWithStatsParams{ + documents, err := api.db.Queries.GetDocumentsWithStats(api.db.Ctx, database.GetDocumentsWithStatsParams{ UserID: auth.UserName, Query: query, Offset: (*qParams.Page - 1) * *qParams.Limit, diff --git a/api/streamer.go b/api/streamer.go index d3e8865..c533b01 100644 --- a/api/streamer.go +++ b/api/streamer.go @@ -19,7 +19,7 @@ type streamer struct { func (api *API) newStreamer(c *gin.Context, data string) *streamer { stream := &streamer{ - templates: api.Templates, + templates: api.templates, writer: c.Writer, completeCh: make(chan struct{}), } diff --git a/config/config.go b/config/config.go index 0c6db75..a69b9d7 100644 --- a/config/config.go +++ b/config/config.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "path" + "path/filepath" "runtime" "strings" @@ -40,7 +41,7 @@ type customFormatter struct { log.Formatter } -// Force UTC & Set Type (app) +// Force UTC & Set type (app) func (cf customFormatter) Format(e *log.Entry) ([]byte, error) { if e.Data["type"] == nil { e.Data["type"] = "app" @@ -70,18 +71,18 @@ func Load() *Config { CookieHTTPOnly: trimLowerString(getEnv("COOKIE_HTTP_ONLY", "true")) == "true", } - // Log Level + // Parse log level logLevel, err := log.ParseLevel(c.LogLevel) if err != nil { logLevel = log.InfoLevel } - // Log Formatter + // Create custom formatter logFormatter := &customFormatter{&log.JSONFormatter{ CallerPrettyfier: prettyCaller, }} - // Log Rotater + // Create log rotator rotateFileHook, err := NewRotateFileHook(RotateFileConfig{ Filename: path.Join(c.ConfigPath, "logs/antholume.log"), MaxSize: 50, @@ -94,17 +95,34 @@ func Load() *Config { log.Fatal("Unable to initialize file rotate hook") } - // Rotate Now + // Rotate now rotateFileHook.Rotate() + // Set logger settings log.SetLevel(logLevel) log.SetFormatter(logFormatter) log.SetReportCaller(true) log.AddHook(rotateFileHook) + // Ensure directories exist + c.EnsureDirectories() + return c } +// Ensures needed directories exist +func (c *Config) EnsureDirectories() { + os.Mkdir(c.ConfigPath, 0755) + os.Mkdir(c.DataPath, 0755) + + docDir := filepath.Join(c.DataPath, "documents") + coversDir := filepath.Join(c.DataPath, "covers") + backupDir := filepath.Join(c.DataPath, "backups") + os.Mkdir(docDir, 0755) + os.Mkdir(coversDir, 0755) + os.Mkdir(backupDir, 0755) +} + func getEnv(key, fallback string) string { if value, ok := os.LookupEnv(key); ok { return value diff --git a/database/manager.go b/database/manager.go index 64194cc..210176a 100644 --- a/database/manager.go +++ b/database/manager.go @@ -19,6 +19,7 @@ type DBManager struct { DB *sql.DB Ctx context.Context Queries *Queries + cfg *config.Config } //go:embed schema.sql @@ -27,17 +28,25 @@ var ddl string //go:embed migrations/* var migrations embed.FS +// Returns an initialized manager func NewMgr(c *config.Config) *DBManager { // Create Manager dbm := &DBManager{ Ctx: context.Background(), + cfg: c, } - // Create Database - if c.DBType == "sqlite" || c.DBType == "memory" { + dbm.init() + + return dbm +} + +// Init manager +func (dbm *DBManager) init() { + if dbm.cfg.DBType == "sqlite" || dbm.cfg.DBType == "memory" { var dbLocation string = ":memory:" - if c.DBType == "sqlite" { - dbLocation = filepath.Join(c.ConfigPath, fmt.Sprintf("%s.db", c.DBName)) + if dbm.cfg.DBType == "sqlite" { + dbLocation = filepath.Join(dbm.cfg.ConfigPath, fmt.Sprintf("%s.db", dbm.cfg.DBName)) } var err error @@ -67,12 +76,20 @@ func NewMgr(c *config.Config) *DBManager { } dbm.Queries = New(dbm.DB) - - return dbm } -func (dbm *DBManager) Shutdown() error { - return dbm.DB.Close() +// Reload manager (close DB & reinit) +func (dbm *DBManager) Reload() error { + // Close handle + err := dbm.DB.Close() + if err != nil { + return err + } + + // Reinit DB + dbm.init() + + return nil } func (dbm *DBManager) CacheTempTables() error { diff --git a/main.go b/main.go index 9b7fd62..cf308e7 100644 --- a/main.go +++ b/main.go @@ -4,7 +4,6 @@ import ( "embed" "os" "os/signal" - "sync" "syscall" log "github.com/sirupsen/logrus" @@ -38,18 +37,16 @@ func cmdServer(ctx *cli.Context) error { log.Info("Starting AnthoLume Server") // Create Channel - wg := sync.WaitGroup{} - done := make(chan struct{}) - interrupt := make(chan os.Signal, 1) - signal.Notify(interrupt, os.Interrupt, syscall.SIGTERM) + signals := make(chan os.Signal, 1) + signal.Notify(signals, os.Interrupt, syscall.SIGTERM) // Start Server - server := server.NewServer(&assets) - server.StartServer(&wg, done) + s := server.New(&assets) + s.Start() // Wait & Close - <-interrupt - server.StopServer(&wg, done) + <-signals + s.Stop() // Stop Server os.Exit(0) diff --git a/server/server.go b/server/server.go index f79ab71..a9d89d6 100644 --- a/server/server.go +++ b/server/server.go @@ -1,11 +1,8 @@ package server import ( - "context" "embed" "net/http" - "os" - "path/filepath" "sync" "time" @@ -16,91 +13,79 @@ import ( "reichard.io/antholume/database" ) -type Server struct { - API *api.API - Config *config.Config - Database *database.DBManager - httpServer *http.Server +type server struct { + db *database.DBManager + api *api.API + done chan int + wg sync.WaitGroup } -func NewServer(assets *embed.FS) *Server { +// Create new server +func New(assets *embed.FS) *server { c := config.Load() db := database.NewMgr(c) api := api.NewApi(db, c, assets) - // Create Paths - os.Mkdir(c.ConfigPath, 0755) - os.Mkdir(c.DataPath, 0755) - - // Create Subpaths - docDir := filepath.Join(c.DataPath, "documents") - coversDir := filepath.Join(c.DataPath, "covers") - backupDir := filepath.Join(c.DataPath, "backup") - os.Mkdir(docDir, 0755) - os.Mkdir(coversDir, 0755) - os.Mkdir(backupDir, 0755) - - return &Server{ - API: api, - Config: c, - Database: db, - httpServer: &http.Server{ - Handler: api.Router, - Addr: (":" + c.ListenPort), - }, + return &server{ + db: db, + api: api, + done: make(chan int), } } -func (s *Server) StartServer(wg *sync.WaitGroup, done <-chan struct{}) { - ticker := time.NewTicker(15 * time.Minute) - - wg.Add(2) +// Start server +func (s *server) Start() { + log.Info("Starting server...") + s.wg.Add(2) go func() { - defer wg.Done() + defer s.wg.Done() - err := s.httpServer.ListenAndServe() + err := s.api.Start() if err != nil && err != http.ErrServerClosed { - log.Error("Error starting server:", err) + log.Error("Starting server failed: ", err) } }() go func() { - defer wg.Done() + defer s.wg.Done() + + ticker := time.NewTicker(15 * time.Minute) defer ticker.Stop() for { select { case <-ticker.C: - s.RunScheduledTasks() - case <-done: + s.runScheduledTasks() + case <-s.done: log.Info("Stopping task runner...") return } } }() + + log.Info("Server started") } -func (s *Server) RunScheduledTasks() { - start := time.Now() - if err := s.API.DB.CacheTempTables(); err != nil { - log.Warn("Refreshing temp table cache failure:", err) +// Stop server +func (s *server) Stop() { + log.Info("Stopping server...") + + if err := s.api.Stop(); err != nil { + log.Error("HTTP server stop failed: ", err) } - log.Debug("Completed in: ", time.Since(start)) -} -func (s *Server) StopServer(wg *sync.WaitGroup, done chan<- struct{}) { - log.Info("Stopping HTTP server...") - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - if err := s.httpServer.Shutdown(ctx); err != nil { - log.Info("HTTP server shutdown error: ", err) - } - s.API.DB.Shutdown() - - close(done) - wg.Wait() + close(s.done) + s.wg.Wait() log.Info("Server stopped") } + +// Run normal scheduled tasks +func (s *server) runScheduledTasks() { + start := time.Now() + if err := s.db.CacheTempTables(); err != nil { + log.Warn("Refreshing temp table cache failed: ", err) + } + log.Debug("Completed in: ", time.Since(start)) +}