refactor(managers): privatize manager struct fields
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
Evan Reichard 2024-01-27 14:56:01 -05:00
parent 8c4c1022c3
commit 9792a6ff19
11 changed files with 316 additions and 247 deletions

View File

@ -1,6 +1,7 @@
package api package api
import ( import (
"context"
"crypto/rand" "crypto/rand"
"embed" "embed"
"fmt" "fmt"
@ -22,29 +23,37 @@ import (
) )
type API struct { type API struct {
Router *gin.Engine db *database.DBManager
Config *config.Config cfg *config.Config
DB *database.DBManager assets *embed.FS
HTMLPolicy *bluemonday.Policy templates map[string]*template.Template
Assets *embed.FS httpServer *http.Server
Templates map[string]*template.Template
} }
var htmlPolicy = bluemonday.StrictPolicy()
func NewApi(db *database.DBManager, c *config.Config, assets *embed.FS) *API { func NewApi(db *database.DBManager, c *config.Config, assets *embed.FS) *API {
api := &API{ api := &API{
HTMLPolicy: bluemonday.StrictPolicy(), db: db,
Router: gin.New(), cfg: c,
Config: c, assets: assets,
DB: db, }
Assets: assets,
// Create Router
router := gin.New()
// Add Server
api.httpServer = &http.Server{
Handler: router,
Addr: (":" + c.ListenPort),
} }
// Add Logger // Add Logger
api.Router.Use(apiLogger()) router.Use(apiLogger())
// Assets & Web App Templates // Assets & Web App Templates
assetsDir, _ := fs.Sub(assets, "assets") assetsDir, _ := fs.Sub(assets, "assets")
api.Router.StaticFS("/assets", http.FS(assetsDir)) router.StaticFS("/assets", http.FS(assetsDir))
// Generate Auth Token // Generate Auth Token
var newToken []byte var newToken []byte
@ -78,74 +87,92 @@ func NewApi(db *database.DBManager, c *config.Config, assets *embed.FS) *API {
HttpOnly: c.CookieHTTPOnly, HttpOnly: c.CookieHTTPOnly,
SameSite: http.SameSiteStrictMode, SameSite: http.SameSiteStrictMode,
}) })
api.Router.Use(sessions.Sessions("token", store)) router.Use(sessions.Sessions("token", store))
// Register Web App Route // Register Web App Route
api.registerWebAppRoutes() api.registerWebAppRoutes(router)
// Register API Routes // Register API Routes
apiGroup := api.Router.Group("/api") apiGroup := router.Group("/api")
api.registerKOAPIRoutes(apiGroup) api.registerKOAPIRoutes(apiGroup)
api.registerOPDSRoutes(apiGroup) api.registerOPDSRoutes(apiGroup)
return api return api
} }
func (api *API) registerWebAppRoutes() { func (api *API) Start() error {
return api.httpServer.ListenAndServe()
}
func (api *API) Stop() error {
// Stop Server
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
err := api.httpServer.Shutdown(ctx)
if err != nil {
return err
}
// Close DB
return api.db.DB.Close()
}
func (api *API) registerWebAppRoutes(router *gin.Engine) {
// Generate Templates // Generate Templates
api.Router.HTMLRender = *api.generateTemplates() router.HTMLRender = *api.generateTemplates()
// Static Assets (Required @ Root) // Static Assets (Required @ Root)
api.Router.GET("/manifest.json", api.appWebManifest) router.GET("/manifest.json", api.appWebManifest)
api.Router.GET("/favicon.ico", api.appFaviconIcon) router.GET("/favicon.ico", api.appFaviconIcon)
api.Router.GET("/sw.js", api.appServiceWorker) router.GET("/sw.js", api.appServiceWorker)
// Local / Offline Static Pages (No Template, No Auth) // Local / Offline Static Pages (No Template, No Auth)
api.Router.GET("/local", api.appLocalDocuments) router.GET("/local", api.appLocalDocuments)
// Reader (Reader Page, Document Progress, Devices) // Reader (Reader Page, Document Progress, Devices)
api.Router.GET("/reader", api.appDocumentReader) router.GET("/reader", api.appDocumentReader)
api.Router.GET("/reader/devices", api.authWebAppMiddleware, api.appGetDevices) router.GET("/reader/devices", api.authWebAppMiddleware, api.appGetDevices)
api.Router.GET("/reader/progress/:document", api.authWebAppMiddleware, api.appGetDocumentProgress) router.GET("/reader/progress/:document", api.authWebAppMiddleware, api.appGetDocumentProgress)
// Web App // Web App
api.Router.GET("/", api.authWebAppMiddleware, api.appGetHome) router.GET("/", api.authWebAppMiddleware, api.appGetHome)
api.Router.GET("/activity", api.authWebAppMiddleware, api.appGetActivity) router.GET("/activity", api.authWebAppMiddleware, api.appGetActivity)
api.Router.GET("/progress", api.authWebAppMiddleware, api.appGetProgress) router.GET("/progress", api.authWebAppMiddleware, api.appGetProgress)
api.Router.GET("/documents", api.authWebAppMiddleware, api.appGetDocuments) router.GET("/documents", api.authWebAppMiddleware, api.appGetDocuments)
api.Router.GET("/documents/:document", api.authWebAppMiddleware, api.appGetDocument) router.GET("/documents/:document", api.authWebAppMiddleware, api.appGetDocument)
api.Router.GET("/documents/:document/cover", api.authWebAppMiddleware, api.createGetCoverHandler(appErrorPage)) router.GET("/documents/:document/cover", api.authWebAppMiddleware, api.createGetCoverHandler(appErrorPage))
api.Router.GET("/documents/:document/file", api.authWebAppMiddleware, api.createDownloadDocumentHandler(appErrorPage)) router.GET("/documents/:document/file", api.authWebAppMiddleware, api.createDownloadDocumentHandler(appErrorPage))
api.Router.GET("/login", api.appGetLogin) router.GET("/login", api.appGetLogin)
api.Router.GET("/logout", api.authWebAppMiddleware, api.appAuthLogout) router.GET("/logout", api.authWebAppMiddleware, api.appAuthLogout)
api.Router.GET("/register", api.appGetRegister) router.GET("/register", api.appGetRegister)
api.Router.GET("/settings", api.authWebAppMiddleware, api.appGetSettings) router.GET("/settings", api.authWebAppMiddleware, api.appGetSettings)
api.Router.GET("/admin/logs", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appGetAdminLogs) router.GET("/admin/logs", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appGetAdminLogs)
api.Router.GET("/admin/users", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appGetAdminUsers) router.GET("/admin/users", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appGetAdminUsers)
api.Router.GET("/admin", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appGetAdmin) router.GET("/admin", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appGetAdmin)
api.Router.POST("/admin", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appPerformAdminAction) router.POST("/admin", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appPerformAdminAction)
api.Router.POST("/login", api.appAuthFormLogin) router.POST("/login", api.appAuthFormLogin)
api.Router.POST("/register", api.appAuthFormRegister) router.POST("/register", api.appAuthFormRegister)
// Demo Mode Enabled Configuration // Demo Mode Enabled Configuration
if api.Config.DemoMode { if api.cfg.DemoMode {
api.Router.POST("/documents", api.authWebAppMiddleware, api.appDemoModeError) router.POST("/documents", api.authWebAppMiddleware, api.appDemoModeError)
api.Router.POST("/documents/:document/delete", api.authWebAppMiddleware, api.appDemoModeError) router.POST("/documents/:document/delete", api.authWebAppMiddleware, api.appDemoModeError)
api.Router.POST("/documents/:document/edit", api.authWebAppMiddleware, api.appDemoModeError) router.POST("/documents/:document/edit", api.authWebAppMiddleware, api.appDemoModeError)
api.Router.POST("/documents/:document/identify", api.authWebAppMiddleware, api.appDemoModeError) router.POST("/documents/:document/identify", api.authWebAppMiddleware, api.appDemoModeError)
api.Router.POST("/settings", api.authWebAppMiddleware, api.appDemoModeError) router.POST("/settings", api.authWebAppMiddleware, api.appDemoModeError)
} else { } else {
api.Router.POST("/documents", api.authWebAppMiddleware, api.appUploadNewDocument) router.POST("/documents", api.authWebAppMiddleware, api.appUploadNewDocument)
api.Router.POST("/documents/:document/delete", api.authWebAppMiddleware, api.appDeleteDocument) router.POST("/documents/:document/delete", api.authWebAppMiddleware, api.appDeleteDocument)
api.Router.POST("/documents/:document/edit", api.authWebAppMiddleware, api.appEditDocument) router.POST("/documents/:document/edit", api.authWebAppMiddleware, api.appEditDocument)
api.Router.POST("/documents/:document/identify", api.authWebAppMiddleware, api.appIdentifyDocument) router.POST("/documents/:document/identify", api.authWebAppMiddleware, api.appIdentifyDocument)
api.Router.POST("/settings", api.authWebAppMiddleware, api.appEditSettings) router.POST("/settings", api.authWebAppMiddleware, api.appEditSettings)
} }
// Search Enabled Configuration // Search Enabled Configuration
if api.Config.SearchEnabled { if api.cfg.SearchEnabled {
api.Router.GET("/search", api.authWebAppMiddleware, api.appGetSearch) router.GET("/search", api.authWebAppMiddleware, api.appGetSearch)
api.Router.POST("/search", api.authWebAppMiddleware, api.appSaveNewDocument) router.POST("/search", api.authWebAppMiddleware, api.appSaveNewDocument)
} }
} }
@ -162,7 +189,7 @@ func (api *API) registerKOAPIRoutes(apiGroup *gin.RouterGroup) {
koGroup.PUT("/syncs/progress", api.authKOMiddleware, api.koSetProgress) koGroup.PUT("/syncs/progress", api.authKOMiddleware, api.koSetProgress)
// Demo Mode Enabled Configuration // Demo Mode Enabled Configuration
if api.Config.DemoMode { if api.cfg.DemoMode {
koGroup.POST("/documents", api.authKOMiddleware, api.koDemoModeJSONError) koGroup.POST("/documents", api.authKOMiddleware, api.koDemoModeJSONError)
koGroup.POST("/syncs/documents", api.authKOMiddleware, api.koDemoModeJSONError) koGroup.POST("/syncs/documents", api.authKOMiddleware, api.koDemoModeJSONError)
koGroup.PUT("/documents/:document/file", api.authKOMiddleware, api.koDemoModeJSONError) koGroup.PUT("/documents/:document/file", api.authKOMiddleware, api.koDemoModeJSONError)
@ -200,50 +227,50 @@ func (api *API) generateTemplates() *multitemplate.Renderer {
} }
// Load Base // Load Base
b, _ := api.Assets.ReadFile("templates/base.tmpl") b, _ := api.assets.ReadFile("templates/base.tmpl")
baseTemplate := template.Must(template.New("base").Funcs(helperFuncs).Parse(string(b))) baseTemplate := template.Must(template.New("base").Funcs(helperFuncs).Parse(string(b)))
// Load SVGs // Load SVGs
svgs, _ := api.Assets.ReadDir("templates/svgs") svgs, _ := api.assets.ReadDir("templates/svgs")
for _, item := range svgs { for _, item := range svgs {
basename := item.Name() basename := item.Name()
path := fmt.Sprintf("templates/svgs/%s", basename) path := fmt.Sprintf("templates/svgs/%s", basename)
name := strings.TrimSuffix(basename, filepath.Ext(basename)) name := strings.TrimSuffix(basename, filepath.Ext(basename))
b, _ := api.Assets.ReadFile(path) b, _ := api.assets.ReadFile(path)
baseTemplate = template.Must(baseTemplate.New("svg/" + name).Parse(string(b))) baseTemplate = template.Must(baseTemplate.New("svg/" + name).Parse(string(b)))
templates["svg/"+name] = baseTemplate templates["svg/"+name] = baseTemplate
} }
// Load Components // Load Components
components, _ := api.Assets.ReadDir("templates/components") components, _ := api.assets.ReadDir("templates/components")
for _, item := range components { for _, item := range components {
basename := item.Name() basename := item.Name()
path := fmt.Sprintf("templates/components/%s", basename) path := fmt.Sprintf("templates/components/%s", basename)
name := strings.TrimSuffix(basename, filepath.Ext(basename)) name := strings.TrimSuffix(basename, filepath.Ext(basename))
// Clone Base Template // Clone Base Template
b, _ := api.Assets.ReadFile(path) b, _ := api.assets.ReadFile(path)
baseTemplate = template.Must(baseTemplate.New("component/" + name).Parse(string(b))) baseTemplate = template.Must(baseTemplate.New("component/" + name).Parse(string(b)))
render.Add("component/"+name, baseTemplate) render.Add("component/"+name, baseTemplate)
templates["component/"+name] = baseTemplate templates["component/"+name] = baseTemplate
} }
// Load Pages // Load Pages
pages, _ := api.Assets.ReadDir("templates/pages") pages, _ := api.assets.ReadDir("templates/pages")
for _, item := range pages { for _, item := range pages {
basename := item.Name() basename := item.Name()
path := fmt.Sprintf("templates/pages/%s", basename) path := fmt.Sprintf("templates/pages/%s", basename)
name := strings.TrimSuffix(basename, filepath.Ext(basename)) name := strings.TrimSuffix(basename, filepath.Ext(basename))
// Clone Base Template // Clone Base Template
b, _ := api.Assets.ReadFile(path) b, _ := api.assets.ReadFile(path)
pageTemplate, _ := template.Must(baseTemplate.Clone()).New("page/" + name).Parse(string(b)) pageTemplate, _ := template.Must(baseTemplate.Clone()).New("page/" + name).Parse(string(b))
render.Add("page/"+name, pageTemplate) render.Add("page/"+name, pageTemplate)
templates["page/"+name] = pageTemplate templates["page/"+name] = pageTemplate
} }
api.Templates = templates api.templates = templates
return &render return &render
} }

View File

@ -116,23 +116,23 @@ type requestDocumentAdd struct {
func (api *API) appWebManifest(c *gin.Context) { func (api *API) appWebManifest(c *gin.Context) {
c.Header("Content-Type", "application/manifest+json") c.Header("Content-Type", "application/manifest+json")
c.FileFromFS("assets/manifest.json", http.FS(api.Assets)) c.FileFromFS("assets/manifest.json", http.FS(api.assets))
} }
func (api *API) appServiceWorker(c *gin.Context) { func (api *API) appServiceWorker(c *gin.Context) {
c.FileFromFS("assets/sw.js", http.FS(api.Assets)) c.FileFromFS("assets/sw.js", http.FS(api.assets))
} }
func (api *API) appFaviconIcon(c *gin.Context) { func (api *API) appFaviconIcon(c *gin.Context) {
c.FileFromFS("assets/icons/favicon.ico", http.FS(api.Assets)) c.FileFromFS("assets/icons/favicon.ico", http.FS(api.assets))
} }
func (api *API) appLocalDocuments(c *gin.Context) { func (api *API) appLocalDocuments(c *gin.Context) {
c.FileFromFS("assets/local/index.htm", http.FS(api.Assets)) c.FileFromFS("assets/local/index.htm", http.FS(api.assets))
} }
func (api *API) appDocumentReader(c *gin.Context) { func (api *API) appDocumentReader(c *gin.Context) {
c.FileFromFS("assets/reader/index.htm", http.FS(api.Assets)) c.FileFromFS("assets/reader/index.htm", http.FS(api.assets))
} }
func (api *API) appGetDocuments(c *gin.Context) { func (api *API) appGetDocuments(c *gin.Context) {
@ -145,7 +145,7 @@ func (api *API) appGetDocuments(c *gin.Context) {
query = &search query = &search
} }
documents, err := api.DB.Queries.GetDocumentsWithStats(api.DB.Ctx, database.GetDocumentsWithStatsParams{ documents, err := api.db.Queries.GetDocumentsWithStats(api.db.Ctx, database.GetDocumentsWithStatsParams{
UserID: auth.UserName, UserID: auth.UserName,
Query: query, Query: query,
Offset: (*qParams.Page - 1) * *qParams.Limit, Offset: (*qParams.Page - 1) * *qParams.Limit,
@ -157,7 +157,7 @@ func (api *API) appGetDocuments(c *gin.Context) {
return return
} }
length, err := api.DB.Queries.GetDocumentsSize(api.DB.Ctx, query) length, err := api.db.Queries.GetDocumentsSize(api.db.Ctx, query)
if err != nil { if err != nil {
log.Error("GetDocumentsSize DB Error: ", err) log.Error("GetDocumentsSize DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocumentsSize DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocumentsSize DB Error: %v", err))
@ -196,7 +196,7 @@ func (api *API) appGetDocument(c *gin.Context) {
return return
} }
document, err := api.DB.Queries.GetDocumentWithStats(api.DB.Ctx, database.GetDocumentWithStatsParams{ document, err := api.db.Queries.GetDocumentWithStats(api.db.Ctx, database.GetDocumentWithStatsParams{
UserID: auth.UserName, UserID: auth.UserName,
DocumentID: rDocID.DocumentID, DocumentID: rDocID.DocumentID,
}) })
@ -228,7 +228,7 @@ func (api *API) appGetProgress(c *gin.Context) {
progressFilter.DocumentID = *qParams.Document progressFilter.DocumentID = *qParams.Document
} }
progress, err := api.DB.Queries.GetProgress(api.DB.Ctx, progressFilter) progress, err := api.db.Queries.GetProgress(api.db.Ctx, progressFilter)
if err != nil { if err != nil {
log.Error("GetProgress DB Error: ", err) log.Error("GetProgress DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetActivity DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetActivity DB Error: %v", err))
@ -255,7 +255,7 @@ func (api *API) appGetActivity(c *gin.Context) {
activityFilter.DocumentID = *qParams.Document activityFilter.DocumentID = *qParams.Document
} }
activity, err := api.DB.Queries.GetActivity(api.DB.Ctx, activityFilter) activity, err := api.db.Queries.GetActivity(api.db.Ctx, activityFilter)
if err != nil { if err != nil {
log.Error("GetActivity DB Error: ", err) log.Error("GetActivity DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetActivity DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetActivity DB Error: %v", err))
@ -271,7 +271,7 @@ func (api *API) appGetHome(c *gin.Context) {
templateVars, auth := api.getBaseTemplateVars("home", c) templateVars, auth := api.getBaseTemplateVars("home", c)
start := time.Now() start := time.Now()
graphData, err := api.DB.Queries.GetDailyReadStats(api.DB.Ctx, auth.UserName) graphData, err := api.db.Queries.GetDailyReadStats(api.db.Ctx, auth.UserName)
if err != nil { if err != nil {
log.Error("GetDailyReadStats DB Error: ", err) log.Error("GetDailyReadStats DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDailyReadStats DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDailyReadStats DB Error: %v", err))
@ -280,7 +280,7 @@ func (api *API) appGetHome(c *gin.Context) {
log.Debug("GetDailyReadStats DB Performance: ", time.Since(start)) log.Debug("GetDailyReadStats DB Performance: ", time.Since(start))
start = time.Now() start = time.Now()
databaseInfo, err := api.DB.Queries.GetDatabaseInfo(api.DB.Ctx, auth.UserName) databaseInfo, err := api.db.Queries.GetDatabaseInfo(api.db.Ctx, auth.UserName)
if err != nil { if err != nil {
log.Error("GetDatabaseInfo DB Error: ", err) log.Error("GetDatabaseInfo DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDatabaseInfo DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDatabaseInfo DB Error: %v", err))
@ -289,7 +289,7 @@ func (api *API) appGetHome(c *gin.Context) {
log.Debug("GetDatabaseInfo DB Performance: ", time.Since(start)) log.Debug("GetDatabaseInfo DB Performance: ", time.Since(start))
start = time.Now() start = time.Now()
streaks, err := api.DB.Queries.GetUserStreaks(api.DB.Ctx, auth.UserName) streaks, err := api.db.Queries.GetUserStreaks(api.db.Ctx, auth.UserName)
if err != nil { if err != nil {
log.Error("GetUserStreaks DB Error: ", err) log.Error("GetUserStreaks DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUserStreaks DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUserStreaks DB Error: %v", err))
@ -298,7 +298,7 @@ func (api *API) appGetHome(c *gin.Context) {
log.Debug("GetUserStreaks DB Performance: ", time.Since(start)) log.Debug("GetUserStreaks DB Performance: ", time.Since(start))
start = time.Now() start = time.Now()
userStatistics, err := api.DB.Queries.GetUserStatistics(api.DB.Ctx) userStatistics, err := api.db.Queries.GetUserStatistics(api.db.Ctx)
if err != nil { if err != nil {
log.Error("GetUserStatistics DB Error: ", err) log.Error("GetUserStatistics DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUserStatistics DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUserStatistics DB Error: %v", err))
@ -319,14 +319,14 @@ func (api *API) appGetHome(c *gin.Context) {
func (api *API) appGetSettings(c *gin.Context) { func (api *API) appGetSettings(c *gin.Context) {
templateVars, auth := api.getBaseTemplateVars("settings", c) templateVars, auth := api.getBaseTemplateVars("settings", c)
user, err := api.DB.Queries.GetUser(api.DB.Ctx, auth.UserName) user, err := api.db.Queries.GetUser(api.db.Ctx, auth.UserName)
if err != nil { if err != nil {
log.Error("GetUser DB Error: ", err) log.Error("GetUser DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUser DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUser DB Error: %v", err))
return return
} }
devices, err := api.DB.Queries.GetDevices(api.DB.Ctx, auth.UserName) devices, err := api.db.Queries.GetDevices(api.db.Ctx, auth.UserName)
if err != nil { if err != nil {
log.Error("GetDevices DB Error: ", err) log.Error("GetDevices DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDevices DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDevices DB Error: %v", err))
@ -350,7 +350,7 @@ func (api *API) appGetAdminLogs(c *gin.Context) {
templateVars, _ := api.getBaseTemplateVars("admin-logs", c) templateVars, _ := api.getBaseTemplateVars("admin-logs", c)
// Open Log File // Open Log File
logPath := filepath.Join(api.Config.ConfigPath, "logs/antholume.log") logPath := filepath.Join(api.cfg.ConfigPath, "logs/antholume.log")
logFile, err := os.Open(logPath) logFile, err := os.Open(logPath)
if err != nil { if err != nil {
appErrorPage(c, http.StatusBadRequest, "Missing AnthoLume log file.") appErrorPage(c, http.StatusBadRequest, "Missing AnthoLume log file.")
@ -388,7 +388,7 @@ func (api *API) appGetAdminLogs(c *gin.Context) {
func (api *API) appGetAdminUsers(c *gin.Context) { func (api *API) appGetAdminUsers(c *gin.Context) {
templateVars, _ := api.getBaseTemplateVars("admin-users", c) templateVars, _ := api.getBaseTemplateVars("admin-users", c)
users, err := api.DB.Queries.GetUsers(api.DB.Ctx) users, err := api.db.Queries.GetUsers(api.db.Ctx)
if err != nil { if err != nil {
log.Error("GetUsers DB Error: ", err) log.Error("GetUsers DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUsers DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUsers DB Error: %v", err))
@ -422,12 +422,21 @@ func (api *API) appPerformAdminAction(c *gin.Context) {
// 1. Documents xref most recent metadata table? // 1. Documents xref most recent metadata table?
// 2. Select all / deselect? // 2. Select all / deselect?
case adminCacheTables: case adminCacheTables:
go api.DB.CacheTempTables() go api.db.CacheTempTables()
case adminRestore: case adminRestore:
api.processRestoreFile(rAdminAction, c) api.processRestoreFile(rAdminAction, c)
case adminBackup: case adminBackup:
// Vacuum
_, err := api.db.DB.ExecContext(api.db.Ctx, "VACUUM;")
if err != nil {
log.Error("Unable to vacuum DB: ", err)
appErrorPage(c, http.StatusInternalServerError, "Unable to vacuum database.")
return
}
// Set Headers
c.Header("Content-type", "application/octet-stream") c.Header("Content-type", "application/octet-stream")
c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"AnthoLumeBackup_%s.zip\"", time.Now().Format("20060102"))) c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"AnthoLumeBackup_%s.zip\"", time.Now().Format("20060102150405")))
// Stream Backup ZIP Archive // Stream Backup ZIP Archive
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
@ -479,18 +488,18 @@ func (api *API) appGetSearch(c *gin.Context) {
func (api *API) appGetLogin(c *gin.Context) { func (api *API) appGetLogin(c *gin.Context) {
templateVars, _ := api.getBaseTemplateVars("login", c) templateVars, _ := api.getBaseTemplateVars("login", c)
templateVars["RegistrationEnabled"] = api.Config.RegistrationEnabled templateVars["RegistrationEnabled"] = api.cfg.RegistrationEnabled
c.HTML(http.StatusOK, "page/login", templateVars) c.HTML(http.StatusOK, "page/login", templateVars)
} }
func (api *API) appGetRegister(c *gin.Context) { func (api *API) appGetRegister(c *gin.Context) {
if !api.Config.RegistrationEnabled { if !api.cfg.RegistrationEnabled {
c.Redirect(http.StatusFound, "/login") c.Redirect(http.StatusFound, "/login")
return return
} }
templateVars, _ := api.getBaseTemplateVars("login", c) templateVars, _ := api.getBaseTemplateVars("login", c)
templateVars["RegistrationEnabled"] = api.Config.RegistrationEnabled templateVars["RegistrationEnabled"] = api.cfg.RegistrationEnabled
templateVars["Register"] = true templateVars["Register"] = true
c.HTML(http.StatusOK, "page/login", templateVars) c.HTML(http.StatusOK, "page/login", templateVars)
} }
@ -508,7 +517,7 @@ func (api *API) appGetDocumentProgress(c *gin.Context) {
return return
} }
progress, err := api.DB.Queries.GetDocumentProgress(api.DB.Ctx, database.GetDocumentProgressParams{ progress, err := api.db.Queries.GetDocumentProgress(api.db.Ctx, database.GetDocumentProgressParams{
DocumentID: rDoc.DocumentID, DocumentID: rDoc.DocumentID,
UserID: auth.UserName, UserID: auth.UserName,
}) })
@ -519,7 +528,7 @@ func (api *API) appGetDocumentProgress(c *gin.Context) {
return return
} }
document, err := api.DB.Queries.GetDocumentWithStats(api.DB.Ctx, database.GetDocumentWithStatsParams{ document, err := api.db.Queries.GetDocumentWithStats(api.db.Ctx, database.GetDocumentWithStatsParams{
UserID: auth.UserName, UserID: auth.UserName,
DocumentID: rDoc.DocumentID, DocumentID: rDoc.DocumentID,
}) })
@ -545,7 +554,7 @@ func (api *API) appGetDevices(c *gin.Context) {
auth = data.(authData) auth = data.(authData)
} }
devices, err := api.DB.Queries.GetDevices(api.DB.Ctx, auth.UserName) devices, err := api.db.Queries.GetDevices(api.db.Ctx, auth.UserName)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
log.Error("GetDevices DB Error: ", err) log.Error("GetDevices DB Error: ", err)
@ -627,7 +636,7 @@ func (api *API) appUploadNewDocument(c *gin.Context) {
} }
// Check Exists // Check Exists
_, err = api.DB.Queries.GetDocument(api.DB.Ctx, partialMD5) _, err = api.db.Queries.GetDocument(api.db.Ctx, partialMD5)
if err == nil { if err == nil {
c.Redirect(http.StatusFound, fmt.Sprintf("./documents/%s", partialMD5)) c.Redirect(http.StatusFound, fmt.Sprintf("./documents/%s", partialMD5))
return return
@ -670,7 +679,7 @@ func (api *API) appUploadNewDocument(c *gin.Context) {
fileName = "." + filepath.Clean(fmt.Sprintf("/%s [%s]%s", fileName, partialMD5, fileExtension)) fileName = "." + filepath.Clean(fmt.Sprintf("/%s [%s]%s", fileName, partialMD5, fileExtension))
// Generate Storage Path & Open File // Generate Storage Path & Open File
safePath := filepath.Join(api.Config.DataPath, "documents", fileName) safePath := filepath.Join(api.cfg.DataPath, "documents", fileName)
destFile, err := os.Create(safePath) destFile, err := os.Create(safePath)
if err != nil { if err != nil {
log.Error("Dest File Error: ", err) log.Error("Dest File Error: ", err)
@ -687,7 +696,7 @@ func (api *API) appUploadNewDocument(c *gin.Context) {
} }
// Upsert Document // Upsert Document
if _, err = api.DB.Queries.UpsertDocument(api.DB.Ctx, database.UpsertDocumentParams{ if _, err = api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
ID: partialMD5, ID: partialMD5,
Title: metadataInfo.Title, Title: metadataInfo.Title,
Author: metadataInfo.Author, Author: metadataInfo.Author,
@ -764,7 +773,7 @@ func (api *API) appEditDocument(c *gin.Context) {
// Generate Storage Path // Generate Storage Path
fileName := fmt.Sprintf("%s%s", rDocID.DocumentID, fileExtension) fileName := fmt.Sprintf("%s%s", rDocID.DocumentID, fileExtension)
safePath := filepath.Join(api.Config.DataPath, "covers", fileName) safePath := filepath.Join(api.cfg.DataPath, "covers", fileName)
// Save // Save
err = c.SaveUploadedFile(rDocEdit.CoverFile, safePath) err = c.SaveUploadedFile(rDocEdit.CoverFile, safePath)
@ -776,7 +785,7 @@ func (api *API) appEditDocument(c *gin.Context) {
coverFileName = &fileName coverFileName = &fileName
} else if rDocEdit.CoverGBID != nil { } else if rDocEdit.CoverGBID != nil {
var coverDir string = filepath.Join(api.Config.DataPath, "covers") var coverDir string = filepath.Join(api.cfg.DataPath, "covers")
fileName, err := metadata.CacheCover(*rDocEdit.CoverGBID, coverDir, rDocID.DocumentID, true) fileName, err := metadata.CacheCover(*rDocEdit.CoverGBID, coverDir, rDocID.DocumentID, true)
if err == nil { if err == nil {
coverFileName = fileName coverFileName = fileName
@ -784,7 +793,7 @@ func (api *API) appEditDocument(c *gin.Context) {
} }
// Update Document // Update Document
if _, err := api.DB.Queries.UpsertDocument(api.DB.Ctx, database.UpsertDocumentParams{ if _, err := api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
ID: rDocID.DocumentID, ID: rDocID.DocumentID,
Title: api.sanitizeInput(rDocEdit.Title), Title: api.sanitizeInput(rDocEdit.Title),
Author: api.sanitizeInput(rDocEdit.Author), Author: api.sanitizeInput(rDocEdit.Author),
@ -809,7 +818,7 @@ func (api *API) appDeleteDocument(c *gin.Context) {
appErrorPage(c, http.StatusNotFound, "Invalid document.") appErrorPage(c, http.StatusNotFound, "Invalid document.")
return return
} }
changed, err := api.DB.Queries.DeleteDocument(api.DB.Ctx, rDocID.DocumentID) changed, err := api.db.Queries.DeleteDocument(api.db.Ctx, rDocID.DocumentID)
if err != nil { if err != nil {
log.Error("DeleteDocument DB Error") log.Error("DeleteDocument DB Error")
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("DeleteDocument DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("DeleteDocument DB Error: %v", err))
@ -871,7 +880,7 @@ func (api *API) appIdentifyDocument(c *gin.Context) {
firstResult := metadataResults[0] firstResult := metadataResults[0]
// Store First Metadata Result // Store First Metadata Result
if _, err = api.DB.Queries.AddMetadata(api.DB.Ctx, database.AddMetadataParams{ if _, err = api.db.Queries.AddMetadata(api.db.Ctx, database.AddMetadataParams{
DocumentID: rDocID.DocumentID, DocumentID: rDocID.DocumentID,
Title: firstResult.Title, Title: firstResult.Title,
Author: firstResult.Author, Author: firstResult.Author,
@ -890,7 +899,7 @@ func (api *API) appIdentifyDocument(c *gin.Context) {
templateVars["MetadataError"] = "No Metadata Found" templateVars["MetadataError"] = "No Metadata Found"
} }
document, err := api.DB.Queries.GetDocumentWithStats(api.DB.Ctx, database.GetDocumentWithStatsParams{ document, err := api.db.Queries.GetDocumentWithStats(api.db.Ctx, database.GetDocumentWithStatsParams{
UserID: auth.UserName, UserID: auth.UserName,
DocumentID: rDocID.DocumentID, DocumentID: rDocID.DocumentID,
}) })
@ -1001,7 +1010,7 @@ func (api *API) appSaveNewDocument(c *gin.Context) {
defer sourceFile.Close() defer sourceFile.Close()
// Generate Storage Path & Open File // Generate Storage Path & Open File
safePath := filepath.Join(api.Config.DataPath, "documents", fileName) safePath := filepath.Join(api.cfg.DataPath, "documents", fileName)
destFile, err := os.Create(safePath) destFile, err := os.Create(safePath)
if err != nil { if err != nil {
log.Error("Dest File Error: ", err) log.Error("Dest File Error: ", err)
@ -1043,7 +1052,7 @@ func (api *API) appSaveNewDocument(c *gin.Context) {
sendDownloadMessage("Saving to database...", gin.H{"Progress": 90}) sendDownloadMessage("Saving to database...", gin.H{"Progress": 90})
// Upsert Document // Upsert Document
if _, err = api.DB.Queries.UpsertDocument(api.DB.Ctx, database.UpsertDocumentParams{ if _, err = api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
ID: partialMD5, ID: partialMD5,
Title: rDocAdd.Title, Title: rDocAdd.Title,
Author: rDocAdd.Author, Author: rDocAdd.Author,
@ -1110,7 +1119,7 @@ func (api *API) appEditSettings(c *gin.Context) {
} }
// Update User // Update User
_, err := api.DB.Queries.UpdateUser(api.DB.Ctx, newUserSettings) _, err := api.db.Queries.UpdateUser(api.db.Ctx, newUserSettings)
if err != nil { if err != nil {
log.Error("UpdateUser DB Error: ", err) log.Error("UpdateUser DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("UpdateUser DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("UpdateUser DB Error: %v", err))
@ -1118,7 +1127,7 @@ func (api *API) appEditSettings(c *gin.Context) {
} }
// Get User // Get User
user, err := api.DB.Queries.GetUser(api.DB.Ctx, auth.UserName) user, err := api.db.Queries.GetUser(api.db.Ctx, auth.UserName)
if err != nil { if err != nil {
log.Error("GetUser DB Error: ", err) log.Error("GetUser DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUser DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUser DB Error: %v", err))
@ -1126,7 +1135,7 @@ func (api *API) appEditSettings(c *gin.Context) {
} }
// Get Devices // Get Devices
devices, err := api.DB.Queries.GetDevices(api.DB.Ctx, auth.UserName) devices, err := api.db.Queries.GetDevices(api.db.Ctx, auth.UserName)
if err != nil { if err != nil {
log.Error("GetDevices DB Error: ", err) log.Error("GetDevices DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDevices DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDevices DB Error: %v", err))
@ -1147,7 +1156,7 @@ func (api *API) appDemoModeError(c *gin.Context) {
func (api *API) getDocumentsWordCount(documents []database.GetDocumentsWithStatsRow) error { func (api *API) getDocumentsWordCount(documents []database.GetDocumentsWithStatsRow) error {
// Do Transaction // Do Transaction
tx, err := api.DB.DB.Begin() tx, err := api.db.DB.Begin()
if err != nil { if err != nil {
log.Error("Transaction Begin DB Error: ", err) log.Error("Transaction Begin DB Error: ", err)
return err return err
@ -1155,16 +1164,16 @@ func (api *API) getDocumentsWordCount(documents []database.GetDocumentsWithStats
// Defer & Start Transaction // Defer & Start Transaction
defer tx.Rollback() defer tx.Rollback()
qtx := api.DB.Queries.WithTx(tx) qtx := api.db.Queries.WithTx(tx)
for _, item := range documents { for _, item := range documents {
if item.Words == nil && item.Filepath != nil { if item.Words == nil && item.Filepath != nil {
filePath := filepath.Join(api.Config.DataPath, "documents", *item.Filepath) filePath := filepath.Join(api.cfg.DataPath, "documents", *item.Filepath)
wordCount, err := metadata.GetWordCount(filePath) wordCount, err := metadata.GetWordCount(filePath)
if err != nil { if err != nil {
log.Warn("Word Count Error: ", err) log.Warn("Word Count Error: ", err)
} else { } else {
if _, err := qtx.UpsertDocument(api.DB.Ctx, database.UpsertDocumentParams{ if _, err := qtx.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
ID: item.ID, ID: item.ID,
Words: &wordCount, Words: &wordCount,
}); err != nil { }); err != nil {
@ -1194,9 +1203,9 @@ func (api *API) getBaseTemplateVars(routeName string, c *gin.Context) (gin.H, au
"Authorization": auth, "Authorization": auth,
"RouteName": routeName, "RouteName": routeName,
"Config": gin.H{ "Config": gin.H{
"Version": api.Config.Version, "Version": api.cfg.Version,
"SearchEnabled": api.Config.SearchEnabled, "SearchEnabled": api.cfg.SearchEnabled,
"RegistrationEnabled": api.Config.RegistrationEnabled, "RegistrationEnabled": api.cfg.RegistrationEnabled,
}, },
}, auth }, auth
} }
@ -1402,7 +1411,7 @@ func (api *API) processRestoreFile(rAdminAction requestAdminAction, c *gin.Conte
} }
// Create Backup File // Create Backup File
backupFilePath := filepath.Join(api.Config.ConfigPath, fmt.Sprintf("backup/AnthoLumeBackup_%s.zip", time.Now().Format("20060102"))) backupFilePath := filepath.Join(api.cfg.ConfigPath, fmt.Sprintf("backups/AnthoLumeBackup_%s.zip", time.Now().Format("20060102150405")))
backupFile, err := os.Create(backupFilePath) backupFile, err := os.Create(backupFilePath)
if err != nil { if err != nil {
log.Error("Unable to create backup file: ", err) log.Error("Unable to create backup file: ", err)
@ -1411,6 +1420,14 @@ func (api *API) processRestoreFile(rAdminAction requestAdminAction, c *gin.Conte
} }
defer backupFile.Close() defer backupFile.Close()
// Vacuum DB
_, err = api.db.DB.ExecContext(api.db.Ctx, "VACUUM;")
if err != nil {
log.Error("Unable to vacuum DB: ", err)
appErrorPage(c, http.StatusInternalServerError, "Unable to vacuum database.")
return
}
// Save Backup File // Save Backup File
w := bufio.NewWriter(backupFile) w := bufio.NewWriter(backupFile)
err = api.createBackup(w, []string{"covers", "documents"}) err = api.createBackup(w, []string{"covers", "documents"})
@ -1423,6 +1440,7 @@ func (api *API) processRestoreFile(rAdminAction requestAdminAction, c *gin.Conte
// Remove Data // Remove Data
err = api.removeData() err = api.removeData()
if err != nil { if err != nil {
log.Error("Unable to delete data: ", err)
appErrorPage(c, http.StatusInternalServerError, "Unable to delete data.") appErrorPage(c, http.StatusInternalServerError, "Unable to delete data.")
return return
} }
@ -1431,19 +1449,26 @@ func (api *API) processRestoreFile(rAdminAction requestAdminAction, c *gin.Conte
err = api.restoreData(zipReader) err = api.restoreData(zipReader)
if err != nil { if err != nil {
appErrorPage(c, http.StatusInternalServerError, "Unable to restore data.") appErrorPage(c, http.StatusInternalServerError, "Unable to restore data.")
log.Panic("Unable to restore data: ", err)
// Panic?
log.Panic("Oh no")
return return
} }
// TODO: // Close DB
// - Extract from temp directory err = api.db.DB.Close()
if err != nil {
appErrorPage(c, http.StatusInternalServerError, "Unable to close DB.")
log.Panic("Unable to close DB: ", err)
}
// Reinit DB
api.db.Reload()
} }
func (api *API) restoreData(zipReader *zip.Reader) error { func (api *API) restoreData(zipReader *zip.Reader) error {
// Ensure Directories
api.cfg.EnsureDirectories()
// Restore Data
for _, file := range zipReader.File { for _, file := range zipReader.File {
rc, err := file.Open() rc, err := file.Open()
if err != nil { if err != nil {
@ -1451,7 +1476,7 @@ func (api *API) restoreData(zipReader *zip.Reader) error {
} }
defer rc.Close() defer rc.Close()
destPath := filepath.Join(api.Config.DataPath, file.Name) destPath := filepath.Join(api.cfg.DataPath, file.Name)
destFile, err := os.Create(destPath) destFile, err := os.Create(destPath)
if err != nil { if err != nil {
fmt.Println("Error creating destination file:", err) fmt.Println("Error creating destination file:", err)
@ -1481,7 +1506,7 @@ func (api *API) removeData() error {
} }
for _, name := range allPaths { for _, name := range allPaths {
fullPath := filepath.Join(api.Config.DataPath, name) fullPath := filepath.Join(api.cfg.DataPath, name)
err := os.RemoveAll(fullPath) err := os.RemoveAll(fullPath)
if err != nil { if err != nil {
log.Errorf("Unable to delete %s: %v", name, err) log.Errorf("Unable to delete %s: %v", name, err)
@ -1531,8 +1556,8 @@ func (api *API) createBackup(w io.Writer, directories []string) error {
} }
// Get DB Path // Get DB Path
fileName := fmt.Sprintf("%s.db", api.Config.DBName) fileName := fmt.Sprintf("%s.db", api.cfg.DBName)
dbLocation := filepath.Join(api.Config.ConfigPath, fileName) dbLocation := filepath.Join(api.cfg.ConfigPath, fileName)
// Copy Database File // Copy Database File
dbFile, err := os.Open(dbLocation) dbFile, err := os.Open(dbLocation)
@ -1549,7 +1574,7 @@ func (api *API) createBackup(w io.Writer, directories []string) error {
// Backup Covers & Documents // Backup Covers & Documents
for _, dir := range directories { for _, dir := range directories {
err = filepath.WalkDir(filepath.Join(api.Config.DataPath, dir), exportWalker) err = filepath.WalkDir(filepath.Join(api.cfg.DataPath, dir), exportWalker)
if err != nil { if err != nil {
return err return err
} }

View File

@ -32,7 +32,7 @@ type authOPDSHeader struct {
} }
func (api *API) authorizeCredentials(username string, password string) (auth *authData) { func (api *API) authorizeCredentials(username string, password string) (auth *authData) {
user, err := api.DB.Queries.GetUser(api.DB.Ctx, username) user, err := api.db.Queries.GetUser(api.db.Ctx, username)
if err != nil { if err != nil {
return return
} }
@ -174,7 +174,7 @@ func (api *API) appAuthFormLogin(c *gin.Context) {
} }
func (api *API) appAuthFormRegister(c *gin.Context) { func (api *API) appAuthFormRegister(c *gin.Context) {
if !api.Config.RegistrationEnabled { if !api.cfg.RegistrationEnabled {
appErrorPage(c, http.StatusUnauthorized, "Nice try. Registration is disabled.") appErrorPage(c, http.StatusUnauthorized, "Nice try. Registration is disabled.")
return return
} }
@ -199,7 +199,7 @@ func (api *API) appAuthFormRegister(c *gin.Context) {
return return
} }
rows, err := api.DB.Queries.CreateUser(api.DB.Ctx, database.CreateUserParams{ rows, err := api.db.Queries.CreateUser(api.db.Ctx, database.CreateUserParams{
ID: username, ID: username,
Pass: &hashedPassword, Pass: &hashedPassword,
}) })
@ -221,7 +221,7 @@ func (api *API) appAuthFormRegister(c *gin.Context) {
} }
// Get User // Get User
user, err := api.DB.Queries.GetUser(api.DB.Ctx, username) user, err := api.db.Queries.GetUser(api.db.Ctx, username)
if err != nil { if err != nil {
log.Error("GetUser DB Error:", err) log.Error("GetUser DB Error:", err)
templateVars["Error"] = "Registration Disabled or User Already Exists" templateVars["Error"] = "Registration Disabled or User Already Exists"

View File

@ -21,7 +21,7 @@ func (api *API) createDownloadDocumentHandler(errorFunc func(*gin.Context, int,
} }
// Get Document // Get Document
document, err := api.DB.Queries.GetDocument(api.DB.Ctx, rDoc.DocumentID) document, err := api.db.Queries.GetDocument(api.db.Ctx, rDoc.DocumentID)
if err != nil { if err != nil {
log.Error("GetDocument DB Error:", err) log.Error("GetDocument DB Error:", err)
errorFunc(c, http.StatusBadRequest, "Unknown Document") errorFunc(c, http.StatusBadRequest, "Unknown Document")
@ -35,12 +35,12 @@ func (api *API) createDownloadDocumentHandler(errorFunc func(*gin.Context, int,
} }
// Derive Storage Location // Derive Storage Location
filePath := filepath.Join(api.Config.DataPath, "documents", *document.Filepath) filePath := filepath.Join(api.cfg.DataPath, "documents", *document.Filepath)
// Validate File Exists // Validate File Exists
_, err = os.Stat(filePath) _, err = os.Stat(filePath)
if os.IsNotExist(err) { if os.IsNotExist(err) {
log.Error("File Doesn't Exist:", rDoc.DocumentID) log.Error("File should but doesn't exist: ", err)
errorFunc(c, http.StatusBadRequest, "Document Doesn't Exist") errorFunc(c, http.StatusBadRequest, "Document Doesn't Exist")
return return
} }
@ -61,7 +61,7 @@ func (api *API) createGetCoverHandler(errorFunc func(*gin.Context, int, string))
} }
// Validate Document Exists in DB // Validate Document Exists in DB
document, err := api.DB.Queries.GetDocument(api.DB.Ctx, rDoc.DocumentID) document, err := api.db.Queries.GetDocument(api.db.Ctx, rDoc.DocumentID)
if err != nil { if err != nil {
log.Error("GetDocument DB Error:", err) log.Error("GetDocument DB Error:", err)
errorFunc(c, http.StatusInternalServerError, fmt.Sprintf("GetDocument DB Error: %v", err)) errorFunc(c, http.StatusInternalServerError, fmt.Sprintf("GetDocument DB Error: %v", err))
@ -71,18 +71,18 @@ func (api *API) createGetCoverHandler(errorFunc func(*gin.Context, int, string))
// Handle Identified Document // Handle Identified Document
if document.Coverfile != nil { if document.Coverfile != nil {
if *document.Coverfile == "UNKNOWN" { if *document.Coverfile == "UNKNOWN" {
c.FileFromFS("assets/images/no-cover.jpg", http.FS(api.Assets)) c.FileFromFS("assets/images/no-cover.jpg", http.FS(api.assets))
return return
} }
// Derive Path // Derive Path
safePath := filepath.Join(api.Config.DataPath, "covers", *document.Coverfile) safePath := filepath.Join(api.cfg.DataPath, "covers", *document.Coverfile)
// Validate File Exists // Validate File Exists
_, err = os.Stat(safePath) _, err = os.Stat(safePath)
if err != nil { if err != nil {
log.Error("File Should But Doesn't Exist:", err) log.Error("File should but doesn't exist: ", err)
c.FileFromFS("assets/images/no-cover.jpg", http.FS(api.Assets)) c.FileFromFS("assets/images/no-cover.jpg", http.FS(api.assets))
return return
} }
@ -91,7 +91,7 @@ func (api *API) createGetCoverHandler(errorFunc func(*gin.Context, int, string))
} }
// Attempt Metadata // Attempt Metadata
var coverDir string = filepath.Join(api.Config.DataPath, "covers") var coverDir string = filepath.Join(api.cfg.DataPath, "covers")
var coverFile string = "UNKNOWN" var coverFile string = "UNKNOWN"
// Identify Documents & Save Covers // Identify Documents & Save Covers
@ -110,7 +110,7 @@ func (api *API) createGetCoverHandler(errorFunc func(*gin.Context, int, string))
} }
// Store First Metadata Result // Store First Metadata Result
if _, err = api.DB.Queries.AddMetadata(api.DB.Ctx, database.AddMetadataParams{ if _, err = api.db.Queries.AddMetadata(api.db.Ctx, database.AddMetadataParams{
DocumentID: document.ID, DocumentID: document.ID,
Title: firstResult.Title, Title: firstResult.Title,
Author: firstResult.Author, Author: firstResult.Author,
@ -125,7 +125,7 @@ func (api *API) createGetCoverHandler(errorFunc func(*gin.Context, int, string))
} }
// Upsert Document // Upsert Document
if _, err = api.DB.Queries.UpsertDocument(api.DB.Ctx, database.UpsertDocumentParams{ if _, err = api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
ID: document.ID, ID: document.ID,
Coverfile: &coverFile, Coverfile: &coverFile,
}); err != nil { }); err != nil {
@ -134,7 +134,7 @@ func (api *API) createGetCoverHandler(errorFunc func(*gin.Context, int, string))
// Return Unknown Cover // Return Unknown Cover
if coverFile == "UNKNOWN" { if coverFile == "UNKNOWN" {
c.FileFromFS("assets/images/no-cover.jpg", http.FS(api.Assets)) c.FileFromFS("assets/images/no-cover.jpg", http.FS(api.assets))
return return
} }

View File

@ -82,7 +82,7 @@ func (api *API) koAuthorizeUser(c *gin.Context) {
} }
func (api *API) koCreateUser(c *gin.Context) { func (api *API) koCreateUser(c *gin.Context) {
if !api.Config.RegistrationEnabled { if !api.cfg.RegistrationEnabled {
c.AbortWithStatus(http.StatusConflict) c.AbortWithStatus(http.StatusConflict)
return return
} }
@ -107,7 +107,7 @@ func (api *API) koCreateUser(c *gin.Context) {
return return
} }
rows, err := api.DB.Queries.CreateUser(api.DB.Ctx, database.CreateUserParams{ rows, err := api.db.Queries.CreateUser(api.db.Ctx, database.CreateUserParams{
ID: rUser.Username, ID: rUser.Username,
Pass: &hashedPassword, Pass: &hashedPassword,
}) })
@ -143,7 +143,7 @@ func (api *API) koSetProgress(c *gin.Context) {
} }
// Upsert Device // Upsert Device
if _, err := api.DB.Queries.UpsertDevice(api.DB.Ctx, database.UpsertDeviceParams{ if _, err := api.db.Queries.UpsertDevice(api.db.Ctx, database.UpsertDeviceParams{
ID: rPosition.DeviceID, ID: rPosition.DeviceID,
UserID: auth.UserName, UserID: auth.UserName,
DeviceName: rPosition.Device, DeviceName: rPosition.Device,
@ -153,14 +153,14 @@ func (api *API) koSetProgress(c *gin.Context) {
} }
// Upsert Document // Upsert Document
if _, err := api.DB.Queries.UpsertDocument(api.DB.Ctx, database.UpsertDocumentParams{ if _, err := api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
ID: rPosition.DocumentID, ID: rPosition.DocumentID,
}); err != nil { }); err != nil {
log.Error("UpsertDocument DB Error:", err) log.Error("UpsertDocument DB Error:", err)
} }
// Create or Replace Progress // Create or Replace Progress
progress, err := api.DB.Queries.UpdateProgress(api.DB.Ctx, database.UpdateProgressParams{ progress, err := api.db.Queries.UpdateProgress(api.db.Ctx, database.UpdateProgressParams{
Percentage: rPosition.Percentage, Percentage: rPosition.Percentage,
DocumentID: rPosition.DocumentID, DocumentID: rPosition.DocumentID,
DeviceID: rPosition.DeviceID, DeviceID: rPosition.DeviceID,
@ -192,7 +192,7 @@ func (api *API) koGetProgress(c *gin.Context) {
return return
} }
progress, err := api.DB.Queries.GetDocumentProgress(api.DB.Ctx, database.GetDocumentProgressParams{ progress, err := api.db.Queries.GetDocumentProgress(api.db.Ctx, database.GetDocumentProgressParams{
DocumentID: rDocID.DocumentID, DocumentID: rDocID.DocumentID,
UserID: auth.UserName, UserID: auth.UserName,
}) })
@ -230,7 +230,7 @@ func (api *API) koAddActivities(c *gin.Context) {
} }
// Do Transaction // Do Transaction
tx, err := api.DB.DB.Begin() tx, err := api.db.DB.Begin()
if err != nil { if err != nil {
log.Error("Transaction Begin DB Error:", err) log.Error("Transaction Begin DB Error:", err)
apiErrorPage(c, http.StatusBadRequest, "Unknown Error") apiErrorPage(c, http.StatusBadRequest, "Unknown Error")
@ -246,11 +246,11 @@ func (api *API) koAddActivities(c *gin.Context) {
// Defer & Start Transaction // Defer & Start Transaction
defer tx.Rollback() defer tx.Rollback()
qtx := api.DB.Queries.WithTx(tx) qtx := api.db.Queries.WithTx(tx)
// Upsert Documents // Upsert Documents
for _, doc := range allDocuments { for _, doc := range allDocuments {
if _, err := qtx.UpsertDocument(api.DB.Ctx, database.UpsertDocumentParams{ if _, err := qtx.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
ID: doc, ID: doc,
}); err != nil { }); err != nil {
log.Error("UpsertDocument DB Error:", err) log.Error("UpsertDocument DB Error:", err)
@ -260,7 +260,7 @@ func (api *API) koAddActivities(c *gin.Context) {
} }
// Upsert Device // Upsert Device
if _, err = qtx.UpsertDevice(api.DB.Ctx, database.UpsertDeviceParams{ if _, err = qtx.UpsertDevice(api.db.Ctx, database.UpsertDeviceParams{
ID: rActivity.DeviceID, ID: rActivity.DeviceID,
UserID: auth.UserName, UserID: auth.UserName,
DeviceName: rActivity.Device, DeviceName: rActivity.Device,
@ -273,7 +273,7 @@ func (api *API) koAddActivities(c *gin.Context) {
// Add All Activity // Add All Activity
for _, item := range rActivity.Activity { for _, item := range rActivity.Activity {
if _, err := qtx.AddActivity(api.DB.Ctx, database.AddActivityParams{ if _, err := qtx.AddActivity(api.db.Ctx, database.AddActivityParams{
UserID: auth.UserName, UserID: auth.UserName,
DocumentID: item.DocumentID, DocumentID: item.DocumentID,
DeviceID: rActivity.DeviceID, DeviceID: rActivity.DeviceID,
@ -314,7 +314,7 @@ func (api *API) koCheckActivitySync(c *gin.Context) {
} }
// Upsert Device // Upsert Device
if _, err := api.DB.Queries.UpsertDevice(api.DB.Ctx, database.UpsertDeviceParams{ if _, err := api.db.Queries.UpsertDevice(api.db.Ctx, database.UpsertDeviceParams{
ID: rCheckActivity.DeviceID, ID: rCheckActivity.DeviceID,
UserID: auth.UserName, UserID: auth.UserName,
DeviceName: rCheckActivity.Device, DeviceName: rCheckActivity.Device,
@ -326,7 +326,7 @@ func (api *API) koCheckActivitySync(c *gin.Context) {
} }
// Get Last Device Activity // Get Last Device Activity
lastActivity, err := api.DB.Queries.GetLastActivity(api.DB.Ctx, database.GetLastActivityParams{ lastActivity, err := api.db.Queries.GetLastActivity(api.db.Ctx, database.GetLastActivityParams{
UserID: auth.UserName, UserID: auth.UserName,
DeviceID: rCheckActivity.DeviceID, DeviceID: rCheckActivity.DeviceID,
}) })
@ -360,7 +360,7 @@ func (api *API) koAddDocuments(c *gin.Context) {
} }
// Do Transaction // Do Transaction
tx, err := api.DB.DB.Begin() tx, err := api.db.DB.Begin()
if err != nil { if err != nil {
log.Error("Transaction Begin DB Error:", err) log.Error("Transaction Begin DB Error:", err)
apiErrorPage(c, http.StatusBadRequest, "Unknown Error") apiErrorPage(c, http.StatusBadRequest, "Unknown Error")
@ -369,11 +369,11 @@ func (api *API) koAddDocuments(c *gin.Context) {
// Defer & Start Transaction // Defer & Start Transaction
defer tx.Rollback() defer tx.Rollback()
qtx := api.DB.Queries.WithTx(tx) qtx := api.db.Queries.WithTx(tx)
// Upsert Documents // Upsert Documents
for _, doc := range rNewDocs.Documents { for _, doc := range rNewDocs.Documents {
_, err := qtx.UpsertDocument(api.DB.Ctx, database.UpsertDocumentParams{ _, err := qtx.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
ID: doc.ID, ID: doc.ID,
Title: api.sanitizeInput(doc.Title), Title: api.sanitizeInput(doc.Title),
Author: api.sanitizeInput(doc.Author), Author: api.sanitizeInput(doc.Author),
@ -415,7 +415,7 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) {
} }
// Upsert Device // Upsert Device
_, err := api.DB.Queries.UpsertDevice(api.DB.Ctx, database.UpsertDeviceParams{ _, err := api.db.Queries.UpsertDevice(api.db.Ctx, database.UpsertDeviceParams{
ID: rCheckDocs.DeviceID, ID: rCheckDocs.DeviceID,
UserID: auth.UserName, UserID: auth.UserName,
DeviceName: rCheckDocs.Device, DeviceName: rCheckDocs.Device,
@ -431,7 +431,7 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) {
deletedDocIDs := []string{} deletedDocIDs := []string{}
// Get Missing Documents // Get Missing Documents
missingDocs, err = api.DB.Queries.GetMissingDocuments(api.DB.Ctx, rCheckDocs.Have) missingDocs, err = api.db.Queries.GetMissingDocuments(api.db.Ctx, rCheckDocs.Have)
if err != nil { if err != nil {
log.Error("GetMissingDocuments DB Error", err) log.Error("GetMissingDocuments DB Error", err)
apiErrorPage(c, http.StatusBadRequest, "Invalid Request") apiErrorPage(c, http.StatusBadRequest, "Invalid Request")
@ -439,7 +439,7 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) {
} }
// Get Deleted Documents // Get Deleted Documents
deletedDocIDs, err = api.DB.Queries.GetDeletedDocuments(api.DB.Ctx, rCheckDocs.Have) deletedDocIDs, err = api.db.Queries.GetDeletedDocuments(api.db.Ctx, rCheckDocs.Have)
if err != nil { if err != nil {
log.Error("GetDeletedDocuments DB Error", err) log.Error("GetDeletedDocuments DB Error", err)
apiErrorPage(c, http.StatusBadRequest, "Invalid Request") apiErrorPage(c, http.StatusBadRequest, "Invalid Request")
@ -454,7 +454,7 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) {
return return
} }
wantedDocs, err := api.DB.Queries.GetWantedDocuments(api.DB.Ctx, string(jsonHaves)) wantedDocs, err := api.db.Queries.GetWantedDocuments(api.db.Ctx, string(jsonHaves))
if err != nil { if err != nil {
log.Error("GetWantedDocuments DB Error", err) log.Error("GetWantedDocuments DB Error", err)
apiErrorPage(c, http.StatusBadRequest, "Invalid Request") apiErrorPage(c, http.StatusBadRequest, "Invalid Request")
@ -524,7 +524,7 @@ func (api *API) koUploadExistingDocument(c *gin.Context) {
} }
// Validate Document Exists in DB // Validate Document Exists in DB
document, err := api.DB.Queries.GetDocument(api.DB.Ctx, rDoc.DocumentID) document, err := api.db.Queries.GetDocument(api.db.Ctx, rDoc.DocumentID)
if err != nil { if err != nil {
log.Error("GetDocument DB Error:", err) log.Error("GetDocument DB Error:", err)
apiErrorPage(c, http.StatusBadRequest, "Unknown Document") apiErrorPage(c, http.StatusBadRequest, "Unknown Document")
@ -552,7 +552,7 @@ func (api *API) koUploadExistingDocument(c *gin.Context) {
fileName = "." + filepath.Clean(fmt.Sprintf("/%s [%s]%s", fileName, document.ID, fileExtension)) fileName = "." + filepath.Clean(fmt.Sprintf("/%s [%s]%s", fileName, document.ID, fileExtension))
// Generate Storage Path // Generate Storage Path
safePath := filepath.Join(api.Config.DataPath, "documents", fileName) safePath := filepath.Join(api.cfg.DataPath, "documents", fileName)
// Save & Prevent Overwrites // Save & Prevent Overwrites
_, err = os.Stat(safePath) _, err = os.Stat(safePath)
@ -582,7 +582,7 @@ func (api *API) koUploadExistingDocument(c *gin.Context) {
} }
// Upsert Document // Upsert Document
if _, err = api.DB.Queries.UpsertDocument(api.DB.Ctx, database.UpsertDocumentParams{ if _, err = api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{
ID: document.ID, ID: document.ID,
Md5: fileHash, Md5: fileHash,
Filepath: &fileName, Filepath: &fileName,
@ -610,12 +610,12 @@ func (api *API) sanitizeInput(val any) *string {
switch v := val.(type) { switch v := val.(type) {
case *string: case *string:
if v != nil { if v != nil {
newString := html.UnescapeString(api.HTMLPolicy.Sanitize(string(*v))) newString := html.UnescapeString(htmlPolicy.Sanitize(string(*v)))
return &newString return &newString
} }
case string: case string:
if v != "" { if v != "" {
newString := html.UnescapeString(api.HTMLPolicy.Sanitize(string(v))) newString := html.UnescapeString(htmlPolicy.Sanitize(string(v)))
return &newString return &newString
} }
} }

View File

@ -77,7 +77,7 @@ func (api *API) opdsDocuments(c *gin.Context) {
} }
// Get Documents // Get Documents
documents, err := api.DB.Queries.GetDocumentsWithStats(api.DB.Ctx, database.GetDocumentsWithStatsParams{ documents, err := api.db.Queries.GetDocumentsWithStats(api.db.Ctx, database.GetDocumentsWithStatsParams{
UserID: auth.UserName, UserID: auth.UserName,
Query: query, Query: query,
Offset: (*qParams.Page - 1) * *qParams.Limit, Offset: (*qParams.Page - 1) * *qParams.Limit,

View File

@ -19,7 +19,7 @@ type streamer struct {
func (api *API) newStreamer(c *gin.Context, data string) *streamer { func (api *API) newStreamer(c *gin.Context, data string) *streamer {
stream := &streamer{ stream := &streamer{
templates: api.Templates, templates: api.templates,
writer: c.Writer, writer: c.Writer,
completeCh: make(chan struct{}), completeCh: make(chan struct{}),
} }

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"os" "os"
"path" "path"
"path/filepath"
"runtime" "runtime"
"strings" "strings"
@ -40,7 +41,7 @@ type customFormatter struct {
log.Formatter log.Formatter
} }
// Force UTC & Set Type (app) // Force UTC & Set type (app)
func (cf customFormatter) Format(e *log.Entry) ([]byte, error) { func (cf customFormatter) Format(e *log.Entry) ([]byte, error) {
if e.Data["type"] == nil { if e.Data["type"] == nil {
e.Data["type"] = "app" e.Data["type"] = "app"
@ -70,18 +71,18 @@ func Load() *Config {
CookieHTTPOnly: trimLowerString(getEnv("COOKIE_HTTP_ONLY", "true")) == "true", CookieHTTPOnly: trimLowerString(getEnv("COOKIE_HTTP_ONLY", "true")) == "true",
} }
// Log Level // Parse log level
logLevel, err := log.ParseLevel(c.LogLevel) logLevel, err := log.ParseLevel(c.LogLevel)
if err != nil { if err != nil {
logLevel = log.InfoLevel logLevel = log.InfoLevel
} }
// Log Formatter // Create custom formatter
logFormatter := &customFormatter{&log.JSONFormatter{ logFormatter := &customFormatter{&log.JSONFormatter{
CallerPrettyfier: prettyCaller, CallerPrettyfier: prettyCaller,
}} }}
// Log Rotater // Create log rotator
rotateFileHook, err := NewRotateFileHook(RotateFileConfig{ rotateFileHook, err := NewRotateFileHook(RotateFileConfig{
Filename: path.Join(c.ConfigPath, "logs/antholume.log"), Filename: path.Join(c.ConfigPath, "logs/antholume.log"),
MaxSize: 50, MaxSize: 50,
@ -94,17 +95,34 @@ func Load() *Config {
log.Fatal("Unable to initialize file rotate hook") log.Fatal("Unable to initialize file rotate hook")
} }
// Rotate Now // Rotate now
rotateFileHook.Rotate() rotateFileHook.Rotate()
// Set logger settings
log.SetLevel(logLevel) log.SetLevel(logLevel)
log.SetFormatter(logFormatter) log.SetFormatter(logFormatter)
log.SetReportCaller(true) log.SetReportCaller(true)
log.AddHook(rotateFileHook) log.AddHook(rotateFileHook)
// Ensure directories exist
c.EnsureDirectories()
return c return c
} }
// Ensures needed directories exist
func (c *Config) EnsureDirectories() {
os.Mkdir(c.ConfigPath, 0755)
os.Mkdir(c.DataPath, 0755)
docDir := filepath.Join(c.DataPath, "documents")
coversDir := filepath.Join(c.DataPath, "covers")
backupDir := filepath.Join(c.DataPath, "backups")
os.Mkdir(docDir, 0755)
os.Mkdir(coversDir, 0755)
os.Mkdir(backupDir, 0755)
}
func getEnv(key, fallback string) string { func getEnv(key, fallback string) string {
if value, ok := os.LookupEnv(key); ok { if value, ok := os.LookupEnv(key); ok {
return value return value

View File

@ -19,6 +19,7 @@ type DBManager struct {
DB *sql.DB DB *sql.DB
Ctx context.Context Ctx context.Context
Queries *Queries Queries *Queries
cfg *config.Config
} }
//go:embed schema.sql //go:embed schema.sql
@ -27,17 +28,25 @@ var ddl string
//go:embed migrations/* //go:embed migrations/*
var migrations embed.FS var migrations embed.FS
// Returns an initialized manager
func NewMgr(c *config.Config) *DBManager { func NewMgr(c *config.Config) *DBManager {
// Create Manager // Create Manager
dbm := &DBManager{ dbm := &DBManager{
Ctx: context.Background(), Ctx: context.Background(),
cfg: c,
} }
// Create Database dbm.init()
if c.DBType == "sqlite" || c.DBType == "memory" {
return dbm
}
// Init manager
func (dbm *DBManager) init() {
if dbm.cfg.DBType == "sqlite" || dbm.cfg.DBType == "memory" {
var dbLocation string = ":memory:" var dbLocation string = ":memory:"
if c.DBType == "sqlite" { if dbm.cfg.DBType == "sqlite" {
dbLocation = filepath.Join(c.ConfigPath, fmt.Sprintf("%s.db", c.DBName)) dbLocation = filepath.Join(dbm.cfg.ConfigPath, fmt.Sprintf("%s.db", dbm.cfg.DBName))
} }
var err error var err error
@ -67,12 +76,20 @@ func NewMgr(c *config.Config) *DBManager {
} }
dbm.Queries = New(dbm.DB) dbm.Queries = New(dbm.DB)
return dbm
} }
func (dbm *DBManager) Shutdown() error { // Reload manager (close DB & reinit)
return dbm.DB.Close() func (dbm *DBManager) Reload() error {
// Close handle
err := dbm.DB.Close()
if err != nil {
return err
}
// Reinit DB
dbm.init()
return nil
} }
func (dbm *DBManager) CacheTempTables() error { func (dbm *DBManager) CacheTempTables() error {

15
main.go
View File

@ -4,7 +4,6 @@ import (
"embed" "embed"
"os" "os"
"os/signal" "os/signal"
"sync"
"syscall" "syscall"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -38,18 +37,16 @@ func cmdServer(ctx *cli.Context) error {
log.Info("Starting AnthoLume Server") log.Info("Starting AnthoLume Server")
// Create Channel // Create Channel
wg := sync.WaitGroup{} signals := make(chan os.Signal, 1)
done := make(chan struct{}) signal.Notify(signals, os.Interrupt, syscall.SIGTERM)
interrupt := make(chan os.Signal, 1)
signal.Notify(interrupt, os.Interrupt, syscall.SIGTERM)
// Start Server // Start Server
server := server.NewServer(&assets) s := server.New(&assets)
server.StartServer(&wg, done) s.Start()
// Wait & Close // Wait & Close
<-interrupt <-signals
server.StopServer(&wg, done) s.Stop()
// Stop Server // Stop Server
os.Exit(0) os.Exit(0)

View File

@ -1,11 +1,8 @@
package server package server
import ( import (
"context"
"embed" "embed"
"net/http" "net/http"
"os"
"path/filepath"
"sync" "sync"
"time" "time"
@ -16,91 +13,79 @@ import (
"reichard.io/antholume/database" "reichard.io/antholume/database"
) )
type Server struct { type server struct {
API *api.API db *database.DBManager
Config *config.Config api *api.API
Database *database.DBManager done chan int
httpServer *http.Server wg sync.WaitGroup
} }
func NewServer(assets *embed.FS) *Server { // Create new server
func New(assets *embed.FS) *server {
c := config.Load() c := config.Load()
db := database.NewMgr(c) db := database.NewMgr(c)
api := api.NewApi(db, c, assets) api := api.NewApi(db, c, assets)
// Create Paths return &server{
os.Mkdir(c.ConfigPath, 0755) db: db,
os.Mkdir(c.DataPath, 0755) api: api,
done: make(chan int),
// Create Subpaths
docDir := filepath.Join(c.DataPath, "documents")
coversDir := filepath.Join(c.DataPath, "covers")
backupDir := filepath.Join(c.DataPath, "backup")
os.Mkdir(docDir, 0755)
os.Mkdir(coversDir, 0755)
os.Mkdir(backupDir, 0755)
return &Server{
API: api,
Config: c,
Database: db,
httpServer: &http.Server{
Handler: api.Router,
Addr: (":" + c.ListenPort),
},
} }
} }
func (s *Server) StartServer(wg *sync.WaitGroup, done <-chan struct{}) { // Start server
ticker := time.NewTicker(15 * time.Minute) func (s *server) Start() {
log.Info("Starting server...")
wg.Add(2) s.wg.Add(2)
go func() { go func() {
defer wg.Done() defer s.wg.Done()
err := s.httpServer.ListenAndServe() err := s.api.Start()
if err != nil && err != http.ErrServerClosed { if err != nil && err != http.ErrServerClosed {
log.Error("Error starting server:", err) log.Error("Starting server failed: ", err)
} }
}() }()
go func() { go func() {
defer wg.Done() defer s.wg.Done()
ticker := time.NewTicker(15 * time.Minute)
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
s.RunScheduledTasks() s.runScheduledTasks()
case <-done: case <-s.done:
log.Info("Stopping task runner...") log.Info("Stopping task runner...")
return return
} }
} }
}() }()
log.Info("Server started")
} }
func (s *Server) RunScheduledTasks() { // Stop server
start := time.Now() func (s *server) Stop() {
if err := s.API.DB.CacheTempTables(); err != nil { log.Info("Stopping server...")
log.Warn("Refreshing temp table cache failure:", err)
} if err := s.api.Stop(); err != nil {
log.Debug("Completed in: ", time.Since(start)) log.Error("HTTP server stop failed: ", err)
} }
func (s *Server) StopServer(wg *sync.WaitGroup, done chan<- struct{}) { close(s.done)
log.Info("Stopping HTTP server...") s.wg.Wait()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := s.httpServer.Shutdown(ctx); err != nil {
log.Info("HTTP server shutdown error: ", err)
}
s.API.DB.Shutdown()
close(done)
wg.Wait()
log.Info("Server stopped") log.Info("Server stopped")
} }
// Run normal scheduled tasks
func (s *server) runScheduledTasks() {
start := time.Now()
if err := s.db.CacheTempTables(); err != nil {
log.Warn("Refreshing temp table cache failed: ", err)
}
log.Debug("Completed in: ", time.Since(start))
}