73 Commits

Author SHA1 Message Date
75c872264f chore: remove unnecessary crap ai added
Some checks failed
continuous-integration/drone/pr Build is failing
2026-04-03 19:46:05 -04:00
0930054847 more reader
Some checks failed
continuous-integration/drone/pr Build is failing
2026-04-03 13:45:17 -04:00
aa812c6917 wip reader migration 2026-04-03 12:15:48 -04:00
8ec3349b7c chore(api): update to allow CRUD progress and activity in v1 2026-04-03 10:37:50 -04:00
decc3f0195 fix: toast theme & error msgs 2026-04-03 10:08:13 -04:00
b13f9b362c theme draft 2 (done?) 2026-03-22 17:21:34 -04:00
6c2c4f6b8b remove dumb auth 2026-03-22 17:21:34 -04:00
d38392ac9a theme draft 1 2026-03-22 17:21:34 -04:00
63ad73755d wip 22 2026-03-22 17:21:34 -04:00
784e53c557 wip 21 2026-03-22 17:21:34 -04:00
9ed63b2695 wip 20 2026-03-22 17:21:34 -04:00
27e651c4f5 wip 19 2026-03-22 17:21:34 -04:00
7e96e41ba4 wip 18 2026-03-22 17:21:33 -04:00
ee1d62858b wip 17 2026-03-22 17:21:33 -04:00
4d133994ab wip 16 2026-03-22 17:21:33 -04:00
ba919bbde4 wip 15 2026-03-22 17:21:33 -04:00
197a1577c2 wip 14 2026-03-22 17:21:33 -04:00
fd9afe86b0 wip 13 2026-03-22 17:21:33 -04:00
93707ff513 wip 12 2026-03-22 17:21:33 -04:00
75e0228fe0 wip 11 2026-03-22 17:21:33 -04:00
b1b8eb297e wip 10 2026-03-22 17:21:33 -04:00
7c47f2d2eb wip 9 2026-03-22 17:21:33 -04:00
c46dcb440d wip 8 2026-03-22 17:21:33 -04:00
5cb17bace7 wip 7 2026-03-22 17:21:32 -04:00
ecf77fd105 wip 6 2026-03-22 17:21:32 -04:00
e289d1a29b wip 5 2026-03-22 17:21:32 -04:00
3e9a193d08 wip 4 2026-03-22 17:21:32 -04:00
4306d86080 wip 3 2026-03-22 17:21:32 -04:00
d40f8fc375 wip 2 2026-03-22 17:21:32 -04:00
c84bc2522e wip 1 2026-03-22 17:21:32 -04:00
0704b5d650 fix: book search
All checks were successful
continuous-integration/drone/push Build is passing
2026-03-22 17:21:16 -04:00
4c1789fc16 fix: doc parsing
All checks were successful
continuous-integration/drone Build is passing
2026-01-24 13:33:09 -05:00
082f7e926c fix: fix annas archive url 2026-01-24 13:33:01 -05:00
6031cf06d4 chore: update nix flake 2026-01-24 13:25:52 -05:00
8fd2aeb6a2 chore: add various tests 2025-12-13 14:04:32 -05:00
bc076a4f44 fix: metadata count test
All checks were successful
continuous-integration/drone/push Build is passing
2025-11-20 17:02:10 -05:00
f9f23f2d3f fix: word count calculation
Some checks failed
continuous-integration/drone/push Build is failing
2025-11-12 19:13:04 -05:00
3cff965393 fix: annas archive parsing
All checks were successful
continuous-integration/drone/push Build is passing
2025-08-17 17:04:46 -04:00
7937890acd fix: docker build
All checks were successful
continuous-integration/drone/push Build is passing
2025-08-10 13:18:37 -04:00
938dd69e5e chore(db): use context & add db helper 2025-08-10 13:17:51 -04:00
7c92c346fa feat(utils): add pkg utils 2025-08-10 13:17:44 -04:00
456b6e457c chore: update go & flake
Some checks failed
continuous-integration/drone/push Build is failing
2025-08-07 17:42:41 -04:00
d304421798 hm
All checks were successful
continuous-integration/drone/push Build is passing
2025-07-05 18:17:47 -04:00
0fe52bc541 fix: search parsing
Some checks failed
continuous-integration/drone/push Build is failing
2025-07-05 16:46:06 -04:00
49f3d53170 chore: nix flake
Some checks failed
continuous-integration/drone/push Build is failing
2025-07-05 15:21:44 -04:00
57f81e5dd7 fix(api): ko json content type
All checks were successful
continuous-integration/drone/push Build is passing
2025-05-13 12:37:45 -04:00
162adfbe16 feat: basic toc
All checks were successful
continuous-integration/drone/push Build is passing
2025-04-26 10:19:00 -04:00
e2cfdb3a0c update cicd
All checks were successful
continuous-integration/drone/push Build is passing
2025-03-14 08:36:01 -04:00
acf4119d9a fix(sql): document user stats
Some checks failed
continuous-integration/drone/push Build is passing
continuous-integration/drone Build was killed
2025-01-25 15:03:07 -05:00
f6dd8cee50 fix(streaks): incorrect calculation logic
All checks were successful
continuous-integration/drone/push Build is passing
2024-12-02 19:27:50 -05:00
a981d98ba5 feat(admin): basic log filter
All checks were successful
continuous-integration/drone/push Build is passing
2024-12-01 19:48:51 -05:00
a193f97d29 perf(db): incremental user streaks cache
All checks were successful
continuous-integration/drone/push Build is passing
2024-12-01 18:58:46 -05:00
841b29c425 improve(search): progress & retries
All checks were successful
continuous-integration/drone/push Build is passing
2024-12-01 17:04:41 -05:00
3d61d0f5ef perf(db): incremental document stats cache
All checks were successful
continuous-integration/drone/push Build is passing
2024-12-01 12:48:25 -05:00
5e388730a5 formatting: lua plugin 2024-12-01 11:28:33 -05:00
0a1dfeab65 fix(search): set user agent for dl
All checks were successful
continuous-integration/drone/push Build is passing
2024-08-13 22:32:16 -04:00
d4c8e4d2da fix(search): broken parser & download source
All checks were successful
continuous-integration/drone/push Build is passing
2024-08-11 11:02:46 -04:00
bbd3a00102 tests(db): additional document tests 2024-08-10 09:26:30 -04:00
3a633235ea tests(db): add additional tests & comments
All checks were successful
continuous-integration/drone/push Build is passing
2024-06-16 20:00:41 -04:00
9809a09d2e chore(prettier): format templates
All checks were successful
continuous-integration/drone/push Build is passing
2024-06-16 18:04:43 -04:00
f37bff365f chore(templates): prettier plugin & tables 2024-06-16 17:08:10 -04:00
77527bfb05 chore(templates): add better template loading
All checks were successful
continuous-integration/drone/push Build is passing
2024-05-27 20:20:47 -04:00
8de6fed5df fix(ui): document add styling 2024-05-27 14:01:10 -04:00
f9277d3b32 feat(admin): handle user deletion
All checks were successful
continuous-integration/drone/push Build is passing
2024-05-27 13:32:40 -04:00
db9629a618 chore(lint): address linter
All checks were successful
continuous-integration/drone/push Build is passing
2024-05-26 19:56:59 -04:00
546600db93 feat(admin): handle user demotion & promotion
All checks were successful
continuous-integration/drone/push Build is passing
2024-05-25 21:12:07 -04:00
7c6acad689 chore(templates): component-ize things
All checks were successful
continuous-integration/drone/push Build is passing
2024-05-25 20:04:26 -04:00
5482899075 feat(admin): adding user & importing 2024-05-25 20:02:57 -04:00
5a64ff7029 fix(tz): incorrect local_time function use
All checks were successful
continuous-integration/drone/push Build is passing
2024-04-06 20:56:30 -04:00
a7ecb1a6f8 fix(tz): add tzdata to docker image
All checks were successful
continuous-integration/drone/push Build is passing
2024-04-06 09:39:04 -04:00
2d206826d6 add(admin): add user
All checks were successful
continuous-integration/drone/push Build is passing
2024-03-11 22:20:41 -07:00
f1414e3e4e fix(timezones): move from utc offsets to timezones
This fixed various issues related to calculating streaks, etc. Now we
appropriately handle time as it was, vs as it is relative to an offset.
2024-03-11 22:20:21 -07:00
8e81acd381 fix(users): update user stomped on admin
All checks were successful
continuous-integration/drone/push Build is passing
2024-03-10 21:48:43 -04:00
278 changed files with 31874 additions and 3807 deletions

View File

@@ -1,7 +1,11 @@
kind: pipeline kind: pipeline
type: kubernetes type: docker
name: default name: default
trigger:
branch:
- master
steps: steps:
# Unit Tests # Unit Tests
- name: tests - name: tests
@@ -23,6 +27,8 @@ steps:
registry: gitea.va.reichard.io registry: gitea.va.reichard.io
tags: tags:
- dev - dev
custom_dns:
- 8.8.8.8
username: username:
from_secret: docker_username from_secret: docker_username
password: password:

2
.envrc
View File

@@ -1 +1 @@
use nix use flake

1
.gitignore vendored
View File

@@ -4,3 +4,4 @@ data/
build/ build/
.direnv/ .direnv/
cover.html cover.html
node_modules

3
.prettierrc Normal file
View File

@@ -0,0 +1,3 @@
{
"plugins": ["prettier-plugin-go-template"]
}

75
AGENTS.md Normal file
View File

@@ -0,0 +1,75 @@
# AnthoLume Agent Guide
## 1) Working Style
- Keep changes targeted.
- Do not refactor broadly unless the task requires it.
- Validate only what is relevant to the change when practical.
- If a fix will require substantial refactoring or wide-reaching changes, stop and ask first.
## 2) Hard Rules
- Never edit generated files directly.
- Never write ad-hoc SQL.
- For Go error wrapping, use `fmt.Errorf("message: %w", err)`.
- Do not use `github.com/pkg/errors`.
## 3) Generated Code
### OpenAPI
Edit:
- `api/v1/openapi.yaml`
Regenerate:
- `go generate ./api/v1/generate.go`
- `cd frontend && bun run generate:api`
Notes:
- If you add response headers in `api/v1/openapi.yaml` (for example `Set-Cookie`), `oapi-codegen` will generate typed response header structs in `api/v1/api.gen.go`; update the handler response values to populate those headers explicitly.
Examples of generated files:
- `api/v1/api.gen.go`
- `frontend/src/generated/**/*.ts`
### SQLC
Edit:
- `database/query.sql`
Regenerate:
- `sqlc generate`
## 4) Backend / Assets
### Common commands
- Dev server: `make dev`
- Direct dev run: `CONFIG_PATH=./data DATA_PATH=./data REGISTRATION_ENABLED=true go run main.go serve`
- Tests: `make tests`
- Tailwind asset build: `make build_tailwind`
### Notes
- The Go server embeds `templates/*` and `assets/*`.
- Root Tailwind output is built to `assets/style.css`.
- Be mindful of whether a change affects the embedded server-rendered app, the React frontend, or both.
- SQLite timestamps are stored as RFC3339 strings (usually with a trailing `Z`); prefer `parseTime` / `parseTimePtr` instead of ad-hoc `time.Parse` layouts.
## 5) Frontend
For frontend-specific implementation notes and commands, also read:
- `frontend/AGENTS.md`
## 6) Regeneration Summary
- Go API: `go generate ./api/v1/generate.go`
- Frontend API client: `cd frontend && bun run generate:api`
- SQLC: `sqlc generate`
## 7) Updating This File
After completing a task, update this `AGENTS.md` if you learned something general that would help future agents.
Rules for updates:
- Add only repository-wide guidance.
- Do not add one-off task history.
- Keep updates short, concrete, and organized.
- Place new guidance in the most relevant section.
- If the new information would help future agents avoid repeated mistakes, add it proactively.

View File

@@ -1,9 +1,9 @@
# Certificate Store # Certificates & Timezones
FROM alpine AS certs FROM alpine AS alpine
RUN apk update && apk add ca-certificates RUN apk update && apk add --no-cache ca-certificates tzdata
# Build Image # Build Image
FROM golang:1.21 AS build FROM golang:1.24 AS build
# Create Package Directory # Create Package Directory
RUN mkdir -p /opt/antholume RUN mkdir -p /opt/antholume
@@ -19,7 +19,8 @@ RUN go build \
# Create Image # Create Image
FROM busybox:1.36 FROM busybox:1.36
COPY --from=certs /etc/ssl/certs /etc/ssl/certs COPY --from=alpine /etc/ssl/certs /etc/ssl/certs
COPY --from=alpine /usr/share/zoneinfo /usr/share/zoneinfo
COPY --from=build /opt/antholume /opt/antholume COPY --from=build /opt/antholume /opt/antholume
WORKDIR /opt/antholume WORKDIR /opt/antholume
EXPOSE 8585 EXPOSE 8585

View File

@@ -1,6 +1,6 @@
# Certificate Store # Certificates & Timezones
FROM alpine AS certs FROM alpine AS alpine
RUN apk update && apk add ca-certificates RUN apk update && apk add --no-cache ca-certificates tzdata
# Build Image # Build Image
FROM --platform=$BUILDPLATFORM golang:1.21 AS build FROM --platform=$BUILDPLATFORM golang:1.21 AS build
@@ -21,7 +21,8 @@ RUN --mount=target=. \
# Create Image # Create Image
FROM busybox:1.36 FROM busybox:1.36
COPY --from=certs /etc/ssl/certs /etc/ssl/certs COPY --from=alpine /etc/ssl/certs /etc/ssl/certs
COPY --from=alpine /usr/share/zoneinfo /usr/share/zoneinfo
COPY --from=build /opt/antholume /opt/antholume COPY --from=build /opt/antholume /opt/antholume
WORKDIR /opt/antholume WORKDIR /opt/antholume
EXPOSE 8585 EXPOSE 8585

View File

@@ -27,7 +27,7 @@ docker_build_release_latest: build_tailwind
--push . --push .
build_tailwind: build_tailwind:
tailwind build -o ./assets/style.css --minify tailwindcss build -o ./assets/style.css --minify
dev: build_tailwind dev: build_tailwind
GIN_MODE=release \ GIN_MODE=release \

View File

@@ -118,7 +118,7 @@ See documentation in the `client` subfolder: [SyncNinja](https://gitea.va.reicha
## Development ## Development
SQLC Generation (v1.21.0): SQLC Generation (v1.26.0):
```bash ```bash
go install github.com/sqlc-dev/sqlc/cmd/sqlc@latest go install github.com/sqlc-dev/sqlc/cmd/sqlc@latest

BIN
antholume Executable file

Binary file not shown.

View File

@@ -6,6 +6,7 @@ import (
"html/template" "html/template"
"io/fs" "io/fs"
"net/http" "net/http"
"path"
"path/filepath" "path/filepath"
"strings" "strings"
"time" "time"
@@ -37,6 +38,7 @@ func NewApi(db *database.DBManager, c *config.Config, assets fs.FS) *API {
db: db, db: db,
cfg: c, cfg: c,
assets: assets, assets: assets,
templates: make(map[string]*template.Template),
userAuthCache: make(map[string]string), userAuthCache: make(map[string]string),
} }
@@ -111,6 +113,11 @@ func (api *API) Start() error {
return api.httpServer.ListenAndServe() return api.httpServer.ListenAndServe()
} }
// Handler returns the underlying http.Handler for the Gin router
func (api *API) Handler() http.Handler {
return api.httpServer.Handler
}
func (api *API) Stop() error { func (api *API) Stop() error {
// Stop server // Stop server
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -157,6 +164,7 @@ func (api *API) registerWebAppRoutes(router *gin.Engine) {
router.GET("/admin/import", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appGetAdminImport) router.GET("/admin/import", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appGetAdminImport)
router.POST("/admin/import", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appPerformAdminImport) router.POST("/admin/import", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appPerformAdminImport)
router.GET("/admin/users", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appGetAdminUsers) router.GET("/admin/users", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appGetAdminUsers)
router.POST("/admin/users", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appUpdateAdminUsers)
router.GET("/admin", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appGetAdmin) router.GET("/admin", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appGetAdmin)
router.POST("/admin", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appPerformAdminAction) router.POST("/admin", api.authWebAppMiddleware, api.authAdminWebAppMiddleware, api.appPerformAdminAction)
router.POST("/login", api.appAuthLogin) router.POST("/login", api.appAuthLogin)
@@ -222,67 +230,112 @@ func (api *API) registerOPDSRoutes(apiGroup *gin.RouterGroup) {
func (api *API) generateTemplates() *multitemplate.Renderer { func (api *API) generateTemplates() *multitemplate.Renderer {
// Define templates & helper functions // Define templates & helper functions
templates := make(map[string]*template.Template)
render := multitemplate.NewRenderer() render := multitemplate.NewRenderer()
templates := make(map[string]*template.Template)
helperFuncs := template.FuncMap{ helperFuncs := template.FuncMap{
"dict": dict, "dict": dict,
"slice": slice,
"fields": fields, "fields": fields,
"getSVGGraphData": getSVGGraphData, "getSVGGraphData": getSVGGraphData,
"getUTCOffsets": getUTCOffsets, "getTimeZones": getTimeZones,
"hasPrefix": strings.HasPrefix, "hasPrefix": strings.HasPrefix,
"niceNumbers": niceNumbers, "niceNumbers": niceNumbers,
"niceSeconds": niceSeconds, "niceSeconds": niceSeconds,
} }
// Load base // Load Base
b, _ := fs.ReadFile(api.assets, "templates/base.tmpl") b, err := fs.ReadFile(api.assets, "templates/base.tmpl")
baseTemplate := template.Must(template.New("base").Funcs(helperFuncs).Parse(string(b))) if err != nil {
log.Errorf("error reading base template: %v", err)
return &render
}
// Parse Base
baseTemplate, err := template.New("base").Funcs(helperFuncs).Parse(string(b))
if err != nil {
log.Errorf("error parsing base template: %v", err)
return &render
}
// Load SVGs // Load SVGs
svgs, _ := fs.ReadDir(api.assets, "templates/svgs") err = api.loadTemplates("svg", baseTemplate, templates, false)
for _, item := range svgs { if err != nil {
basename := item.Name() log.Errorf("error loading svg templates: %v", err)
path := fmt.Sprintf("templates/svgs/%s", basename) return &render
name := strings.TrimSuffix(basename, filepath.Ext(basename))
b, _ := fs.ReadFile(api.assets, path)
baseTemplate = template.Must(baseTemplate.New("svg/" + name).Parse(string(b)))
templates["svg/"+name] = baseTemplate
} }
// Load components // Load Components
components, _ := fs.ReadDir(api.assets, "templates/components") err = api.loadTemplates("component", baseTemplate, templates, false)
for _, item := range components { if err != nil {
basename := item.Name() log.Errorf("error loading component templates: %v", err)
path := fmt.Sprintf("templates/components/%s", basename) return &render
name := strings.TrimSuffix(basename, filepath.Ext(basename))
// Clone Base Template
b, _ := fs.ReadFile(api.assets, path)
baseTemplate = template.Must(baseTemplate.New("component/" + name).Parse(string(b)))
render.Add("component/"+name, baseTemplate)
templates["component/"+name] = baseTemplate
} }
// Load pages // Load Pages
pages, _ := fs.ReadDir(api.assets, "templates/pages") err = api.loadTemplates("page", baseTemplate, templates, true)
for _, item := range pages { if err != nil {
basename := item.Name() log.Errorf("error loading page templates: %v", err)
path := fmt.Sprintf("templates/pages/%s", basename) return &render
name := strings.TrimSuffix(basename, filepath.Ext(basename))
// Clone Base Template
b, _ := fs.ReadFile(api.assets, path)
pageTemplate, _ := template.Must(baseTemplate.Clone()).New("page/" + name).Parse(string(b))
render.Add("page/"+name, pageTemplate)
templates["page/"+name] = pageTemplate
} }
// Populate Renderer
api.templates = templates api.templates = templates
for templateName, templateValue := range templates {
render.Add(templateName, templateValue)
}
return &render return &render
} }
func (api *API) loadTemplates(
basePath string,
baseTemplate *template.Template,
allTemplates map[string]*template.Template,
cloneBase bool,
) error {
// Load Templates (Pluralize)
templateDirectory := fmt.Sprintf("templates/%ss", basePath)
allFiles, err := fs.ReadDir(api.assets, templateDirectory)
if err != nil {
return fmt.Errorf("unable to read template dir %s: %w", templateDirectory, err)
}
// Generate Templates
for _, item := range allFiles {
templateFile := item.Name()
templatePath := path.Join(templateDirectory, templateFile)
templateName := fmt.Sprintf("%s/%s", basePath, strings.TrimSuffix(templateFile, filepath.Ext(templateFile)))
// Read Template
b, err := fs.ReadFile(api.assets, templatePath)
if err != nil {
return fmt.Errorf("unable to read template %s: %w", templateName, err)
}
// Clone? (Pages - Don't Stomp)
if cloneBase {
baseTemplate = template.Must(baseTemplate.Clone())
}
// Parse Template
baseTemplate, err = baseTemplate.New(templateName).Parse(string(b))
if err != nil {
return fmt.Errorf("unable to parse template %s: %w", templateName, err)
}
allTemplates[templateName] = baseTemplate
}
return nil
}
func (api *API) templateMiddleware(router *gin.Engine) gin.HandlerFunc {
return func(c *gin.Context) {
router.HTMLRender = *api.generateTemplates()
c.Next()
}
}
func loggingMiddleware(c *gin.Context) { func loggingMiddleware(c *gin.Context) {
// Start timer // Start timer
startTime := time.Now() startTime := time.Now()
@@ -298,7 +351,7 @@ func loggingMiddleware(c *gin.Context) {
logData := log.Fields{ logData := log.Fields{
"type": "access", "type": "access",
"ip": c.ClientIP(), "ip": c.ClientIP(),
"latency": fmt.Sprintf("%s", latency), "latency": latency.String(),
"status": c.Writer.Status(), "status": c.Writer.Status(),
"method": c.Request.Method, "method": c.Request.Method,
"path": c.Request.URL.Path, "path": c.Request.URL.Path,
@@ -318,10 +371,3 @@ func loggingMiddleware(c *gin.Context) {
// Log result // Log result
log.WithFields(logData).Info(fmt.Sprintf("%s %s", c.Request.Method, c.Request.URL.Path)) log.WithFields(logData).Info(fmt.Sprintf("%s %s", c.Request.Method, c.Request.URL.Path))
} }
func (api *API) templateMiddleware(router *gin.Engine) gin.HandlerFunc {
return func(c *gin.Context) {
router.HTMLRender = *api.generateTemplates()
c.Next()
}
}

View File

@@ -3,6 +3,8 @@ package api
import ( import (
"archive/zip" "archive/zip"
"bufio" "bufio"
"context"
"crypto/md5"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@@ -12,14 +14,18 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"slices" "slices"
"sort"
"strings" "strings"
"time" "time"
argon2 "github.com/alexedwards/argon2id"
"github.com/gabriel-vasile/mimetype" "github.com/gabriel-vasile/mimetype"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/itchyny/gojq" "github.com/itchyny/gojq"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"reichard.io/antholume/database"
"reichard.io/antholume/metadata" "reichard.io/antholume/metadata"
"reichard.io/antholume/utils"
) )
type adminAction string type adminAction string
@@ -54,10 +60,41 @@ type requestAdminImport struct {
Type importType `form:"type"` Type importType `form:"type"`
} }
type operationType string
const (
opUpdate operationType = "UPDATE"
opCreate operationType = "CREATE"
opDelete operationType = "DELETE"
)
type requestAdminUpdateUser struct {
User string `form:"user"`
Password *string `form:"password"`
IsAdmin *bool `form:"is_admin"`
Operation operationType `form:"operation"`
}
type requestAdminLogs struct { type requestAdminLogs struct {
Filter string `form:"filter"` Filter string `form:"filter"`
} }
type importStatus string
const (
importFailed importStatus = "FAILED"
importSuccess importStatus = "SUCCESS"
importExists importStatus = "EXISTS"
)
type importResult struct {
ID string
Name string
Path string
Status importStatus
Error error
}
func (api *API) appPerformAdminAction(c *gin.Context) { func (api *API) appPerformAdminAction(c *gin.Context) {
templateVars, _ := api.getBaseTemplateVars("admin", c) templateVars, _ := api.getBaseTemplateVars("admin", c)
@@ -68,21 +105,24 @@ func (api *API) appPerformAdminAction(c *gin.Context) {
return return
} }
// TODO - Messages
switch rAdminAction.Action { switch rAdminAction.Action {
case adminMetadataMatch: case adminMetadataMatch:
// TODO // TODO
// 1. Documents xref most recent metadata table? // 1. Documents xref most recent metadata table?
// 2. Select all / deselect? // 2. Select all / deselect?
case adminCacheTables: case adminCacheTables:
go api.db.CacheTempTables() go func() {
// TODO - Message err := api.db.CacheTempTables(c)
if err != nil {
log.Error("Unable to cache temp tables: ", err)
}
}()
case adminRestore: case adminRestore:
api.processRestoreFile(rAdminAction, c) api.processRestoreFile(rAdminAction, c)
return return
case adminBackup: case adminBackup:
// Vacuum // Vacuum
_, err := api.db.DB.ExecContext(api.db.Ctx, "VACUUM;") _, err := api.db.DB.ExecContext(c, "VACUUM;")
if err != nil { if err != nil {
log.Error("Unable to vacuum DB: ", err) log.Error("Unable to vacuum DB: ", err)
appErrorPage(c, http.StatusInternalServerError, "Unable to vacuum database") appErrorPage(c, http.StatusInternalServerError, "Unable to vacuum database")
@@ -104,7 +144,7 @@ func (api *API) appPerformAdminAction(c *gin.Context) {
} }
} }
err := api.createBackup(w, directories) err := api.createBackup(c, w, directories)
if err != nil { if err != nil {
log.Error("Backup Error: ", err) log.Error("Backup Error: ", err)
} }
@@ -134,7 +174,10 @@ func (api *API) appGetAdminLogs(c *gin.Context) {
rAdminLogs.Filter = strings.TrimSpace(rAdminLogs.Filter) rAdminLogs.Filter = strings.TrimSpace(rAdminLogs.Filter)
var jqFilter *gojq.Code var jqFilter *gojq.Code
if rAdminLogs.Filter != "" { var basicFilter string
if strings.HasPrefix(rAdminLogs.Filter, "\"") && strings.HasSuffix(rAdminLogs.Filter, "\"") {
basicFilter = rAdminLogs.Filter[1 : len(rAdminLogs.Filter)-1]
} else if rAdminLogs.Filter != "" {
parsed, err := gojq.Parse(rAdminLogs.Filter) parsed, err := gojq.Parse(rAdminLogs.Filter)
if err != nil { if err != nil {
log.Error("Unable to parse JQ filter") log.Error("Unable to parse JQ filter")
@@ -166,7 +209,7 @@ func (api *API) appGetAdminLogs(c *gin.Context) {
rawLog := scanner.Text() rawLog := scanner.Text()
// Attempt JSON Pretty // Attempt JSON Pretty
var jsonMap map[string]interface{} var jsonMap map[string]any
err := json.Unmarshal([]byte(rawLog), &jsonMap) err := json.Unmarshal([]byte(rawLog), &jsonMap)
if err != nil { if err != nil {
logLines = append(logLines, scanner.Text()) logLines = append(logLines, scanner.Text())
@@ -180,12 +223,17 @@ func (api *API) appGetAdminLogs(c *gin.Context) {
continue continue
} }
// No Filter // Basic Filter
if jqFilter == nil { if basicFilter != "" && strings.Contains(string(rawData), basicFilter) {
logLines = append(logLines, string(rawData)) logLines = append(logLines, string(rawData))
continue continue
} }
// No JQ Filter
if jqFilter == nil {
continue
}
// Error or nil // Error or nil
result, _ := jqFilter.Run(jsonMap).Next() result, _ := jqFilter.Run(jsonMap).Next()
if _, ok := result.(error); ok { if _, ok := result.(error); ok {
@@ -213,7 +261,53 @@ func (api *API) appGetAdminLogs(c *gin.Context) {
func (api *API) appGetAdminUsers(c *gin.Context) { func (api *API) appGetAdminUsers(c *gin.Context) {
templateVars, _ := api.getBaseTemplateVars("admin-users", c) templateVars, _ := api.getBaseTemplateVars("admin-users", c)
users, err := api.db.Queries.GetUsers(api.db.Ctx) users, err := api.db.Queries.GetUsers(c)
if err != nil {
log.Error("GetUsers DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUsers DB Error: %v", err))
return
}
templateVars["Data"] = users
c.HTML(http.StatusOK, "page/admin-users", templateVars)
}
func (api *API) appUpdateAdminUsers(c *gin.Context) {
templateVars, _ := api.getBaseTemplateVars("admin-users", c)
var rUpdate requestAdminUpdateUser
if err := c.ShouldBind(&rUpdate); err != nil {
log.Error("Invalid URI Bind")
appErrorPage(c, http.StatusNotFound, "Invalid user parameters")
return
}
// Ensure Username
if rUpdate.User == "" {
appErrorPage(c, http.StatusInternalServerError, "User cannot be empty")
return
}
var err error
switch rUpdate.Operation {
case opCreate:
err = api.createUser(c, rUpdate.User, rUpdate.Password, rUpdate.IsAdmin)
case opUpdate:
err = api.updateUser(c, rUpdate.User, rUpdate.Password, rUpdate.IsAdmin)
case opDelete:
err = api.deleteUser(c, rUpdate.User)
default:
appErrorPage(c, http.StatusNotFound, "Unknown user operation")
return
}
if err != nil {
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("Unable to create or update user: %v", err))
return
}
users, err := api.db.Queries.GetUsers(c)
if err != nil { if err != nil {
log.Error("GetUsers DB Error: ", err) log.Error("GetUsers DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUsers DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUsers DB Error: %v", err))
@@ -285,46 +379,157 @@ func (api *API) appPerformAdminImport(c *gin.Context) {
return return
} }
// TODO - Store results for approval? // Get import directory
// Walk import directory & copy or import files
importDirectory := filepath.Clean(rAdminImport.Directory) importDirectory := filepath.Clean(rAdminImport.Directory)
_ = filepath.WalkDir(importDirectory, func(currentPath string, f fs.DirEntry, err error) error {
// Get data directory
absoluteDataPath, _ := filepath.Abs(filepath.Join(api.cfg.DataPath, "documents"))
// Validate different path
if absoluteDataPath == importDirectory {
appErrorPage(c, http.StatusBadRequest, "Directory is the same as data path")
return
}
// Do Transaction
tx, err := api.db.DB.Begin()
if err != nil {
log.Error("Transaction Begin DB Error:", err)
apiErrorPage(c, http.StatusBadRequest, "Unknown error")
return
}
// Defer & Start Transaction
defer func() {
if err := tx.Rollback(); err != nil {
log.Error("DB Rollback Error:", err)
}
}()
qtx := api.db.Queries.WithTx(tx)
// Track imports
importResults := make([]importResult, 0)
// Walk Directory & Import
err = filepath.WalkDir(importDirectory, func(importPath string, f fs.DirEntry, err error) error {
if err != nil { if err != nil {
return err return err
} }
if f.IsDir() { if f.IsDir() {
return nil return nil
} }
// Get metadata // Get relative path
fileMeta, err := metadata.GetMetadata(currentPath) basePath := importDirectory
relFilePath, err := filepath.Rel(importDirectory, importPath)
if err != nil { if err != nil {
fmt.Printf("metadata error: %v\n", err) log.Warnf("path error: %v", err)
return nil return nil
} }
// Only needed if copying // Track imports
newName := deriveBaseFileName(fileMeta) iResult := importResult{
Path: relFilePath,
Status: importFailed,
}
defer func() {
importResults = append(importResults, iResult)
}()
// Open File on Disk // Get metadata
// file, err := os.Open(currentPath) fileMeta, err := metadata.GetMetadata(importPath)
// if err != nil { if err != nil {
// return err log.Errorf("metadata error: %v", err)
// } iResult.Error = err
// defer file.Close() return nil
}
iResult.ID = *fileMeta.PartialMD5
iResult.Name = fmt.Sprintf("%s - %s", *fileMeta.Author, *fileMeta.Title)
// TODO - BasePath in DB // Check already exists
// TODO - Copy / Import _, err = qtx.GetDocument(c, *fileMeta.PartialMD5)
if err == nil {
log.Warnf("document already exists: %s", *fileMeta.PartialMD5)
iResult.Status = importExists
return nil
}
fmt.Printf("New File Metadata: %s\n", newName) // Import Copy
if rAdminImport.Type == importCopy {
// Derive & Sanitize File Name
relFilePath = deriveBaseFileName(fileMeta)
safePath := filepath.Join(api.cfg.DataPath, "documents", relFilePath)
// Open Source File
srcFile, err := os.Open(importPath)
if err != nil {
log.Errorf("unable to open current file: %v", err)
iResult.Error = err
return nil
}
defer srcFile.Close()
// Open Destination File
destFile, err := os.Create(safePath)
if err != nil {
log.Errorf("unable to open destination file: %v", err)
iResult.Error = err
return nil
}
defer destFile.Close()
// Copy File
if _, err = io.Copy(destFile, srcFile); err != nil {
log.Errorf("unable to save file: %v", err)
iResult.Error = err
return nil
}
// Update Base & Path
basePath = filepath.Join(api.cfg.DataPath, "documents")
iResult.Path = relFilePath
}
// Upsert document
if _, err = qtx.UpsertDocument(c, database.UpsertDocumentParams{
ID: *fileMeta.PartialMD5,
Title: fileMeta.Title,
Author: fileMeta.Author,
Description: fileMeta.Description,
Md5: fileMeta.MD5,
Words: fileMeta.WordCount,
Filepath: &relFilePath,
Basepath: &basePath,
}); err != nil {
log.Errorf("UpsertDocument DB Error: %v", err)
iResult.Error = err
return nil
}
iResult.Status = importSuccess
return nil return nil
}) })
if err != nil {
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("Import Failed: %v", err))
return
}
templateVars["CurrentPath"] = filepath.Clean(rAdminImport.Directory) // Commit transaction
if err := tx.Commit(); err != nil {
log.Error("Transaction Commit DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("Import DB Error: %v", err))
return
}
c.HTML(http.StatusOK, "page/admin-import", templateVars) // Sort import results
sort.Slice(importResults, func(i int, j int) bool {
return importStatusPriority(importResults[i].Status) <
importStatusPriority(importResults[j].Status)
})
templateVars["Data"] = importResults
c.HTML(http.StatusOK, "page/admin-import-results", templateVars)
} }
func (api *API) processRestoreFile(rAdminAction requestAdminAction, c *gin.Context) { func (api *API) processRestoreFile(rAdminAction requestAdminAction, c *gin.Context) {
@@ -420,17 +625,9 @@ func (api *API) processRestoreFile(rAdminAction requestAdminAction, c *gin.Conte
} }
defer backupFile.Close() defer backupFile.Close()
// Vacuum DB
_, err = api.db.DB.ExecContext(api.db.Ctx, "VACUUM;")
if err != nil {
log.Error("Unable to vacuum DB: ", err)
appErrorPage(c, http.StatusInternalServerError, "Unable to vacuum database")
return
}
// Save Backup File // Save Backup File
w := bufio.NewWriter(backupFile) w := bufio.NewWriter(backupFile)
err = api.createBackup(w, []string{"covers", "documents"}) err = api.createBackup(c, w, []string{"covers", "documents"})
if err != nil { if err != nil {
log.Error("Unable to save backup file: ", err) log.Error("Unable to save backup file: ", err)
appErrorPage(c, http.StatusInternalServerError, "Unable to save backup file") appErrorPage(c, http.StatusInternalServerError, "Unable to save backup file")
@@ -453,13 +650,13 @@ func (api *API) processRestoreFile(rAdminAction requestAdminAction, c *gin.Conte
} }
// Reinit DB // Reinit DB
if err := api.db.Reload(); err != nil { if err := api.db.Reload(c); err != nil {
appErrorPage(c, http.StatusInternalServerError, "Unable to reload DB") appErrorPage(c, http.StatusInternalServerError, "Unable to reload DB")
log.Panicf("Unable to reload DB: %v", err) log.Panicf("Unable to reload DB: %v", err)
} }
// Rotate Auth Hashes // Rotate Auth Hashes
if err := api.rotateAllAuthHashes(); err != nil { if err := api.rotateAllAuthHashes(c); err != nil {
appErrorPage(c, http.StatusInternalServerError, "Unable to rotate hashes") appErrorPage(c, http.StatusInternalServerError, "Unable to rotate hashes")
log.Panicf("Unable to rotate auth hashes: %v", err) log.Panicf("Unable to rotate auth hashes: %v", err)
} }
@@ -468,7 +665,6 @@ func (api *API) processRestoreFile(rAdminAction requestAdminAction, c *gin.Conte
c.Redirect(http.StatusFound, "/login") c.Redirect(http.StatusFound, "/login")
} }
// Restore all data
func (api *API) restoreData(zipReader *zip.Reader) error { func (api *API) restoreData(zipReader *zip.Reader) error {
// Ensure Directories // Ensure Directories
api.cfg.EnsureDirectories() api.cfg.EnsureDirectories()
@@ -484,14 +680,14 @@ func (api *API) restoreData(zipReader *zip.Reader) error {
destPath := filepath.Join(api.cfg.DataPath, file.Name) destPath := filepath.Join(api.cfg.DataPath, file.Name)
destFile, err := os.Create(destPath) destFile, err := os.Create(destPath)
if err != nil { if err != nil {
fmt.Println("Error creating destination file:", err) log.Errorf("error creating destination file: %v", err)
return err return err
} }
defer destFile.Close() defer destFile.Close()
// Copy the contents from the zip file to the destination file. // Copy the contents from the zip file to the destination file.
if _, err := io.Copy(destFile, rc); err != nil { if _, err := io.Copy(destFile, rc); err != nil {
fmt.Println("Error copying file contents:", err) log.Errorf("Error copying file contents: %v", err)
return err return err
} }
} }
@@ -499,7 +695,6 @@ func (api *API) restoreData(zipReader *zip.Reader) error {
return nil return nil
} }
// Remove all data
func (api *API) removeData() error { func (api *API) removeData() error {
allPaths := []string{ allPaths := []string{
"covers", "covers",
@@ -522,10 +717,14 @@ func (api *API) removeData() error {
return nil return nil
} }
// Backup all data func (api *API) createBackup(ctx context.Context, w io.Writer, directories []string) error {
func (api *API) createBackup(w io.Writer, directories []string) error { // Vacuum DB
ar := zip.NewWriter(w) _, err := api.db.DB.ExecContext(ctx, "VACUUM;")
if err != nil {
return fmt.Errorf("Unable to vacuum database: %w", err)
}
ar := zip.NewWriter(w)
exportWalker := func(currentPath string, f fs.DirEntry, err error) error { exportWalker := func(currentPath string, f fs.DirEntry, err error) error {
if err != nil { if err != nil {
return err return err
@@ -575,7 +774,11 @@ func (api *API) createBackup(w io.Writer, directories []string) error {
if err != nil { if err != nil {
return err return err
} }
io.Copy(newDbFile, dbFile)
_, err = io.Copy(newDbFile, dbFile)
if err != nil {
return err
}
// Backup Covers & Documents // Backup Covers & Documents
for _, dir := range directories { for _, dir := range directories {
@@ -588,3 +791,159 @@ func (api *API) createBackup(w io.Writer, directories []string) error {
ar.Close() ar.Close()
return nil return nil
} }
func (api *API) isLastAdmin(ctx context.Context, userID string) (bool, error) {
allUsers, err := api.db.Queries.GetUsers(ctx)
if err != nil {
return false, fmt.Errorf("GetUsers DB Error: %w", err)
}
hasAdmin := false
for _, user := range allUsers {
if user.Admin && user.ID != userID {
hasAdmin = true
break
}
}
return !hasAdmin, nil
}
func (api *API) createUser(ctx context.Context, user string, rawPassword *string, isAdmin *bool) error {
// Validate Necessary Parameters
if rawPassword == nil || *rawPassword == "" {
return fmt.Errorf("password can't be empty")
}
// Base Params
createParams := database.CreateUserParams{
ID: user,
}
// Handle Admin (Explicit or False)
if isAdmin != nil {
createParams.Admin = *isAdmin
} else {
createParams.Admin = false
}
// Parse Password
password := fmt.Sprintf("%x", md5.Sum([]byte(*rawPassword)))
hashedPassword, err := argon2.CreateHash(password, argon2.DefaultParams)
if err != nil {
return fmt.Errorf("unable to create hashed password")
}
createParams.Pass = &hashedPassword
// Generate Auth Hash
rawAuthHash, err := utils.GenerateToken(64)
if err != nil {
return fmt.Errorf("unable to create token for user")
}
authHash := fmt.Sprintf("%x", rawAuthHash)
createParams.AuthHash = &authHash
// Create user in DB
if rows, err := api.db.Queries.CreateUser(ctx, createParams); err != nil {
log.Error("CreateUser DB Error:", err)
return fmt.Errorf("unable to create user")
} else if rows == 0 {
log.Warn("User Already Exists:", createParams.ID)
return fmt.Errorf("user already exists")
}
return nil
}
func (api *API) updateUser(ctx context.Context, user string, rawPassword *string, isAdmin *bool) error {
// Validate Necessary Parameters
if rawPassword == nil && isAdmin == nil {
return fmt.Errorf("nothing to update")
}
// Base Params
updateParams := database.UpdateUserParams{
UserID: user,
}
// Handle Admin (Update or Existing)
if isAdmin != nil {
updateParams.Admin = *isAdmin
} else {
user, err := api.db.Queries.GetUser(ctx, user)
if err != nil {
return fmt.Errorf("GetUser DB Error: %w", err)
}
updateParams.Admin = user.Admin
}
// Check Admins - Disallow Demotion
if isLast, err := api.isLastAdmin(ctx, user); err != nil {
return err
} else if isLast && !updateParams.Admin {
return fmt.Errorf("unable to demote %s - last admin", user)
}
// Handle Password
if rawPassword != nil {
if *rawPassword == "" {
return fmt.Errorf("password can't be empty")
}
// Parse Password
password := fmt.Sprintf("%x", md5.Sum([]byte(*rawPassword)))
hashedPassword, err := argon2.CreateHash(password, argon2.DefaultParams)
if err != nil {
return fmt.Errorf("unable to create hashed password")
}
updateParams.Password = &hashedPassword
// Generate Auth Hash
rawAuthHash, err := utils.GenerateToken(64)
if err != nil {
return fmt.Errorf("unable to create token for user")
}
authHash := fmt.Sprintf("%x", rawAuthHash)
updateParams.AuthHash = &authHash
}
// Update User
_, err := api.db.Queries.UpdateUser(ctx, updateParams)
if err != nil {
return fmt.Errorf("UpdateUser DB Error: %w", err)
}
return nil
}
func (api *API) deleteUser(ctx context.Context, user string) error {
// Check Admins
if isLast, err := api.isLastAdmin(ctx, user); err != nil {
return err
} else if isLast {
return fmt.Errorf("unable to delete %s - last admin", user)
}
// Create Backup File
backupFilePath := filepath.Join(api.cfg.ConfigPath, fmt.Sprintf("backups/AnthoLumeBackup_%s.zip", time.Now().Format("20060102150405")))
backupFile, err := os.Create(backupFilePath)
if err != nil {
return err
}
defer backupFile.Close()
// Save Backup File (DB Only)
w := bufio.NewWriter(backupFile)
err = api.createBackup(ctx, w, []string{})
if err != nil {
return err
}
// Delete User
_, err = api.db.Queries.DeleteUser(ctx, user)
if err != nil {
return fmt.Errorf("DeleteUser DB Error: %w", err)
}
return nil
}

View File

@@ -1,6 +1,7 @@
package api package api
import ( import (
"context"
"crypto/md5" "crypto/md5"
"database/sql" "database/sql"
"fmt" "fmt"
@@ -22,8 +23,8 @@ import (
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"reichard.io/antholume/database" "reichard.io/antholume/database"
"reichard.io/antholume/metadata" "reichard.io/antholume/metadata"
"reichard.io/antholume/pkg/ptr"
"reichard.io/antholume/search" "reichard.io/antholume/search"
"reichard.io/antholume/utils"
) )
type backupType string type backupType string
@@ -69,7 +70,7 @@ type requestDocumentIdentify struct {
type requestSettingsEdit struct { type requestSettingsEdit struct {
Password *string `form:"password"` Password *string `form:"password"`
NewPassword *string `form:"new_password"` NewPassword *string `form:"new_password"`
TimeOffset *string `form:"time_offset"` Timezone *string `form:"timezone"`
} }
type requestDocumentAdd struct { type requestDocumentAdd struct {
@@ -110,9 +111,10 @@ func (api *API) appGetDocuments(c *gin.Context) {
query = &search query = &search
} }
documents, err := api.db.Queries.GetDocumentsWithStats(api.db.Ctx, database.GetDocumentsWithStatsParams{ documents, err := api.db.Queries.GetDocumentsWithStats(c, database.GetDocumentsWithStatsParams{
UserID: auth.UserName, UserID: auth.UserName,
Query: query, Query: query,
Deleted: ptr.Of(false),
Offset: (*qParams.Page - 1) * *qParams.Limit, Offset: (*qParams.Page - 1) * *qParams.Limit,
Limit: *qParams.Limit, Limit: *qParams.Limit,
}) })
@@ -122,14 +124,14 @@ func (api *API) appGetDocuments(c *gin.Context) {
return return
} }
length, err := api.db.Queries.GetDocumentsSize(api.db.Ctx, query) length, err := api.db.Queries.GetDocumentsSize(c, query)
if err != nil { if err != nil {
log.Error("GetDocumentsSize DB Error: ", err) log.Error("GetDocumentsSize DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocumentsSize DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocumentsSize DB Error: %v", err))
return return
} }
if err = api.getDocumentsWordCount(documents); err != nil { if err = api.getDocumentsWordCount(c, documents); err != nil {
log.Error("Unable to Get Word Counts: ", err) log.Error("Unable to Get Word Counts: ", err)
} }
@@ -161,13 +163,10 @@ func (api *API) appGetDocument(c *gin.Context) {
return return
} }
document, err := api.db.Queries.GetDocumentWithStats(api.db.Ctx, database.GetDocumentWithStatsParams{ document, err := api.db.GetDocument(c, rDocID.DocumentID, auth.UserName)
UserID: auth.UserName,
DocumentID: rDocID.DocumentID,
})
if err != nil { if err != nil {
log.Error("GetDocumentWithStats DB Error: ", err) log.Error("GetDocument DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocumentsWithStats DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocument DB Error: %v", err))
return return
} }
@@ -193,7 +192,7 @@ func (api *API) appGetProgress(c *gin.Context) {
progressFilter.DocumentID = *qParams.Document progressFilter.DocumentID = *qParams.Document
} }
progress, err := api.db.Queries.GetProgress(api.db.Ctx, progressFilter) progress, err := api.db.Queries.GetProgress(c, progressFilter)
if err != nil { if err != nil {
log.Error("GetProgress DB Error: ", err) log.Error("GetProgress DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetActivity DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetActivity DB Error: %v", err))
@@ -220,7 +219,7 @@ func (api *API) appGetActivity(c *gin.Context) {
activityFilter.DocumentID = *qParams.Document activityFilter.DocumentID = *qParams.Document
} }
activity, err := api.db.Queries.GetActivity(api.db.Ctx, activityFilter) activity, err := api.db.Queries.GetActivity(c, activityFilter)
if err != nil { if err != nil {
log.Error("GetActivity DB Error: ", err) log.Error("GetActivity DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetActivity DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetActivity DB Error: %v", err))
@@ -236,7 +235,7 @@ func (api *API) appGetHome(c *gin.Context) {
templateVars, auth := api.getBaseTemplateVars("home", c) templateVars, auth := api.getBaseTemplateVars("home", c)
start := time.Now() start := time.Now()
graphData, err := api.db.Queries.GetDailyReadStats(api.db.Ctx, auth.UserName) graphData, err := api.db.Queries.GetDailyReadStats(c, auth.UserName)
if err != nil { if err != nil {
log.Error("GetDailyReadStats DB Error: ", err) log.Error("GetDailyReadStats DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDailyReadStats DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDailyReadStats DB Error: %v", err))
@@ -245,7 +244,7 @@ func (api *API) appGetHome(c *gin.Context) {
log.Debug("GetDailyReadStats DB Performance: ", time.Since(start)) log.Debug("GetDailyReadStats DB Performance: ", time.Since(start))
start = time.Now() start = time.Now()
databaseInfo, err := api.db.Queries.GetDatabaseInfo(api.db.Ctx, auth.UserName) databaseInfo, err := api.db.Queries.GetDatabaseInfo(c, auth.UserName)
if err != nil { if err != nil {
log.Error("GetDatabaseInfo DB Error: ", err) log.Error("GetDatabaseInfo DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDatabaseInfo DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDatabaseInfo DB Error: %v", err))
@@ -254,7 +253,7 @@ func (api *API) appGetHome(c *gin.Context) {
log.Debug("GetDatabaseInfo DB Performance: ", time.Since(start)) log.Debug("GetDatabaseInfo DB Performance: ", time.Since(start))
start = time.Now() start = time.Now()
streaks, err := api.db.Queries.GetUserStreaks(api.db.Ctx, auth.UserName) streaks, err := api.db.Queries.GetUserStreaks(c, auth.UserName)
if err != nil { if err != nil {
log.Error("GetUserStreaks DB Error: ", err) log.Error("GetUserStreaks DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUserStreaks DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUserStreaks DB Error: %v", err))
@@ -263,7 +262,7 @@ func (api *API) appGetHome(c *gin.Context) {
log.Debug("GetUserStreaks DB Performance: ", time.Since(start)) log.Debug("GetUserStreaks DB Performance: ", time.Since(start))
start = time.Now() start = time.Now()
userStatistics, err := api.db.Queries.GetUserStatistics(api.db.Ctx) userStatistics, err := api.db.Queries.GetUserStatistics(c)
if err != nil { if err != nil {
log.Error("GetUserStatistics DB Error: ", err) log.Error("GetUserStatistics DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUserStatistics DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUserStatistics DB Error: %v", err))
@@ -284,14 +283,14 @@ func (api *API) appGetHome(c *gin.Context) {
func (api *API) appGetSettings(c *gin.Context) { func (api *API) appGetSettings(c *gin.Context) {
templateVars, auth := api.getBaseTemplateVars("settings", c) templateVars, auth := api.getBaseTemplateVars("settings", c)
user, err := api.db.Queries.GetUser(api.db.Ctx, auth.UserName) user, err := api.db.Queries.GetUser(c, auth.UserName)
if err != nil { if err != nil {
log.Error("GetUser DB Error: ", err) log.Error("GetUser DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUser DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUser DB Error: %v", err))
return return
} }
devices, err := api.db.Queries.GetDevices(api.db.Ctx, auth.UserName) devices, err := api.db.Queries.GetDevices(c, auth.UserName)
if err != nil { if err != nil {
log.Error("GetDevices DB Error: ", err) log.Error("GetDevices DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDevices DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDevices DB Error: %v", err))
@@ -299,7 +298,7 @@ func (api *API) appGetSettings(c *gin.Context) {
} }
templateVars["Data"] = gin.H{ templateVars["Data"] = gin.H{
"TimeOffset": *user.TimeOffset, "Timezone": *user.Timezone,
"Devices": devices, "Devices": devices,
} }
@@ -314,7 +313,11 @@ func (api *API) appGetSearch(c *gin.Context) {
templateVars, _ := api.getBaseTemplateVars("search", c) templateVars, _ := api.getBaseTemplateVars("search", c)
var sParams searchParams var sParams searchParams
c.BindQuery(&sParams) err := c.BindQuery(&sParams)
if err != nil {
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("Invalid Form Bind: %v", err))
return
}
// Only Handle Query // Only Handle Query
if sParams.Query != nil && sParams.Source != nil { if sParams.Query != nil && sParams.Source != nil {
@@ -365,24 +368,20 @@ func (api *API) appGetDocumentProgress(c *gin.Context) {
return return
} }
progress, err := api.db.Queries.GetDocumentProgress(api.db.Ctx, database.GetDocumentProgressParams{ progress, err := api.db.Queries.GetDocumentProgress(c, database.GetDocumentProgressParams{
DocumentID: rDoc.DocumentID, DocumentID: rDoc.DocumentID,
UserID: auth.UserName, UserID: auth.UserName,
}) })
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
log.Error("UpsertDocument DB Error: ", err) log.Error("GetDocumentProgress DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("UpsertDocument DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocumentProgress DB Error: %v", err))
return return
} }
document, err := api.db.Queries.GetDocumentWithStats(api.db.Ctx, database.GetDocumentWithStatsParams{ document, err := api.db.GetDocument(c, rDoc.DocumentID, auth.UserName)
UserID: auth.UserName,
DocumentID: rDoc.DocumentID,
})
if err != nil { if err != nil {
log.Error("GetDocumentWithStats DB Error: ", err) log.Error("GetDocument DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocumentWithStats DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocument DB Error: %v", err))
return return
} }
@@ -402,7 +401,7 @@ func (api *API) appGetDevices(c *gin.Context) {
auth = data.(authData) auth = data.(authData)
} }
devices, err := api.db.Queries.GetDevices(api.db.Ctx, auth.UserName) devices, err := api.db.Queries.GetDevices(c, auth.UserName)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
log.Error("GetDevices DB Error: ", err) log.Error("GetDevices DB Error: ", err)
@@ -453,7 +452,7 @@ func (api *API) appUploadNewDocument(c *gin.Context) {
} }
// Check Already Exists // Check Already Exists
_, err = api.db.Queries.GetDocument(api.db.Ctx, *metadataInfo.PartialMD5) _, err = api.db.Queries.GetDocument(c, *metadataInfo.PartialMD5)
if err == nil { if err == nil {
log.Warnf("document already exists: %s", *metadataInfo.PartialMD5) log.Warnf("document already exists: %s", *metadataInfo.PartialMD5)
c.Redirect(http.StatusFound, fmt.Sprintf("./documents/%s", *metadataInfo.PartialMD5)) c.Redirect(http.StatusFound, fmt.Sprintf("./documents/%s", *metadataInfo.PartialMD5))
@@ -461,7 +460,8 @@ func (api *API) appUploadNewDocument(c *gin.Context) {
// Derive & Sanitize File Name // Derive & Sanitize File Name
fileName := deriveBaseFileName(metadataInfo) fileName := deriveBaseFileName(metadataInfo)
safePath := filepath.Join(api.cfg.DataPath, "documents", fileName) basePath := filepath.Join(api.cfg.DataPath, "documents")
safePath := filepath.Join(basePath, fileName)
// Open Destination File // Open Destination File
destFile, err := os.Create(safePath) destFile, err := os.Create(safePath)
@@ -480,7 +480,7 @@ func (api *API) appUploadNewDocument(c *gin.Context) {
} }
// Upsert Document // Upsert Document
if _, err = api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ if _, err = api.db.Queries.UpsertDocument(c, database.UpsertDocumentParams{
ID: *metadataInfo.PartialMD5, ID: *metadataInfo.PartialMD5,
Title: metadataInfo.Title, Title: metadataInfo.Title,
Author: metadataInfo.Author, Author: metadataInfo.Author,
@@ -488,9 +488,7 @@ func (api *API) appUploadNewDocument(c *gin.Context) {
Md5: metadataInfo.MD5, Md5: metadataInfo.MD5,
Words: metadataInfo.WordCount, Words: metadataInfo.WordCount,
Filepath: &fileName, Filepath: &fileName,
Basepath: &basePath,
// TODO (BasePath):
// - Should be current config directory
}); err != nil { }); err != nil {
log.Errorf("UpsertDocument DB Error: %v", err) log.Errorf("UpsertDocument DB Error: %v", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("UpsertDocument DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("UpsertDocument DB Error: %v", err))
@@ -572,7 +570,7 @@ func (api *API) appEditDocument(c *gin.Context) {
coverFileName = &fileName coverFileName = &fileName
} else if rDocEdit.CoverGBID != nil { } else if rDocEdit.CoverGBID != nil {
var coverDir string = filepath.Join(api.cfg.DataPath, "covers") coverDir := filepath.Join(api.cfg.DataPath, "covers")
fileName, err := metadata.CacheCover(*rDocEdit.CoverGBID, coverDir, rDocID.DocumentID, true) fileName, err := metadata.CacheCover(*rDocEdit.CoverGBID, coverDir, rDocID.DocumentID, true)
if err == nil { if err == nil {
coverFileName = fileName coverFileName = fileName
@@ -580,7 +578,7 @@ func (api *API) appEditDocument(c *gin.Context) {
} }
// Update Document // Update Document
if _, err := api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ if _, err := api.db.Queries.UpsertDocument(c, database.UpsertDocumentParams{
ID: rDocID.DocumentID, ID: rDocID.DocumentID,
Title: api.sanitizeInput(rDocEdit.Title), Title: api.sanitizeInput(rDocEdit.Title),
Author: api.sanitizeInput(rDocEdit.Author), Author: api.sanitizeInput(rDocEdit.Author),
@@ -595,7 +593,6 @@ func (api *API) appEditDocument(c *gin.Context) {
} }
c.Redirect(http.StatusFound, "./") c.Redirect(http.StatusFound, "./")
return
} }
func (api *API) appDeleteDocument(c *gin.Context) { func (api *API) appDeleteDocument(c *gin.Context) {
@@ -605,7 +602,7 @@ func (api *API) appDeleteDocument(c *gin.Context) {
appErrorPage(c, http.StatusNotFound, "Invalid document") appErrorPage(c, http.StatusNotFound, "Invalid document")
return return
} }
changed, err := api.db.Queries.DeleteDocument(api.db.Ctx, rDocID.DocumentID) changed, err := api.db.Queries.DeleteDocument(c, rDocID.DocumentID)
if err != nil { if err != nil {
log.Error("DeleteDocument DB Error") log.Error("DeleteDocument DB Error")
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("DeleteDocument DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("DeleteDocument DB Error: %v", err))
@@ -667,7 +664,7 @@ func (api *API) appIdentifyDocument(c *gin.Context) {
firstResult := metadataResults[0] firstResult := metadataResults[0]
// Store First Metadata Result // Store First Metadata Result
if _, err = api.db.Queries.AddMetadata(api.db.Ctx, database.AddMetadataParams{ if _, err = api.db.Queries.AddMetadata(c, database.AddMetadataParams{
DocumentID: rDocID.DocumentID, DocumentID: rDocID.DocumentID,
Title: firstResult.Title, Title: firstResult.Title,
Author: firstResult.Author, Author: firstResult.Author,
@@ -686,13 +683,10 @@ func (api *API) appIdentifyDocument(c *gin.Context) {
templateVars["MetadataError"] = "No Metadata Found" templateVars["MetadataError"] = "No Metadata Found"
} }
document, err := api.db.Queries.GetDocumentWithStats(api.db.Ctx, database.GetDocumentWithStatsParams{ document, err := api.db.GetDocument(c, rDocID.DocumentID, auth.UserName)
UserID: auth.UserName,
DocumentID: rDocID.DocumentID,
})
if err != nil { if err != nil {
log.Error("GetDocumentWithStats DB Error: ", err) log.Error("GetDocument DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocumentWithStats DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocument DB Error: %v", err))
return return
} }
@@ -739,52 +733,50 @@ func (api *API) appSaveNewDocument(c *gin.Context) {
} }
// Send Message // Send Message
sendDownloadMessage("Downloading document...", gin.H{"Progress": 10}) sendDownloadMessage("Downloading document...", gin.H{"Progress": 1})
// Scaled Download Function
lastTime := time.Now()
downloadFunc := func(p float32) {
nowTime := time.Now()
if nowTime.Before(lastTime.Add(time.Millisecond * 500)) {
return
}
scaledProgress := int((p * 95 / 100) + 2)
sendDownloadMessage("Downloading document...", gin.H{"Progress": scaledProgress})
lastTime = nowTime
}
// Save Book // Save Book
tempFilePath, err := search.SaveBook(rDocAdd.ID, rDocAdd.Source) tempFilePath, metadata, err := search.SaveBook(rDocAdd.ID, rDocAdd.Source, downloadFunc)
if err != nil { if err != nil {
log.Warn("Temp File Error: ", err) log.Warn("Save Book Error: ", err)
sendDownloadMessage("Unable to download file", gin.H{"Error": true}) sendDownloadMessage("Unable to download file", gin.H{"Error": true})
return return
} }
// Send Message // Send Message
sendDownloadMessage("Calculating partial MD5...", gin.H{"Progress": 60}) sendDownloadMessage("Saving document...", gin.H{"Progress": 98})
// Calculate Partial MD5 ID // Derive Author / Title
partialMD5, err := utils.CalculatePartialMD5(tempFilePath) docAuthor := "Unknown"
if err != nil { if *metadata.Author != "" {
log.Warn("Partial MD5 Error: ", err) docAuthor = *metadata.Author
sendDownloadMessage("Unable to calculate partial MD5", gin.H{"Error": true}) } else if *rDocAdd.Author != "" {
docAuthor = *rDocAdd.Author
} }
// Send Message docTitle := "Unknown"
sendDownloadMessage("Saving file...", gin.H{"Progress": 60}) if *metadata.Title != "" {
docTitle = *metadata.Title
// Derive Extension on MIME } else if *rDocAdd.Title != "" {
fileMime, err := mimetype.DetectFile(tempFilePath) docTitle = *rDocAdd.Title
fileExtension := fileMime.Extension()
// Derive Filename
var fileName string
if *rDocAdd.Author != "" {
fileName = fileName + *rDocAdd.Author
} else {
fileName = fileName + "Unknown"
} }
if *rDocAdd.Title != "" { // Remove Slashes & Sanitize File Name
fileName = fileName + " - " + *rDocAdd.Title fileName := fmt.Sprintf("%s - %s", docAuthor, docTitle)
} else {
fileName = fileName + " - Unknown"
}
// Remove Slashes
fileName = strings.ReplaceAll(fileName, "/", "") fileName = strings.ReplaceAll(fileName, "/", "")
fileName = "." + filepath.Clean(fmt.Sprintf("/%s [%s]%s", fileName, *metadata.PartialMD5, metadata.Type))
// Derive & Sanitize File Name
fileName = "." + filepath.Clean(fmt.Sprintf("/%s [%s]%s", fileName, *partialMD5, fileExtension))
// Open Source File // Open Source File
sourceFile, err := os.Open(tempFilePath) sourceFile, err := os.Open(tempFilePath)
@@ -797,7 +789,9 @@ func (api *API) appSaveNewDocument(c *gin.Context) {
defer sourceFile.Close() defer sourceFile.Close()
// Generate Storage Path & Open File // Generate Storage Path & Open File
safePath := filepath.Join(api.cfg.DataPath, "documents", fileName) basePath := filepath.Join(api.cfg.DataPath, "documents")
safePath := filepath.Join(basePath, fileName)
destFile, err := os.Create(safePath) destFile, err := os.Create(safePath)
if err != nil { if err != nil {
log.Error("Dest File Error: ", err) log.Error("Dest File Error: ", err)
@@ -814,38 +808,17 @@ func (api *API) appSaveNewDocument(c *gin.Context) {
} }
// Send Message // Send Message
sendDownloadMessage("Calculating MD5...", gin.H{"Progress": 70}) sendDownloadMessage("Saving to database...", gin.H{"Progress": 99})
// Get MD5 Hash
fileHash, err := getFileMD5(safePath)
if err != nil {
log.Error("Hash Failure: ", err)
sendDownloadMessage("Unable to calculate MD5", gin.H{"Error": true})
return
}
// Send Message
sendDownloadMessage("Calculating word count...", gin.H{"Progress": 80})
// Get Word Count
wordCount, err := metadata.GetWordCount(safePath)
if err != nil {
log.Error("Word Count Failure: ", err)
sendDownloadMessage("Unable to calculate word count", gin.H{"Error": true})
return
}
// Send Message
sendDownloadMessage("Saving to database...", gin.H{"Progress": 90})
// Upsert Document // Upsert Document
if _, err = api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ if _, err = api.db.Queries.UpsertDocument(c, database.UpsertDocumentParams{
ID: *partialMD5, ID: *metadata.PartialMD5,
Title: rDocAdd.Title, Title: &docTitle,
Author: rDocAdd.Author, Author: &docAuthor,
Md5: fileHash, Md5: metadata.MD5,
Words: metadata.WordCount,
Filepath: &fileName, Filepath: &fileName,
Words: wordCount, Basepath: &basePath,
}); err != nil { }); err != nil {
log.Error("UpsertDocument DB Error: ", err) log.Error("UpsertDocument DB Error: ", err)
sendDownloadMessage("Unable to save to database", gin.H{"Error": true}) sendDownloadMessage("Unable to save to database", gin.H{"Error": true})
@@ -856,7 +829,7 @@ func (api *API) appSaveNewDocument(c *gin.Context) {
sendDownloadMessage("Download Success", gin.H{ sendDownloadMessage("Download Success", gin.H{
"Progress": 100, "Progress": 100,
"ButtonText": "Go to Book", "ButtonText": "Go to Book",
"ButtonHref": fmt.Sprintf("./documents/%s", *partialMD5), "ButtonHref": fmt.Sprintf("./documents/%s", *metadata.PartialMD5),
}) })
} }
@@ -869,7 +842,7 @@ func (api *API) appEditSettings(c *gin.Context) {
} }
// Validate Something Exists // Validate Something Exists
if rUserSettings.Password == nil && rUserSettings.NewPassword == nil && rUserSettings.TimeOffset == nil { if rUserSettings.Password == nil && rUserSettings.NewPassword == nil && rUserSettings.Timezone == nil {
log.Error("Missing Form Values") log.Error("Missing Form Values")
appErrorPage(c, http.StatusBadRequest, "Invalid or missing form values") appErrorPage(c, http.StatusBadRequest, "Invalid or missing form values")
return return
@@ -879,12 +852,13 @@ func (api *API) appEditSettings(c *gin.Context) {
newUserSettings := database.UpdateUserParams{ newUserSettings := database.UpdateUserParams{
UserID: auth.UserName, UserID: auth.UserName,
Admin: auth.IsAdmin,
} }
// Set New Password // Set New Password
if rUserSettings.Password != nil && rUserSettings.NewPassword != nil { if rUserSettings.Password != nil && rUserSettings.NewPassword != nil {
password := fmt.Sprintf("%x", md5.Sum([]byte(*rUserSettings.Password))) password := fmt.Sprintf("%x", md5.Sum([]byte(*rUserSettings.Password)))
data := api.authorizeCredentials(auth.UserName, password) data := api.authorizeCredentials(c, auth.UserName, password)
if data == nil { if data == nil {
templateVars["PasswordErrorMessage"] = "Invalid Password" templateVars["PasswordErrorMessage"] = "Invalid Password"
} else { } else {
@@ -900,13 +874,13 @@ func (api *API) appEditSettings(c *gin.Context) {
} }
// Set Time Offset // Set Time Offset
if rUserSettings.TimeOffset != nil { if rUserSettings.Timezone != nil {
templateVars["TimeOffsetMessage"] = "Time Offset Updated" templateVars["TimeOffsetMessage"] = "Time Offset Updated"
newUserSettings.TimeOffset = rUserSettings.TimeOffset newUserSettings.Timezone = rUserSettings.Timezone
} }
// Update User // Update User
_, err := api.db.Queries.UpdateUser(api.db.Ctx, newUserSettings) _, err := api.db.Queries.UpdateUser(c, newUserSettings)
if err != nil { if err != nil {
log.Error("UpdateUser DB Error: ", err) log.Error("UpdateUser DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("UpdateUser DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("UpdateUser DB Error: %v", err))
@@ -914,7 +888,7 @@ func (api *API) appEditSettings(c *gin.Context) {
} }
// Get User // Get User
user, err := api.db.Queries.GetUser(api.db.Ctx, auth.UserName) user, err := api.db.Queries.GetUser(c, auth.UserName)
if err != nil { if err != nil {
log.Error("GetUser DB Error: ", err) log.Error("GetUser DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUser DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetUser DB Error: %v", err))
@@ -922,7 +896,7 @@ func (api *API) appEditSettings(c *gin.Context) {
} }
// Get Devices // Get Devices
devices, err := api.db.Queries.GetDevices(api.db.Ctx, auth.UserName) devices, err := api.db.Queries.GetDevices(c, auth.UserName)
if err != nil { if err != nil {
log.Error("GetDevices DB Error: ", err) log.Error("GetDevices DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDevices DB Error: %v", err)) appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDevices DB Error: %v", err))
@@ -930,7 +904,7 @@ func (api *API) appEditSettings(c *gin.Context) {
} }
templateVars["Data"] = gin.H{ templateVars["Data"] = gin.H{
"TimeOffset": *user.TimeOffset, "Timezone": *user.Timezone,
"Devices": devices, "Devices": devices,
} }
@@ -941,7 +915,7 @@ func (api *API) appDemoModeError(c *gin.Context) {
appErrorPage(c, http.StatusUnauthorized, "Not Allowed in Demo Mode") appErrorPage(c, http.StatusUnauthorized, "Not Allowed in Demo Mode")
} }
func (api *API) getDocumentsWordCount(documents []database.GetDocumentsWithStatsRow) error { func (api *API) getDocumentsWordCount(ctx context.Context, documents []database.GetDocumentsWithStatsRow) error {
// Do Transaction // Do Transaction
tx, err := api.db.DB.Begin() tx, err := api.db.DB.Begin()
if err != nil { if err != nil {
@@ -950,7 +924,11 @@ func (api *API) getDocumentsWordCount(documents []database.GetDocumentsWithStats
} }
// Defer & Start Transaction // Defer & Start Transaction
defer tx.Rollback() defer func() {
if err := tx.Rollback(); err != nil {
log.Error("DB Rollback Error:", err)
}
}()
qtx := api.db.Queries.WithTx(tx) qtx := api.db.Queries.WithTx(tx)
for _, item := range documents { for _, item := range documents {
@@ -960,7 +938,7 @@ func (api *API) getDocumentsWordCount(documents []database.GetDocumentsWithStats
if err != nil { if err != nil {
log.Warn("Word Count Error: ", err) log.Warn("Word Count Error: ", err)
} else { } else {
if _, err := qtx.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ if _, err := qtx.UpsertDocument(ctx, database.UpsertDocumentParams{
ID: item.ID, ID: item.ID,
Words: wordCount, Words: wordCount,
}); err != nil { }); err != nil {
@@ -999,7 +977,11 @@ func (api *API) getBaseTemplateVars(routeName string, c *gin.Context) (gin.H, au
func bindQueryParams(c *gin.Context, defaultLimit int64) queryParams { func bindQueryParams(c *gin.Context, defaultLimit int64) queryParams {
var qParams queryParams var qParams queryParams
c.BindQuery(&qParams) err := c.BindQuery(&qParams)
if err != nil {
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("Invalid Form Bind: %v", err))
return qParams
}
if qParams.Limit == nil { if qParams.Limit == nil {
qParams.Limit = &defaultLimit qParams.Limit = &defaultLimit
@@ -1017,7 +999,7 @@ func bindQueryParams(c *gin.Context, defaultLimit int64) queryParams {
} }
func appErrorPage(c *gin.Context, errorCode int, errorMessage string) { func appErrorPage(c *gin.Context, errorCode int, errorMessage string) {
var errorHuman string = "We're not even sure what happened." errorHuman := "We're not even sure what happened."
switch errorCode { switch errorCode {
case http.StatusInternalServerError: case http.StatusInternalServerError:
@@ -1039,11 +1021,11 @@ func appErrorPage(c *gin.Context, errorCode int, errorMessage string) {
func arrangeUserStatistics(userStatistics []database.GetUserStatisticsRow) gin.H { func arrangeUserStatistics(userStatistics []database.GetUserStatisticsRow) gin.H {
// Item Sorter // Item Sorter
sortItem := func(userStatistics []database.GetUserStatisticsRow, key string, less func(i int, j int) bool) []map[string]interface{} { sortItem := func(userStatistics []database.GetUserStatisticsRow, key string, less func(i int, j int) bool) []map[string]any {
sortedData := append([]database.GetUserStatisticsRow(nil), userStatistics...) sortedData := append([]database.GetUserStatisticsRow(nil), userStatistics...)
sort.SliceStable(sortedData, less) sort.SliceStable(sortedData, less)
newData := make([]map[string]interface{}, 0) newData := make([]map[string]any, 0)
for _, item := range sortedData { for _, item := range sortedData {
v := reflect.Indirect(reflect.ValueOf(item)) v := reflect.Indirect(reflect.ValueOf(item))
@@ -1059,7 +1041,7 @@ func arrangeUserStatistics(userStatistics []database.GetUserStatisticsRow) gin.H
value = niceNumbers(rawVal) value = niceNumbers(rawVal)
} }
newData = append(newData, map[string]interface{}{ newData = append(newData, map[string]any{
"UserID": item.UserID, "UserID": item.UserID,
"Value": value, "Value": value,
}) })

View File

@@ -1,6 +1,7 @@
package api package api
import ( import (
"context"
"crypto/md5" "crypto/md5"
"fmt" "fmt"
"net/http" "net/http"
@@ -28,22 +29,17 @@ type authKOHeader struct {
AuthKey string `header:"x-auth-key"` AuthKey string `header:"x-auth-key"`
} }
// OPDS Auth Headers func (api *API) authorizeCredentials(ctx context.Context, username string, password string) (auth *authData) {
type authOPDSHeader struct { user, err := api.db.Queries.GetUser(ctx, username)
Authorization string `header:"authorization"`
}
func (api *API) authorizeCredentials(username string, password string) (auth *authData) {
user, err := api.db.Queries.GetUser(api.db.Ctx, username)
if err != nil { if err != nil {
return return
} }
if match, err := argon2.ComparePasswordAndHash(password, *user.Pass); err != nil || match != true { if match, err := argon2.ComparePasswordAndHash(password, *user.Pass); err != nil || !match {
return return
} }
// Update Auth Cache // Update auth cache
api.userAuthCache[user.ID] = *user.AuthHash api.userAuthCache[user.ID] = *user.AuthHash
return &authData{ return &authData{
@@ -57,7 +53,7 @@ func (api *API) authKOMiddleware(c *gin.Context) {
session := sessions.Default(c) session := sessions.Default(c)
// Check Session First // Check Session First
if auth, ok := api.getSession(session); ok == true { if auth, ok := api.getSession(c, session); ok {
c.Set("Authorization", auth) c.Set("Authorization", auth)
c.Header("Cache-Control", "private") c.Header("Cache-Control", "private")
c.Next() c.Next()
@@ -76,7 +72,7 @@ func (api *API) authKOMiddleware(c *gin.Context) {
return return
} }
authData := api.authorizeCredentials(rHeader.AuthUser, rHeader.AuthKey) authData := api.authorizeCredentials(c, rHeader.AuthUser, rHeader.AuthKey)
if authData == nil { if authData == nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return return
@@ -98,14 +94,14 @@ func (api *API) authOPDSMiddleware(c *gin.Context) {
user, rawPassword, hasAuth := c.Request.BasicAuth() user, rawPassword, hasAuth := c.Request.BasicAuth()
// Validate Auth Fields // Validate Auth Fields
if hasAuth != true || user == "" || rawPassword == "" { if !hasAuth || user == "" || rawPassword == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid Authorization Headers"}) c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid Authorization Headers"})
return return
} }
// Validate Auth // Validate Auth
password := fmt.Sprintf("%x", md5.Sum([]byte(rawPassword))) password := fmt.Sprintf("%x", md5.Sum([]byte(rawPassword)))
authData := api.authorizeCredentials(user, password) authData := api.authorizeCredentials(c, user, password)
if authData == nil { if authData == nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return return
@@ -120,7 +116,7 @@ func (api *API) authWebAppMiddleware(c *gin.Context) {
session := sessions.Default(c) session := sessions.Default(c)
// Check Session // Check Session
if auth, ok := api.getSession(session); ok == true { if auth, ok := api.getSession(c, session); ok {
c.Set("Authorization", auth) c.Set("Authorization", auth)
c.Header("Cache-Control", "private") c.Header("Cache-Control", "private")
c.Next() c.Next()
@@ -129,13 +125,12 @@ func (api *API) authWebAppMiddleware(c *gin.Context) {
c.Redirect(http.StatusFound, "/login") c.Redirect(http.StatusFound, "/login")
c.Abort() c.Abort()
return
} }
func (api *API) authAdminWebAppMiddleware(c *gin.Context) { func (api *API) authAdminWebAppMiddleware(c *gin.Context) {
if data, _ := c.Get("Authorization"); data != nil { if data, _ := c.Get("Authorization"); data != nil {
auth := data.(authData) auth := data.(authData)
if auth.IsAdmin == true { if auth.IsAdmin {
c.Next() c.Next()
return return
} }
@@ -143,7 +138,6 @@ func (api *API) authAdminWebAppMiddleware(c *gin.Context) {
appErrorPage(c, http.StatusUnauthorized, "Admin Permissions Required") appErrorPage(c, http.StatusUnauthorized, "Admin Permissions Required")
c.Abort() c.Abort()
return
} }
func (api *API) appAuthLogin(c *gin.Context) { func (api *API) appAuthLogin(c *gin.Context) {
@@ -160,7 +154,7 @@ func (api *API) appAuthLogin(c *gin.Context) {
// MD5 - KOSync Compatiblity // MD5 - KOSync Compatiblity
password := fmt.Sprintf("%x", md5.Sum([]byte(rawPassword))) password := fmt.Sprintf("%x", md5.Sum([]byte(rawPassword)))
authData := api.authorizeCredentials(username, password) authData := api.authorizeCredentials(c, username, password)
if authData == nil { if authData == nil {
templateVars["Error"] = "Invalid Credentials" templateVars["Error"] = "Invalid Credentials"
c.HTML(http.StatusUnauthorized, "page/login", templateVars) c.HTML(http.StatusUnauthorized, "page/login", templateVars)
@@ -215,7 +209,7 @@ func (api *API) appAuthRegister(c *gin.Context) {
} }
// Get current users // Get current users
currentUsers, err := api.db.Queries.GetUsers(api.db.Ctx) currentUsers, err := api.db.Queries.GetUsers(c)
if err != nil { if err != nil {
log.Error("Failed to check all users: ", err) log.Error("Failed to check all users: ", err)
templateVars["Error"] = "Failed to Create User" templateVars["Error"] = "Failed to Create User"
@@ -231,7 +225,7 @@ func (api *API) appAuthRegister(c *gin.Context) {
// Create user in DB // Create user in DB
authHash := fmt.Sprintf("%x", rawAuthHash) authHash := fmt.Sprintf("%x", rawAuthHash)
if rows, err := api.db.Queries.CreateUser(api.db.Ctx, database.CreateUserParams{ if rows, err := api.db.Queries.CreateUser(c, database.CreateUserParams{
ID: username, ID: username,
Pass: &hashedPassword, Pass: &hashedPassword,
AuthHash: &authHash, AuthHash: &authHash,
@@ -249,7 +243,7 @@ func (api *API) appAuthRegister(c *gin.Context) {
} }
// Get user // Get user
user, err := api.db.Queries.GetUser(api.db.Ctx, username) user, err := api.db.Queries.GetUser(c, username)
if err != nil { if err != nil {
log.Error("GetUser DB Error:", err) log.Error("GetUser DB Error:", err)
templateVars["Error"] = "Registration Disabled or User Already Exists" templateVars["Error"] = "Registration Disabled or User Already Exists"
@@ -276,7 +270,10 @@ func (api *API) appAuthRegister(c *gin.Context) {
func (api *API) appAuthLogout(c *gin.Context) { func (api *API) appAuthLogout(c *gin.Context) {
session := sessions.Default(c) session := sessions.Default(c)
session.Clear() session.Clear()
session.Save() if err := session.Save(); err != nil {
log.Error("unable to save session")
}
c.Redirect(http.StatusFound, "/login") c.Redirect(http.StatusFound, "/login")
} }
@@ -316,7 +313,7 @@ func (api *API) koAuthRegister(c *gin.Context) {
} }
// Get current users // Get current users
currentUsers, err := api.db.Queries.GetUsers(api.db.Ctx) currentUsers, err := api.db.Queries.GetUsers(c)
if err != nil { if err != nil {
log.Error("Failed to check all users: ", err) log.Error("Failed to check all users: ", err)
apiErrorPage(c, http.StatusBadRequest, "Failed to Create User") apiErrorPage(c, http.StatusBadRequest, "Failed to Create User")
@@ -331,7 +328,7 @@ func (api *API) koAuthRegister(c *gin.Context) {
// Create user // Create user
authHash := fmt.Sprintf("%x", rawAuthHash) authHash := fmt.Sprintf("%x", rawAuthHash)
if rows, err := api.db.Queries.CreateUser(api.db.Ctx, database.CreateUserParams{ if rows, err := api.db.Queries.CreateUser(c, database.CreateUserParams{
ID: rUser.Username, ID: rUser.Username,
Pass: &hashedPassword, Pass: &hashedPassword,
AuthHash: &authHash, AuthHash: &authHash,
@@ -351,7 +348,7 @@ func (api *API) koAuthRegister(c *gin.Context) {
}) })
} }
func (api *API) getSession(session sessions.Session) (auth authData, ok bool) { func (api *API) getSession(ctx context.Context, session sessions.Session) (auth authData, ok bool) {
// Get Session // Get Session
authorizedUser := session.Get("authorizedUser") authorizedUser := session.Get("authorizedUser")
isAdmin := session.Get("isAdmin") isAdmin := session.Get("isAdmin")
@@ -369,7 +366,7 @@ func (api *API) getSession(session sessions.Session) (auth authData, ok bool) {
} }
// Validate Auth Hash // Validate Auth Hash
correctAuthHash, err := api.getUserAuthHash(auth.UserName) correctAuthHash, err := api.getUserAuthHash(ctx, auth.UserName)
if err != nil || correctAuthHash != auth.AuthHash { if err != nil || correctAuthHash != auth.AuthHash {
return return
} }
@@ -377,7 +374,10 @@ func (api *API) getSession(session sessions.Session) (auth authData, ok bool) {
// Refresh // Refresh
if expiresAt.(int64)-time.Now().Unix() < 60*60*24 { if expiresAt.(int64)-time.Now().Unix() < 60*60*24 {
log.Info("Refreshing Session") log.Info("Refreshing Session")
api.setSession(session, auth) if err := api.setSession(session, auth); err != nil {
log.Error("unable to get session")
return
}
} }
// Authorized // Authorized
@@ -394,14 +394,14 @@ func (api *API) setSession(session sessions.Session, auth authData) error {
return session.Save() return session.Save()
} }
func (api *API) getUserAuthHash(username string) (string, error) { func (api *API) getUserAuthHash(ctx context.Context, username string) (string, error) {
// Return Cache // Return Cache
if api.userAuthCache[username] != "" { if api.userAuthCache[username] != "" {
return api.userAuthCache[username], nil return api.userAuthCache[username], nil
} }
// Get DB // Get DB
user, err := api.db.Queries.GetUser(api.db.Ctx, username) user, err := api.db.Queries.GetUser(ctx, username)
if err != nil { if err != nil {
log.Error("GetUser DB Error:", err) log.Error("GetUser DB Error:", err)
return "", err return "", err
@@ -413,31 +413,7 @@ func (api *API) getUserAuthHash(username string) (string, error) {
return api.userAuthCache[username], nil return api.userAuthCache[username], nil
} }
func (api *API) rotateUserAuthHash(username string) error { func (api *API) rotateAllAuthHashes(ctx context.Context) error {
// Generate Auth Hash
rawAuthHash, err := utils.GenerateToken(64)
if err != nil {
log.Error("Failed to generate user token: ", err)
return err
}
// Update User
authHash := fmt.Sprintf("%x", rawAuthHash)
if _, err = api.db.Queries.UpdateUser(api.db.Ctx, database.UpdateUserParams{
UserID: username,
AuthHash: &authHash,
}); err != nil {
log.Error("UpdateUser DB Error: ", err)
return err
}
// Update Cache
api.userAuthCache[username] = fmt.Sprintf("%x", rawAuthHash)
return nil
}
func (api *API) rotateAllAuthHashes() error {
// Do Transaction // Do Transaction
tx, err := api.db.DB.Begin() tx, err := api.db.DB.Begin()
if err != nil { if err != nil {
@@ -446,15 +422,20 @@ func (api *API) rotateAllAuthHashes() error {
} }
// Defer & Start Transaction // Defer & Start Transaction
defer tx.Rollback() defer func() {
if err := tx.Rollback(); err != nil {
log.Error("DB Rollback Error:", err)
}
}()
qtx := api.db.Queries.WithTx(tx) qtx := api.db.Queries.WithTx(tx)
users, err := qtx.GetUsers(api.db.Ctx) users, err := qtx.GetUsers(ctx)
if err != nil { if err != nil {
return err return err
} }
// Update users // Update Users
newAuthHashCache := make(map[string]string, 0)
for _, user := range users { for _, user := range users {
// Generate Auth Hash // Generate Auth Hash
rawAuthHash, err := utils.GenerateToken(64) rawAuthHash, err := utils.GenerateToken(64)
@@ -464,15 +445,16 @@ func (api *API) rotateAllAuthHashes() error {
// Update User // Update User
authHash := fmt.Sprintf("%x", rawAuthHash) authHash := fmt.Sprintf("%x", rawAuthHash)
if _, err = qtx.UpdateUser(api.db.Ctx, database.UpdateUserParams{ if _, err = qtx.UpdateUser(ctx, database.UpdateUserParams{
UserID: user.ID, UserID: user.ID,
AuthHash: &authHash, AuthHash: &authHash,
Admin: user.Admin,
}); err != nil { }); err != nil {
return err return err
} }
// Update Cache // Save New Hash Cache
api.userAuthCache[user.ID] = fmt.Sprintf("%x", rawAuthHash) newAuthHashCache[user.ID] = fmt.Sprintf("%x", rawAuthHash)
} }
// Commit Transaction // Commit Transaction
@@ -481,5 +463,10 @@ func (api *API) rotateAllAuthHashes() error {
return err return err
} }
// Transaction Succeeded -> Update Cache
for user, hash := range newAuthHashCache {
api.userAuthCache[user] = hash
}
return nil return nil
} }

View File

@@ -2,11 +2,12 @@ package api
import ( import (
"fmt" "fmt"
"github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
"reichard.io/antholume/database" "reichard.io/antholume/database"
"reichard.io/antholume/metadata" "reichard.io/antholume/metadata"
) )
@@ -21,7 +22,7 @@ func (api *API) createDownloadDocumentHandler(errorFunc func(*gin.Context, int,
} }
// Get Document // Get Document
document, err := api.db.Queries.GetDocument(api.db.Ctx, rDoc.DocumentID) document, err := api.db.Queries.GetDocument(c, rDoc.DocumentID)
if err != nil { if err != nil {
log.Error("GetDocument DB Error:", err) log.Error("GetDocument DB Error:", err)
errorFunc(c, http.StatusBadRequest, "Unknown Document") errorFunc(c, http.StatusBadRequest, "Unknown Document")
@@ -34,8 +35,14 @@ func (api *API) createDownloadDocumentHandler(errorFunc func(*gin.Context, int,
return return
} }
// Derive Basepath
basepath := filepath.Join(api.cfg.DataPath, "documents")
if document.Basepath != nil && *document.Basepath != "" {
basepath = *document.Basepath
}
// Derive Storage Location // Derive Storage Location
filePath := filepath.Join(api.cfg.DataPath, "documents", *document.Filepath) filePath := filepath.Join(basepath, *document.Filepath)
// Validate File Exists // Validate File Exists
_, err = os.Stat(filePath) _, err = os.Stat(filePath)
@@ -61,7 +68,7 @@ func (api *API) createGetCoverHandler(errorFunc func(*gin.Context, int, string))
} }
// Validate Document Exists in DB // Validate Document Exists in DB
document, err := api.db.Queries.GetDocument(api.db.Ctx, rDoc.DocumentID) document, err := api.db.Queries.GetDocument(c, rDoc.DocumentID)
if err != nil { if err != nil {
log.Error("GetDocument DB Error:", err) log.Error("GetDocument DB Error:", err)
errorFunc(c, http.StatusInternalServerError, fmt.Sprintf("GetDocument DB Error: %v", err)) errorFunc(c, http.StatusInternalServerError, fmt.Sprintf("GetDocument DB Error: %v", err))
@@ -110,7 +117,7 @@ func (api *API) createGetCoverHandler(errorFunc func(*gin.Context, int, string))
} }
// Store First Metadata Result // Store First Metadata Result
if _, err = api.db.Queries.AddMetadata(api.db.Ctx, database.AddMetadataParams{ if _, err = api.db.Queries.AddMetadata(c, database.AddMetadataParams{
DocumentID: document.ID, DocumentID: document.ID,
Title: firstResult.Title, Title: firstResult.Title,
Author: firstResult.Author, Author: firstResult.Author,
@@ -125,7 +132,7 @@ func (api *API) createGetCoverHandler(errorFunc func(*gin.Context, int, string))
} }
// Upsert Document // Upsert Document
if _, err = api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ if _, err = api.db.Queries.UpsertDocument(c, database.UpsertDocumentParams{
ID: document.ID, ID: document.ID,
Coverfile: &coverFile, Coverfile: &coverFile,
}); err != nil { }); err != nil {

View File

@@ -72,7 +72,7 @@ type requestDocumentID struct {
} }
func (api *API) koAuthorizeUser(c *gin.Context) { func (api *API) koAuthorizeUser(c *gin.Context) {
c.JSON(200, gin.H{ koJSON(c, 200, gin.H{
"authorized": "OK", "authorized": "OK",
}) })
} }
@@ -91,7 +91,7 @@ func (api *API) koSetProgress(c *gin.Context) {
} }
// Upsert Device // Upsert Device
if _, err := api.db.Queries.UpsertDevice(api.db.Ctx, database.UpsertDeviceParams{ if _, err := api.db.Queries.UpsertDevice(c, database.UpsertDeviceParams{
ID: rPosition.DeviceID, ID: rPosition.DeviceID,
UserID: auth.UserName, UserID: auth.UserName,
DeviceName: rPosition.Device, DeviceName: rPosition.Device,
@@ -101,14 +101,14 @@ func (api *API) koSetProgress(c *gin.Context) {
} }
// Upsert Document // Upsert Document
if _, err := api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ if _, err := api.db.Queries.UpsertDocument(c, database.UpsertDocumentParams{
ID: rPosition.DocumentID, ID: rPosition.DocumentID,
}); err != nil { }); err != nil {
log.Error("UpsertDocument DB Error:", err) log.Error("UpsertDocument DB Error:", err)
} }
// Create or Replace Progress // Create or Replace Progress
progress, err := api.db.Queries.UpdateProgress(api.db.Ctx, database.UpdateProgressParams{ progress, err := api.db.Queries.UpdateProgress(c, database.UpdateProgressParams{
Percentage: rPosition.Percentage, Percentage: rPosition.Percentage,
DocumentID: rPosition.DocumentID, DocumentID: rPosition.DocumentID,
DeviceID: rPosition.DeviceID, DeviceID: rPosition.DeviceID,
@@ -121,7 +121,7 @@ func (api *API) koSetProgress(c *gin.Context) {
return return
} }
c.JSON(http.StatusOK, gin.H{ koJSON(c, http.StatusOK, gin.H{
"document": progress.DocumentID, "document": progress.DocumentID,
"timestamp": progress.CreatedAt, "timestamp": progress.CreatedAt,
}) })
@@ -140,14 +140,14 @@ func (api *API) koGetProgress(c *gin.Context) {
return return
} }
progress, err := api.db.Queries.GetDocumentProgress(api.db.Ctx, database.GetDocumentProgressParams{ progress, err := api.db.Queries.GetDocumentProgress(c, database.GetDocumentProgressParams{
DocumentID: rDocID.DocumentID, DocumentID: rDocID.DocumentID,
UserID: auth.UserName, UserID: auth.UserName,
}) })
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// Not Found // Not Found
c.JSON(http.StatusOK, gin.H{}) koJSON(c, http.StatusOK, gin.H{})
return return
} else if err != nil { } else if err != nil {
log.Error("GetDocumentProgress DB Error:", err) log.Error("GetDocumentProgress DB Error:", err)
@@ -155,7 +155,7 @@ func (api *API) koGetProgress(c *gin.Context) {
return return
} }
c.JSON(http.StatusOK, gin.H{ koJSON(c, http.StatusOK, gin.H{
"document": progress.DocumentID, "document": progress.DocumentID,
"percentage": progress.Percentage, "percentage": progress.Percentage,
"progress": progress.Progress, "progress": progress.Progress,
@@ -193,12 +193,16 @@ func (api *API) koAddActivities(c *gin.Context) {
allDocuments := getKeys(allDocumentsMap) allDocuments := getKeys(allDocumentsMap)
// Defer & Start Transaction // Defer & Start Transaction
defer tx.Rollback() defer func() {
if err := tx.Rollback(); err != nil {
log.Error("DB Rollback Error:", err)
}
}()
qtx := api.db.Queries.WithTx(tx) qtx := api.db.Queries.WithTx(tx)
// Upsert Documents // Upsert Documents
for _, doc := range allDocuments { for _, doc := range allDocuments {
if _, err := qtx.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ if _, err := qtx.UpsertDocument(c, database.UpsertDocumentParams{
ID: doc, ID: doc,
}); err != nil { }); err != nil {
log.Error("UpsertDocument DB Error:", err) log.Error("UpsertDocument DB Error:", err)
@@ -208,7 +212,7 @@ func (api *API) koAddActivities(c *gin.Context) {
} }
// Upsert Device // Upsert Device
if _, err = qtx.UpsertDevice(api.db.Ctx, database.UpsertDeviceParams{ if _, err = qtx.UpsertDevice(c, database.UpsertDeviceParams{
ID: rActivity.DeviceID, ID: rActivity.DeviceID,
UserID: auth.UserName, UserID: auth.UserName,
DeviceName: rActivity.Device, DeviceName: rActivity.Device,
@@ -221,7 +225,7 @@ func (api *API) koAddActivities(c *gin.Context) {
// Add All Activity // Add All Activity
for _, item := range rActivity.Activity { for _, item := range rActivity.Activity {
if _, err := qtx.AddActivity(api.db.Ctx, database.AddActivityParams{ if _, err := qtx.AddActivity(c, database.AddActivityParams{
UserID: auth.UserName, UserID: auth.UserName,
DocumentID: item.DocumentID, DocumentID: item.DocumentID,
DeviceID: rActivity.DeviceID, DeviceID: rActivity.DeviceID,
@@ -243,7 +247,7 @@ func (api *API) koAddActivities(c *gin.Context) {
return return
} }
c.JSON(http.StatusOK, gin.H{ koJSON(c, http.StatusOK, gin.H{
"added": len(rActivity.Activity), "added": len(rActivity.Activity),
}) })
} }
@@ -262,7 +266,7 @@ func (api *API) koCheckActivitySync(c *gin.Context) {
} }
// Upsert Device // Upsert Device
if _, err := api.db.Queries.UpsertDevice(api.db.Ctx, database.UpsertDeviceParams{ if _, err := api.db.Queries.UpsertDevice(c, database.UpsertDeviceParams{
ID: rCheckActivity.DeviceID, ID: rCheckActivity.DeviceID,
UserID: auth.UserName, UserID: auth.UserName,
DeviceName: rCheckActivity.Device, DeviceName: rCheckActivity.Device,
@@ -274,7 +278,7 @@ func (api *API) koCheckActivitySync(c *gin.Context) {
} }
// Get Last Device Activity // Get Last Device Activity
lastActivity, err := api.db.Queries.GetLastActivity(api.db.Ctx, database.GetLastActivityParams{ lastActivity, err := api.db.Queries.GetLastActivity(c, database.GetLastActivityParams{
UserID: auth.UserName, UserID: auth.UserName,
DeviceID: rCheckActivity.DeviceID, DeviceID: rCheckActivity.DeviceID,
}) })
@@ -294,7 +298,7 @@ func (api *API) koCheckActivitySync(c *gin.Context) {
return return
} }
c.JSON(http.StatusOK, gin.H{ koJSON(c, http.StatusOK, gin.H{
"last_sync": parsedTime.Unix(), "last_sync": parsedTime.Unix(),
}) })
} }
@@ -316,12 +320,16 @@ func (api *API) koAddDocuments(c *gin.Context) {
} }
// Defer & Start Transaction // Defer & Start Transaction
defer tx.Rollback() defer func() {
if err := tx.Rollback(); err != nil {
log.Error("DB Rollback Error:", err)
}
}()
qtx := api.db.Queries.WithTx(tx) qtx := api.db.Queries.WithTx(tx)
// Upsert Documents // Upsert Documents
for _, doc := range rNewDocs.Documents { for _, doc := range rNewDocs.Documents {
_, err := qtx.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ _, err := qtx.UpsertDocument(c, database.UpsertDocumentParams{
ID: doc.ID, ID: doc.ID,
Title: api.sanitizeInput(doc.Title), Title: api.sanitizeInput(doc.Title),
Author: api.sanitizeInput(doc.Author), Author: api.sanitizeInput(doc.Author),
@@ -344,7 +352,7 @@ func (api *API) koAddDocuments(c *gin.Context) {
return return
} }
c.JSON(http.StatusOK, gin.H{ koJSON(c, http.StatusOK, gin.H{
"changed": len(rNewDocs.Documents), "changed": len(rNewDocs.Documents),
}) })
} }
@@ -363,7 +371,7 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) {
} }
// Upsert Device // Upsert Device
_, err := api.db.Queries.UpsertDevice(api.db.Ctx, database.UpsertDeviceParams{ _, err := api.db.Queries.UpsertDevice(c, database.UpsertDeviceParams{
ID: rCheckDocs.DeviceID, ID: rCheckDocs.DeviceID,
UserID: auth.UserName, UserID: auth.UserName,
DeviceName: rCheckDocs.Device, DeviceName: rCheckDocs.Device,
@@ -375,11 +383,8 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) {
return return
} }
missingDocs := []database.Document{}
deletedDocIDs := []string{}
// Get Missing Documents // Get Missing Documents
missingDocs, err = api.db.Queries.GetMissingDocuments(api.db.Ctx, rCheckDocs.Have) missingDocs, err := api.db.Queries.GetMissingDocuments(c, rCheckDocs.Have)
if err != nil { if err != nil {
log.Error("GetMissingDocuments DB Error", err) log.Error("GetMissingDocuments DB Error", err)
apiErrorPage(c, http.StatusBadRequest, "Invalid Request") apiErrorPage(c, http.StatusBadRequest, "Invalid Request")
@@ -387,7 +392,7 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) {
} }
// Get Deleted Documents // Get Deleted Documents
deletedDocIDs, err = api.db.Queries.GetDeletedDocuments(api.db.Ctx, rCheckDocs.Have) deletedDocIDs, err := api.db.Queries.GetDeletedDocuments(c, rCheckDocs.Have)
if err != nil { if err != nil {
log.Error("GetDeletedDocuments DB Error", err) log.Error("GetDeletedDocuments DB Error", err)
apiErrorPage(c, http.StatusBadRequest, "Invalid Request") apiErrorPage(c, http.StatusBadRequest, "Invalid Request")
@@ -402,7 +407,7 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) {
return return
} }
wantedDocs, err := api.db.Queries.GetWantedDocuments(api.db.Ctx, string(jsonHaves)) wantedDocs, err := api.db.Queries.GetWantedDocuments(c, string(jsonHaves))
if err != nil { if err != nil {
log.Error("GetWantedDocuments DB Error", err) log.Error("GetWantedDocuments DB Error", err)
apiErrorPage(c, http.StatusBadRequest, "Invalid Request") apiErrorPage(c, http.StatusBadRequest, "Invalid Request")
@@ -442,7 +447,7 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) {
rCheckDocSync.Delete = deletedDocIDs rCheckDocSync.Delete = deletedDocIDs
} }
c.JSON(http.StatusOK, rCheckDocSync) koJSON(c, http.StatusOK, rCheckDocSync)
} }
func (api *API) koUploadExistingDocument(c *gin.Context) { func (api *API) koUploadExistingDocument(c *gin.Context) {
@@ -462,7 +467,7 @@ func (api *API) koUploadExistingDocument(c *gin.Context) {
} }
// Validate Document Exists in DB // Validate Document Exists in DB
document, err := api.db.Queries.GetDocument(api.db.Ctx, rDoc.DocumentID) document, err := api.db.Queries.GetDocument(c, rDoc.DocumentID)
if err != nil { if err != nil {
log.Error("GetDocument DB Error:", err) log.Error("GetDocument DB Error:", err)
apiErrorPage(c, http.StatusBadRequest, "Unknown Document") apiErrorPage(c, http.StatusBadRequest, "Unknown Document")
@@ -494,7 +499,8 @@ func (api *API) koUploadExistingDocument(c *gin.Context) {
}) })
// Generate Storage Path // Generate Storage Path
safePath := filepath.Join(api.cfg.DataPath, "documents", fileName) basePath := filepath.Join(api.cfg.DataPath, "documents")
safePath := filepath.Join(basePath, fileName)
// Save & Prevent Overwrites // Save & Prevent Overwrites
_, err = os.Stat(safePath) _, err = os.Stat(safePath)
@@ -516,18 +522,19 @@ func (api *API) koUploadExistingDocument(c *gin.Context) {
} }
// Upsert Document // Upsert Document
if _, err = api.db.Queries.UpsertDocument(api.db.Ctx, database.UpsertDocumentParams{ if _, err = api.db.Queries.UpsertDocument(c, database.UpsertDocumentParams{
ID: document.ID, ID: document.ID,
Md5: metadataInfo.MD5, Md5: metadataInfo.MD5,
Words: metadataInfo.WordCount, Words: metadataInfo.WordCount,
Filepath: &fileName, Filepath: &fileName,
Basepath: &basePath,
}); err != nil { }); err != nil {
log.Error("UpsertDocument DB Error:", err) log.Error("UpsertDocument DB Error:", err)
apiErrorPage(c, http.StatusBadRequest, "Document Error") apiErrorPage(c, http.StatusBadRequest, "Document Error")
return return
} }
c.JSON(http.StatusOK, gin.H{ koJSON(c, http.StatusOK, gin.H{
"status": "ok", "status": "ok",
}) })
} }
@@ -582,3 +589,10 @@ func getFileMD5(filePath string) (*string, error) {
return &fileHash, nil return &fileHash, nil
} }
// koJSON forces koJSON Content-Type to only return `application/json`. This is addressing
// the following issue: https://github.com/koreader/koreader/issues/13629
func koJSON(c *gin.Context, code int, obj any) {
c.Header("Content-Type", "application/json")
c.JSON(code, obj)
}

View File

@@ -10,6 +10,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"reichard.io/antholume/database" "reichard.io/antholume/database"
"reichard.io/antholume/opds" "reichard.io/antholume/opds"
"reichard.io/antholume/pkg/ptr"
) )
var mimeMapping map[string]string = map[string]string{ var mimeMapping map[string]string = map[string]string{
@@ -77,9 +78,10 @@ func (api *API) opdsDocuments(c *gin.Context) {
} }
// Get Documents // Get Documents
documents, err := api.db.Queries.GetDocumentsWithStats(api.db.Ctx, database.GetDocumentsWithStatsParams{ documents, err := api.db.Queries.GetDocumentsWithStats(c, database.GetDocumentsWithStatsParams{
UserID: auth.UserName, UserID: auth.UserName,
Query: query, Query: query,
Deleted: ptr.Of(false),
Offset: (*qParams.Page - 1) * *qParams.Limit, Offset: (*qParams.Page - 1) * *qParams.Limit,
Limit: *qParams.Limit, Limit: *qParams.Limit,
}) })

View File

@@ -13,56 +13,49 @@ import (
"reichard.io/antholume/metadata" "reichard.io/antholume/metadata"
) )
type UTCOffset struct { // getTimeZones returns a string slice of IANA timezones.
Name string func getTimeZones() []string {
Value string return []string{
} "Africa/Cairo",
"Africa/Johannesburg",
var UTC_OFFSETS = []UTCOffset{ "Africa/Lagos",
{Value: "-12 hours", Name: "UTC12:00"}, "Africa/Nairobi",
{Value: "-11 hours", Name: "UTC11:00"}, "America/Adak",
{Value: "-10 hours", Name: "UTC10:00"}, "America/Anchorage",
{Value: "-9.5 hours", Name: "UTC09:30"}, "America/Buenos_Aires",
{Value: "-9 hours", Name: "UTC09:00"}, "America/Chicago",
{Value: "-8 hours", Name: "UTC08:00"}, "America/Denver",
{Value: "-7 hours", Name: "UTC07:00"}, "America/Los_Angeles",
{Value: "-6 hours", Name: "UTC06:00"}, "America/Mexico_City",
{Value: "-5 hours", Name: "UTC05:00"}, "America/New_York",
{Value: "-4 hours", Name: "UTC04:00"}, "America/Nuuk",
{Value: "-3.5 hours", Name: "UTC03:30"}, "America/Phoenix",
{Value: "-3 hours", Name: "UTC03:00"}, "America/Puerto_Rico",
{Value: "-2 hours", Name: "UTC02:00"}, "America/Sao_Paulo",
{Value: "-1 hours", Name: "UTC01:00"}, "America/St_Johns",
{Value: "0 hours", Name: "UTC±00:00"}, "America/Toronto",
{Value: "+1 hours", Name: "UTC+01:00"}, "Asia/Dubai",
{Value: "+2 hours", Name: "UTC+02:00"}, "Asia/Hong_Kong",
{Value: "+3 hours", Name: "UTC+03:00"}, "Asia/Kolkata",
{Value: "+3.5 hours", Name: "UTC+03:30"}, "Asia/Seoul",
{Value: "+4 hours", Name: "UTC+04:00"}, "Asia/Shanghai",
{Value: "+4.5 hours", Name: "UTC+04:30"}, "Asia/Singapore",
{Value: "+5 hours", Name: "UTC+05:00"}, "Asia/Tokyo",
{Value: "+5.5 hours", Name: "UTC+05:30"}, "Atlantic/Azores",
{Value: "+5.75 hours", Name: "UTC+05:45"}, "Australia/Melbourne",
{Value: "+6 hours", Name: "UTC+06:00"}, "Australia/Sydney",
{Value: "+6.5 hours", Name: "UTC+06:30"}, "Europe/Berlin",
{Value: "+7 hours", Name: "UTC+07:00"}, "Europe/London",
{Value: "+8 hours", Name: "UTC+08:00"}, "Europe/Moscow",
{Value: "+8.75 hours", Name: "UTC+08:45"}, "Europe/Paris",
{Value: "+9 hours", Name: "UTC+09:00"}, "Pacific/Auckland",
{Value: "+9.5 hours", Name: "UTC+09:30"}, "Pacific/Honolulu",
{Value: "+10 hours", Name: "UTC+10:00"}, }
{Value: "+10.5 hours", Name: "UTC+10:30"},
{Value: "+11 hours", Name: "UTC+11:00"},
{Value: "+12 hours", Name: "UTC+12:00"},
{Value: "+12.75 hours", Name: "UTC+12:45"},
{Value: "+13 hours", Name: "UTC+13:00"},
{Value: "+14 hours", Name: "UTC+14:00"},
}
func getUTCOffsets() []UTCOffset {
return UTC_OFFSETS
} }
// niceSeconds takes in an int (in seconds) and returns a string readable
// representation. For example 1928371 -> "22d 7h 39m 31s".
// Deprecated: Use formatters.FormatDuration
func niceSeconds(input int64) (result string) { func niceSeconds(input int64) (result string) {
if input == 0 { if input == 0 {
return "N/A" return "N/A"
@@ -91,6 +84,9 @@ func niceSeconds(input int64) (result string) {
return return
} }
// niceNumbers takes in an int and returns a string representation. For example
// 19823 -> "19.8k".
// Deprecated: Use formatters.FormatNumber
func niceNumbers(input int64) string { func niceNumbers(input int64) string {
if input == 0 { if input == 0 {
return "0" return "0"
@@ -109,7 +105,8 @@ func niceNumbers(input int64) string {
} }
} }
// Convert Database Array -> Int64 Array // getSVGGraphData builds SVGGraphData from the provided stats, width and height.
// It is used exclusively in templates to generate the daily read stats graph.
func getSVGGraphData(inputData []database.GetDailyReadStatsRow, svgWidth int, svgHeight int) graph.SVGGraphData { func getSVGGraphData(inputData []database.GetDailyReadStatsRow, svgWidth int, svgHeight int) graph.SVGGraphData {
var intData []int64 var intData []int64
for _, item := range inputData { for _, item := range inputData {
@@ -119,11 +116,13 @@ func getSVGGraphData(inputData []database.GetDailyReadStatsRow, svgWidth int, sv
return graph.GetSVGGraphData(intData, svgWidth, svgHeight) return graph.GetSVGGraphData(intData, svgWidth, svgHeight)
} }
func dict(values ...interface{}) (map[string]interface{}, error) { // dict returns a map[string]any dict. Each pair of two is a key & value
// respectively. It's primarily utilized in templates.
func dict(values ...any) (map[string]any, error) {
if len(values)%2 != 0 { if len(values)%2 != 0 {
return nil, errors.New("invalid dict call") return nil, errors.New("invalid dict call")
} }
dict := make(map[string]interface{}, len(values)/2) dict := make(map[string]any, len(values)/2)
for i := 0; i < len(values); i += 2 { for i := 0; i < len(values); i += 2 {
key, ok := values[i].(string) key, ok := values[i].(string)
if !ok { if !ok {
@@ -134,12 +133,14 @@ func dict(values ...interface{}) (map[string]interface{}, error) {
return dict, nil return dict, nil
} }
func fields(value interface{}) (map[string]interface{}, error) { // fields returns a map[string]any of the provided struct. It's primarily
// utilized in templates.
func fields(value any) (map[string]any, error) {
v := reflect.Indirect(reflect.ValueOf(value)) v := reflect.Indirect(reflect.ValueOf(value))
if v.Kind() != reflect.Struct { if v.Kind() != reflect.Struct {
return nil, fmt.Errorf("%T is not a struct", value) return nil, fmt.Errorf("%T is not a struct", value)
} }
m := make(map[string]interface{}) m := make(map[string]any)
t := v.Type() t := v.Type()
for i := 0; i < t.NumField(); i++ { for i := 0; i < t.NumField(); i++ {
sv := t.Field(i) sv := t.Field(i)
@@ -148,6 +149,13 @@ func fields(value interface{}) (map[string]interface{}, error) {
return m, nil return m, nil
} }
// slice returns a slice of the provided arguments. It's primarily utilized in
// templates.
func slice(elements ...any) []any {
return elements
}
// deriveBaseFileName builds the base filename for a given MetadataInfo object.
func deriveBaseFileName(metadataInfo *metadata.MetadataInfo) string { func deriveBaseFileName(metadataInfo *metadata.MetadataInfo) string {
// Derive New FileName // Derive New FileName
var newFileName string var newFileName string
@@ -166,3 +174,15 @@ func deriveBaseFileName(metadataInfo *metadata.MetadataInfo) string {
fileName := strings.ReplaceAll(newFileName, "/", "") fileName := strings.ReplaceAll(newFileName, "/", "")
return "." + filepath.Clean(fmt.Sprintf("/%s [%s]%s", fileName, *metadataInfo.PartialMD5, metadataInfo.Type)) return "." + filepath.Clean(fmt.Sprintf("/%s [%s]%s", fileName, *metadataInfo.PartialMD5, metadataInfo.Type))
} }
// importStatusPriority returns the order priority for import status in the UI.
func importStatusPriority(status importStatus) int {
switch status {
case importFailed:
return 1
case importExists:
return 2
default:
return 3
}
}

151
api/v1/activity.go Normal file
View File

@@ -0,0 +1,151 @@
package v1
import (
"context"
"time"
log "github.com/sirupsen/logrus"
"reichard.io/antholume/database"
)
// GET /activity
func (s *Server) GetActivity(ctx context.Context, request GetActivityRequestObject) (GetActivityResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return GetActivity401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
docFilter := false
if request.Params.DocFilter != nil {
docFilter = *request.Params.DocFilter
}
documentID := ""
if request.Params.DocumentId != nil {
documentID = *request.Params.DocumentId
}
offset := int64(0)
if request.Params.Offset != nil {
offset = *request.Params.Offset
}
limit := int64(100)
if request.Params.Limit != nil {
limit = *request.Params.Limit
}
activities, err := s.db.Queries.GetActivity(ctx, database.GetActivityParams{
UserID: auth.UserName,
DocFilter: docFilter,
DocumentID: documentID,
Offset: offset,
Limit: limit,
})
if err != nil {
return GetActivity500JSONResponse{Code: 500, Message: err.Error()}, nil
}
apiActivities := make([]Activity, len(activities))
for i, a := range activities {
// Convert StartTime from interface{} to string
startTimeStr := ""
if a.StartTime != nil {
if str, ok := a.StartTime.(string); ok {
startTimeStr = str
}
}
apiActivities[i] = Activity{
DocumentId: a.DocumentID,
DeviceId: a.DeviceID,
StartTime: startTimeStr,
Title: a.Title,
Author: a.Author,
Duration: a.Duration,
StartPercentage: float32(a.StartPercentage),
EndPercentage: float32(a.EndPercentage),
ReadPercentage: float32(a.ReadPercentage),
}
}
response := ActivityResponse{
Activities: apiActivities,
}
return GetActivity200JSONResponse(response), nil
}
// POST /activity
func (s *Server) CreateActivity(ctx context.Context, request CreateActivityRequestObject) (CreateActivityResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return CreateActivity401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
if request.Body == nil {
return CreateActivity400JSONResponse{Code: 400, Message: "Request body is required"}, nil
}
tx, err := s.db.DB.Begin()
if err != nil {
log.Error("Transaction Begin DB Error:", err)
return CreateActivity500JSONResponse{Code: 500, Message: "Database error"}, nil
}
committed := false
defer func() {
if committed {
return
}
if rollbackErr := tx.Rollback(); rollbackErr != nil {
log.Debug("Transaction Rollback DB Error:", rollbackErr)
}
}()
qtx := s.db.Queries.WithTx(tx)
allDocumentsMap := make(map[string]struct{})
for _, item := range request.Body.Activity {
allDocumentsMap[item.DocumentId] = struct{}{}
}
for documentID := range allDocumentsMap {
if _, err := qtx.UpsertDocument(ctx, database.UpsertDocumentParams{ID: documentID}); err != nil {
log.Error("UpsertDocument DB Error:", err)
return CreateActivity400JSONResponse{Code: 400, Message: "Invalid document"}, nil
}
}
if _, err := qtx.UpsertDevice(ctx, database.UpsertDeviceParams{
ID: request.Body.DeviceId,
UserID: auth.UserName,
DeviceName: request.Body.DeviceName,
LastSynced: time.Now().UTC().Format(time.RFC3339),
}); err != nil {
log.Error("UpsertDevice DB Error:", err)
return CreateActivity400JSONResponse{Code: 400, Message: "Invalid device"}, nil
}
for _, item := range request.Body.Activity {
if _, err := qtx.AddActivity(ctx, database.AddActivityParams{
UserID: auth.UserName,
DocumentID: item.DocumentId,
DeviceID: request.Body.DeviceId,
StartTime: time.Unix(item.StartTime, 0).UTC().Format(time.RFC3339),
Duration: item.Duration,
StartPercentage: float64(item.Page) / float64(item.Pages),
EndPercentage: float64(item.Page+1) / float64(item.Pages),
}); err != nil {
log.Error("AddActivity DB Error:", err)
return CreateActivity400JSONResponse{Code: 400, Message: "Invalid activity"}, nil
}
}
if err := tx.Commit(); err != nil {
log.Error("Transaction Commit DB Error:", err)
return CreateActivity500JSONResponse{Code: 500, Message: "Database error"}, nil
}
committed = true
response := CreateActivityResponse{Added: int64(len(request.Body.Activity))}
return CreateActivity200JSONResponse(response), nil
}

1070
api/v1/admin.go Normal file

File diff suppressed because it is too large Load Diff

152
api/v1/admin_test.go Normal file
View File

@@ -0,0 +1,152 @@
package v1
import (
"bytes"
"context"
"crypto/md5"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
argon2 "github.com/alexedwards/argon2id"
"github.com/stretchr/testify/require"
"reichard.io/antholume/config"
"reichard.io/antholume/database"
)
func createAdminTestUser(t *testing.T, db *database.DBManager, username, password string) {
t.Helper()
md5Hash := fmt.Sprintf("%x", md5.Sum([]byte(password)))
hashedPassword, err := argon2.CreateHash(md5Hash, argon2.DefaultParams)
require.NoError(t, err)
authHash := "test-auth-hash"
_, err = db.Queries.CreateUser(context.Background(), database.CreateUserParams{
ID: username,
Pass: &hashedPassword,
AuthHash: &authHash,
Admin: true,
})
require.NoError(t, err)
}
func loginAdminTestUser(t *testing.T, srv *Server, username, password string) *http.Cookie {
t.Helper()
body, err := json.Marshal(LoginRequest{Username: username, Password: password})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(body))
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
cookies := w.Result().Cookies()
require.Len(t, cookies, 1)
return cookies[0]
}
func TestGetLogsPagination(t *testing.T) {
configPath := t.TempDir()
require.NoError(t, os.MkdirAll(filepath.Join(configPath, "logs"), 0o755))
require.NoError(t, os.WriteFile(filepath.Join(configPath, "logs", "antholume.log"), []byte(
"{\"level\":\"info\",\"msg\":\"one\"}\n"+
"plain two\n"+
"{\"level\":\"error\",\"msg\":\"three\"}\n"+
"plain four\n",
), 0o644))
cfg := &config.Config{
ListenPort: "8080",
DBType: "memory",
DBName: "test",
ConfigPath: configPath,
CookieAuthKey: "test-auth-key-32-bytes-long-enough",
CookieEncKey: "0123456789abcdef",
CookieSecure: false,
CookieHTTPOnly: true,
Version: "test",
DemoMode: false,
RegistrationEnabled: true,
}
db := database.NewMgr(cfg)
srv := NewServer(db, cfg, nil)
createAdminTestUser(t, db, "admin", "password")
cookie := loginAdminTestUser(t, srv, "admin", "password")
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/logs?page=2&limit=2", nil)
req.AddCookie(cookie)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
var resp LogsResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
require.NotNil(t, resp.Logs)
require.Len(t, *resp.Logs, 2)
require.NotNil(t, resp.Page)
require.Equal(t, int64(2), *resp.Page)
require.NotNil(t, resp.Limit)
require.Equal(t, int64(2), *resp.Limit)
require.NotNil(t, resp.Total)
require.Equal(t, int64(4), *resp.Total)
require.Nil(t, resp.NextPage)
require.NotNil(t, resp.PreviousPage)
require.Equal(t, int64(1), *resp.PreviousPage)
require.Contains(t, (*resp.Logs)[0], "three")
require.Contains(t, (*resp.Logs)[1], "plain four")
}
func TestGetLogsPaginationWithBasicFilter(t *testing.T) {
configPath := t.TempDir()
require.NoError(t, os.MkdirAll(filepath.Join(configPath, "logs"), 0o755))
require.NoError(t, os.WriteFile(filepath.Join(configPath, "logs", "antholume.log"), []byte(
"{\"level\":\"info\",\"msg\":\"match-1\"}\n"+
"{\"level\":\"info\",\"msg\":\"skip\"}\n"+
"plain match-2\n"+
"{\"level\":\"info\",\"msg\":\"match-3\"}\n",
), 0o644))
cfg := &config.Config{
ListenPort: "8080",
DBType: "memory",
DBName: "test",
ConfigPath: configPath,
CookieAuthKey: "test-auth-key-32-bytes-long-enough",
CookieEncKey: "0123456789abcdef",
CookieSecure: false,
CookieHTTPOnly: true,
Version: "test",
DemoMode: false,
RegistrationEnabled: true,
}
db := database.NewMgr(cfg)
srv := NewServer(db, cfg, nil)
createAdminTestUser(t, db, "admin", "password")
cookie := loginAdminTestUser(t, srv, "admin", "password")
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/logs?filter=%22match%22&page=1&limit=2", nil)
req.AddCookie(cookie)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
var resp LogsResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
require.NotNil(t, resp.Logs)
require.Len(t, *resp.Logs, 2)
require.NotNil(t, resp.Total)
require.Equal(t, int64(3), *resp.Total)
require.NotNil(t, resp.NextPage)
require.Equal(t, int64(2), *resp.NextPage)
}

4146
api/v1/api.gen.go Normal file

File diff suppressed because it is too large Load Diff

286
api/v1/auth.go Normal file
View File

@@ -0,0 +1,286 @@
package v1
import (
"context"
"crypto/md5"
"fmt"
"net/http"
"time"
argon2 "github.com/alexedwards/argon2id"
"github.com/gorilla/sessions"
log "github.com/sirupsen/logrus"
)
// POST /auth/login
func (s *Server) Login(ctx context.Context, request LoginRequestObject) (LoginResponseObject, error) {
if request.Body == nil {
return Login400JSONResponse{Code: 400, Message: "Invalid request body"}, nil
}
req := *request.Body
if req.Username == "" || req.Password == "" {
return Login400JSONResponse{Code: 400, Message: "Invalid credentials"}, nil
}
// MD5 - KOSync compatibility
password := fmt.Sprintf("%x", md5.Sum([]byte(req.Password)))
// Verify credentials
user, err := s.db.Queries.GetUser(ctx, req.Username)
if err != nil {
return Login401JSONResponse{Code: 401, Message: "Invalid credentials"}, nil
}
if match, err := argon2.ComparePasswordAndHash(password, *user.Pass); err != nil || !match {
return Login401JSONResponse{Code: 401, Message: "Invalid credentials"}, nil
}
if err := s.saveUserSession(ctx, user.ID, user.Admin, *user.AuthHash); err != nil {
return Login500JSONResponse{Code: 500, Message: err.Error()}, nil
}
return Login200JSONResponse{
Body: LoginResponse{
Username: user.ID,
IsAdmin: user.Admin,
},
Headers: Login200ResponseHeaders{
SetCookie: s.getSetCookieFromContext(ctx),
},
}, nil
}
// POST /auth/register
func (s *Server) Register(ctx context.Context, request RegisterRequestObject) (RegisterResponseObject, error) {
if !s.cfg.RegistrationEnabled {
return Register403JSONResponse{Code: 403, Message: "Registration is disabled"}, nil
}
if request.Body == nil {
return Register400JSONResponse{Code: 400, Message: "Invalid request body"}, nil
}
req := *request.Body
if req.Username == "" || req.Password == "" {
return Register400JSONResponse{Code: 400, Message: "Invalid user or password"}, nil
}
currentUsers, err := s.db.Queries.GetUsers(ctx)
if err != nil {
return Register500JSONResponse{Code: 500, Message: "Failed to create user"}, nil
}
isAdmin := len(currentUsers) == 0
if err := s.createUser(ctx, req.Username, &req.Password, &isAdmin); err != nil {
return Register400JSONResponse{Code: 400, Message: err.Error()}, nil
}
user, err := s.db.Queries.GetUser(ctx, req.Username)
if err != nil {
return Register500JSONResponse{Code: 500, Message: "Failed to load created user"}, nil
}
if err := s.saveUserSession(ctx, user.ID, user.Admin, *user.AuthHash); err != nil {
return Register500JSONResponse{Code: 500, Message: err.Error()}, nil
}
return Register201JSONResponse{
Body: LoginResponse{
Username: user.ID,
IsAdmin: user.Admin,
},
Headers: Register201ResponseHeaders{
SetCookie: s.getSetCookieFromContext(ctx),
},
}, nil
}
// POST /auth/logout
func (s *Server) Logout(ctx context.Context, request LogoutRequestObject) (LogoutResponseObject, error) {
_, ok := s.getSessionFromContext(ctx)
if !ok {
return Logout401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
r := s.getRequestFromContext(ctx)
w := s.getResponseWriterFromContext(ctx)
if r == nil || w == nil {
return Logout401JSONResponse{Code: 401, Message: "Internal context error"}, nil
}
session, err := s.getCookieSession(r)
if err != nil {
return Logout401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
session.Values = make(map[any]any)
if err := session.Save(r, w); err != nil {
return Logout401JSONResponse{Code: 401, Message: "Failed to logout"}, nil
}
return Logout200Response{}, nil
}
// GET /auth/me
func (s *Server) GetMe(ctx context.Context, request GetMeRequestObject) (GetMeResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return GetMe401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
return GetMe200JSONResponse{
Username: auth.UserName,
IsAdmin: auth.IsAdmin,
}, nil
}
func (s *Server) saveUserSession(ctx context.Context, username string, isAdmin bool, authHash string) error {
r := s.getRequestFromContext(ctx)
w := s.getResponseWriterFromContext(ctx)
if r == nil || w == nil {
return fmt.Errorf("internal context error")
}
session, err := s.getCookieSession(r)
if err != nil {
return fmt.Errorf("unauthorized")
}
session.Values["authorizedUser"] = username
session.Values["isAdmin"] = isAdmin
session.Values["expiresAt"] = time.Now().Unix() + (60 * 60 * 24 * 7)
session.Values["authHash"] = authHash
if err := session.Save(r, w); err != nil {
return fmt.Errorf("failed to create session")
}
return nil
}
func (s *Server) getCookieSession(r *http.Request) (*sessions.Session, error) {
store := sessions.NewCookieStore([]byte(s.cfg.CookieAuthKey))
if s.cfg.CookieEncKey != "" {
if len(s.cfg.CookieEncKey) == 16 || len(s.cfg.CookieEncKey) == 32 {
store = sessions.NewCookieStore([]byte(s.cfg.CookieAuthKey), []byte(s.cfg.CookieEncKey))
}
}
session, err := store.Get(r, "token")
if err != nil {
return nil, fmt.Errorf("failed to get session: %w", err)
}
session.Options.SameSite = http.SameSiteLaxMode
session.Options.HttpOnly = true
session.Options.Secure = s.cfg.CookieSecure
return session, nil
}
// getSessionFromContext extracts authData from context
func (s *Server) getSessionFromContext(ctx context.Context) (authData, bool) {
auth, ok := ctx.Value("auth").(authData)
if !ok {
return authData{}, false
}
return auth, true
}
// isAdmin checks if a user has admin privileges
func (s *Server) isAdmin(ctx context.Context) bool {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return false
}
return auth.IsAdmin
}
// getRequestFromContext extracts the HTTP request from context
func (s *Server) getRequestFromContext(ctx context.Context) *http.Request {
r, ok := ctx.Value("request").(*http.Request)
if !ok {
return nil
}
return r
}
// getResponseWriterFromContext extracts the response writer from context
func (s *Server) getResponseWriterFromContext(ctx context.Context) http.ResponseWriter {
w, ok := ctx.Value("response").(http.ResponseWriter)
if !ok {
return nil
}
return w
}
func (s *Server) getSetCookieFromContext(ctx context.Context) string {
w := s.getResponseWriterFromContext(ctx)
if w == nil {
return ""
}
return w.Header().Get("Set-Cookie")
}
// getSession retrieves auth data from the session cookie
func (s *Server) getSession(r *http.Request) (auth authData, ok bool) {
// Get session from cookie store
store := sessions.NewCookieStore([]byte(s.cfg.CookieAuthKey))
if s.cfg.CookieEncKey != "" {
if len(s.cfg.CookieEncKey) == 16 || len(s.cfg.CookieEncKey) == 32 {
store = sessions.NewCookieStore([]byte(s.cfg.CookieAuthKey), []byte(s.cfg.CookieEncKey))
} else {
log.Error("invalid cookie encryption key (must be 16 or 32 bytes)")
return authData{}, false
}
}
session, err := store.Get(r, "token")
if err != nil {
return authData{}, false
}
// Get session values
authorizedUser := session.Values["authorizedUser"]
isAdmin := session.Values["isAdmin"]
expiresAt := session.Values["expiresAt"]
authHash := session.Values["authHash"]
if authorizedUser == nil || isAdmin == nil || expiresAt == nil || authHash == nil {
return authData{}, false
}
auth = authData{
UserName: authorizedUser.(string),
IsAdmin: isAdmin.(bool),
AuthHash: authHash.(string),
}
// Validate auth hash
ctx := r.Context()
correctAuthHash, err := s.getUserAuthHash(ctx, auth.UserName)
if err != nil || correctAuthHash != auth.AuthHash {
return authData{}, false
}
return auth, true
}
// getUserAuthHash retrieves the user's auth hash from DB or cache
func (s *Server) getUserAuthHash(ctx context.Context, username string) (string, error) {
user, err := s.db.Queries.GetUser(ctx, username)
if err != nil {
return "", err
}
return *user.AuthHash, nil
}
// authData represents authenticated user information
type authData struct {
UserName string
IsAdmin bool
AuthHash string
}

228
api/v1/auth_test.go Normal file
View File

@@ -0,0 +1,228 @@
package v1
import (
"bytes"
"crypto/md5"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/suite"
argon2 "github.com/alexedwards/argon2id"
"reichard.io/antholume/config"
"reichard.io/antholume/database"
)
type AuthTestSuite struct {
suite.Suite
db *database.DBManager
cfg *config.Config
srv *Server
}
func (suite *AuthTestSuite) setupConfig() *config.Config {
return &config.Config{
ListenPort: "8080",
DBType: "memory",
DBName: "test",
ConfigPath: "/tmp",
CookieAuthKey: "test-auth-key-32-bytes-long-enough",
CookieEncKey: "0123456789abcdef",
CookieSecure: false,
CookieHTTPOnly: true,
Version: "test",
DemoMode: false,
RegistrationEnabled: true,
}
}
func TestAuth(t *testing.T) {
suite.Run(t, new(AuthTestSuite))
}
func (suite *AuthTestSuite) SetupTest() {
suite.cfg = suite.setupConfig()
suite.db = database.NewMgr(suite.cfg)
suite.srv = NewServer(suite.db, suite.cfg, nil)
}
func (suite *AuthTestSuite) createTestUser(username, password string) {
md5Hash := fmt.Sprintf("%x", md5.Sum([]byte(password)))
hashedPassword, err := argon2.CreateHash(md5Hash, argon2.DefaultParams)
suite.Require().NoError(err)
authHash := "test-auth-hash"
_, err = suite.db.Queries.CreateUser(suite.T().Context(), database.CreateUserParams{
ID: username,
Pass: &hashedPassword,
AuthHash: &authHash,
Admin: true,
})
suite.Require().NoError(err)
}
func (suite *AuthTestSuite) assertSessionCookie(cookie *http.Cookie) {
suite.Require().NotNil(cookie)
suite.Equal("token", cookie.Name)
suite.NotEmpty(cookie.Value)
suite.True(cookie.HttpOnly)
}
func (suite *AuthTestSuite) login(username, password string) *http.Cookie {
reqBody := LoginRequest{
Username: username,
Password: password,
}
body, err := json.Marshal(reqBody)
suite.Require().NoError(err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(body))
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusOK, w.Code, "login should return 200")
var resp LoginResponse
suite.Require().NoError(json.Unmarshal(w.Body.Bytes(), &resp))
cookies := w.Result().Cookies()
suite.Require().Len(cookies, 1, "should have session cookie")
suite.assertSessionCookie(cookies[0])
return cookies[0]
}
func (suite *AuthTestSuite) TestAPILogin() {
suite.createTestUser("testuser", "testpass")
reqBody := LoginRequest{
Username: "testuser",
Password: "testpass",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(body))
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusOK, w.Code)
var resp LoginResponse
suite.Require().NoError(json.Unmarshal(w.Body.Bytes(), &resp))
suite.Equal("testuser", resp.Username)
cookies := w.Result().Cookies()
suite.Require().Len(cookies, 1)
suite.assertSessionCookie(cookies[0])
}
func (suite *AuthTestSuite) TestAPILoginInvalidCredentials() {
reqBody := LoginRequest{
Username: "testuser",
Password: "wrongpass",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(body))
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusUnauthorized, w.Code)
}
func (suite *AuthTestSuite) TestAPIRegister() {
reqBody := LoginRequest{
Username: "newuser",
Password: "newpass",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusCreated, w.Code)
var resp LoginResponse
suite.Require().NoError(json.Unmarshal(w.Body.Bytes(), &resp))
suite.Equal("newuser", resp.Username)
suite.True(resp.IsAdmin, "first registered user should mirror legacy admin bootstrap behavior")
cookies := w.Result().Cookies()
suite.Require().Len(cookies, 1, "register should set a session cookie")
suite.assertSessionCookie(cookies[0])
user, err := suite.db.Queries.GetUser(suite.T().Context(), "newuser")
suite.Require().NoError(err)
suite.True(user.Admin)
}
func (suite *AuthTestSuite) TestAPIRegisterDisabled() {
suite.cfg.RegistrationEnabled = false
suite.srv = NewServer(suite.db, suite.cfg, nil)
reqBody := LoginRequest{
Username: "newuser",
Password: "newpass",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusForbidden, w.Code)
}
func (suite *AuthTestSuite) TestAPILogout() {
suite.createTestUser("testuser", "testpass")
cookie := suite.login("testuser", "testpass")
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", nil)
req.AddCookie(cookie)
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusOK, w.Code)
cookies := w.Result().Cookies()
suite.Require().Len(cookies, 1)
suite.Equal("token", cookies[0].Name)
}
func (suite *AuthTestSuite) TestAPIGetMe() {
suite.createTestUser("testuser", "testpass")
cookie := suite.login("testuser", "testpass")
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil)
req.AddCookie(cookie)
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusOK, w.Code)
var resp UserData
suite.Require().NoError(json.Unmarshal(w.Body.Bytes(), &resp))
suite.Equal("testuser", resp.Username)
}
func (suite *AuthTestSuite) TestAPIGetMeUnauthenticated() {
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil)
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusUnauthorized, w.Code)
}

827
api/v1/documents.go Normal file
View File

@@ -0,0 +1,827 @@
package v1
import (
"context"
"fmt"
"io"
"io/fs"
"net/http"
"os"
"path/filepath"
"strings"
"time"
log "github.com/sirupsen/logrus"
"reichard.io/antholume/database"
"reichard.io/antholume/metadata"
)
// GET /documents
func (s *Server) GetDocuments(ctx context.Context, request GetDocumentsRequestObject) (GetDocumentsResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return GetDocuments401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
page := int64(1)
if request.Params.Page != nil {
page = *request.Params.Page
}
limit := int64(9)
if request.Params.Limit != nil {
limit = *request.Params.Limit
}
search := ""
if request.Params.Search != nil {
search = "%" + *request.Params.Search + "%"
}
rows, err := s.db.Queries.GetDocumentsWithStats(
ctx,
database.GetDocumentsWithStatsParams{
UserID: auth.UserName,
Query: &search,
Deleted: ptrOf(false),
Offset: (page - 1) * limit,
Limit: limit,
},
)
if err != nil {
return GetDocuments500JSONResponse{Code: 500, Message: err.Error()}, nil
}
total := int64(len(rows))
var nextPage *int64
var previousPage *int64
if page*limit < total {
nextPage = ptrOf(page + 1)
}
if page > 1 {
previousPage = ptrOf(page - 1)
}
apiDocuments := make([]Document, len(rows))
for i, row := range rows {
apiDocuments[i] = Document{
Id: row.ID,
Title: *row.Title,
Author: *row.Author,
Description: row.Description,
Isbn10: row.Isbn10,
Isbn13: row.Isbn13,
Words: row.Words,
Filepath: row.Filepath,
Percentage: ptrOf(float32(row.Percentage)),
TotalTimeSeconds: ptrOf(row.TotalTimeSeconds),
Wpm: ptrOf(float32(row.Wpm)),
SecondsPerPercent: ptrOf(row.SecondsPerPercent),
LastRead: parseInterfaceTime(row.LastRead),
CreatedAt: time.Now(), // Will be overwritten if we had a proper created_at from DB
UpdatedAt: time.Now(), // Will be overwritten if we had a proper updated_at from DB
Deleted: false, // Default, should be overridden if available
}
}
response := DocumentsResponse{
Documents: apiDocuments,
Total: total,
Page: page,
Limit: limit,
NextPage: nextPage,
PreviousPage: previousPage,
Search: request.Params.Search,
}
return GetDocuments200JSONResponse(response), nil
}
// GET /documents/{id}
func (s *Server) GetDocument(ctx context.Context, request GetDocumentRequestObject) (GetDocumentResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return GetDocument401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
// Use GetDocumentsWithStats to get document with stats
docs, err := s.db.Queries.GetDocumentsWithStats(
ctx,
database.GetDocumentsWithStatsParams{
UserID: auth.UserName,
ID: &request.Id,
Deleted: ptrOf(false),
Offset: 0,
Limit: 1,
},
)
if err != nil || len(docs) == 0 {
return GetDocument404JSONResponse{Code: 404, Message: "Document not found"}, nil
}
doc := docs[0]
apiDoc := Document{
Id: doc.ID,
Title: *doc.Title,
Author: *doc.Author,
Description: doc.Description,
Isbn10: doc.Isbn10,
Isbn13: doc.Isbn13,
Words: doc.Words,
Filepath: doc.Filepath,
Percentage: ptrOf(float32(doc.Percentage)),
TotalTimeSeconds: ptrOf(doc.TotalTimeSeconds),
Wpm: ptrOf(float32(doc.Wpm)),
SecondsPerPercent: ptrOf(doc.SecondsPerPercent),
LastRead: parseInterfaceTime(doc.LastRead),
CreatedAt: time.Now(), // Will be overwritten if we had a proper created_at from DB
UpdatedAt: time.Now(), // Will be overwritten if we had a proper updated_at from DB
Deleted: false, // Default, should be overridden if available
}
response := DocumentResponse{
Document: apiDoc,
}
return GetDocument200JSONResponse(response), nil
}
// POST /documents/{id}
func (s *Server) EditDocument(ctx context.Context, request EditDocumentRequestObject) (EditDocumentResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return EditDocument401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
if request.Body == nil {
return EditDocument400JSONResponse{Code: 400, Message: "Missing request body"}, nil
}
// Validate document exists and get current state
currentDoc, err := s.db.Queries.GetDocument(ctx, request.Id)
if err != nil {
return EditDocument404JSONResponse{Code: 404, Message: "Document not found"}, nil
}
// Validate at least one editable field is provided
if request.Body.Title == nil &&
request.Body.Author == nil &&
request.Body.Description == nil &&
request.Body.Isbn10 == nil &&
request.Body.Isbn13 == nil &&
request.Body.CoverGbid == nil {
return EditDocument400JSONResponse{Code: 400, Message: "No editable fields provided"}, nil
}
// Handle cover via Google Books ID
var coverFileName *string
if request.Body.CoverGbid != nil {
coverDir := filepath.Join(s.cfg.DataPath, "covers")
fileName, err := metadata.CacheCoverWithContext(ctx, *request.Body.CoverGbid, coverDir, request.Id, true)
if err == nil {
coverFileName = fileName
}
}
// Update document with provided editable fields only
_, err = s.db.Queries.UpsertDocument(ctx, database.UpsertDocumentParams{
ID: request.Id,
Title: request.Body.Title,
Author: request.Body.Author,
Description: request.Body.Description,
Isbn10: request.Body.Isbn10,
Isbn13: request.Body.Isbn13,
Coverfile: coverFileName,
// Preserve existing values for non-editable fields
Md5: currentDoc.Md5,
Basepath: currentDoc.Basepath,
Filepath: currentDoc.Filepath,
Words: currentDoc.Words,
})
if err != nil {
log.Error("UpsertDocument DB Error:", err)
return EditDocument500JSONResponse{Code: 500, Message: "Failed to update document"}, nil
}
// Use GetDocumentsWithStats to get document with stats for the response
docs, err := s.db.Queries.GetDocumentsWithStats(
ctx,
database.GetDocumentsWithStatsParams{
UserID: auth.UserName,
ID: &request.Id,
Deleted: ptrOf(false),
Offset: 0,
Limit: 1,
},
)
if err != nil || len(docs) == 0 {
return EditDocument404JSONResponse{Code: 404, Message: "Document not found"}, nil
}
doc := docs[0]
apiDoc := Document{
Id: doc.ID,
Title: *doc.Title,
Author: *doc.Author,
Description: doc.Description,
Isbn10: doc.Isbn10,
Isbn13: doc.Isbn13,
Words: doc.Words,
Filepath: doc.Filepath,
Percentage: ptrOf(float32(doc.Percentage)),
TotalTimeSeconds: ptrOf(doc.TotalTimeSeconds),
Wpm: ptrOf(float32(doc.Wpm)),
SecondsPerPercent: ptrOf(doc.SecondsPerPercent),
LastRead: parseInterfaceTime(doc.LastRead),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
Deleted: false,
}
response := DocumentResponse{
Document: apiDoc,
}
return EditDocument200JSONResponse(response), nil
}
// deriveBaseFileName builds the base filename for a given MetadataInfo object.
func deriveBaseFileName(metadataInfo *metadata.MetadataInfo) string {
// Derive New FileName
var newFileName string
if metadataInfo.Author != nil && *metadataInfo.Author != "" {
newFileName = newFileName + *metadataInfo.Author
} else {
newFileName = newFileName + "Unknown"
}
if metadataInfo.Title != nil && *metadataInfo.Title != "" {
newFileName = newFileName + " - " + *metadataInfo.Title
} else {
newFileName = newFileName + " - Unknown"
}
// Remove Slashes
fileName := strings.ReplaceAll(newFileName, "/", "")
return "." + filepath.Clean(fmt.Sprintf("/%s [%s]%s", fileName, *metadataInfo.PartialMD5, metadataInfo.Type))
}
// parseInterfaceTime converts an interface{} to time.Time for SQLC queries
func parseInterfaceTime(t any) *time.Time {
if t == nil {
return nil
}
switch v := t.(type) {
case string:
parsed, err := time.Parse(time.RFC3339, v)
if err != nil {
return nil
}
return &parsed
case time.Time:
return &v
default:
return nil
}
}
// serveNoCover serves the default no-cover image from assets
func (s *Server) serveNoCover() (fs.File, string, int64, error) {
// Try to open the no-cover image from assets
file, err := s.assets.Open("assets/images/no-cover.jpg")
if err != nil {
return nil, "", 0, err
}
// Get file info
info, err := file.Stat()
if err != nil {
file.Close()
return nil, "", 0, err
}
return file, "image/jpeg", info.Size(), nil
}
// openFileReader opens a file and returns it as an io.ReaderCloser
func openFileReader(path string) (*os.File, error) {
return os.Open(path)
}
// GET /documents/{id}/cover
func (s *Server) GetDocumentCover(ctx context.Context, request GetDocumentCoverRequestObject) (GetDocumentCoverResponseObject, error) {
// Authentication is handled by middleware, which also adds auth data to context
// This endpoint just serves the cover image
// Validate Document Exists in DB
document, err := s.db.Queries.GetDocument(ctx, request.Id)
if err != nil {
log.Error("GetDocument DB Error:", err)
return GetDocumentCover404JSONResponse{Code: 404, Message: "Document not found"}, nil
}
var coverFile fs.File
var contentType string
var contentLength int64
var needMetadataFetch bool
// Handle Identified Document
if document.Coverfile != nil {
if *document.Coverfile == "UNKNOWN" {
// Serve no-cover image
file, ct, size, err := s.serveNoCover()
if err != nil {
log.Error("Failed to open no-cover image:", err)
return GetDocumentCover404JSONResponse{Code: 404, Message: "Cover not found"}, nil
}
coverFile = file
contentType = ct
contentLength = size
needMetadataFetch = true
} else {
// Derive Path
coverPath := filepath.Join(s.cfg.DataPath, "covers", *document.Coverfile)
// Validate File Exists
fileInfo, err := os.Stat(coverPath)
if os.IsNotExist(err) {
log.Error("Cover file should but doesn't exist: ", err)
// Serve no-cover image
file, ct, size, err := s.serveNoCover()
if err != nil {
log.Error("Failed to open no-cover image:", err)
return GetDocumentCover404JSONResponse{Code: 404, Message: "Cover not found"}, nil
}
coverFile = file
contentType = ct
contentLength = size
needMetadataFetch = true
} else {
// Open the cover file
file, err := openFileReader(coverPath)
if err != nil {
log.Error("Failed to open cover file:", err)
return GetDocumentCover500JSONResponse{Code: 500, Message: "Failed to open cover"}, nil
}
coverFile = file
contentLength = fileInfo.Size()
// Determine content type based on file extension
contentType = "image/jpeg"
if strings.HasSuffix(coverPath, ".png") {
contentType = "image/png"
}
}
}
} else {
needMetadataFetch = true
}
// Attempt Metadata fetch if needed
var cachedCoverFile string = "UNKNOWN"
var coverDir string = filepath.Join(s.cfg.DataPath, "covers")
if needMetadataFetch {
// Create context with timeout for metadata service calls
metadataCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
// Identify Documents & Save Covers
metadataResults, err := metadata.SearchMetadataWithContext(metadataCtx, metadata.SOURCE_GBOOK, metadata.MetadataInfo{
Title: document.Title,
Author: document.Author,
})
if err == nil && len(metadataResults) > 0 && metadataResults[0].ID != nil {
firstResult := metadataResults[0]
// Save Cover
fileName, err := metadata.CacheCoverWithContext(metadataCtx, *firstResult.ID, coverDir, document.ID, false)
if err == nil {
cachedCoverFile = *fileName
}
// Store First Metadata Result
if _, err = s.db.Queries.AddMetadata(ctx, database.AddMetadataParams{
DocumentID: document.ID,
Title: firstResult.Title,
Author: firstResult.Author,
Description: firstResult.Description,
Gbid: firstResult.ID,
Olid: nil,
Isbn10: firstResult.ISBN10,
Isbn13: firstResult.ISBN13,
}); err != nil {
log.Error("AddMetadata DB Error:", err)
}
}
// Upsert Document
if _, err = s.db.Queries.UpsertDocument(ctx, database.UpsertDocumentParams{
ID: document.ID,
Coverfile: &cachedCoverFile,
}); err != nil {
log.Warn("UpsertDocument DB Error:", err)
}
// Update cover file if we got a new cover
if cachedCoverFile != "UNKNOWN" {
coverPath := filepath.Join(coverDir, cachedCoverFile)
fileInfo, err := os.Stat(coverPath)
if err != nil {
log.Error("Failed to stat cached cover:", err)
// Keep the no-cover image
} else {
file, err := openFileReader(coverPath)
if err != nil {
log.Error("Failed to open cached cover:", err)
// Keep the no-cover image
} else {
_ = coverFile.Close() // Close the previous file
coverFile = file
contentLength = fileInfo.Size()
// Determine content type based on file extension
contentType = "image/jpeg"
if strings.HasSuffix(coverPath, ".png") {
contentType = "image/png"
}
}
}
}
}
return &GetDocumentCover200Response{
Body: coverFile,
ContentLength: contentLength,
ContentType: contentType,
}, nil
}
// POST /documents/{id}/cover
func (s *Server) UploadDocumentCover(ctx context.Context, request UploadDocumentCoverRequestObject) (UploadDocumentCoverResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return UploadDocumentCover401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
if request.Body == nil {
return UploadDocumentCover400JSONResponse{Code: 400, Message: "Missing request body"}, nil
}
// Validate document exists
_, err := s.db.Queries.GetDocument(ctx, request.Id)
if err != nil {
return UploadDocumentCover404JSONResponse{Code: 404, Message: "Document not found"}, nil
}
// Read multipart form
form, err := request.Body.ReadForm(32 << 20) // 32MB max
if err != nil {
log.Error("ReadForm error:", err)
return UploadDocumentCover500JSONResponse{Code: 500, Message: "Failed to read form"}, nil
}
// Get file from form
fileField := form.File["cover_file"]
if len(fileField) == 0 {
return UploadDocumentCover400JSONResponse{Code: 400, Message: "No file provided"}, nil
}
file := fileField[0]
// Validate file extension
if !strings.HasSuffix(strings.ToLower(file.Filename), ".jpg") && !strings.HasSuffix(strings.ToLower(file.Filename), ".png") {
return UploadDocumentCover400JSONResponse{Code: 400, Message: "Only JPG and PNG files are allowed"}, nil
}
// Open file
f, err := file.Open()
if err != nil {
log.Error("Open file error:", err)
return UploadDocumentCover500JSONResponse{Code: 500, Message: "Failed to open file"}, nil
}
defer f.Close()
// Read file content
data, err := io.ReadAll(f)
if err != nil {
log.Error("Read file error:", err)
return UploadDocumentCover500JSONResponse{Code: 500, Message: "Failed to read file"}, nil
}
// Validate actual content type
contentType := http.DetectContentType(data)
allowedTypes := map[string]bool{
"image/jpeg": true,
"image/png": true,
}
if !allowedTypes[contentType] {
return UploadDocumentCover400JSONResponse{
Code: 400,
Message: fmt.Sprintf("Invalid file type: %s. Only JPG and PNG files are allowed.", contentType),
}, nil
}
// Derive storage path
coverDir := filepath.Join(s.cfg.DataPath, "covers")
fileName := fmt.Sprintf("%s%s", request.Id, strings.ToLower(filepath.Ext(file.Filename)))
safePath := filepath.Join(coverDir, fileName)
// Save file
err = os.WriteFile(safePath, data, 0644)
if err != nil {
log.Error("Save file error:", err)
return UploadDocumentCover500JSONResponse{Code: 500, Message: "Unable to save cover"}, nil
}
// Upsert document with new cover
_, err = s.db.Queries.UpsertDocument(ctx, database.UpsertDocumentParams{
ID: request.Id,
Coverfile: &fileName,
})
if err != nil {
log.Error("UpsertDocument DB error:", err)
return UploadDocumentCover500JSONResponse{Code: 500, Message: "Failed to save cover"}, nil
}
// Use GetDocumentsWithStats to get document with stats for the response
docs, err := s.db.Queries.GetDocumentsWithStats(
ctx,
database.GetDocumentsWithStatsParams{
UserID: auth.UserName,
ID: &request.Id,
Deleted: ptrOf(false),
Offset: 0,
Limit: 1,
},
)
if err != nil || len(docs) == 0 {
return UploadDocumentCover404JSONResponse{Code: 404, Message: "Document not found"}, nil
}
doc := docs[0]
apiDoc := Document{
Id: doc.ID,
Title: *doc.Title,
Author: *doc.Author,
Description: doc.Description,
Isbn10: doc.Isbn10,
Isbn13: doc.Isbn13,
Words: doc.Words,
Filepath: doc.Filepath,
Percentage: ptrOf(float32(doc.Percentage)),
TotalTimeSeconds: ptrOf(doc.TotalTimeSeconds),
Wpm: ptrOf(float32(doc.Wpm)),
SecondsPerPercent: ptrOf(doc.SecondsPerPercent),
LastRead: parseInterfaceTime(doc.LastRead),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
Deleted: false,
}
response := DocumentResponse{
Document: apiDoc,
}
return UploadDocumentCover200JSONResponse(response), nil
}
// GET /documents/{id}/file
func (s *Server) GetDocumentFile(ctx context.Context, request GetDocumentFileRequestObject) (GetDocumentFileResponseObject, error) {
// Authentication is handled by middleware, which also adds auth data to context
// This endpoint just serves the document file download
// Get Document
document, err := s.db.Queries.GetDocument(ctx, request.Id)
if err != nil {
log.Error("GetDocument DB Error:", err)
return GetDocumentFile404JSONResponse{Code: 404, Message: "Document not found"}, nil
}
if document.Filepath == nil {
log.Error("Document Doesn't Have File:", request.Id)
return GetDocumentFile404JSONResponse{Code: 404, Message: "Document file not found"}, nil
}
// Derive Basepath
basepath := filepath.Join(s.cfg.DataPath, "documents")
if document.Basepath != nil && *document.Basepath != "" {
basepath = *document.Basepath
}
// Derive Storage Location
filePath := filepath.Join(basepath, *document.Filepath)
// Validate File Exists
fileInfo, err := os.Stat(filePath)
if os.IsNotExist(err) {
log.Error("File should but doesn't exist:", err)
return GetDocumentFile404JSONResponse{Code: 404, Message: "Document file not found"}, nil
}
// Open file
file, err := os.Open(filePath)
if err != nil {
log.Error("Failed to open document file:", err)
return GetDocumentFile500JSONResponse{Code: 500, Message: "Failed to open document"}, nil
}
return &GetDocumentFile200Response{
Body: file,
ContentLength: fileInfo.Size(),
Filename: filepath.Base(*document.Filepath),
}, nil
}
// POST /documents
func (s *Server) CreateDocument(ctx context.Context, request CreateDocumentRequestObject) (CreateDocumentResponseObject, error) {
_, ok := s.getSessionFromContext(ctx)
if !ok {
return CreateDocument401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
if request.Body == nil {
return CreateDocument400JSONResponse{Code: 400, Message: "Missing request body"}, nil
}
// Read multipart form
form, err := request.Body.ReadForm(32 << 20) // 32MB max memory
if err != nil {
log.Error("ReadForm error:", err)
return CreateDocument500JSONResponse{Code: 500, Message: "Failed to read form"}, nil
}
// Get file from form
fileField := form.File["document_file"]
if len(fileField) == 0 {
return CreateDocument400JSONResponse{Code: 400, Message: "No file provided"}, nil
}
file := fileField[0]
// Validate file extension
if !strings.HasSuffix(strings.ToLower(file.Filename), ".epub") {
return CreateDocument400JSONResponse{Code: 400, Message: "Only EPUB files are allowed"}, nil
}
// Open file
f, err := file.Open()
if err != nil {
log.Error("Open file error:", err)
return CreateDocument500JSONResponse{Code: 500, Message: "Failed to open file"}, nil
}
defer f.Close()
// Read file content
data, err := io.ReadAll(f)
if err != nil {
log.Error("Read file error:", err)
return CreateDocument500JSONResponse{Code: 500, Message: "Failed to read file"}, nil
}
// Validate actual content type
contentType := http.DetectContentType(data)
if contentType != "application/epub+zip" && contentType != "application/zip" {
return CreateDocument400JSONResponse{
Code: 400,
Message: fmt.Sprintf("Invalid file type: %s. Only EPUB files are allowed.", contentType),
}, nil
}
// Create temp file to get metadata
tempFile, err := os.CreateTemp("", "book")
if err != nil {
log.Error("Temp file create error:", err)
return CreateDocument500JSONResponse{Code: 500, Message: "Unable to create temp file"}, nil
}
defer os.Remove(tempFile.Name())
defer tempFile.Close()
// Write data to temp file
if _, err := tempFile.Write(data); err != nil {
log.Error("Write temp file error:", err)
return CreateDocument500JSONResponse{Code: 500, Message: "Unable to write temp file"}, nil
}
// Get metadata using metadata package
metadataInfo, err := metadata.GetMetadata(tempFile.Name())
if err != nil {
log.Error("GetMetadata error:", err)
return CreateDocument500JSONResponse{Code: 500, Message: "Unable to acquire metadata"}, nil
}
// Check if already exists
_, err = s.db.Queries.GetDocument(ctx, *metadataInfo.PartialMD5)
if err == nil {
// Document already exists
existingDoc, _ := s.db.Queries.GetDocument(ctx, *metadataInfo.PartialMD5)
apiDoc := Document{
Id: existingDoc.ID,
Title: *existingDoc.Title,
Author: *existingDoc.Author,
Description: existingDoc.Description,
Isbn10: existingDoc.Isbn10,
Isbn13: existingDoc.Isbn13,
Words: existingDoc.Words,
Filepath: existingDoc.Filepath,
CreatedAt: parseTime(existingDoc.CreatedAt),
UpdatedAt: parseTime(existingDoc.UpdatedAt),
Deleted: existingDoc.Deleted,
}
response := DocumentResponse{
Document: apiDoc,
}
return CreateDocument200JSONResponse(response), nil
}
// Derive & sanitize file name
fileName := deriveBaseFileName(metadataInfo)
basePath := filepath.Join(s.cfg.DataPath, "documents")
safePath := filepath.Join(basePath, fileName)
// Save file to storage
err = os.WriteFile(safePath, data, 0644)
if err != nil {
log.Error("Save file error:", err)
return CreateDocument500JSONResponse{Code: 500, Message: "Unable to save file"}, nil
}
// Upsert document
doc, err := s.db.Queries.UpsertDocument(ctx, database.UpsertDocumentParams{
ID: *metadataInfo.PartialMD5,
Title: metadataInfo.Title,
Author: metadataInfo.Author,
Description: metadataInfo.Description,
Md5: metadataInfo.MD5,
Words: metadataInfo.WordCount,
Filepath: &fileName,
Basepath: &basePath,
})
if err != nil {
log.Error("UpsertDocument DB error:", err)
return CreateDocument500JSONResponse{Code: 500, Message: "Failed to save document"}, nil
}
apiDoc := Document{
Id: doc.ID,
Title: *doc.Title,
Author: *doc.Author,
Description: doc.Description,
Isbn10: doc.Isbn10,
Isbn13: doc.Isbn13,
Words: doc.Words,
Filepath: doc.Filepath,
CreatedAt: parseTime(doc.CreatedAt),
UpdatedAt: parseTime(doc.UpdatedAt),
Deleted: doc.Deleted,
}
response := DocumentResponse{
Document: apiDoc,
}
return CreateDocument200JSONResponse(response), nil
}
// GetDocumentCover200Response is a custom response type that allows setting content type
type GetDocumentCover200Response struct {
Body io.Reader
ContentLength int64
ContentType string
}
func (response GetDocumentCover200Response) VisitGetDocumentCoverResponse(w http.ResponseWriter) error {
w.Header().Set("Content-Type", response.ContentType)
if response.ContentLength != 0 {
w.Header().Set("Content-Length", fmt.Sprint(response.ContentLength))
}
w.WriteHeader(200)
if closer, ok := response.Body.(io.Closer); ok {
defer closer.Close()
}
_, err := io.Copy(w, response.Body)
return err
}
// GetDocumentFile200Response is a custom response type that allows setting filename for download
type GetDocumentFile200Response struct {
Body io.Reader
ContentLength int64
Filename string
}
func (response GetDocumentFile200Response) VisitGetDocumentFileResponse(w http.ResponseWriter) error {
w.Header().Set("Content-Type", "application/octet-stream")
if response.ContentLength != 0 {
w.Header().Set("Content-Length", fmt.Sprint(response.ContentLength))
}
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", response.Filename))
w.WriteHeader(200)
if closer, ok := response.Body.(io.Closer); ok {
defer closer.Close()
}
_, err := io.Copy(w, response.Body)
return err
}

178
api/v1/documents_test.go Normal file
View File

@@ -0,0 +1,178 @@
package v1
import (
"bytes"
"crypto/md5"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/suite"
argon2 "github.com/alexedwards/argon2id"
"reichard.io/antholume/config"
"reichard.io/antholume/database"
"reichard.io/antholume/pkg/ptr"
)
type DocumentsTestSuite struct {
suite.Suite
db *database.DBManager
cfg *config.Config
srv *Server
}
func (suite *DocumentsTestSuite) setupConfig() *config.Config {
return &config.Config{
ListenPort: "8080",
DBType: "memory",
DBName: "test",
ConfigPath: "/tmp",
CookieAuthKey: "test-auth-key-32-bytes-long-enough",
CookieEncKey: "0123456789abcdef",
CookieSecure: false,
CookieHTTPOnly: true,
Version: "test",
DemoMode: false,
RegistrationEnabled: true,
}
}
func TestDocuments(t *testing.T) {
suite.Run(t, new(DocumentsTestSuite))
}
func (suite *DocumentsTestSuite) SetupTest() {
suite.cfg = suite.setupConfig()
suite.db = database.NewMgr(suite.cfg)
suite.srv = NewServer(suite.db, suite.cfg, nil)
}
func (suite *DocumentsTestSuite) createTestUser(username, password string) {
suite.authTestSuiteHelper(username, password)
}
func (suite *DocumentsTestSuite) login(username, password string) *http.Cookie {
return suite.authLoginHelper(username, password)
}
func (suite *DocumentsTestSuite) authTestSuiteHelper(username, password string) {
// MD5 hash for KOSync compatibility (matches existing system)
md5Hash := fmt.Sprintf("%x", md5.Sum([]byte(password)))
// Then argon2 hash the MD5
hashedPassword, err := argon2.CreateHash(md5Hash, argon2.DefaultParams)
suite.Require().NoError(err)
_, err = suite.db.Queries.CreateUser(suite.T().Context(), database.CreateUserParams{
ID: username,
Pass: &hashedPassword,
AuthHash: ptr.Of("test-auth-hash"),
Admin: true,
})
suite.Require().NoError(err)
}
func (suite *DocumentsTestSuite) authLoginHelper(username, password string) *http.Cookie {
reqBody := LoginRequest{Username: username, Password: password}
body, err := json.Marshal(reqBody)
suite.Require().NoError(err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(body))
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusOK, w.Code)
cookies := w.Result().Cookies()
suite.Require().Len(cookies, 1)
return cookies[0]
}
func (suite *DocumentsTestSuite) TestAPIGetDocuments() {
suite.createTestUser("testuser", "testpass")
cookie := suite.login("testuser", "testpass")
req := httptest.NewRequest(http.MethodGet, "/api/v1/documents?page=1&limit=9", nil)
req.AddCookie(cookie)
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusOK, w.Code)
var resp DocumentsResponse
suite.Require().NoError(json.Unmarshal(w.Body.Bytes(), &resp))
suite.Equal(int64(1), resp.Page)
suite.Equal(int64(9), resp.Limit)
}
func (suite *DocumentsTestSuite) TestAPIGetDocumentsUnauthenticated() {
req := httptest.NewRequest(http.MethodGet, "/api/v1/documents", nil)
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusUnauthorized, w.Code)
}
func (suite *DocumentsTestSuite) TestAPIGetDocument() {
suite.createTestUser("testuser", "testpass")
docID := "test-doc-1"
_, err := suite.db.Queries.UpsertDocument(suite.T().Context(), database.UpsertDocumentParams{
ID: docID,
Title: ptr.Of("Test Document"),
Author: ptr.Of("Test Author"),
})
suite.Require().NoError(err)
cookie := suite.login("testuser", "testpass")
req := httptest.NewRequest(http.MethodGet, "/api/v1/documents/"+docID, nil)
req.AddCookie(cookie)
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusOK, w.Code)
var resp DocumentResponse
suite.Require().NoError(json.Unmarshal(w.Body.Bytes(), &resp))
suite.Equal(docID, resp.Document.Id)
suite.Equal("Test Document", resp.Document.Title)
}
func (suite *DocumentsTestSuite) TestAPIGetDocumentNotFound() {
suite.createTestUser("testuser", "testpass")
cookie := suite.login("testuser", "testpass")
req := httptest.NewRequest(http.MethodGet, "/api/v1/documents/non-existent", nil)
req.AddCookie(cookie)
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusNotFound, w.Code)
}
func (suite *DocumentsTestSuite) TestAPIGetDocumentCoverUnauthenticated() {
req := httptest.NewRequest(http.MethodGet, "/api/v1/documents/test-id/cover", nil)
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusUnauthorized, w.Code)
}
func (suite *DocumentsTestSuite) TestAPIGetDocumentFileUnauthenticated() {
req := httptest.NewRequest(http.MethodGet, "/api/v1/documents/test-id/file", nil)
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusUnauthorized, w.Code)
}

3
api/v1/generate.go Normal file
View File

@@ -0,0 +1,3 @@
package v1
//go:generate oapi-codegen -config oapi-codegen.yaml openapi.yaml

226
api/v1/home.go Normal file
View File

@@ -0,0 +1,226 @@
package v1
import (
"context"
"sort"
log "github.com/sirupsen/logrus"
"reichard.io/antholume/database"
"reichard.io/antholume/graph"
)
// GET /home
func (s *Server) GetHome(ctx context.Context, request GetHomeRequestObject) (GetHomeResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return GetHome401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
// Get database info
dbInfo, err := s.db.Queries.GetDatabaseInfo(ctx, auth.UserName)
if err != nil {
log.Error("GetDatabaseInfo DB Error:", err)
return GetHome500JSONResponse{Code: 500, Message: "Database error"}, nil
}
// Get streaks
streaks, err := s.db.Queries.GetUserStreaks(ctx, auth.UserName)
if err != nil {
log.Error("GetUserStreaks DB Error:", err)
return GetHome500JSONResponse{Code: 500, Message: "Database error"}, nil
}
// Get graph data
graphData, err := s.db.Queries.GetDailyReadStats(ctx, auth.UserName)
if err != nil {
log.Error("GetDailyReadStats DB Error:", err)
return GetHome500JSONResponse{Code: 500, Message: "Database error"}, nil
}
// Get user statistics
userStats, err := s.db.Queries.GetUserStatistics(ctx)
if err != nil {
log.Error("GetUserStatistics DB Error:", err)
return GetHome500JSONResponse{Code: 500, Message: "Database error"}, nil
}
// Build response
response := HomeResponse{
DatabaseInfo: DatabaseInfo{
DocumentsSize: dbInfo.DocumentsSize,
ActivitySize: dbInfo.ActivitySize,
ProgressSize: dbInfo.ProgressSize,
DevicesSize: dbInfo.DevicesSize,
},
Streaks: StreaksResponse{
Streaks: convertStreaks(streaks),
},
GraphData: GraphDataResponse{
GraphData: convertGraphData(graphData),
},
UserStatistics: arrangeUserStatistics(userStats),
}
return GetHome200JSONResponse(response), nil
}
// GET /home/streaks
func (s *Server) GetStreaks(ctx context.Context, request GetStreaksRequestObject) (GetStreaksResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return GetStreaks401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
streaks, err := s.db.Queries.GetUserStreaks(ctx, auth.UserName)
if err != nil {
log.Error("GetUserStreaks DB Error:", err)
return GetStreaks500JSONResponse{Code: 500, Message: "Database error"}, nil
}
response := StreaksResponse{
Streaks: convertStreaks(streaks),
}
return GetStreaks200JSONResponse(response), nil
}
// GET /home/graph
func (s *Server) GetGraphData(ctx context.Context, request GetGraphDataRequestObject) (GetGraphDataResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return GetGraphData401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
graphData, err := s.db.Queries.GetDailyReadStats(ctx, auth.UserName)
if err != nil {
log.Error("GetDailyReadStats DB Error:", err)
return GetGraphData500JSONResponse{Code: 500, Message: "Database error"}, nil
}
response := GraphDataResponse{
GraphData: convertGraphData(graphData),
}
return GetGraphData200JSONResponse(response), nil
}
// GET /home/statistics
func (s *Server) GetUserStatistics(ctx context.Context, request GetUserStatisticsRequestObject) (GetUserStatisticsResponseObject, error) {
_, ok := s.getSessionFromContext(ctx)
if !ok {
return GetUserStatistics401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
userStats, err := s.db.Queries.GetUserStatistics(ctx)
if err != nil {
log.Error("GetUserStatistics DB Error:", err)
return GetUserStatistics500JSONResponse{Code: 500, Message: "Database error"}, nil
}
response := arrangeUserStatistics(userStats)
return GetUserStatistics200JSONResponse(response), nil
}
func convertStreaks(streaks []database.UserStreak) []UserStreak {
result := make([]UserStreak, len(streaks))
for i, streak := range streaks {
result[i] = UserStreak{
Window: streak.Window,
MaxStreak: streak.MaxStreak,
MaxStreakStartDate: streak.MaxStreakStartDate,
MaxStreakEndDate: streak.MaxStreakEndDate,
CurrentStreak: streak.CurrentStreak,
CurrentStreakStartDate: streak.CurrentStreakStartDate,
CurrentStreakEndDate: streak.CurrentStreakEndDate,
}
}
return result
}
func convertGraphData(graphData []database.GetDailyReadStatsRow) []GraphDataPoint {
result := make([]GraphDataPoint, len(graphData))
for i, data := range graphData {
result[i] = GraphDataPoint{
Date: data.Date,
MinutesRead: data.MinutesRead,
}
}
return result
}
func arrangeUserStatistics(userStatistics []database.GetUserStatisticsRow) UserStatisticsResponse {
// Sort by WPM for each period
sortByWPM := func(stats []database.GetUserStatisticsRow, getter func(database.GetUserStatisticsRow) float64) []LeaderboardEntry {
sorted := append([]database.GetUserStatisticsRow(nil), stats...)
sort.SliceStable(sorted, func(i, j int) bool {
return getter(sorted[i]) > getter(sorted[j])
})
result := make([]LeaderboardEntry, len(sorted))
for i, item := range sorted {
result[i] = LeaderboardEntry{UserId: item.UserID, Value: getter(item)}
}
return result
}
// Sort by duration (seconds) for each period
sortByDuration := func(stats []database.GetUserStatisticsRow, getter func(database.GetUserStatisticsRow) int64) []LeaderboardEntry {
sorted := append([]database.GetUserStatisticsRow(nil), stats...)
sort.SliceStable(sorted, func(i, j int) bool {
return getter(sorted[i]) > getter(sorted[j])
})
result := make([]LeaderboardEntry, len(sorted))
for i, item := range sorted {
result[i] = LeaderboardEntry{UserId: item.UserID, Value: float64(getter(item))}
}
return result
}
// Sort by words for each period
sortByWords := func(stats []database.GetUserStatisticsRow, getter func(database.GetUserStatisticsRow) int64) []LeaderboardEntry {
sorted := append([]database.GetUserStatisticsRow(nil), stats...)
sort.SliceStable(sorted, func(i, j int) bool {
return getter(sorted[i]) > getter(sorted[j])
})
result := make([]LeaderboardEntry, len(sorted))
for i, item := range sorted {
result[i] = LeaderboardEntry{UserId: item.UserID, Value: float64(getter(item))}
}
return result
}
return UserStatisticsResponse{
Wpm: LeaderboardData{
All: sortByWPM(userStatistics, func(s database.GetUserStatisticsRow) float64 { return s.TotalWpm }),
Year: sortByWPM(userStatistics, func(s database.GetUserStatisticsRow) float64 { return s.YearlyWpm }),
Month: sortByWPM(userStatistics, func(s database.GetUserStatisticsRow) float64 { return s.MonthlyWpm }),
Week: sortByWPM(userStatistics, func(s database.GetUserStatisticsRow) float64 { return s.WeeklyWpm }),
},
Duration: LeaderboardData{
All: sortByDuration(userStatistics, func(s database.GetUserStatisticsRow) int64 { return s.TotalSeconds }),
Year: sortByDuration(userStatistics, func(s database.GetUserStatisticsRow) int64 { return s.YearlySeconds }),
Month: sortByDuration(userStatistics, func(s database.GetUserStatisticsRow) int64 { return s.MonthlySeconds }),
Week: sortByDuration(userStatistics, func(s database.GetUserStatisticsRow) int64 { return s.WeeklySeconds }),
},
Words: LeaderboardData{
All: sortByWords(userStatistics, func(s database.GetUserStatisticsRow) int64 { return s.TotalWordsRead }),
Year: sortByWords(userStatistics, func(s database.GetUserStatisticsRow) int64 { return s.YearlyWordsRead }),
Month: sortByWords(userStatistics, func(s database.GetUserStatisticsRow) int64 { return s.MonthlyWordsRead }),
Week: sortByWords(userStatistics, func(s database.GetUserStatisticsRow) int64 { return s.WeeklyWordsRead }),
},
}
}
// GetSVGGraphData generates SVG bezier path for graph visualization
func GetSVGGraphData(inputData []GraphDataPoint, svgWidth int, svgHeight int) graph.SVGGraphData {
// Convert to int64 slice expected by graph package
intData := make([]int64, len(inputData))
for i, data := range inputData {
intData[i] = int64(data.MinutesRead)
}
return graph.GetSVGGraphData(intData, svgWidth, svgHeight)
}

6
api/v1/oapi-codegen.yaml Normal file
View File

@@ -0,0 +1,6 @@
package: v1
generate:
std-http-server: true
strict-server: true
models: true
output: api.gen.go

1977
api/v1/openapi.yaml Normal file

File diff suppressed because it is too large Load Diff

163
api/v1/progress.go Normal file
View File

@@ -0,0 +1,163 @@
package v1
import (
"context"
"math"
"time"
log "github.com/sirupsen/logrus"
"reichard.io/antholume/database"
)
// GET /progress
func (s *Server) GetProgressList(ctx context.Context, request GetProgressListRequestObject) (GetProgressListResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return GetProgressList401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
page := int64(1)
if request.Params.Page != nil {
page = *request.Params.Page
}
limit := int64(15)
if request.Params.Limit != nil {
limit = *request.Params.Limit
}
filter := database.GetProgressParams{
UserID: auth.UserName,
Offset: (page - 1) * limit,
Limit: limit,
}
if request.Params.Document != nil && *request.Params.Document != "" {
filter.DocFilter = true
filter.DocumentID = *request.Params.Document
}
progress, err := s.db.Queries.GetProgress(ctx, filter)
if err != nil {
log.Error("GetProgress DB Error:", err)
return GetProgressList500JSONResponse{Code: 500, Message: "Database error"}, nil
}
total := int64(len(progress))
var nextPage *int64
var previousPage *int64
// Calculate total pages
totalPages := int64(math.Ceil(float64(total) / float64(limit)))
if page < totalPages {
nextPage = ptrOf(page + 1)
}
if page > 1 {
previousPage = ptrOf(page - 1)
}
apiProgress := make([]Progress, len(progress))
for i, row := range progress {
apiProgress[i] = Progress{
Title: row.Title,
Author: row.Author,
DeviceName: &row.DeviceName,
Percentage: &row.Percentage,
DocumentId: &row.DocumentID,
UserId: &row.UserID,
CreatedAt: parseTimePtr(row.CreatedAt),
}
}
response := ProgressListResponse{
Progress: &apiProgress,
Page: &page,
Limit: &limit,
NextPage: nextPage,
PreviousPage: previousPage,
Total: &total,
}
return GetProgressList200JSONResponse(response), nil
}
// GET /progress/{id}
func (s *Server) GetProgress(ctx context.Context, request GetProgressRequestObject) (GetProgressResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return GetProgress401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
row, err := s.db.Queries.GetDocumentProgress(ctx, database.GetDocumentProgressParams{
UserID: auth.UserName,
DocumentID: request.Id,
})
if err != nil {
log.Error("GetDocumentProgress DB Error:", err)
return GetProgress404JSONResponse{Code: 404, Message: "Progress not found"}, nil
}
apiProgress := Progress{
DeviceName: &row.DeviceName,
DeviceId: &row.DeviceID,
Percentage: &row.Percentage,
Progress: &row.Progress,
DocumentId: &row.DocumentID,
UserId: &row.UserID,
CreatedAt: parseTimePtr(row.CreatedAt),
}
response := ProgressResponse{
Progress: &apiProgress,
}
return GetProgress200JSONResponse(response), nil
}
// PUT /progress
func (s *Server) UpdateProgress(ctx context.Context, request UpdateProgressRequestObject) (UpdateProgressResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return UpdateProgress401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
if request.Body == nil {
return UpdateProgress400JSONResponse{Code: 400, Message: "Request body is required"}, nil
}
if _, err := s.db.Queries.UpsertDevice(ctx, database.UpsertDeviceParams{
ID: request.Body.DeviceId,
UserID: auth.UserName,
DeviceName: request.Body.DeviceName,
LastSynced: time.Now().UTC().Format(time.RFC3339),
}); err != nil {
log.Error("UpsertDevice DB Error:", err)
return UpdateProgress500JSONResponse{Code: 500, Message: "Database error"}, nil
}
if _, err := s.db.Queries.UpsertDocument(ctx, database.UpsertDocumentParams{
ID: request.Body.DocumentId,
}); err != nil {
log.Error("UpsertDocument DB Error:", err)
return UpdateProgress500JSONResponse{Code: 500, Message: "Database error"}, nil
}
progress, err := s.db.Queries.UpdateProgress(ctx, database.UpdateProgressParams{
Percentage: request.Body.Percentage,
DocumentID: request.Body.DocumentId,
DeviceID: request.Body.DeviceId,
UserID: auth.UserName,
Progress: request.Body.Progress,
})
if err != nil {
log.Error("UpdateProgress DB Error:", err)
return UpdateProgress400JSONResponse{Code: 400, Message: "Invalid request"}, nil
}
response := UpdateProgressResponse{
DocumentId: progress.DocumentID,
Timestamp: parseTime(progress.CreatedAt),
}
return UpdateProgress200JSONResponse(response), nil
}

59
api/v1/search.go Normal file
View File

@@ -0,0 +1,59 @@
package v1
import (
"context"
"reichard.io/antholume/search"
log "github.com/sirupsen/logrus"
)
// GET /search
func (s *Server) GetSearch(ctx context.Context, request GetSearchRequestObject) (GetSearchResponseObject, error) {
if request.Params.Query == "" {
return GetSearch400JSONResponse{Code: 400, Message: "Invalid query"}, nil
}
query := request.Params.Query
source := string(request.Params.Source)
// Validate source
if source != "LibGen" && source != "Annas Archive" {
return GetSearch400JSONResponse{Code: 400, Message: "Invalid source"}, nil
}
searchResults, err := search.SearchBook(query, search.Source(source))
if err != nil {
log.Error("Search Error:", err)
return GetSearch500JSONResponse{Code: 500, Message: "Search error"}, nil
}
apiResults := make([]SearchItem, len(searchResults))
for i, item := range searchResults {
apiResults[i] = SearchItem{
Id: ptrOf(item.ID),
Title: ptrOf(item.Title),
Author: ptrOf(item.Author),
Language: ptrOf(item.Language),
Series: ptrOf(item.Series),
FileType: ptrOf(item.FileType),
FileSize: ptrOf(item.FileSize),
UploadDate: ptrOf(item.UploadDate),
}
}
response := SearchResponse{
Results: apiResults,
Source: source,
Query: query,
}
return GetSearch200JSONResponse(response), nil
}
// POST /search
func (s *Server) PostSearch(ctx context.Context, request PostSearchRequestObject) (PostSearchResponseObject, error) {
// This endpoint is used by the SSR template to queue a download
// For the API, we just return success - the actual download happens via /documents POST
return PostSearch200Response{}, nil
}

99
api/v1/server.go Normal file
View File

@@ -0,0 +1,99 @@
package v1
import (
"context"
"encoding/json"
"io/fs"
"net/http"
"reichard.io/antholume/config"
"reichard.io/antholume/database"
)
var _ StrictServerInterface = (*Server)(nil)
type Server struct {
mux *http.ServeMux
db *database.DBManager
cfg *config.Config
assets fs.FS
}
// NewServer creates a new native HTTP server
func NewServer(db *database.DBManager, cfg *config.Config, assets fs.FS) *Server {
s := &Server{
mux: http.NewServeMux(),
db: db,
cfg: cfg,
assets: assets,
}
// Create strict handler with authentication middleware
strictHandler := NewStrictHandler(s, []StrictMiddlewareFunc{s.authMiddleware})
s.mux = HandlerFromMuxWithBaseURL(strictHandler, s.mux, "/api/v1").(*http.ServeMux)
return s
}
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.mux.ServeHTTP(w, r)
}
// authMiddleware adds authentication context to requests
func (s *Server) authMiddleware(handler StrictHandlerFunc, operationID string) StrictHandlerFunc {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request, request any) (any, error) {
// Store request and response in context for all handlers
ctx = context.WithValue(ctx, "request", r)
ctx = context.WithValue(ctx, "response", w)
// Skip auth for public auth and info endpoints - cover and file require auth via cookies
if operationID == "Login" || operationID == "Register" || operationID == "GetInfo" {
return handler(ctx, w, r, request)
}
auth, ok := s.getSession(r)
if !ok {
// Write 401 response directly
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(401)
json.NewEncoder(w).Encode(ErrorResponse{Code: 401, Message: "Unauthorized"})
return nil, nil
}
// Check admin status for admin-only endpoints
adminEndpoints := []string{
"GetAdmin",
"PostAdminAction",
"GetUsers",
"UpdateUser",
"GetImportDirectory",
"PostImport",
"GetImportResults",
"GetLogs",
}
for _, adminEndpoint := range adminEndpoints {
if operationID == adminEndpoint && !auth.IsAdmin {
// Write 403 response directly
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(403)
json.NewEncoder(w).Encode(ErrorResponse{Code: 403, Message: "Admin privileges required"})
return nil, nil
}
}
// Store auth in context for handlers to access
ctx = context.WithValue(ctx, "auth", auth)
return handler(ctx, w, r, request)
}
}
// GetInfo returns server information
func (s *Server) GetInfo(ctx context.Context, request GetInfoRequestObject) (GetInfoResponseObject, error) {
return GetInfo200JSONResponse{
Version: s.cfg.Version,
SearchEnabled: s.cfg.SearchEnabled,
RegistrationEnabled: s.cfg.RegistrationEnabled,
}, nil
}

58
api/v1/server_test.go Normal file
View File

@@ -0,0 +1,58 @@
package v1
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/suite"
"reichard.io/antholume/config"
"reichard.io/antholume/database"
)
type ServerTestSuite struct {
suite.Suite
db *database.DBManager
cfg *config.Config
srv *Server
}
func TestServer(t *testing.T) {
suite.Run(t, new(ServerTestSuite))
}
func (suite *ServerTestSuite) SetupTest() {
suite.cfg = &config.Config{
ListenPort: "8080",
DBType: "memory",
DBName: "test",
ConfigPath: "/tmp",
CookieAuthKey: "test-auth-key-32-bytes-long-enough",
CookieEncKey: "0123456789abcdef",
CookieSecure: false,
CookieHTTPOnly: true,
Version: "test",
DemoMode: false,
RegistrationEnabled: true,
}
suite.db = database.NewMgr(suite.cfg)
suite.srv = NewServer(suite.db, suite.cfg, nil)
}
func (suite *ServerTestSuite) TestNewServer() {
suite.NotNil(suite.srv)
suite.NotNil(suite.srv.mux)
suite.NotNil(suite.srv.db)
suite.NotNil(suite.srv.cfg)
}
func (suite *ServerTestSuite) TestServerServeHTTP() {
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil)
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusUnauthorized, w.Code)
}

157
api/v1/settings.go Normal file
View File

@@ -0,0 +1,157 @@
package v1
import (
"context"
"crypto/md5"
"fmt"
"reichard.io/antholume/database"
argon2id "github.com/alexedwards/argon2id"
)
// GET /settings
func (s *Server) GetSettings(ctx context.Context, request GetSettingsRequestObject) (GetSettingsResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return GetSettings401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
user, err := s.db.Queries.GetUser(ctx, auth.UserName)
if err != nil {
return GetSettings500JSONResponse{Code: 500, Message: err.Error()}, nil
}
devices, err := s.db.Queries.GetDevices(ctx, auth.UserName)
if err != nil {
return GetSettings500JSONResponse{Code: 500, Message: err.Error()}, nil
}
apiDevices := make([]Device, len(devices))
for i, device := range devices {
apiDevices[i] = Device{
Id: &device.ID,
DeviceName: &device.DeviceName,
CreatedAt: parseTimePtr(device.CreatedAt),
LastSynced: parseTimePtr(device.LastSynced),
}
}
response := SettingsResponse{
User: UserData{Username: auth.UserName, IsAdmin: auth.IsAdmin},
Timezone: user.Timezone,
Devices: &apiDevices,
}
return GetSettings200JSONResponse(response), nil
}
// authorizeCredentials verifies if credentials are valid
func (s *Server) authorizeCredentials(ctx context.Context, username string, password string) bool {
user, err := s.db.Queries.GetUser(ctx, username)
if err != nil {
return false
}
// Try argon2 hash comparison
if match, err := argon2id.ComparePasswordAndHash(password, *user.Pass); err == nil && match {
return true
}
return false
}
// PUT /settings
func (s *Server) UpdateSettings(ctx context.Context, request UpdateSettingsRequestObject) (UpdateSettingsResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return UpdateSettings401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
if request.Body == nil {
return UpdateSettings400JSONResponse{Code: 400, Message: "Request body is required"}, nil
}
user, err := s.db.Queries.GetUser(ctx, auth.UserName)
if err != nil {
return UpdateSettings500JSONResponse{Code: 500, Message: err.Error()}, nil
}
updateParams := database.UpdateUserParams{
UserID: auth.UserName,
Admin: auth.IsAdmin,
}
// Update password if provided
if request.Body.NewPassword != nil {
if request.Body.Password == nil {
return UpdateSettings400JSONResponse{Code: 400, Message: "Current password is required to set new password"}, nil
}
// Verify current password - first try bcrypt (new format), then argon2, then MD5 (legacy format)
currentPasswordMatched := false
// Try argon2 (current format)
if !currentPasswordMatched {
currentPassword := fmt.Sprintf("%x", md5.Sum([]byte(*request.Body.Password)))
if match, err := argon2id.ComparePasswordAndHash(currentPassword, *user.Pass); err == nil && match {
currentPasswordMatched = true
}
}
if !currentPasswordMatched {
return UpdateSettings400JSONResponse{Code: 400, Message: "Invalid current password"}, nil
}
// Hash new password with argon2
newPassword := fmt.Sprintf("%x", md5.Sum([]byte(*request.Body.NewPassword)))
hashedPassword, err := argon2id.CreateHash(newPassword, argon2id.DefaultParams)
if err != nil {
return UpdateSettings500JSONResponse{Code: 500, Message: "Failed to hash password"}, nil
}
updateParams.Password = &hashedPassword
}
// Update timezone if provided
if request.Body.Timezone != nil {
updateParams.Timezone = request.Body.Timezone
}
// If nothing to update, return error
if request.Body.NewPassword == nil && request.Body.Timezone == nil {
return UpdateSettings400JSONResponse{Code: 400, Message: "At least one field must be provided"}, nil
}
// Update user
_, err = s.db.Queries.UpdateUser(ctx, updateParams)
if err != nil {
return UpdateSettings500JSONResponse{Code: 500, Message: err.Error()}, nil
}
// Get updated settings to return
user, err = s.db.Queries.GetUser(ctx, auth.UserName)
if err != nil {
return UpdateSettings500JSONResponse{Code: 500, Message: err.Error()}, nil
}
devices, err := s.db.Queries.GetDevices(ctx, auth.UserName)
if err != nil {
return UpdateSettings500JSONResponse{Code: 500, Message: err.Error()}, nil
}
apiDevices := make([]Device, len(devices))
for i, device := range devices {
apiDevices[i] = Device{
Id: &device.ID,
DeviceName: &device.DeviceName,
CreatedAt: parseTimePtr(device.CreatedAt),
LastSynced: parseTimePtr(device.LastSynced),
}
}
response := SettingsResponse{
User: UserData{Username: auth.UserName, IsAdmin: auth.IsAdmin},
Timezone: user.Timezone,
Devices: &apiDevices,
}
return UpdateSettings200JSONResponse(response), nil
}

84
api/v1/utils.go Normal file
View File

@@ -0,0 +1,84 @@
package v1
import (
"encoding/json"
"net/http"
"net/url"
"strconv"
"time"
)
// writeJSON writes a JSON response (deprecated - used by tests only)
func writeJSON(w http.ResponseWriter, status int, data any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
if err := json.NewEncoder(w).Encode(data); err != nil {
writeJSONError(w, http.StatusInternalServerError, "Failed to encode response")
}
}
// writeJSONError writes a JSON error response (deprecated - used by tests only)
func writeJSONError(w http.ResponseWriter, status int, message string) {
writeJSON(w, status, ErrorResponse{
Code: status,
Message: message,
})
}
// QueryParams represents parsed query parameters (deprecated - used by tests only)
type QueryParams struct {
Page int64
Limit int64
Search *string
}
// parseQueryParams parses URL query parameters (deprecated - used by tests only)
func parseQueryParams(query url.Values, defaultLimit int64) QueryParams {
page, _ := strconv.ParseInt(query.Get("page"), 10, 64)
if page == 0 {
page = 1
}
limit, _ := strconv.ParseInt(query.Get("limit"), 10, 64)
if limit == 0 {
limit = defaultLimit
}
search := query.Get("search")
var searchPtr *string
if search != "" {
searchPtr = ptrOf("%" + search + "%")
}
return QueryParams{
Page: page,
Limit: limit,
Search: searchPtr,
}
}
// ptrOf returns a pointer to the given value
func ptrOf[T any](v T) *T {
return &v
}
// parseTime parses a string to time.Time
func parseTime(s string) time.Time {
t, _ := time.Parse(time.RFC3339, s)
if t.IsZero() {
t, _ = time.Parse("2006-01-02T15:04:05", s)
}
return t
}
// parseTimePtr parses an interface{} (from SQL) to *time.Time
func parseTimePtr(v interface{}) *time.Time {
if v == nil {
return nil
}
if s, ok := v.(string); ok {
t := parseTime(s)
if t.IsZero() {
return nil
}
return &t
}
return nil
}

76
api/v1/utils_test.go Normal file
View File

@@ -0,0 +1,76 @@
package v1
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/suite"
)
type UtilsTestSuite struct {
suite.Suite
}
func TestUtils(t *testing.T) {
suite.Run(t, new(UtilsTestSuite))
}
func (suite *UtilsTestSuite) TestWriteJSON() {
w := httptest.NewRecorder()
data := map[string]string{"test": "value"}
writeJSON(w, http.StatusOK, data)
suite.Equal("application/json", w.Header().Get("Content-Type"))
suite.Equal(http.StatusOK, w.Code)
var resp map[string]string
suite.Require().NoError(json.Unmarshal(w.Body.Bytes(), &resp))
suite.Equal("value", resp["test"])
}
func (suite *UtilsTestSuite) TestWriteJSONError() {
w := httptest.NewRecorder()
writeJSONError(w, http.StatusBadRequest, "test error")
suite.Equal(http.StatusBadRequest, w.Code)
var resp ErrorResponse
suite.Require().NoError(json.Unmarshal(w.Body.Bytes(), &resp))
suite.Equal(http.StatusBadRequest, resp.Code)
suite.Equal("test error", resp.Message)
}
func (suite *UtilsTestSuite) TestParseQueryParams() {
query := make(map[string][]string)
query["page"] = []string{"2"}
query["limit"] = []string{"15"}
query["search"] = []string{"test"}
params := parseQueryParams(query, 9)
suite.Equal(int64(2), params.Page)
suite.Equal(int64(15), params.Limit)
suite.NotNil(params.Search)
}
func (suite *UtilsTestSuite) TestParseQueryParamsDefaults() {
query := make(map[string][]string)
params := parseQueryParams(query, 9)
suite.Equal(int64(1), params.Page)
suite.Equal(int64(9), params.Limit)
suite.Nil(params.Search)
}
func (suite *UtilsTestSuite) TestPtrOf() {
value := "test"
ptr := ptrOf(value)
suite.NotNil(ptr)
suite.Equal("test", *ptr)
}

View File

@@ -82,7 +82,8 @@
id="top-bar" id="top-bar"
class="transition-all duration-200 absolute z-10 bg-gray-100 dark:bg-gray-800 w-full px-2" class="transition-all duration-200 absolute z-10 bg-gray-100 dark:bg-gray-800 w-full px-2"
> >
<div class="w-full h-32 flex items-center justify-around relative"> <div class="max-h-[75vh] w-full flex flex-col items-center justify-around relative dark:text-white">
<div class="h-32">
<div class="text-gray-500 absolute top-6 left-4 flex flex-col gap-4"> <div class="text-gray-500 absolute top-6 left-4 flex flex-col gap-4">
<a href="#"> <a href="#">
<svg <svg
@@ -152,6 +153,8 @@
</div> </div>
</div> </div>
</div> </div>
<div id="toc" class="w-full text-center max-h-[50%] overflow-scroll no-scrollbar"></div>
</div>
</div> </div>
<div <div

View File

@@ -66,6 +66,56 @@ function populateMetadata(data) {
authorEl.innerText = data.author; authorEl.innerText = data.author;
} }
/**
* Populate the Table of Contents
**/
function populateTOC() {
if (!currentReader.book.navigation.toc) {
console.warn("[populateTOC] No TOC");
return;
}
let tocEl = document.querySelector("#toc");
if (!tocEl) {
console.warn("[populateTOC] No TOC Element");
return;
}
// Parse the Table of Contents
let parsedTOC = currentReader.book.navigation.toc.reduce((agg, item) => {
let sectionTitle = item.label.trim();
agg.push({ title: sectionTitle, href: item.href });
if (item.subitems.length == 0) {
return agg;
}
let allSubSections = item.subitems.map(item => {
let itemTitle = item.label.trim();
if (sectionTitle != "") {
itemTitle = sectionTitle + " - " + item.label.trim();
}
return { title: itemTitle, href: item.href };
});
agg.push(...allSubSections);
return agg;
}, [])
// Add Table of Contents to DOM
let listEl = document.createElement("ul");
listEl.classList.add("m-4")
parsedTOC.forEach(item => {
let listItem = document.createElement("li");
listItem.style.cursor = "pointer";
listItem.addEventListener("click", () => {
currentReader.rendition.display(item.href);
});
listItem.textContent = item.title;
listEl.appendChild(listItem);
});
tocEl.appendChild(listEl);
}
/** /**
* This is the main reader class. All functionality is wrapped in this class. * This is the main reader class. All functionality is wrapped in this class.
* Responsible for handling gesture / clicks, flushing progress & activity, * Responsible for handling gesture / clicks, flushing progress & activity,
@@ -439,6 +489,7 @@ class EBookReader {
// ------------------------------------------------ // // ------------------------------------------------ //
// ----------------- Swipe Helpers ---------------- // // ----------------- Swipe Helpers ---------------- //
// ------------------------------------------------ // // ------------------------------------------------ //
let disablePagination = false;
let touchStartX, let touchStartX,
touchStartY, touchStartY,
touchEndX, touchEndX,
@@ -459,25 +510,38 @@ class EBookReader {
} }
// Swipe Left // Swipe Left
if (touchEndX + drasticity < touchStartX) { if (!disablePagination && touchEndX + drasticity < touchStartX) {
nextPage(); nextPage();
} }
// Swipe Right // Swipe Right
if (touchEndX - drasticity > touchStartX) { if (!disablePagination && touchEndX - drasticity > touchStartX) {
prevPage(); prevPage();
} }
} }
function handleSwipeDown() { function handleSwipeDown() {
if (bottomBar.classList.contains("bottom-0")) if (bottomBar.classList.contains("bottom-0")) {
bottomBar.classList.remove("bottom-0"); bottomBar.classList.remove("bottom-0");
else topBar.classList.add("top-0"); disablePagination = false;
} else {
topBar.classList.add("top-0");
populateTOC()
disablePagination = true;
}
} }
function handleSwipeUp() { function handleSwipeUp() {
if (topBar.classList.contains("top-0")) topBar.classList.remove("top-0"); if (topBar.classList.contains("top-0")) {
else bottomBar.classList.add("bottom-0"); topBar.classList.remove("top-0");
disablePagination = false;
const tocEl = document.querySelector("#toc");
if (tocEl) tocEl.innerHTML = "";
} else {
bottomBar.classList.add("bottom-0");
disablePagination = true;
}
} }
this.rendition.hooks.render.register(function (doc, data) { this.rendition.hooks.render.register(function (doc, data) {
@@ -523,8 +587,8 @@ class EBookReader {
// Handle Event // Handle Event
if (yCoord < top) handleSwipeDown(); if (yCoord < top) handleSwipeDown();
else if (yCoord > bottom) handleSwipeUp(); else if (yCoord > bottom) handleSwipeUp();
else if (xCoord < left) prevPage(); else if (!disablePagination && xCoord < left) prevPage();
else if (xCoord > right) nextPage(); else if (!disablePagination && xCoord > right) nextPage();
else { else {
bottomBar.classList.remove("bottom-0"); bottomBar.classList.remove("bottom-0");
topBar.classList.remove("top-0"); topBar.classList.remove("top-0");
@@ -670,6 +734,9 @@ class EBookReader {
// Close Top Bar // Close Top Bar
document.querySelector(".close-top-bar").addEventListener("click", () => { document.querySelector(".close-top-bar").addEventListener("click", () => {
topBar.classList.remove("top-0"); topBar.classList.remove("top-0");
const tocEl = document.querySelector("#toc");
if (tocEl) tocEl.innerHTML = "";
}); });
} }
@@ -949,10 +1016,16 @@ class EBookReader {
**/ **/
async getXPathFromCFI(cfi) { async getXPathFromCFI(cfi) {
// Get DocFragment (Spine Index) // Get DocFragment (Spine Index)
let startCFI = cfi.replace("epubcfi(", ""); let cfiBaseMatch = cfi.match(/\(([^!]+)/);
if (!cfiBaseMatch) {
console.error("[getXPathFromCFI] No CFI Match");
return {};
}
let startCFI = cfiBaseMatch[1];
let docFragmentIndex = let docFragmentIndex =
this.book.spine.spineItems.find((item) => this.book.spine.spineItems.find((item) =>
startCFI.startsWith(item.cfiBase), item.cfiBase == startCFI
).index + 1; ).index + 1;
// Base Progress // Base Progress
@@ -1029,10 +1102,6 @@ class EBookReader {
return {}; return {};
} }
// Match Item Index
let indexMatch = xpath.match(/\.(\d+)$/);
let itemIndex = indexMatch ? parseInt(indexMatch[1]) : 0;
// Get Spine Item // Get Spine Item
let spinePosition = parseInt(fragMatch[1]) - 1; let spinePosition = parseInt(fragMatch[1]) - 1;
let sectionItem = this.book.spine.get(spinePosition); let sectionItem = this.book.spine.get(spinePosition);
@@ -1124,6 +1193,11 @@ class EBookReader {
let element = docSearch.iterateNext() || derivedSelectorElement; let element = docSearch.iterateNext() || derivedSelectorElement;
let cfi = sectionItem.cfiFromElement(element); let cfi = sectionItem.cfiFromElement(element);
// Hack - epub.js crashes sometimes when its a bare section with no element
// so just return the first.
if (cfi.endsWith("!/)"))
cfi = cfi.slice(0, -1) + "0)"
return { cfi, element }; return { cfi, element };
} }
@@ -1243,7 +1317,7 @@ class EBookReader {
let spineWC = await Promise.all( let spineWC = await Promise.all(
this.book.spine.spineItems.map(async (item) => { this.book.spine.spineItems.map(async (item) => {
let newDoc = await item.load(this.book.load.bind(this.book)); let newDoc = await item.load(this.book.load.bind(this.book));
let spineWords = newDoc.innerText.trim().split(/\s+/).length; let spineWords = (newDoc.innerText || "").trim().split(/\s+/).length;
item.wordCount = spineWords; item.wordCount = spineWords;
return spineWords; return spineWords;
}), }),
@@ -1271,14 +1345,3 @@ class EBookReader {
} }
document.addEventListener("DOMContentLoaded", initReader); document.addEventListener("DOMContentLoaded", initReader);
// WIP
async function getTOC() {
let toc = currentReader.book.navigation.toc;
// Alternatively:
// let nav = await currentReader.book.loaded.navigation;
// let toc = nav.toc;
currentReader.rendition.display(nav.toc[10].href);
}

File diff suppressed because one or more lines are too long

View File

@@ -99,7 +99,7 @@ const PRECACHE_ASSETS = [
// ----------------------- Helpers ----------------------- // // ----------------------- Helpers ----------------------- //
// ------------------------------------------------------- // // ------------------------------------------------------- //
function purgeCache() { async function purgeCache() {
console.log("[purgeCache] Purging Cache"); console.log("[purgeCache] Purging Cache");
return caches.keys().then(function (names) { return caches.keys().then(function (names) {
for (let name of names) caches.delete(name); for (let name of names) caches.delete(name);
@@ -136,7 +136,7 @@ async function handleFetch(event) {
const directive = ROUTES.find( const directive = ROUTES.find(
(item) => (item) =>
(item.route instanceof RegExp && url.match(item.route)) || (item.route instanceof RegExp && url.match(item.route)) ||
url == item.route url == item.route,
) || { type: CACHE_NEVER }; ) || { type: CACHE_NEVER };
// Get Fallback // Get Fallback
@@ -161,11 +161,11 @@ async function handleFetch(event) {
); );
case CACHE_UPDATE_SYNC: case CACHE_UPDATE_SYNC:
return updateCache(event.request).catch( return updateCache(event.request).catch(
(e) => currentCache || fallbackFunc(event) (e) => currentCache || fallbackFunc(event),
); );
case CACHE_UPDATE_ASYNC: case CACHE_UPDATE_ASYNC:
let newResponse = updateCache(event.request).catch((e) => let newResponse = updateCache(event.request).catch((e) =>
fallbackFunc(event) fallbackFunc(event),
); );
return currentCache || newResponse; return currentCache || newResponse;
@@ -192,7 +192,7 @@ function handleMessage(event) {
.filter( .filter(
(item) => (item) =>
item.startsWith("/documents/") || item.startsWith("/documents/") ||
item.startsWith("/reader/progress/") item.startsWith("/reader/progress/"),
); );
// Derive Unique IDs // Derive Unique IDs
@@ -200,8 +200,8 @@ function handleMessage(event) {
new Set( new Set(
docResources docResources
.filter((item) => item.startsWith("/documents/")) .filter((item) => item.startsWith("/documents/"))
.map((item) => item.split("/")[2]) .map((item) => item.split("/")[2]),
) ),
); );
/** /**
@@ -214,14 +214,14 @@ function handleMessage(event) {
.filter( .filter(
(id) => (id) =>
docResources.includes("/documents/" + id + "/file") && docResources.includes("/documents/" + id + "/file") &&
docResources.includes("/reader/progress/" + id) docResources.includes("/reader/progress/" + id),
) )
.map(async (id) => { .map(async (id) => {
let url = "/reader/progress/" + id; let url = "/reader/progress/" + id;
let currentCache = await caches.match(url); let currentCache = await caches.match(url);
let resp = await updateCache(url).catch((e) => currentCache); let resp = await updateCache(url).catch((e) => currentCache);
return resp.json(); return resp.json();
}) }),
); );
event.source.postMessage({ id, data: cachedDocuments }); event.source.postMessage({ id, data: cachedDocuments });
@@ -233,7 +233,7 @@ function handleMessage(event) {
Promise.all([ Promise.all([
cache.delete("/documents/" + data.id + "/file"), cache.delete("/documents/" + data.id + "/file"),
cache.delete("/reader/progress/" + data.id), cache.delete("/reader/progress/" + data.id),
]) ]),
) )
.then(() => event.source.postMessage({ id, data: "SUCCESS" })) .then(() => event.source.postMessage({ id, data: "SUCCESS" }))
.catch(() => event.source.postMessage({ id, data: "FAILURE" })); .catch(() => event.source.postMessage({ id, data: "FAILURE" }));

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT. // Code generated by sqlc. DO NOT EDIT.
// versions: // versions:
// sqlc v1.25.0 // sqlc v1.29.0
package database package database

View File

@@ -0,0 +1,151 @@
WITH grouped_activity AS (
SELECT
ga.user_id,
ga.document_id,
MAX(ga.created_at) AS created_at,
MAX(ga.start_time) AS start_time,
MIN(ga.start_percentage) AS start_percentage,
MAX(ga.end_percentage) AS end_percentage,
-- Total Duration & Percentage
SUM(ga.duration) AS total_time_seconds,
SUM(ga.end_percentage - ga.start_percentage) AS total_read_percentage,
-- Yearly Duration
SUM(
CASE
WHEN
ga.start_time >= DATE('now', '-1 year')
THEN ga.duration
ELSE 0
END
)
AS yearly_time_seconds,
-- Yearly Percentage
SUM(
CASE
WHEN
ga.start_time >= DATE('now', '-1 year')
THEN ga.end_percentage - ga.start_percentage
ELSE 0
END
)
AS yearly_read_percentage,
-- Monthly Duration
SUM(
CASE
WHEN
ga.start_time >= DATE('now', '-1 month')
THEN ga.duration
ELSE 0
END
)
AS monthly_time_seconds,
-- Monthly Percentage
SUM(
CASE
WHEN
ga.start_time >= DATE('now', '-1 month')
THEN ga.end_percentage - ga.start_percentage
ELSE 0
END
)
AS monthly_read_percentage,
-- Weekly Duration
SUM(
CASE
WHEN
ga.start_time >= DATE('now', '-7 days')
THEN ga.duration
ELSE 0
END
)
AS weekly_time_seconds,
-- Weekly Percentage
SUM(
CASE
WHEN
ga.start_time >= DATE('now', '-7 days')
THEN ga.end_percentage - ga.start_percentage
ELSE 0
END
)
AS weekly_read_percentage
FROM activity AS ga
GROUP BY ga.user_id, ga.document_id
),
current_progress AS (
SELECT
user_id,
document_id,
COALESCE((
SELECT dp.percentage
FROM document_progress AS dp
WHERE
dp.user_id = iga.user_id
AND dp.document_id = iga.document_id
ORDER BY dp.created_at DESC
LIMIT 1
), end_percentage) AS percentage
FROM grouped_activity AS iga
)
INSERT INTO document_user_statistics
SELECT
ga.document_id,
ga.user_id,
cp.percentage,
MAX(ga.start_time) AS last_read,
MAX(ga.created_at) AS last_seen,
SUM(ga.total_read_percentage) AS read_percentage,
-- All Time WPM
SUM(ga.total_time_seconds) AS total_time_seconds,
(CAST(COALESCE(d.words, 0.0) AS REAL) * SUM(ga.total_read_percentage))
AS total_words_read,
(CAST(COALESCE(d.words, 0.0) AS REAL) * SUM(ga.total_read_percentage))
/ (SUM(ga.total_time_seconds) / 60.0) AS total_wpm,
-- Yearly WPM
ga.yearly_time_seconds,
CAST(COALESCE(d.words, 0.0) AS REAL) * ga.yearly_read_percentage
AS yearly_words_read,
COALESCE(
(CAST(COALESCE(d.words, 0.0) AS REAL) * ga.yearly_read_percentage)
/ (ga.yearly_time_seconds / 60), 0.0)
AS yearly_wpm,
-- Monthly WPM
ga.monthly_time_seconds,
CAST(COALESCE(d.words, 0.0) AS REAL) * ga.monthly_read_percentage
AS monthly_words_read,
COALESCE(
(CAST(COALESCE(d.words, 0.0) AS REAL) * ga.monthly_read_percentage)
/ (ga.monthly_time_seconds / 60), 0.0)
AS monthly_wpm,
-- Weekly WPM
ga.weekly_time_seconds,
CAST(COALESCE(d.words, 0.0) AS REAL) * ga.weekly_read_percentage
AS weekly_words_read,
COALESCE(
(CAST(COALESCE(d.words, 0.0) AS REAL) * ga.weekly_read_percentage)
/ (ga.weekly_time_seconds / 60), 0.0)
AS weekly_wpm
FROM grouped_activity AS ga
INNER JOIN
current_progress AS cp
ON ga.user_id = cp.user_id AND ga.document_id = cp.document_id
INNER JOIN
documents AS d
ON ga.document_id = d.id
GROUP BY ga.document_id, ga.user_id
ORDER BY total_wpm DESC;

27
database/documents.go Normal file
View File

@@ -0,0 +1,27 @@
package database
import (
"context"
"fmt"
"reichard.io/antholume/pkg/ptr"
"reichard.io/antholume/pkg/sliceutils"
)
func (d *DBManager) GetDocument(ctx context.Context, docID, userID string) (*GetDocumentsWithStatsRow, error) {
documents, err := d.Queries.GetDocumentsWithStats(ctx, GetDocumentsWithStatsParams{
ID: ptr.Of(docID),
UserID: userID,
Limit: 1,
})
if err != nil {
return nil, err
}
document, found := sliceutils.First(documents)
if !found {
return nil, fmt.Errorf("document not found: %s", docID)
}
return &document, nil
}

115
database/documents_test.go Normal file
View File

@@ -0,0 +1,115 @@
package database
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/suite"
"reichard.io/antholume/config"
)
type DocumentsTestSuite struct {
suite.Suite
dbm *DBManager
}
func TestDocuments(t *testing.T) {
suite.Run(t, new(DocumentsTestSuite))
}
func (suite *DocumentsTestSuite) SetupTest() {
cfg := config.Config{
DBType: "memory",
}
suite.dbm = NewMgr(&cfg)
// Create Document
_, err := suite.dbm.Queries.UpsertDocument(context.Background(), UpsertDocumentParams{
ID: documentID,
Title: &documentTitle,
Author: &documentAuthor,
Words: &documentWords,
})
suite.NoError(err)
}
// DOCUMENT - TODO:
// - 󰊕 (q *Queries) GetDocumentProgress
// - 󰊕 (q *Queries) GetDocumentWithStats
// - 󰊕 (q *Queries) GetDocumentsSize
// - 󰊕 (q *Queries) GetDocumentsWithStats
// - 󰊕 (q *Queries) GetMissingDocuments
func (suite *DocumentsTestSuite) TestGetDocument() {
doc, err := suite.dbm.Queries.GetDocument(context.Background(), documentID)
suite.Nil(err, "should have nil err")
suite.Equal(documentID, doc.ID, "should have changed the document")
}
func (suite *DocumentsTestSuite) TestUpsertDocument() {
testDocID := "docid1"
doc, err := suite.dbm.Queries.UpsertDocument(context.Background(), UpsertDocumentParams{
ID: testDocID,
Title: &documentTitle,
Author: &documentAuthor,
})
suite.Nil(err, "should have nil err")
suite.Equal(testDocID, doc.ID, "should have document id")
suite.Equal(documentTitle, *doc.Title, "should have document title")
suite.Equal(documentAuthor, *doc.Author, "should have document author")
}
func (suite *DocumentsTestSuite) TestDeleteDocument() {
changed, err := suite.dbm.Queries.DeleteDocument(context.Background(), documentID)
suite.Nil(err, "should have nil err")
suite.Equal(int64(1), changed, "should have changed the document")
doc, err := suite.dbm.Queries.GetDocument(context.Background(), documentID)
suite.Nil(err, "should have nil err")
suite.True(doc.Deleted, "should have deleted the document")
}
func (suite *DocumentsTestSuite) TestGetDeletedDocuments() {
changed, err := suite.dbm.Queries.DeleteDocument(context.Background(), documentID)
suite.Nil(err, "should have nil err")
suite.Equal(int64(1), changed, "should have changed the document")
deletedDocs, err := suite.dbm.Queries.GetDeletedDocuments(context.Background(), []string{documentID})
suite.Nil(err, "should have nil err")
suite.Len(deletedDocs, 1, "should have one deleted document")
}
// TODO - Convert GetWantedDocuments -> (sqlc.slice('document_ids'));
func (suite *DocumentsTestSuite) TestGetWantedDocuments() {
wantedDocs, err := suite.dbm.Queries.GetWantedDocuments(context.Background(), fmt.Sprintf("[\"%s\"]", documentID))
suite.Nil(err, "should have nil err")
suite.Len(wantedDocs, 1, "should have one wanted document")
}
func (suite *DocumentsTestSuite) TestGetMissingDocuments() {
// Create Document
_, err := suite.dbm.Queries.UpsertDocument(context.Background(), UpsertDocumentParams{
ID: documentID,
Filepath: &documentFilepath,
})
suite.NoError(err)
missingDocs, err := suite.dbm.Queries.GetMissingDocuments(context.Background(), []string{documentID})
suite.Nil(err, "should have nil err")
suite.Len(missingDocs, 0, "should have no wanted document")
missingDocs, err = suite.dbm.Queries.GetMissingDocuments(context.Background(), []string{"other"})
suite.Nil(err, "should have nil err")
suite.Len(missingDocs, 1, "should have one missing document")
suite.Equal(documentID, missingDocs[0].ID, "should have missing doc")
// TODO - https://github.com/sqlc-dev/sqlc/issues/3451
// missingDocs, err = suite.dbm.Queries.GetMissingDocuments(context.Background(), []string{})
// suite.Nil(err, "should have nil err")
// suite.Len(missingDocs, 1, "should have one missing document")
// suite.Equal(documentID, missingDocs[0].ID, "should have missing doc")
}

View File

@@ -3,22 +3,22 @@ package database
import ( import (
"context" "context"
"database/sql" "database/sql"
"database/sql/driver"
"embed" "embed"
_ "embed" "errors"
"fmt" "fmt"
"path/filepath" "path/filepath"
"time" "time"
"github.com/pressly/goose/v3" "github.com/pressly/goose/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
_ "modernc.org/sqlite" sqlite "modernc.org/sqlite"
"reichard.io/antholume/config" "reichard.io/antholume/config"
_ "reichard.io/antholume/database/migrations" _ "reichard.io/antholume/database/migrations"
) )
type DBManager struct { type DBManager struct {
DB *sql.DB DB *sql.DB
Ctx context.Context
Queries *Queries Queries *Queries
cfg *config.Config cfg *config.Config
} }
@@ -26,26 +26,43 @@ type DBManager struct {
//go:embed schema.sql //go:embed schema.sql
var ddl string var ddl string
//go:embed user_streaks.sql
var user_streaks string
//go:embed document_user_statistics.sql
var document_user_statistics string
//go:embed migrations/* //go:embed migrations/*
var migrations embed.FS var migrations embed.FS
// Returns an initialized manager // Register scalar sqlite function on init
func init() {
sqlite.MustRegisterFunction("LOCAL_TIME", &sqlite.FunctionImpl{
NArgs: 2,
Deterministic: true,
Scalar: localTime,
})
sqlite.MustRegisterFunction("LOCAL_DATE", &sqlite.FunctionImpl{
NArgs: 2,
Deterministic: true,
Scalar: localDate,
})
}
// NewMgr Returns an initialized manager
func NewMgr(c *config.Config) *DBManager { func NewMgr(c *config.Config) *DBManager {
// Create Manager // Create Manager
dbm := &DBManager{ dbm := &DBManager{cfg: c}
Ctx: context.Background(),
cfg: c,
}
if err := dbm.init(); err != nil { if err := dbm.init(context.Background()); err != nil {
log.Panic("Unable to init DB") log.Panic("Unable to init DB")
} }
return dbm return dbm
} }
// Init manager // init loads the DB manager
func (dbm *DBManager) init() error { func (dbm *DBManager) init(ctx context.Context) error {
// Build DB Location // Build DB Location
var dbLocation string var dbLocation string
switch dbm.cfg.DBType { switch dbm.cfg.DBType {
@@ -91,20 +108,22 @@ func (dbm *DBManager) init() error {
} }
// Update settings // Update settings
err = dbm.updateSettings() err = dbm.updateSettings(ctx)
if err != nil { if err != nil {
log.Panicf("Error running DB settings update: %v", err) log.Panicf("Error running DB settings update: %v", err)
return err return err
} }
// Cache tables // Cache tables
go dbm.CacheTempTables() if err := dbm.CacheTempTables(ctx); err != nil {
log.Warn("Refreshing temp table cache failed: ", err)
}
return nil return nil
} }
// Reload manager (close DB & reinit) // Reload closes the DB & reinits
func (dbm *DBManager) Reload() error { func (dbm *DBManager) Reload(ctx context.Context) error {
// Close handle // Close handle
err := dbm.DB.Close() err := dbm.DB.Close()
if err != nil { if err != nil {
@@ -112,30 +131,23 @@ func (dbm *DBManager) Reload() error {
} }
// Reinit DB // Reinit DB
if err := dbm.init(); err != nil { if err := dbm.init(ctx); err != nil {
return err return err
} }
return nil return nil
} }
func (dbm *DBManager) CacheTempTables() error { // CacheTempTables clears existing statistics and recalculates
func (dbm *DBManager) CacheTempTables(ctx context.Context) error {
start := time.Now() start := time.Now()
user_streaks_sql := ` if _, err := dbm.DB.ExecContext(ctx, user_streaks); err != nil {
DELETE FROM user_streaks;
INSERT INTO user_streaks SELECT * FROM view_user_streaks;
`
if _, err := dbm.DB.ExecContext(dbm.Ctx, user_streaks_sql); err != nil {
return err return err
} }
log.Debug("Cached 'user_streaks' in: ", time.Since(start)) log.Debug("Cached 'user_streaks' in: ", time.Since(start))
start = time.Now() start = time.Now()
document_statistics_sql := ` if _, err := dbm.DB.ExecContext(ctx, document_user_statistics); err != nil {
DELETE FROM document_user_statistics;
INSERT INTO document_user_statistics SELECT * FROM view_document_user_statistics;
`
if _, err := dbm.DB.ExecContext(dbm.Ctx, document_statistics_sql); err != nil {
return err return err
} }
log.Debug("Cached 'document_user_statistics' in: ", time.Since(start)) log.Debug("Cached 'document_user_statistics' in: ", time.Since(start))
@@ -143,7 +155,9 @@ func (dbm *DBManager) CacheTempTables() error {
return nil return nil
} }
func (dbm *DBManager) updateSettings() error { // updateSettings ensures that we're enforcing foreign keys and enable journal
// mode.
func (dbm *DBManager) updateSettings(ctx context.Context) error {
// Set SQLite PRAGMA Settings // Set SQLite PRAGMA Settings
pragmaQuery := ` pragmaQuery := `
PRAGMA foreign_keys = ON; PRAGMA foreign_keys = ON;
@@ -155,7 +169,7 @@ func (dbm *DBManager) updateSettings() error {
} }
// Update Antholume Version in DB // Update Antholume Version in DB
if _, err := dbm.Queries.UpdateSettings(dbm.Ctx, UpdateSettingsParams{ if _, err := dbm.Queries.UpdateSettings(ctx, UpdateSettingsParams{
Name: "version", Name: "version",
Value: dbm.cfg.Version, Value: dbm.cfg.Version,
}); err != nil { }); err != nil {
@@ -166,9 +180,10 @@ func (dbm *DBManager) updateSettings() error {
return nil return nil
} }
// performMigrations runs all migrations
func (dbm *DBManager) performMigrations(isNew bool) error { func (dbm *DBManager) performMigrations(isNew bool) error {
// Create context // Create context
ctx := context.WithValue(context.Background(), "isNew", isNew) ctx := context.WithValue(context.Background(), "isNew", isNew) // nolint
// Set DB migration // Set DB migration
goose.SetBaseFS(migrations) goose.SetBaseFS(migrations)
@@ -182,6 +197,7 @@ func (dbm *DBManager) performMigrations(isNew bool) error {
return goose.UpContext(ctx, dbm.DB, "migrations") return goose.UpContext(ctx, dbm.DB, "migrations")
} }
// isEmpty determines whether the database is empty
func isEmpty(db *sql.DB) (bool, error) { func isEmpty(db *sql.DB) (bool, error) {
var tableCount int var tableCount int
err := db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table';").Scan(&tableCount) err := db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table';").Scan(&tableCount)
@@ -190,3 +206,53 @@ func isEmpty(db *sql.DB) (bool, error) {
} }
return tableCount == 0, nil return tableCount == 0, nil
} }
// localTime is a custom SQL function that is registered as LOCAL_TIME in the init function
func localTime(ctx *sqlite.FunctionContext, args []driver.Value) (driver.Value, error) {
timeStr, ok := args[0].(string)
if !ok {
return nil, errors.New("both arguments to TZTime must be strings")
}
timeZoneStr, ok := args[1].(string)
if !ok {
return nil, errors.New("both arguments to TZTime must be strings")
}
timeZone, err := time.LoadLocation(timeZoneStr)
if err != nil {
return nil, errors.New("unable to parse timezone")
}
formattedTime, err := time.ParseInLocation(time.RFC3339, timeStr, time.UTC)
if err != nil {
return nil, errors.New("unable to parse time")
}
return formattedTime.In(timeZone).Format(time.RFC3339), nil
}
// localDate is a custom SQL function that is registered as LOCAL_DATE in the init function
func localDate(ctx *sqlite.FunctionContext, args []driver.Value) (driver.Value, error) {
timeStr, ok := args[0].(string)
if !ok {
return nil, errors.New("both arguments to TZTime must be strings")
}
timeZoneStr, ok := args[1].(string)
if !ok {
return nil, errors.New("both arguments to TZTime must be strings")
}
timeZone, err := time.LoadLocation(timeZoneStr)
if err != nil {
return nil, errors.New("unable to parse timezone")
}
formattedTime, err := time.ParseInLocation(time.RFC3339, timeStr, time.UTC)
if err != nil {
return nil, errors.New("unable to parse time")
}
return formattedTime.In(timeZone).Format("2006-01-02"), nil
}

View File

@@ -1,102 +1,78 @@
package database package database
import ( import (
"context"
"fmt" "fmt"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite"
"reichard.io/antholume/config" "reichard.io/antholume/config"
"reichard.io/antholume/utils" "reichard.io/antholume/utils"
) )
type databaseTest struct { var (
*testing.T userID string = "testUser"
userPass string = "testPass"
deviceID string = "testDevice"
deviceName string = "testDeviceName"
documentID string = "testDocument"
documentTitle string = "testTitle"
documentAuthor string = "testAuthor"
documentFilepath string = "./testPath.epub"
documentWords int64 = 5000
)
type DatabaseTestSuite struct {
suite.Suite
dbm *DBManager dbm *DBManager
} }
var userID string = "testUser" func TestDatabase(t *testing.T) {
var userPass string = "testPass" suite.Run(t, new(DatabaseTestSuite))
var deviceID string = "testDevice" }
var deviceName string = "testDeviceName"
var documentID string = "testDocument"
var documentTitle string = "testTitle"
var documentAuthor string = "testAuthor"
func TestNewMgr(t *testing.T) { // PROGRESS - TODO:
// - 󰊕 (q *Queries) GetProgress
// - 󰊕 (q *Queries) UpdateProgress
func (suite *DatabaseTestSuite) SetupTest() {
cfg := config.Config{ cfg := config.Config{
DBType: "memory", DBType: "memory",
} }
dbm := NewMgr(&cfg) suite.dbm = NewMgr(&cfg)
assert.NotNil(t, dbm, "should not have nil dbm")
t.Run("Database", func(t *testing.T) {
dt := databaseTest{t, dbm}
dt.TestUser()
dt.TestDocument()
dt.TestDevice()
dt.TestActivity()
dt.TestDailyReadStats()
})
}
func (dt *databaseTest) TestUser() {
dt.Run("User", func(t *testing.T) {
// Generate Auth Hash
rawAuthHash, err := utils.GenerateToken(64)
assert.Nil(t, err, "should have nil err")
// Create User
rawAuthHash, _ := utils.GenerateToken(64)
authHash := fmt.Sprintf("%x", rawAuthHash) authHash := fmt.Sprintf("%x", rawAuthHash)
changed, err := dt.dbm.Queries.CreateUser(dt.dbm.Ctx, CreateUserParams{ _, err := suite.dbm.Queries.CreateUser(context.Background(), CreateUserParams{
ID: userID, ID: userID,
Pass: &userPass, Pass: &userPass,
AuthHash: &authHash, AuthHash: &authHash,
}) })
suite.NoError(err)
assert.Nil(t, err, "should have nil err") // Create Document
assert.Equal(t, int64(1), changed) _, err = suite.dbm.Queries.UpsertDocument(context.Background(), UpsertDocumentParams{
user, err := dt.dbm.Queries.GetUser(dt.dbm.Ctx, userID)
assert.Nil(t, err, "should have nil err")
assert.Equal(t, userPass, *user.Pass)
})
}
func (dt *databaseTest) TestDocument() {
dt.Run("Document", func(t *testing.T) {
doc, err := dt.dbm.Queries.UpsertDocument(dt.dbm.Ctx, UpsertDocumentParams{
ID: documentID, ID: documentID,
Title: &documentTitle, Title: &documentTitle,
Author: &documentAuthor, Author: &documentAuthor,
Filepath: &documentFilepath,
Words: &documentWords,
}) })
suite.NoError(err)
assert.Nil(t, err, "should have nil err") // Create Device
assert.Equal(t, documentID, doc.ID, "should have document id") _, err = suite.dbm.Queries.UpsertDevice(context.Background(), UpsertDeviceParams{
assert.Equal(t, documentTitle, *doc.Title, "should have document title")
assert.Equal(t, documentAuthor, *doc.Author, "should have document author")
})
}
func (dt *databaseTest) TestDevice() {
dt.Run("Device", func(t *testing.T) {
device, err := dt.dbm.Queries.UpsertDevice(dt.dbm.Ctx, UpsertDeviceParams{
ID: deviceID, ID: deviceID,
UserID: userID, UserID: userID,
DeviceName: deviceName, DeviceName: deviceName,
}) })
suite.NoError(err)
assert.Nil(t, err, "should have nil err") // Create Activity
assert.Equal(t, deviceID, device.ID, "should have device id")
assert.Equal(t, userID, device.UserID, "should have user id")
assert.Equal(t, deviceName, device.DeviceName, "should have device name")
})
}
func (dt *databaseTest) TestActivity() {
dt.Run("Progress", func(t *testing.T) {
// 10 Activities, 10 Days
end := time.Now() end := time.Now()
start := end.AddDate(0, 0, -9) start := end.AddDate(0, 0, -9)
var counter int64 = 0 var counter int64 = 0
@@ -105,7 +81,7 @@ func (dt *databaseTest) TestActivity() {
counter += 1 counter += 1
// Add Item // Add Item
activity, err := dt.dbm.Queries.AddActivity(dt.dbm.Ctx, AddActivityParams{ activity, err := suite.dbm.Queries.AddActivity(context.Background(), AddActivityParams{
DocumentID: documentID, DocumentID: documentID,
DeviceID: deviceID, DeviceID: deviceID,
UserID: userID, UserID: userID,
@@ -115,25 +91,50 @@ func (dt *databaseTest) TestActivity() {
EndPercentage: float64(counter+1) / 100.0, EndPercentage: float64(counter+1) / 100.0,
}) })
assert.Nil(t, err, fmt.Sprintf("[%d] should have nil err for add activity", counter)) suite.Nil(err, fmt.Sprintf("[%d] should have nil err for add activity", counter))
assert.Equal(t, counter, activity.ID, fmt.Sprintf("[%d] should have correct id for add activity", counter)) suite.Equal(counter, activity.ID, fmt.Sprintf("[%d] should have correct id for add activity", counter))
} }
// Initiate Cache // Initiate Cache
dt.dbm.CacheTempTables() err = suite.dbm.CacheTempTables(context.Background())
suite.NoError(err)
}
// DEVICES - TODO:
// - 󰊕 (q *Queries) GetDevice
// - 󰊕 (q *Queries) GetDevices
// - 󰊕 (q *Queries) UpsertDevice
func (suite *DatabaseTestSuite) TestDevice() {
testDevice := "dev123"
device, err := suite.dbm.Queries.UpsertDevice(context.Background(), UpsertDeviceParams{
ID: testDevice,
UserID: userID,
DeviceName: deviceName,
})
suite.Nil(err, "should have nil err")
suite.Equal(testDevice, device.ID, "should have device id")
suite.Equal(userID, device.UserID, "should have user id")
suite.Equal(deviceName, device.DeviceName, "should have device name")
}
// ACTIVITY - TODO:
// - 󰊕 (q *Queries) AddActivity
// - 󰊕 (q *Queries) GetActivity
// - 󰊕 (q *Queries) GetLastActivity
func (suite *DatabaseTestSuite) TestActivity() {
// Validate Exists // Validate Exists
existsRows, err := dt.dbm.Queries.GetActivity(dt.dbm.Ctx, GetActivityParams{ existsRows, err := suite.dbm.Queries.GetActivity(context.Background(), GetActivityParams{
UserID: userID, UserID: userID,
Offset: 0, Offset: 0,
Limit: 50, Limit: 50,
}) })
assert.Nil(t, err, "should have nil err for get activity") suite.Nil(err, "should have nil err for get activity")
assert.Len(t, existsRows, 10, "should have correct number of rows get activity") suite.Len(existsRows, 10, "should have correct number of rows get activity")
// Validate Doesn't Exist // Validate Doesn't Exist
doesntExistsRows, err := dt.dbm.Queries.GetActivity(dt.dbm.Ctx, GetActivityParams{ doesntExistsRows, err := suite.dbm.Queries.GetActivity(context.Background(), GetActivityParams{
UserID: userID, UserID: userID,
DocumentID: "unknownDoc", DocumentID: "unknownDoc",
DocFilter: true, DocFilter: true,
@@ -141,28 +142,30 @@ func (dt *databaseTest) TestActivity() {
Limit: 50, Limit: 50,
}) })
assert.Nil(t, err, "should have nil err for get activity") suite.Nil(err, "should have nil err for get activity")
assert.Len(t, doesntExistsRows, 0, "should have no rows") suite.Len(doesntExistsRows, 0, "should have no rows")
})
} }
func (dt *databaseTest) TestDailyReadStats() { // MISC - TODO:
dt.Run("DailyReadStats", func(t *testing.T) { // - 󰊕 (q *Queries) AddMetadata
readStats, err := dt.dbm.Queries.GetDailyReadStats(dt.dbm.Ctx, userID) // - 󰊕 (q *Queries) GetDailyReadStats
// - 󰊕 (q *Queries) GetDatabaseInfo
// - 󰊕 (q *Queries) UpdateSettings
func (suite *DatabaseTestSuite) TestGetDailyReadStats() {
readStats, err := suite.dbm.Queries.GetDailyReadStats(context.Background(), userID)
assert.Nil(t, err, "should have nil err") suite.Nil(err, "should have nil err")
assert.Len(t, readStats, 30, "should have length of 30") suite.Len(readStats, 30, "should have length of 30")
// Validate 1 Minute / Day - Last 10 Days // Validate 1 Minute / Day - Last 10 Days
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
stat := readStats[i] stat := readStats[i]
assert.Equal(t, int64(1), stat.MinutesRead, "should have one minute read") suite.Equal(int64(1), stat.MinutesRead, "should have one minute read")
} }
// Validate 0 Minute / Day - Remaining 20 Days // Validate 0 Minute / Day - Remaining 20 Days
for i := 10; i < 30; i++ { for i := 10; i < 30; i++ {
stat := readStats[i] stat := readStats[i]
assert.Equal(t, int64(0), stat.MinutesRead, "should have zero minutes read") suite.Equal(int64(0), stat.MinutesRead, "should have zero minutes read")
} }
})
} }

View File

@@ -0,0 +1,58 @@
package migrations
import (
"context"
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigrationContext(upUserTimezone, downUserTimezone)
}
func upUserTimezone(ctx context.Context, tx *sql.Tx) error {
// Determine if we have a new DB or not
isNew := ctx.Value("isNew").(bool)
if isNew {
return nil
}
// Copy table & create column
_, err := tx.Exec(`
-- Copy Table
CREATE TABLE temp_users AS SELECT * FROM users;
ALTER TABLE temp_users DROP COLUMN time_offset;
ALTER TABLE temp_users ADD COLUMN timezone TEXT;
UPDATE temp_users SET timezone = 'Europe/London';
-- Clean Table
DELETE FROM users;
ALTER TABLE users DROP COLUMN time_offset;
ALTER TABLE users ADD COLUMN timezone TEXT NOT NULL DEFAULT 'Europe/London';
-- Copy Temp Table -> Clean Table
INSERT INTO users SELECT * FROM temp_users;
-- Drop Temp Table
DROP TABLE temp_users;
`)
if err != nil {
return err
}
return nil
}
func downUserTimezone(ctx context.Context, tx *sql.Tx) error {
// Update column name & value
_, err := tx.Exec(`
ALTER TABLE users RENAME COLUMN timezone TO time_offset;
UPDATE users SET time_offset = '0 hours';
`)
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,38 @@
package migrations
import (
"context"
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigrationContext(upImportBasepath, downImportBasepath)
}
func upImportBasepath(ctx context.Context, tx *sql.Tx) error {
// Determine if we have a new DB or not
isNew := ctx.Value("isNew").(bool)
if isNew {
return nil
}
// Add basepath column
_, err := tx.Exec(`ALTER TABLE documents ADD COLUMN basepath TEXT;`)
if err != nil {
return err
}
// This code is executed when the migration is applied.
return nil
}
func downImportBasepath(ctx context.Context, tx *sql.Tx) error {
// Drop basepath column
_, err := tx.Exec("ALTER documents DROP COLUMN basepath;")
if err != nil {
return err
}
return nil
}

View File

@@ -1,11 +1,9 @@
// Code generated by sqlc. DO NOT EDIT. // Code generated by sqlc. DO NOT EDIT.
// versions: // versions:
// sqlc v1.25.0 // sqlc v1.29.0
package database package database
import ()
type Activity struct { type Activity struct {
ID int64 `json:"id"` ID int64 `json:"id"`
UserID string `json:"user_id"` UserID string `json:"user_id"`
@@ -30,6 +28,7 @@ type Device struct {
type Document struct { type Document struct {
ID string `json:"id"` ID string `json:"id"`
Md5 *string `json:"md5"` Md5 *string `json:"md5"`
Basepath *string `json:"basepath"`
Filepath *string `json:"filepath"` Filepath *string `json:"filepath"`
Coverfile *string `json:"coverfile"` Coverfile *string `json:"coverfile"`
Title *string `json:"title"` Title *string `json:"title"`
@@ -63,6 +62,7 @@ type DocumentUserStatistic struct {
UserID string `json:"user_id"` UserID string `json:"user_id"`
Percentage float64 `json:"percentage"` Percentage float64 `json:"percentage"`
LastRead string `json:"last_read"` LastRead string `json:"last_read"`
LastSeen string `json:"last_seen"`
ReadPercentage float64 `json:"read_percentage"` ReadPercentage float64 `json:"read_percentage"`
TotalTimeSeconds int64 `json:"total_time_seconds"` TotalTimeSeconds int64 `json:"total_time_seconds"`
TotalWordsRead int64 `json:"total_words_read"` TotalWordsRead int64 `json:"total_words_read"`
@@ -78,7 +78,7 @@ type DocumentUserStatistic struct {
WeeklyWpm float64 `json:"weekly_wpm"` WeeklyWpm float64 `json:"weekly_wpm"`
} }
type Metadatum struct { type Metadata struct {
ID int64 `json:"id"` ID int64 `json:"id"`
DocumentID string `json:"document_id"` DocumentID string `json:"document_id"`
Title *string `json:"title"` Title *string `json:"title"`
@@ -103,7 +103,7 @@ type User struct {
Pass *string `json:"-"` Pass *string `json:"-"`
AuthHash *string `json:"auth_hash"` AuthHash *string `json:"auth_hash"`
Admin bool `json:"-"` Admin bool `json:"-"`
TimeOffset *string `json:"time_offset"` Timezone *string `json:"timezone"`
CreatedAt string `json:"created_at"` CreatedAt string `json:"created_at"`
} }
@@ -116,35 +116,8 @@ type UserStreak struct {
CurrentStreak int64 `json:"current_streak"` CurrentStreak int64 `json:"current_streak"`
CurrentStreakStartDate string `json:"current_streak_start_date"` CurrentStreakStartDate string `json:"current_streak_start_date"`
CurrentStreakEndDate string `json:"current_streak_end_date"` CurrentStreakEndDate string `json:"current_streak_end_date"`
} LastTimezone string `json:"last_timezone"`
LastSeen string `json:"last_seen"`
type ViewDocumentUserStatistic struct { LastRecord string `json:"last_record"`
DocumentID string `json:"document_id"` LastCalculated string `json:"last_calculated"`
UserID string `json:"user_id"`
Percentage float64 `json:"percentage"`
LastRead interface{} `json:"last_read"`
ReadPercentage *float64 `json:"read_percentage"`
TotalTimeSeconds *float64 `json:"total_time_seconds"`
TotalWordsRead interface{} `json:"total_words_read"`
TotalWpm int64 `json:"total_wpm"`
YearlyTimeSeconds *float64 `json:"yearly_time_seconds"`
YearlyWordsRead interface{} `json:"yearly_words_read"`
YearlyWpm interface{} `json:"yearly_wpm"`
MonthlyTimeSeconds *float64 `json:"monthly_time_seconds"`
MonthlyWordsRead interface{} `json:"monthly_words_read"`
MonthlyWpm interface{} `json:"monthly_wpm"`
WeeklyTimeSeconds *float64 `json:"weekly_time_seconds"`
WeeklyWordsRead interface{} `json:"weekly_words_read"`
WeeklyWpm interface{} `json:"weekly_wpm"`
}
type ViewUserStreak struct {
UserID string `json:"user_id"`
Window string `json:"window"`
MaxStreak interface{} `json:"max_streak"`
MaxStreakStartDate interface{} `json:"max_streak_start_date"`
MaxStreakEndDate interface{} `json:"max_streak_end_date"`
CurrentStreak interface{} `json:"current_streak"`
CurrentStreakStartDate interface{} `json:"current_streak_start_date"`
CurrentStreakEndDate interface{} `json:"current_streak_end_date"`
} }

View File

@@ -30,6 +30,9 @@ INSERT INTO users (id, pass, auth_hash, admin)
VALUES (?, ?, ?, ?) VALUES (?, ?, ?, ?)
ON CONFLICT DO NOTHING; ON CONFLICT DO NOTHING;
-- name: DeleteUser :execrows
DELETE FROM users WHERE id = $id;
-- name: DeleteDocument :execrows -- name: DeleteDocument :execrows
UPDATE documents UPDATE documents
SET SET
@@ -64,7 +67,7 @@ WITH filtered_activity AS (
SELECT SELECT
document_id, document_id,
device_id, device_id,
CAST(STRFTIME('%Y-%m-%d %H:%M:%S', activity.start_time, users.time_offset) AS TEXT) AS start_time, LOCAL_TIME(activity.start_time, users.timezone) AS start_time,
title, title,
author, author,
duration, duration,
@@ -77,7 +80,7 @@ LEFT JOIN users ON users.id = activity.user_id;
-- name: GetDailyReadStats :many -- name: GetDailyReadStats :many
WITH RECURSIVE last_30_days AS ( WITH RECURSIVE last_30_days AS (
SELECT DATE('now', time_offset) AS date SELECT LOCAL_DATE(STRFTIME('%Y-%m-%dT%H:%M:%SZ', 'now'), timezone) AS date
FROM users WHERE users.id = $user_id FROM users WHERE users.id = $user_id
UNION ALL UNION ALL
SELECT DATE(date, '-1 days') SELECT DATE(date, '-1 days')
@@ -96,7 +99,7 @@ filtered_activity AS (
activity_days AS ( activity_days AS (
SELECT SELECT
SUM(duration) AS seconds_read, SUM(duration) AS seconds_read,
DATE(start_time, time_offset) AS day LOCAL_DATE(start_time, timezone) AS day
FROM filtered_activity AS activity FROM filtered_activity AS activity
LEFT JOIN users ON users.id = activity.user_id LEFT JOIN users ON users.id = activity.user_id
GROUP BY day GROUP BY day
@@ -135,8 +138,8 @@ WHERE id = $device_id LIMIT 1;
SELECT SELECT
devices.id, devices.id,
devices.device_name, devices.device_name,
CAST(STRFTIME('%Y-%m-%d %H:%M:%S', devices.created_at, users.time_offset) AS TEXT) AS created_at, LOCAL_TIME(devices.created_at, users.timezone) AS created_at,
CAST(STRFTIME('%Y-%m-%d %H:%M:%S', devices.last_synced, users.time_offset) AS TEXT) AS last_synced LOCAL_TIME(devices.last_synced, users.timezone) AS last_synced
FROM devices FROM devices
JOIN users ON users.id = devices.user_id JOIN users ON users.id = devices.user_id
WHERE users.id = $user_id WHERE users.id = $user_id
@@ -160,42 +163,6 @@ ORDER BY
DESC DESC
LIMIT 1; LIMIT 1;
-- name: GetDocumentWithStats :one
SELECT
docs.id,
docs.title,
docs.author,
docs.description,
docs.isbn10,
docs.isbn13,
docs.filepath,
docs.words,
CAST(COALESCE(dus.total_wpm, 0.0) AS INTEGER) AS wpm,
COALESCE(dus.read_percentage, 0) AS read_percentage,
COALESCE(dus.total_time_seconds, 0) AS total_time_seconds,
STRFTIME('%Y-%m-%d %H:%M:%S', COALESCE(dus.last_read, "1970-01-01"), users.time_offset)
AS last_read,
ROUND(CAST(CASE
WHEN dus.percentage IS NULL THEN 0.0
WHEN (dus.percentage * 100.0) > 97.0 THEN 100.0
ELSE dus.percentage * 100.0
END AS REAL), 2) AS percentage,
CAST(CASE
WHEN dus.total_time_seconds IS NULL THEN 0.0
ELSE
CAST(dus.total_time_seconds AS REAL)
/ (dus.read_percentage * 100.0)
END AS INTEGER) AS seconds_per_percent
FROM documents AS docs
LEFT JOIN users ON users.id = $user_id
LEFT JOIN
document_user_statistics AS dus
ON dus.document_id = docs.id AND dus.user_id = $user_id
WHERE users.id = $user_id
AND docs.id = $document_id
LIMIT 1;
-- name: GetDocuments :many -- name: GetDocuments :many
SELECT * FROM documents SELECT * FROM documents
ORDER BY created_at DESC ORDER BY created_at DESC
@@ -226,33 +193,32 @@ SELECT
CAST(COALESCE(dus.total_wpm, 0.0) AS INTEGER) AS wpm, CAST(COALESCE(dus.total_wpm, 0.0) AS INTEGER) AS wpm,
COALESCE(dus.read_percentage, 0) AS read_percentage, COALESCE(dus.read_percentage, 0) AS read_percentage,
COALESCE(dus.total_time_seconds, 0) AS total_time_seconds, COALESCE(dus.total_time_seconds, 0) AS total_time_seconds,
STRFTIME('%Y-%m-%d %H:%M:%S', COALESCE(dus.last_read, "1970-01-01"), users.time_offset) STRFTIME('%Y-%m-%d %H:%M:%S', LOCAL_TIME(COALESCE(dus.last_read, STRFTIME('%Y-%m-%dT%H:%M:%SZ', 0, 'unixepoch')), users.timezone))
AS last_read, AS last_read,
ROUND(CAST(CASE ROUND(CAST(CASE
WHEN dus.percentage IS NULL THEN 0.0 WHEN dus.percentage IS NULL THEN 0.0
WHEN (dus.percentage * 100.0) > 97.0 THEN 100.0 WHEN (dus.percentage * 100.0) > 97.0 THEN 100.0
ELSE dus.percentage * 100.0 ELSE dus.percentage * 100.0
END AS REAL), 2) AS percentage, END AS REAL), 2) AS percentage,
CAST(CASE
CASE
WHEN dus.total_time_seconds IS NULL THEN 0.0 WHEN dus.total_time_seconds IS NULL THEN 0.0
ELSE ELSE
ROUND(
CAST(dus.total_time_seconds AS REAL) CAST(dus.total_time_seconds AS REAL)
/ (dus.read_percentage * 100.0) / (dus.read_percentage * 100.0)
) END AS INTEGER) AS seconds_per_percent
END AS seconds_per_percent
FROM documents AS docs FROM documents AS docs
LEFT JOIN users ON users.id = $user_id LEFT JOIN users ON users.id = $user_id
LEFT JOIN LEFT JOIN
document_user_statistics AS dus document_user_statistics AS dus
ON dus.document_id = docs.id AND dus.user_id = $user_id ON dus.document_id = docs.id AND dus.user_id = $user_id
WHERE WHERE
docs.deleted = false AND ( (docs.id = sqlc.narg('id') OR $id IS NULL)
$query IS NULL OR ( AND (docs.deleted = sqlc.narg(deleted) OR $deleted IS NULL)
docs.title LIKE $query OR AND (
(
docs.title LIKE sqlc.narg('query') OR
docs.author LIKE $query docs.author LIKE $query
) ) OR $query IS NULL
) )
ORDER BY dus.last_read DESC, docs.created_at DESC ORDER BY dus.last_read DESC, docs.created_at DESC
LIMIT $limit LIMIT $limit
@@ -280,7 +246,7 @@ SELECT
ROUND(CAST(progress.percentage AS REAL) * 100, 2) AS percentage, ROUND(CAST(progress.percentage AS REAL) * 100, 2) AS percentage,
progress.document_id, progress.document_id,
progress.user_id, progress.user_id,
CAST(STRFTIME('%Y-%m-%d %H:%M:%S', progress.created_at, users.time_offset) AS TEXT) AS created_at LOCAL_TIME(progress.created_at, users.timezone) AS created_at
FROM document_progress AS progress FROM document_progress AS progress
LEFT JOIN users ON progress.user_id = users.id LEFT JOIN users ON progress.user_id = users.id
LEFT JOIN devices ON progress.device_id = devices.id LEFT JOIN devices ON progress.device_id = devices.id
@@ -369,7 +335,8 @@ UPDATE users
SET SET
pass = COALESCE($password, pass), pass = COALESCE($password, pass),
auth_hash = COALESCE($auth_hash, auth_hash), auth_hash = COALESCE($auth_hash, auth_hash),
time_offset = COALESCE($time_offset, time_offset) timezone = COALESCE($timezone, timezone),
admin = COALESCE($admin, admin)
WHERE id = $user_id WHERE id = $user_id
RETURNING *; RETURNING *;
@@ -395,6 +362,7 @@ RETURNING *;
INSERT INTO documents ( INSERT INTO documents (
id, id,
md5, md5,
basepath,
filepath, filepath,
coverfile, coverfile,
title, title,
@@ -409,10 +377,11 @@ INSERT INTO documents (
isbn10, isbn10,
isbn13 isbn13
) )
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT DO UPDATE ON CONFLICT DO UPDATE
SET SET
md5 = COALESCE(excluded.md5, md5), md5 = COALESCE(excluded.md5, md5),
basepath = COALESCE(excluded.basepath, basepath),
filepath = COALESCE(excluded.filepath, filepath), filepath = COALESCE(excluded.filepath, filepath),
coverfile = COALESCE(excluded.coverfile, coverfile), coverfile = COALESCE(excluded.coverfile, coverfile),
title = COALESCE(excluded.title, title), title = COALESCE(excluded.title, title),

View File

@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT. // Code generated by sqlc. DO NOT EDIT.
// versions: // versions:
// sqlc v1.25.0 // sqlc v1.29.0
// source: query.sql // source: query.sql
package database package database
@@ -85,7 +85,7 @@ type AddMetadataParams struct {
Isbn13 *string `json:"isbn13"` Isbn13 *string `json:"isbn13"`
} }
func (q *Queries) AddMetadata(ctx context.Context, arg AddMetadataParams) (Metadatum, error) { func (q *Queries) AddMetadata(ctx context.Context, arg AddMetadataParams) (Metadata, error) {
row := q.db.QueryRowContext(ctx, addMetadata, row := q.db.QueryRowContext(ctx, addMetadata,
arg.DocumentID, arg.DocumentID,
arg.Title, arg.Title,
@@ -96,7 +96,7 @@ func (q *Queries) AddMetadata(ctx context.Context, arg AddMetadataParams) (Metad
arg.Isbn10, arg.Isbn10,
arg.Isbn13, arg.Isbn13,
) )
var i Metadatum var i Metadata
err := row.Scan( err := row.Scan(
&i.ID, &i.ID,
&i.DocumentID, &i.DocumentID,
@@ -153,6 +153,18 @@ func (q *Queries) DeleteDocument(ctx context.Context, id string) (int64, error)
return result.RowsAffected() return result.RowsAffected()
} }
const deleteUser = `-- name: DeleteUser :execrows
DELETE FROM users WHERE id = ?1
`
func (q *Queries) DeleteUser(ctx context.Context, id string) (int64, error) {
result, err := q.db.ExecContext(ctx, deleteUser, id)
if err != nil {
return 0, err
}
return result.RowsAffected()
}
const getActivity = `-- name: GetActivity :many const getActivity = `-- name: GetActivity :many
WITH filtered_activity AS ( WITH filtered_activity AS (
SELECT SELECT
@@ -181,7 +193,7 @@ WITH filtered_activity AS (
SELECT SELECT
document_id, document_id,
device_id, device_id,
CAST(STRFTIME('%Y-%m-%d %H:%M:%S', activity.start_time, users.time_offset) AS TEXT) AS start_time, LOCAL_TIME(activity.start_time, users.timezone) AS start_time,
title, title,
author, author,
duration, duration,
@@ -204,7 +216,7 @@ type GetActivityParams struct {
type GetActivityRow struct { type GetActivityRow struct {
DocumentID string `json:"document_id"` DocumentID string `json:"document_id"`
DeviceID string `json:"device_id"` DeviceID string `json:"device_id"`
StartTime string `json:"start_time"` StartTime interface{} `json:"start_time"`
Title *string `json:"title"` Title *string `json:"title"`
Author *string `json:"author"` Author *string `json:"author"`
Duration int64 `json:"duration"` Duration int64 `json:"duration"`
@@ -254,7 +266,7 @@ func (q *Queries) GetActivity(ctx context.Context, arg GetActivityParams) ([]Get
const getDailyReadStats = `-- name: GetDailyReadStats :many const getDailyReadStats = `-- name: GetDailyReadStats :many
WITH RECURSIVE last_30_days AS ( WITH RECURSIVE last_30_days AS (
SELECT DATE('now', time_offset) AS date SELECT LOCAL_DATE(STRFTIME('%Y-%m-%dT%H:%M:%SZ', 'now'), timezone) AS date
FROM users WHERE users.id = ?1 FROM users WHERE users.id = ?1
UNION ALL UNION ALL
SELECT DATE(date, '-1 days') SELECT DATE(date, '-1 days')
@@ -273,7 +285,7 @@ filtered_activity AS (
activity_days AS ( activity_days AS (
SELECT SELECT
SUM(duration) AS seconds_read, SUM(duration) AS seconds_read,
DATE(start_time, time_offset) AS day LOCAL_DATE(start_time, timezone) AS day
FROM filtered_activity AS activity FROM filtered_activity AS activity
LEFT JOIN users ON users.id = activity.user_id LEFT JOIN users ON users.id = activity.user_id
GROUP BY day GROUP BY day
@@ -410,8 +422,8 @@ const getDevices = `-- name: GetDevices :many
SELECT SELECT
devices.id, devices.id,
devices.device_name, devices.device_name,
CAST(STRFTIME('%Y-%m-%d %H:%M:%S', devices.created_at, users.time_offset) AS TEXT) AS created_at, LOCAL_TIME(devices.created_at, users.timezone) AS created_at,
CAST(STRFTIME('%Y-%m-%d %H:%M:%S', devices.last_synced, users.time_offset) AS TEXT) AS last_synced LOCAL_TIME(devices.last_synced, users.timezone) AS last_synced
FROM devices FROM devices
JOIN users ON users.id = devices.user_id JOIN users ON users.id = devices.user_id
WHERE users.id = ?1 WHERE users.id = ?1
@@ -421,8 +433,8 @@ ORDER BY devices.last_synced DESC
type GetDevicesRow struct { type GetDevicesRow struct {
ID string `json:"id"` ID string `json:"id"`
DeviceName string `json:"device_name"` DeviceName string `json:"device_name"`
CreatedAt string `json:"created_at"` CreatedAt interface{} `json:"created_at"`
LastSynced string `json:"last_synced"` LastSynced interface{} `json:"last_synced"`
} }
func (q *Queries) GetDevices(ctx context.Context, userID string) ([]GetDevicesRow, error) { func (q *Queries) GetDevices(ctx context.Context, userID string) ([]GetDevicesRow, error) {
@@ -454,7 +466,7 @@ func (q *Queries) GetDevices(ctx context.Context, userID string) ([]GetDevicesRo
} }
const getDocument = `-- name: GetDocument :one const getDocument = `-- name: GetDocument :one
SELECT id, md5, filepath, coverfile, title, author, series, series_index, lang, description, words, gbid, olid, isbn10, isbn13, synced, deleted, updated_at, created_at FROM documents SELECT id, md5, basepath, filepath, coverfile, title, author, series, series_index, lang, description, words, gbid, olid, isbn10, isbn13, synced, deleted, updated_at, created_at FROM documents
WHERE id = ?1 LIMIT 1 WHERE id = ?1 LIMIT 1
` `
@@ -464,6 +476,7 @@ func (q *Queries) GetDocument(ctx context.Context, documentID string) (Document,
err := row.Scan( err := row.Scan(
&i.ID, &i.ID,
&i.Md5, &i.Md5,
&i.Basepath,
&i.Filepath, &i.Filepath,
&i.Coverfile, &i.Coverfile,
&i.Title, &i.Title,
@@ -530,89 +543,8 @@ func (q *Queries) GetDocumentProgress(ctx context.Context, arg GetDocumentProgre
return i, err return i, err
} }
const getDocumentWithStats = `-- name: GetDocumentWithStats :one
SELECT
docs.id,
docs.title,
docs.author,
docs.description,
docs.isbn10,
docs.isbn13,
docs.filepath,
docs.words,
CAST(COALESCE(dus.total_wpm, 0.0) AS INTEGER) AS wpm,
COALESCE(dus.read_percentage, 0) AS read_percentage,
COALESCE(dus.total_time_seconds, 0) AS total_time_seconds,
STRFTIME('%Y-%m-%d %H:%M:%S', COALESCE(dus.last_read, "1970-01-01"), users.time_offset)
AS last_read,
ROUND(CAST(CASE
WHEN dus.percentage IS NULL THEN 0.0
WHEN (dus.percentage * 100.0) > 97.0 THEN 100.0
ELSE dus.percentage * 100.0
END AS REAL), 2) AS percentage,
CAST(CASE
WHEN dus.total_time_seconds IS NULL THEN 0.0
ELSE
CAST(dus.total_time_seconds AS REAL)
/ (dus.read_percentage * 100.0)
END AS INTEGER) AS seconds_per_percent
FROM documents AS docs
LEFT JOIN users ON users.id = ?1
LEFT JOIN
document_user_statistics AS dus
ON dus.document_id = docs.id AND dus.user_id = ?1
WHERE users.id = ?1
AND docs.id = ?2
LIMIT 1
`
type GetDocumentWithStatsParams struct {
UserID string `json:"user_id"`
DocumentID string `json:"document_id"`
}
type GetDocumentWithStatsRow struct {
ID string `json:"id"`
Title *string `json:"title"`
Author *string `json:"author"`
Description *string `json:"description"`
Isbn10 *string `json:"isbn10"`
Isbn13 *string `json:"isbn13"`
Filepath *string `json:"filepath"`
Words *int64 `json:"words"`
Wpm int64 `json:"wpm"`
ReadPercentage float64 `json:"read_percentage"`
TotalTimeSeconds int64 `json:"total_time_seconds"`
LastRead interface{} `json:"last_read"`
Percentage float64 `json:"percentage"`
SecondsPerPercent int64 `json:"seconds_per_percent"`
}
func (q *Queries) GetDocumentWithStats(ctx context.Context, arg GetDocumentWithStatsParams) (GetDocumentWithStatsRow, error) {
row := q.db.QueryRowContext(ctx, getDocumentWithStats, arg.UserID, arg.DocumentID)
var i GetDocumentWithStatsRow
err := row.Scan(
&i.ID,
&i.Title,
&i.Author,
&i.Description,
&i.Isbn10,
&i.Isbn13,
&i.Filepath,
&i.Words,
&i.Wpm,
&i.ReadPercentage,
&i.TotalTimeSeconds,
&i.LastRead,
&i.Percentage,
&i.SecondsPerPercent,
)
return i, err
}
const getDocuments = `-- name: GetDocuments :many const getDocuments = `-- name: GetDocuments :many
SELECT id, md5, filepath, coverfile, title, author, series, series_index, lang, description, words, gbid, olid, isbn10, isbn13, synced, deleted, updated_at, created_at FROM documents SELECT id, md5, basepath, filepath, coverfile, title, author, series, series_index, lang, description, words, gbid, olid, isbn10, isbn13, synced, deleted, updated_at, created_at FROM documents
ORDER BY created_at DESC ORDER BY created_at DESC
LIMIT ?2 LIMIT ?2
OFFSET ?1 OFFSET ?1
@@ -635,6 +567,7 @@ func (q *Queries) GetDocuments(ctx context.Context, arg GetDocumentsParams) ([]D
if err := rows.Scan( if err := rows.Scan(
&i.ID, &i.ID,
&i.Md5, &i.Md5,
&i.Basepath,
&i.Filepath, &i.Filepath,
&i.Coverfile, &i.Coverfile,
&i.Title, &i.Title,
@@ -698,42 +631,43 @@ SELECT
CAST(COALESCE(dus.total_wpm, 0.0) AS INTEGER) AS wpm, CAST(COALESCE(dus.total_wpm, 0.0) AS INTEGER) AS wpm,
COALESCE(dus.read_percentage, 0) AS read_percentage, COALESCE(dus.read_percentage, 0) AS read_percentage,
COALESCE(dus.total_time_seconds, 0) AS total_time_seconds, COALESCE(dus.total_time_seconds, 0) AS total_time_seconds,
STRFTIME('%Y-%m-%d %H:%M:%S', COALESCE(dus.last_read, "1970-01-01"), users.time_offset) STRFTIME('%Y-%m-%d %H:%M:%S', LOCAL_TIME(COALESCE(dus.last_read, STRFTIME('%Y-%m-%dT%H:%M:%SZ', 0, 'unixepoch')), users.timezone))
AS last_read, AS last_read,
ROUND(CAST(CASE ROUND(CAST(CASE
WHEN dus.percentage IS NULL THEN 0.0 WHEN dus.percentage IS NULL THEN 0.0
WHEN (dus.percentage * 100.0) > 97.0 THEN 100.0 WHEN (dus.percentage * 100.0) > 97.0 THEN 100.0
ELSE dus.percentage * 100.0 ELSE dus.percentage * 100.0
END AS REAL), 2) AS percentage, END AS REAL), 2) AS percentage,
CAST(CASE
CASE
WHEN dus.total_time_seconds IS NULL THEN 0.0 WHEN dus.total_time_seconds IS NULL THEN 0.0
ELSE ELSE
ROUND(
CAST(dus.total_time_seconds AS REAL) CAST(dus.total_time_seconds AS REAL)
/ (dus.read_percentage * 100.0) / (dus.read_percentage * 100.0)
) END AS INTEGER) AS seconds_per_percent
END AS seconds_per_percent
FROM documents AS docs FROM documents AS docs
LEFT JOIN users ON users.id = ?1 LEFT JOIN users ON users.id = ?1
LEFT JOIN LEFT JOIN
document_user_statistics AS dus document_user_statistics AS dus
ON dus.document_id = docs.id AND dus.user_id = ?1 ON dus.document_id = docs.id AND dus.user_id = ?1
WHERE WHERE
docs.deleted = false AND ( (docs.id = ?2 OR ?2 IS NULL)
?2 IS NULL OR ( AND (docs.deleted = ?3 OR ?3 IS NULL)
docs.title LIKE ?2 OR AND (
docs.author LIKE ?2 (
) docs.title LIKE ?4 OR
docs.author LIKE ?4
) OR ?4 IS NULL
) )
ORDER BY dus.last_read DESC, docs.created_at DESC ORDER BY dus.last_read DESC, docs.created_at DESC
LIMIT ?4 LIMIT ?6
OFFSET ?3 OFFSET ?5
` `
type GetDocumentsWithStatsParams struct { type GetDocumentsWithStatsParams struct {
UserID string `json:"user_id"` UserID string `json:"user_id"`
Query interface{} `json:"query"` ID *string `json:"id"`
Deleted *bool `json:"-"`
Query *string `json:"query"`
Offset int64 `json:"offset"` Offset int64 `json:"offset"`
Limit int64 `json:"limit"` Limit int64 `json:"limit"`
} }
@@ -752,12 +686,14 @@ type GetDocumentsWithStatsRow struct {
TotalTimeSeconds int64 `json:"total_time_seconds"` TotalTimeSeconds int64 `json:"total_time_seconds"`
LastRead interface{} `json:"last_read"` LastRead interface{} `json:"last_read"`
Percentage float64 `json:"percentage"` Percentage float64 `json:"percentage"`
SecondsPerPercent interface{} `json:"seconds_per_percent"` SecondsPerPercent int64 `json:"seconds_per_percent"`
} }
func (q *Queries) GetDocumentsWithStats(ctx context.Context, arg GetDocumentsWithStatsParams) ([]GetDocumentsWithStatsRow, error) { func (q *Queries) GetDocumentsWithStats(ctx context.Context, arg GetDocumentsWithStatsParams) ([]GetDocumentsWithStatsRow, error) {
rows, err := q.db.QueryContext(ctx, getDocumentsWithStats, rows, err := q.db.QueryContext(ctx, getDocumentsWithStats,
arg.UserID, arg.UserID,
arg.ID,
arg.Deleted,
arg.Query, arg.Query,
arg.Offset, arg.Offset,
arg.Limit, arg.Limit,
@@ -819,7 +755,7 @@ func (q *Queries) GetLastActivity(ctx context.Context, arg GetLastActivityParams
} }
const getMissingDocuments = `-- name: GetMissingDocuments :many const getMissingDocuments = `-- name: GetMissingDocuments :many
SELECT documents.id, documents.md5, documents.filepath, documents.coverfile, documents.title, documents.author, documents.series, documents.series_index, documents.lang, documents.description, documents.words, documents.gbid, documents.olid, documents.isbn10, documents.isbn13, documents.synced, documents.deleted, documents.updated_at, documents.created_at FROM documents SELECT documents.id, documents.md5, documents.basepath, documents.filepath, documents.coverfile, documents.title, documents.author, documents.series, documents.series_index, documents.lang, documents.description, documents.words, documents.gbid, documents.olid, documents.isbn10, documents.isbn13, documents.synced, documents.deleted, documents.updated_at, documents.created_at FROM documents
WHERE WHERE
documents.filepath IS NOT NULL documents.filepath IS NOT NULL
AND documents.deleted = false AND documents.deleted = false
@@ -848,6 +784,7 @@ func (q *Queries) GetMissingDocuments(ctx context.Context, documentIds []string)
if err := rows.Scan( if err := rows.Scan(
&i.ID, &i.ID,
&i.Md5, &i.Md5,
&i.Basepath,
&i.Filepath, &i.Filepath,
&i.Coverfile, &i.Coverfile,
&i.Title, &i.Title,
@@ -887,7 +824,7 @@ SELECT
ROUND(CAST(progress.percentage AS REAL) * 100, 2) AS percentage, ROUND(CAST(progress.percentage AS REAL) * 100, 2) AS percentage,
progress.document_id, progress.document_id,
progress.user_id, progress.user_id,
CAST(STRFTIME('%Y-%m-%d %H:%M:%S', progress.created_at, users.time_offset) AS TEXT) AS created_at LOCAL_TIME(progress.created_at, users.timezone) AS created_at
FROM document_progress AS progress FROM document_progress AS progress
LEFT JOIN users ON progress.user_id = users.id LEFT JOIN users ON progress.user_id = users.id
LEFT JOIN devices ON progress.device_id = devices.id LEFT JOIN devices ON progress.device_id = devices.id
@@ -920,7 +857,7 @@ type GetProgressRow struct {
Percentage float64 `json:"percentage"` Percentage float64 `json:"percentage"`
DocumentID string `json:"document_id"` DocumentID string `json:"document_id"`
UserID string `json:"user_id"` UserID string `json:"user_id"`
CreatedAt string `json:"created_at"` CreatedAt interface{} `json:"created_at"`
} }
func (q *Queries) GetProgress(ctx context.Context, arg GetProgressParams) ([]GetProgressRow, error) { func (q *Queries) GetProgress(ctx context.Context, arg GetProgressParams) ([]GetProgressRow, error) {
@@ -961,7 +898,7 @@ func (q *Queries) GetProgress(ctx context.Context, arg GetProgressParams) ([]Get
} }
const getUser = `-- name: GetUser :one const getUser = `-- name: GetUser :one
SELECT id, pass, auth_hash, admin, time_offset, created_at FROM users SELECT id, pass, auth_hash, admin, timezone, created_at FROM users
WHERE id = ?1 LIMIT 1 WHERE id = ?1 LIMIT 1
` `
@@ -973,7 +910,7 @@ func (q *Queries) GetUser(ctx context.Context, userID string) (User, error) {
&i.Pass, &i.Pass,
&i.AuthHash, &i.AuthHash,
&i.Admin, &i.Admin,
&i.TimeOffset, &i.Timezone,
&i.CreatedAt, &i.CreatedAt,
) )
return i, err return i, err
@@ -1063,7 +1000,7 @@ func (q *Queries) GetUserStatistics(ctx context.Context) ([]GetUserStatisticsRow
} }
const getUserStreaks = `-- name: GetUserStreaks :many const getUserStreaks = `-- name: GetUserStreaks :many
SELECT user_id, "window", max_streak, max_streak_start_date, max_streak_end_date, current_streak, current_streak_start_date, current_streak_end_date FROM user_streaks SELECT user_id, "window", max_streak, max_streak_start_date, max_streak_end_date, current_streak, current_streak_start_date, current_streak_end_date, last_timezone, last_seen, last_record, last_calculated FROM user_streaks
WHERE user_id = ?1 WHERE user_id = ?1
` `
@@ -1085,6 +1022,10 @@ func (q *Queries) GetUserStreaks(ctx context.Context, userID string) ([]UserStre
&i.CurrentStreak, &i.CurrentStreak,
&i.CurrentStreakStartDate, &i.CurrentStreakStartDate,
&i.CurrentStreakEndDate, &i.CurrentStreakEndDate,
&i.LastTimezone,
&i.LastSeen,
&i.LastRecord,
&i.LastCalculated,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@@ -1100,7 +1041,7 @@ func (q *Queries) GetUserStreaks(ctx context.Context, userID string) ([]UserStre
} }
const getUsers = `-- name: GetUsers :many const getUsers = `-- name: GetUsers :many
SELECT id, pass, auth_hash, admin, time_offset, created_at FROM users SELECT id, pass, auth_hash, admin, timezone, created_at FROM users
` `
func (q *Queries) GetUsers(ctx context.Context) ([]User, error) { func (q *Queries) GetUsers(ctx context.Context) ([]User, error) {
@@ -1117,7 +1058,7 @@ func (q *Queries) GetUsers(ctx context.Context) ([]User, error) {
&i.Pass, &i.Pass,
&i.AuthHash, &i.AuthHash,
&i.Admin, &i.Admin,
&i.TimeOffset, &i.Timezone,
&i.CreatedAt, &i.CreatedAt,
); err != nil { ); err != nil {
return nil, err return nil, err
@@ -1251,15 +1192,17 @@ UPDATE users
SET SET
pass = COALESCE(?1, pass), pass = COALESCE(?1, pass),
auth_hash = COALESCE(?2, auth_hash), auth_hash = COALESCE(?2, auth_hash),
time_offset = COALESCE(?3, time_offset) timezone = COALESCE(?3, timezone),
WHERE id = ?4 admin = COALESCE(?4, admin)
RETURNING id, pass, auth_hash, admin, time_offset, created_at WHERE id = ?5
RETURNING id, pass, auth_hash, admin, timezone, created_at
` `
type UpdateUserParams struct { type UpdateUserParams struct {
Password *string `json:"-"` Password *string `json:"-"`
AuthHash *string `json:"auth_hash"` AuthHash *string `json:"auth_hash"`
TimeOffset *string `json:"time_offset"` Timezone *string `json:"timezone"`
Admin bool `json:"-"`
UserID string `json:"user_id"` UserID string `json:"user_id"`
} }
@@ -1267,7 +1210,8 @@ func (q *Queries) UpdateUser(ctx context.Context, arg UpdateUserParams) (User, e
row := q.db.QueryRowContext(ctx, updateUser, row := q.db.QueryRowContext(ctx, updateUser,
arg.Password, arg.Password,
arg.AuthHash, arg.AuthHash,
arg.TimeOffset, arg.Timezone,
arg.Admin,
arg.UserID, arg.UserID,
) )
var i User var i User
@@ -1276,7 +1220,7 @@ func (q *Queries) UpdateUser(ctx context.Context, arg UpdateUserParams) (User, e
&i.Pass, &i.Pass,
&i.AuthHash, &i.AuthHash,
&i.Admin, &i.Admin,
&i.TimeOffset, &i.Timezone,
&i.CreatedAt, &i.CreatedAt,
) )
return i, err return i, err
@@ -1322,6 +1266,7 @@ const upsertDocument = `-- name: UpsertDocument :one
INSERT INTO documents ( INSERT INTO documents (
id, id,
md5, md5,
basepath,
filepath, filepath,
coverfile, coverfile,
title, title,
@@ -1336,10 +1281,11 @@ INSERT INTO documents (
isbn10, isbn10,
isbn13 isbn13
) )
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT DO UPDATE ON CONFLICT DO UPDATE
SET SET
md5 = COALESCE(excluded.md5, md5), md5 = COALESCE(excluded.md5, md5),
basepath = COALESCE(excluded.basepath, basepath),
filepath = COALESCE(excluded.filepath, filepath), filepath = COALESCE(excluded.filepath, filepath),
coverfile = COALESCE(excluded.coverfile, coverfile), coverfile = COALESCE(excluded.coverfile, coverfile),
title = COALESCE(excluded.title, title), title = COALESCE(excluded.title, title),
@@ -1353,12 +1299,13 @@ SET
gbid = COALESCE(excluded.gbid, gbid), gbid = COALESCE(excluded.gbid, gbid),
isbn10 = COALESCE(excluded.isbn10, isbn10), isbn10 = COALESCE(excluded.isbn10, isbn10),
isbn13 = COALESCE(excluded.isbn13, isbn13) isbn13 = COALESCE(excluded.isbn13, isbn13)
RETURNING id, md5, filepath, coverfile, title, author, series, series_index, lang, description, words, gbid, olid, isbn10, isbn13, synced, deleted, updated_at, created_at RETURNING id, md5, basepath, filepath, coverfile, title, author, series, series_index, lang, description, words, gbid, olid, isbn10, isbn13, synced, deleted, updated_at, created_at
` `
type UpsertDocumentParams struct { type UpsertDocumentParams struct {
ID string `json:"id"` ID string `json:"id"`
Md5 *string `json:"md5"` Md5 *string `json:"md5"`
Basepath *string `json:"basepath"`
Filepath *string `json:"filepath"` Filepath *string `json:"filepath"`
Coverfile *string `json:"coverfile"` Coverfile *string `json:"coverfile"`
Title *string `json:"title"` Title *string `json:"title"`
@@ -1378,6 +1325,7 @@ func (q *Queries) UpsertDocument(ctx context.Context, arg UpsertDocumentParams)
row := q.db.QueryRowContext(ctx, upsertDocument, row := q.db.QueryRowContext(ctx, upsertDocument,
arg.ID, arg.ID,
arg.Md5, arg.Md5,
arg.Basepath,
arg.Filepath, arg.Filepath,
arg.Coverfile, arg.Coverfile,
arg.Title, arg.Title,
@@ -1396,6 +1344,7 @@ func (q *Queries) UpsertDocument(ctx context.Context, arg UpsertDocumentParams)
err := row.Scan( err := row.Scan(
&i.ID, &i.ID,
&i.Md5, &i.Md5,
&i.Basepath,
&i.Filepath, &i.Filepath,
&i.Coverfile, &i.Coverfile,
&i.Title, &i.Title,

View File

@@ -9,7 +9,7 @@ CREATE TABLE IF NOT EXISTS users (
pass TEXT NOT NULL, pass TEXT NOT NULL,
auth_hash TEXT NOT NULL, auth_hash TEXT NOT NULL,
admin BOOLEAN NOT NULL DEFAULT 0 CHECK (admin IN (0, 1)), admin BOOLEAN NOT NULL DEFAULT 0 CHECK (admin IN (0, 1)),
time_offset TEXT NOT NULL DEFAULT '0 hours', timezone TEXT NOT NULL DEFAULT 'Europe/London',
created_at DATETIME NOT NULL DEFAULT (STRFTIME('%Y-%m-%dT%H:%M:%SZ', 'now')) created_at DATETIME NOT NULL DEFAULT (STRFTIME('%Y-%m-%dT%H:%M:%SZ', 'now'))
); );
@@ -19,6 +19,7 @@ CREATE TABLE IF NOT EXISTS documents (
id TEXT NOT NULL PRIMARY KEY, id TEXT NOT NULL PRIMARY KEY,
md5 TEXT, md5 TEXT,
basepath TEXT,
filepath TEXT, filepath TEXT,
coverfile TEXT, coverfile TEXT,
title TEXT, title TEXT,
@@ -117,30 +118,13 @@ CREATE TABLE IF NOT EXISTS settings (
created_at DATETIME NOT NULL DEFAULT (STRFTIME('%Y-%m-%dT%H:%M:%SZ', 'now')) created_at DATETIME NOT NULL DEFAULT (STRFTIME('%Y-%m-%dT%H:%M:%SZ', 'now'))
); );
--------------------------------------------------------------- -- Document User Statistics Table
----------------------- Temporary Tables ---------------------- CREATE TABLE IF NOT EXISTS document_user_statistics (
---------------------------------------------------------------
-- Temporary User Streaks Table (Cached from View)
CREATE TEMPORARY TABLE IF NOT EXISTS user_streaks (
user_id TEXT NOT NULL,
window TEXT NOT NULL,
max_streak INTEGER NOT NULL,
max_streak_start_date TEXT NOT NULL,
max_streak_end_date TEXT NOT NULL,
current_streak INTEGER NOT NULL,
current_streak_start_date TEXT NOT NULL,
current_streak_end_date TEXT NOT NULL
);
-- Temporary Document User Statistics Table (Cached from View)
CREATE TEMPORARY TABLE IF NOT EXISTS document_user_statistics (
document_id TEXT NOT NULL, document_id TEXT NOT NULL,
user_id TEXT NOT NULL, user_id TEXT NOT NULL,
percentage REAL NOT NULL, percentage REAL NOT NULL,
last_read TEXT NOT NULL, last_read DATETIME NOT NULL,
last_seen DATETIME NOT NULL,
read_percentage REAL NOT NULL, read_percentage REAL NOT NULL,
total_time_seconds INTEGER NOT NULL, total_time_seconds INTEGER NOT NULL,
@@ -162,321 +146,39 @@ CREATE TEMPORARY TABLE IF NOT EXISTS document_user_statistics (
UNIQUE(document_id, user_id) ON CONFLICT REPLACE UNIQUE(document_id, user_id) ON CONFLICT REPLACE
); );
-- User Streaks Table
CREATE TABLE IF NOT EXISTS user_streaks (
user_id TEXT NOT NULL,
window TEXT NOT NULL,
max_streak INTEGER NOT NULL,
max_streak_start_date TEXT NOT NULL,
max_streak_end_date TEXT NOT NULL,
current_streak INTEGER NOT NULL,
current_streak_start_date TEXT NOT NULL,
current_streak_end_date TEXT NOT NULL,
last_timezone TEXT NOT NULL,
last_seen TEXT NOT NULL,
last_record TEXT NOT NULL,
last_calculated TEXT NOT NULL,
UNIQUE(user_id, window) ON CONFLICT REPLACE
);
--------------------------------------------------------------- ---------------------------------------------------------------
--------------------------- Indexes --------------------------- --------------------------- Indexes ---------------------------
--------------------------------------------------------------- ---------------------------------------------------------------
CREATE INDEX IF NOT EXISTS activity_start_time ON activity (start_time); CREATE INDEX IF NOT EXISTS activity_start_time ON activity (start_time);
CREATE INDEX IF NOT EXISTS activity_created_at ON activity (created_at);
CREATE INDEX IF NOT EXISTS activity_user_id ON activity (user_id); CREATE INDEX IF NOT EXISTS activity_user_id ON activity (user_id);
CREATE INDEX IF NOT EXISTS activity_user_id_document_id ON activity ( CREATE INDEX IF NOT EXISTS activity_user_id_document_id ON activity (
user_id, user_id,
document_id document_id
); );
---------------------------------------------------------------
---------------------------- Views ----------------------------
---------------------------------------------------------------
DROP VIEW IF EXISTS view_user_streaks;
DROP VIEW IF EXISTS view_document_user_statistics;
--------------------------------
--------- User Streaks ---------
--------------------------------
CREATE VIEW view_user_streaks AS
WITH document_windows AS (
SELECT
activity.user_id,
users.time_offset,
DATE(
activity.start_time,
users.time_offset,
'weekday 0', '-7 day'
) AS weekly_read,
DATE(activity.start_time, users.time_offset) AS daily_read
FROM activity
LEFT JOIN users ON users.id = activity.user_id
GROUP BY activity.user_id, weekly_read, daily_read
),
weekly_partitions AS (
SELECT
user_id,
time_offset,
'WEEK' AS "window",
weekly_read AS read_window,
row_number() OVER (
PARTITION BY user_id ORDER BY weekly_read DESC
) AS seqnum
FROM document_windows
GROUP BY user_id, weekly_read
),
daily_partitions AS (
SELECT
user_id,
time_offset,
'DAY' AS "window",
daily_read AS read_window,
row_number() OVER (
PARTITION BY user_id ORDER BY daily_read DESC
) AS seqnum
FROM document_windows
GROUP BY user_id, daily_read
),
streaks AS (
SELECT
COUNT(*) AS streak,
MIN(read_window) AS start_date,
MAX(read_window) AS end_date,
window,
user_id,
time_offset
FROM daily_partitions
GROUP BY
time_offset,
user_id,
DATE(read_window, '+' || seqnum || ' day')
UNION ALL
SELECT
COUNT(*) AS streak,
MIN(read_window) AS start_date,
MAX(read_window) AS end_date,
window,
user_id,
time_offset
FROM weekly_partitions
GROUP BY
time_offset,
user_id,
DATE(read_window, '+' || (seqnum * 7) || ' day')
),
max_streak AS (
SELECT
MAX(streak) AS max_streak,
start_date AS max_streak_start_date,
end_date AS max_streak_end_date,
window,
user_id
FROM streaks
GROUP BY user_id, window
),
current_streak AS (
SELECT
streak AS current_streak,
start_date AS current_streak_start_date,
end_date AS current_streak_end_date,
window,
user_id
FROM streaks
WHERE CASE
WHEN window = "WEEK" THEN
DATE('now', time_offset, 'weekday 0', '-14 day') = current_streak_end_date
OR DATE('now', time_offset, 'weekday 0', '-7 day') = current_streak_end_date
WHEN window = "DAY" THEN
DATE('now', time_offset, '-1 day') = current_streak_end_date
OR DATE('now', time_offset) = current_streak_end_date
END
GROUP BY user_id, window
)
SELECT
max_streak.user_id,
max_streak.window,
IFNULL(max_streak, 0) AS max_streak,
IFNULL(max_streak_start_date, "N/A") AS max_streak_start_date,
IFNULL(max_streak_end_date, "N/A") AS max_streak_end_date,
IFNULL(current_streak, 0) AS current_streak,
IFNULL(current_streak_start_date, "N/A") AS current_streak_start_date,
IFNULL(current_streak_end_date, "N/A") AS current_streak_end_date
FROM max_streak
LEFT JOIN current_streak ON
current_streak.user_id = max_streak.user_id
AND current_streak.window = max_streak.window;
--------------------------------
------- Document Stats ---------
--------------------------------
CREATE VIEW view_document_user_statistics AS
WITH intermediate_ga AS (
SELECT
ga1.id AS row_id,
ga1.user_id,
ga1.document_id,
ga1.duration,
ga1.start_time,
ga1.start_percentage,
ga1.end_percentage,
-- Find Overlapping Events (Assign Unique ID)
(
SELECT MIN(id)
FROM activity AS ga2
WHERE
ga1.document_id = ga2.document_id
AND ga1.user_id = ga2.user_id
AND ga1.start_percentage <= ga2.end_percentage
AND ga1.end_percentage >= ga2.start_percentage
) AS group_leader
FROM activity AS ga1
),
grouped_activity AS (
SELECT
user_id,
document_id,
MAX(start_time) AS start_time,
MIN(start_percentage) AS start_percentage,
MAX(end_percentage) AS end_percentage,
MAX(end_percentage) - MIN(start_percentage) AS read_percentage,
SUM(duration) AS duration
FROM intermediate_ga
GROUP BY group_leader
),
current_progress AS (
SELECT
user_id,
document_id,
COALESCE((
SELECT percentage
FROM document_progress AS dp
WHERE
dp.user_id = iga.user_id
AND dp.document_id = iga.document_id
ORDER BY created_at DESC
LIMIT 1
), end_percentage) AS percentage
FROM intermediate_ga AS iga
GROUP BY user_id, document_id
HAVING MAX(start_time)
)
SELECT
ga.document_id,
ga.user_id,
cp.percentage,
MAX(start_time) AS last_read,
SUM(read_percentage) AS read_percentage,
-- All Time WPM
SUM(duration) AS total_time_seconds,
(CAST(COALESCE(d.words, 0.0) AS REAL) * SUM(read_percentage))
AS total_words_read,
(CAST(COALESCE(d.words, 0.0) AS REAL) * SUM(read_percentage))
/ (SUM(duration) / 60.0) AS total_wpm,
-- Yearly WPM
SUM(CASE WHEN start_time >= DATE('now', '-1 year') THEN duration ELSE 0 END)
AS yearly_time_seconds,
(
CAST(COALESCE(d.words, 0.0) AS REAL)
* SUM(
CASE
WHEN start_time >= DATE('now', '-1 year') THEN read_percentage
ELSE 0
END
)
)
AS yearly_words_read,
COALESCE((
CAST(COALESCE(d.words, 0.0) AS REAL)
* SUM(
CASE
WHEN start_time >= DATE('now', '-1 year') THEN read_percentage
END
)
)
/ (
SUM(
CASE
WHEN start_time >= DATE('now', '-1 year') THEN duration
END
)
/ 60.0
), 0.0)
AS yearly_wpm,
-- Monthly WPM
SUM(
CASE WHEN start_time >= DATE('now', '-1 month') THEN duration ELSE 0 END
)
AS monthly_time_seconds,
(
CAST(COALESCE(d.words, 0.0) AS REAL)
* SUM(
CASE
WHEN start_time >= DATE('now', '-1 month') THEN read_percentage
ELSE 0
END
)
)
AS monthly_words_read,
COALESCE((
CAST(COALESCE(d.words, 0.0) AS REAL)
* SUM(
CASE
WHEN start_time >= DATE('now', '-1 month') THEN read_percentage
END
)
)
/ (
SUM(
CASE
WHEN start_time >= DATE('now', '-1 month') THEN duration
END
)
/ 60.0
), 0.0)
AS monthly_wpm,
-- Weekly WPM
SUM(CASE WHEN start_time >= DATE('now', '-7 days') THEN duration ELSE 0 END)
AS weekly_time_seconds,
(
CAST(COALESCE(d.words, 0.0) AS REAL)
* SUM(
CASE
WHEN start_time >= DATE('now', '-7 days') THEN read_percentage
ELSE 0
END
)
)
AS weekly_words_read,
COALESCE((
CAST(COALESCE(d.words, 0.0) AS REAL)
* SUM(
CASE
WHEN start_time >= DATE('now', '-7 days') THEN read_percentage
END
)
)
/ (
SUM(
CASE
WHEN start_time >= DATE('now', '-7 days') THEN duration
END
)
/ 60.0
), 0.0)
AS weekly_wpm
FROM grouped_activity AS ga
INNER JOIN
current_progress AS cp
ON ga.user_id = cp.user_id AND ga.document_id = cp.document_id
INNER JOIN
documents AS d
ON ga.document_id = d.id
GROUP BY ga.document_id, ga.user_id
ORDER BY total_wpm DESC;
--------------------------------------------------------------- ---------------------------------------------------------------
--------------------------- Triggers -------------------------- --------------------------- Triggers --------------------------
--------------------------------------------------------------- ---------------------------------------------------------------
@@ -488,3 +190,11 @@ UPDATE documents
SET updated_at = STRFTIME('%Y-%m-%dT%H:%M:%SZ', 'now') SET updated_at = STRFTIME('%Y-%m-%dT%H:%M:%SZ', 'now')
WHERE id = old.id; WHERE id = old.id;
END; END;
-- Delete User
CREATE TRIGGER IF NOT EXISTS user_deleted
BEFORE DELETE ON users BEGIN
DELETE FROM activity WHERE activity.user_id=OLD.id;
DELETE FROM devices WHERE devices.user_id=OLD.id;
DELETE FROM document_progress WHERE document_progress.user_id=OLD.id;
END;

154
database/user_streaks.sql Normal file
View File

@@ -0,0 +1,154 @@
WITH updated_users AS (
SELECT a.user_id
FROM activity AS a
LEFT JOIN users AS u ON u.id = a.user_id
LEFT JOIN user_streaks AS s ON a.user_id = s.user_id AND s.window = 'DAY'
WHERE
a.created_at > COALESCE(s.last_seen, '1970-01-01')
AND LOCAL_DATE(s.last_record, u.timezone) != LOCAL_DATE(a.start_time, u.timezone)
GROUP BY a.user_id
),
outdated_users AS (
SELECT
a.user_id,
u.timezone AS last_timezone,
MAX(a.created_at) AS last_seen,
MAX(a.start_time) AS last_record,
STRFTIME('%Y-%m-%dT%H:%M:%SZ', 'now') AS last_calculated
FROM activity AS a
LEFT JOIN users AS u ON u.id = a.user_id
LEFT JOIN user_streaks AS s ON a.user_id = s.user_id AND s.window = 'DAY'
GROUP BY a.user_id
HAVING
-- User Changed Timezones
s.last_timezone != u.timezone
-- Users Date Changed
OR LOCAL_DATE(COALESCE(s.last_calculated, '1970-01-01T00:00:00Z'), u.timezone) !=
LOCAL_DATE(STRFTIME('%Y-%m-%dT%H:%M:%SZ', 'now'), u.timezone)
-- User Added New Data
OR a.user_id IN updated_users
),
document_windows AS (
SELECT
activity.user_id,
users.timezone,
DATE(
LOCAL_DATE(activity.start_time, users.timezone),
'weekday 0', '-7 day'
) AS weekly_read,
LOCAL_DATE(activity.start_time, users.timezone) AS daily_read
FROM activity
INNER JOIN outdated_users ON outdated_users.user_id = activity.user_id
LEFT JOIN users ON users.id = activity.user_id
GROUP BY activity.user_id, weekly_read, daily_read
),
weekly_partitions AS (
SELECT
user_id,
timezone,
'WEEK' AS "window",
weekly_read AS read_window,
ROW_NUMBER() OVER (
PARTITION BY user_id ORDER BY weekly_read DESC
) AS seqnum
FROM document_windows
GROUP BY user_id, weekly_read
),
daily_partitions AS (
SELECT
user_id,
timezone,
'DAY' AS "window",
daily_read AS read_window,
ROW_NUMBER() OVER (
PARTITION BY user_id ORDER BY daily_read DESC
) AS seqnum
FROM document_windows
GROUP BY user_id, daily_read
),
streaks AS (
SELECT
COUNT(*) AS streak,
MIN(read_window) AS start_date,
MAX(read_window) AS end_date,
window,
user_id,
timezone
FROM daily_partitions
GROUP BY
timezone,
user_id,
DATE(read_window, '+' || seqnum || ' day')
UNION ALL
SELECT
COUNT(*) AS streak,
MIN(read_window) AS start_date,
MAX(read_window) AS end_date,
window,
user_id,
timezone
FROM weekly_partitions
GROUP BY
timezone,
user_id,
DATE(read_window, '+' || (seqnum * 7) || ' day')
),
max_streak AS (
SELECT
MAX(streak) AS max_streak,
start_date AS max_streak_start_date,
end_date AS max_streak_end_date,
window,
user_id
FROM streaks
GROUP BY user_id, window
),
current_streak AS (
SELECT
streak AS current_streak,
start_date AS current_streak_start_date,
end_date AS current_streak_end_date,
window,
user_id
FROM streaks
WHERE CASE
WHEN window = "WEEK" THEN
DATE(LOCAL_DATE(STRFTIME('%Y-%m-%dT%H:%M:%SZ', 'now'), timezone), 'weekday 0', '-14 day') = current_streak_end_date
OR DATE(LOCAL_DATE(STRFTIME('%Y-%m-%dT%H:%M:%SZ', 'now'), timezone), 'weekday 0', '-7 day') = current_streak_end_date
WHEN window = "DAY" THEN
DATE(LOCAL_DATE(STRFTIME('%Y-%m-%dT%H:%M:%SZ', 'now'), timezone), '-1 day') = current_streak_end_date
OR DATE(LOCAL_DATE(STRFTIME('%Y-%m-%dT%H:%M:%SZ', 'now'), timezone)) = current_streak_end_date
END
GROUP BY user_id, window
)
INSERT INTO user_streaks
SELECT
max_streak.user_id,
max_streak.window,
IFNULL(max_streak, 0) AS max_streak,
IFNULL(max_streak_start_date, "N/A") AS max_streak_start_date,
IFNULL(max_streak_end_date, "N/A") AS max_streak_end_date,
IFNULL(current_streak.current_streak, 0) AS current_streak,
IFNULL(current_streak.current_streak_start_date, "N/A") AS current_streak_start_date,
IFNULL(current_streak.current_streak_end_date, "N/A") AS current_streak_end_date,
outdated_users.last_timezone AS last_timezone,
outdated_users.last_seen AS last_seen,
outdated_users.last_record AS last_record,
outdated_users.last_calculated AS last_calculated
FROM max_streak
JOIN outdated_users ON max_streak.user_id = outdated_users.user_id
LEFT JOIN current_streak ON
current_streak.user_id = max_streak.user_id
AND current_streak.window = max_streak.window;

205
database/users_test.go Normal file
View File

@@ -0,0 +1,205 @@
package database
import (
"context"
"database/sql"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/suite"
"reichard.io/antholume/config"
"reichard.io/antholume/utils"
)
var (
testUserID string = "testUser"
testUserPass string = "testPass"
)
type UsersTestSuite struct {
suite.Suite
dbm *DBManager
}
func TestUsers(t *testing.T) {
suite.Run(t, new(UsersTestSuite))
}
func (suite *UsersTestSuite) SetupTest() {
cfg := config.Config{
DBType: "memory",
}
suite.dbm = NewMgr(&cfg)
// Create User
rawAuthHash, _ := utils.GenerateToken(64)
authHash := fmt.Sprintf("%x", rawAuthHash)
_, err := suite.dbm.Queries.CreateUser(context.Background(), CreateUserParams{
ID: testUserID,
Pass: &testUserPass,
AuthHash: &authHash,
})
suite.NoError(err)
// Create Document
_, err = suite.dbm.Queries.UpsertDocument(context.Background(), UpsertDocumentParams{
ID: documentID,
Title: &documentTitle,
Author: &documentAuthor,
Words: &documentWords,
})
suite.NoError(err)
// Create Device
_, err = suite.dbm.Queries.UpsertDevice(context.Background(), UpsertDeviceParams{
ID: deviceID,
UserID: testUserID,
DeviceName: deviceName,
})
suite.NoError(err)
}
func (suite *UsersTestSuite) TestGetUser() {
user, err := suite.dbm.Queries.GetUser(context.Background(), testUserID)
suite.Nil(err, "should have nil err")
suite.Equal(testUserPass, *user.Pass)
}
func (suite *UsersTestSuite) TestCreateUser() {
testUser := "user1"
testPass := "pass1"
// Generate Auth Hash
rawAuthHash, err := utils.GenerateToken(64)
suite.Nil(err, "should have nil err")
authHash := fmt.Sprintf("%x", rawAuthHash)
changed, err := suite.dbm.Queries.CreateUser(context.Background(), CreateUserParams{
ID: testUser,
Pass: &testPass,
AuthHash: &authHash,
})
suite.Nil(err, "should have nil err")
suite.Equal(int64(1), changed)
user, err := suite.dbm.Queries.GetUser(context.Background(), testUser)
suite.Nil(err, "should have nil err")
suite.Equal(testPass, *user.Pass)
}
func (suite *UsersTestSuite) TestDeleteUser() {
changed, err := suite.dbm.Queries.DeleteUser(context.Background(), testUserID)
suite.Nil(err, "should have nil err")
suite.Equal(int64(1), changed, "should have one changed row")
_, err = suite.dbm.Queries.GetUser(context.Background(), testUserID)
suite.ErrorIs(err, sql.ErrNoRows, "should have no rows error")
}
func (suite *UsersTestSuite) TestGetUsers() {
users, err := suite.dbm.Queries.GetUsers(context.Background())
suite.Nil(err, "should have nil err")
suite.Len(users, 1, "should have single user")
}
func (suite *UsersTestSuite) TestUpdateUser() {
newPassword := "newPass123"
user, err := suite.dbm.Queries.UpdateUser(context.Background(), UpdateUserParams{
UserID: testUserID,
Password: &newPassword,
})
suite.Nil(err, "should have nil err")
suite.Equal(newPassword, *user.Pass, "should have new password")
}
func (suite *UsersTestSuite) TestGetUserStatistics() {
err := suite.dbm.CacheTempTables(context.Background())
suite.NoError(err)
// Ensure Zero Items
userStats, err := suite.dbm.Queries.GetUserStatistics(context.Background())
suite.Nil(err, "should have nil err")
suite.Empty(userStats, "should be empty")
// Create Activity
end := time.Now()
start := end.AddDate(0, 0, -9)
var counter int64 = 0
for d := start; d.After(end) == false; d = d.AddDate(0, 0, 1) {
counter += 1
// Add Item
activity, err := suite.dbm.Queries.AddActivity(context.Background(), AddActivityParams{
DocumentID: documentID,
DeviceID: deviceID,
UserID: testUserID,
StartTime: d.UTC().Format(time.RFC3339),
Duration: 60,
StartPercentage: float64(counter) / 100.0,
EndPercentage: float64(counter+1) / 100.0,
})
suite.Nil(err, fmt.Sprintf("[%d] should have nil err for add activity", counter))
suite.Equal(counter, activity.ID, fmt.Sprintf("[%d] should have correct id for add activity", counter))
}
err = suite.dbm.CacheTempTables(context.Background())
suite.NoError(err)
// Ensure One Item
userStats, err = suite.dbm.Queries.GetUserStatistics(context.Background())
suite.Nil(err, "should have nil err")
suite.Len(userStats, 1, "should have length of one")
}
func (suite *UsersTestSuite) TestGetUsersStreaks() {
err := suite.dbm.CacheTempTables(context.Background())
suite.NoError(err)
// Ensure Zero Items
userStats, err := suite.dbm.Queries.GetUserStreaks(context.Background(), testUserID)
suite.Nil(err, "should have nil err")
suite.Empty(userStats, "should be empty")
// Create Activity
end := time.Now()
start := end.AddDate(0, 0, -9)
var counter int64 = 0
for d := start; d.After(end) == false; d = d.AddDate(0, 0, 1) {
counter += 1
// Add Item
activity, err := suite.dbm.Queries.AddActivity(context.Background(), AddActivityParams{
DocumentID: documentID,
DeviceID: deviceID,
UserID: testUserID,
StartTime: d.UTC().Format(time.RFC3339),
Duration: 60,
StartPercentage: float64(counter) / 100.0,
EndPercentage: float64(counter+1) / 100.0,
})
suite.Nil(err, fmt.Sprintf("[%d] should have nil err for add activity", counter))
suite.Equal(counter, activity.ID, fmt.Sprintf("[%d] should have correct id for add activity", counter))
}
err = suite.dbm.CacheTempTables(context.Background())
suite.NoError(err)
// Ensure Two Item
userStats, err = suite.dbm.Queries.GetUserStreaks(context.Background(), testUserID)
suite.Nil(err, "should have nil err")
suite.Len(userStats, 2, "should have length of two")
// Ensure Streak Stats
dayStats := userStats[0]
weekStats := userStats[1]
suite.Equal(int64(10), dayStats.CurrentStreak, "should be 10 days")
suite.Greater(weekStats.CurrentStreak, int64(1), "should be 2 or 3")
}

61
flake.lock generated Normal file
View File

@@ -0,0 +1,61 @@
{
"nodes": {
"flake-utils": {
"inputs": {
"systems": "systems"
},
"locked": {
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1773524153,
"narHash": "sha256-Jms57zzlFf64ayKzzBWSE2SGvJmK+NGt8Gli71d9kmY=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "e9f278faa1d0c2fc835bd331d4666b59b505a410",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixos-25.11",
"repo": "nixpkgs",
"type": "github"
}
},
"root": {
"inputs": {
"flake-utils": "flake-utils",
"nixpkgs": "nixpkgs"
}
},
"systems": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
}
},
"root": "root",
"version": 7
}

37
flake.nix Normal file
View File

@@ -0,0 +1,37 @@
{
description = "Development Environment";
inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixos-25.11";
flake-utils.url = "github:numtide/flake-utils";
};
outputs =
{ self
, nixpkgs
, flake-utils
,
}:
flake-utils.lib.eachDefaultSystem (
system:
let
pkgs = nixpkgs.legacyPackages.${system};
in
{
devShells.default = pkgs.mkShell {
packages = with pkgs; [
go
golangci-lint
gopls
bun
nodejs
tailwindcss
];
shellHook = ''
export PATH=$PATH:~/go/bin
'';
};
}
);
}

2
frontend/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
node_modules
dist

2
frontend/.prettierignore Normal file
View File

@@ -0,0 +1,2 @@
# Generated API code
src/generated/**/*

11
frontend/.prettierrc Normal file
View File

@@ -0,0 +1,11 @@
{
"semi": true,
"singleQuote": true,
"tabWidth": 2,
"useTabs": false,
"trailingComma": "es5",
"printWidth": 100,
"bracketSpacing": true,
"arrowParens": "avoid",
"endOfLine": "lf"
}

76
frontend/AGENTS.md Normal file
View File

@@ -0,0 +1,76 @@
# AnthoLume Frontend Agent Guide
Read this file for work in `frontend/`.
Also follow the repository root guide at `../AGENTS.md`.
## 1) Stack
- Package manager: `bun`
- Framework: React + Vite
- Data fetching: React Query
- API generation: Orval
- Linting: ESLint + Tailwind plugin
- Formatting: Prettier
## 2) Conventions
- Use local icon components from `src/icons/`.
- Do not add external icon libraries.
- Prefer generated types from `src/generated/model/` over `any`.
- Avoid custom class names in JSX `className` values unless the Tailwind lint config already allows them.
- For decorative icons in inputs or labels, disable hover styling via the icon component API rather than overriding it ad hoc.
- Prefer `LoadingState` for result-area loading indicators; avoid early returns that unmount search/filter forms during fetches.
- Use theme tokens from `tailwind.config.js` / `src/index.css` (`bg-surface`, `text-content`, `border-border`, `primary`, etc.) for new UI work instead of adding raw light/dark color pairs.
- Store frontend-only preferences in `src/utils/localSettings.ts` so appearance and view settings share one local-storage shape.
## 3) Generated API client
- Do not edit `src/generated/**` directly.
- Edit `../api/v1/openapi.yaml` and regenerate instead.
- Regenerate with: `bun run generate:api`
### Important behavior
- The generated client returns `{ data, status, headers }` for both success and error responses.
- Do not assume non-2xx responses throw.
- Check `response.status` and response shape before treating a request as successful.
## 4) Auth / Query State
- When changing auth flows, account for React Query cache state.
- Pay special attention to `/api/v1/auth/me`.
- A local auth state update may not be enough if cached query data still reflects a previous auth state.
## 5) Commands
- Lint: `bun run lint`
- Typecheck: `bun run typecheck`
- Lint fix: `bun run lint:fix`
- Format check: `bun run format`
- Format fix: `bun run format:fix`
- Build: `bun run build`
- Generate API client: `bun run generate:api`
## 6) Validation Notes
- ESLint ignores `src/generated/**`.
- Frontend unit tests use Vitest and live alongside source as `src/**/*.test.ts(x)`.
- Read `TESTING_STRATEGY.md` before adding or expanding frontend tests.
- Prefer tests for meaningful app behavior, branching logic, side effects, and user-visible outcomes.
- Avoid low-value tests that mainly assert exact styling classes, duplicate existing coverage, or re-test framework/library behavior.
- `bun run lint` includes test files but does not typecheck.
- Use `bun run typecheck` to run TypeScript validation for app code and colocated tests without a full production build.
- Run frontend tests with `bun run test`.
- `bun run build` still runs `tsc && vite build`, so unrelated TypeScript issues elsewhere in `src/` can fail the build.
- When possible, validate changed files directly before escalating to full-project fixes.
## 7) Updating This File
After completing a frontend task, update this file if you learned something general that would help future frontend agents.
Rules for updates:
- Add only frontend-wide guidance.
- Do not record one-off task history.
- Keep updates concise and action-oriented.
- Prefer notes that prevent repeated mistakes.

111
frontend/README.md Normal file
View File

@@ -0,0 +1,111 @@
# AnthoLume Frontend
A React + TypeScript frontend for AnthoLume, replacing the server-side rendering (SSR) templates.
## Tech Stack
- **React 19** - UI framework
- **TypeScript** - Type safety
- **React Query (TanStack Query)** - Server state management
- **Orval** - API client generation from OpenAPI spec
- **React Router** - Navigation
- **Tailwind CSS** - Styling
- **Vite** - Build tool
- **Axios** - HTTP client with auth interceptors
## Authentication
The frontend includes a complete authentication system:
### Auth Context
- `AuthProvider` - Manages authentication state globally
- `useAuth()` - Hook to access auth state and methods
- Token stored in `localStorage`
- Axios interceptors automatically attach Bearer token to API requests
### Protected Routes
- All main routes are wrapped in `ProtectedRoute`
- Unauthenticated users are redirected to `/login`
- Layout redirects to login if not authenticated
### Login Flow
1. User enters credentials on `/login`
2. POST to `/api/v1/auth/login`
3. Token stored in localStorage
4. Redirect to home page
5. Axios interceptor includes token in subsequent requests
### Logout Flow
1. User clicks "Logout" in dropdown menu
2. POST to `/api/v1/auth/logout`
3. Token cleared from localStorage
4. Redirect to `/login`
### 401 Handling
- Axios response interceptor clears token on 401 errors
- Prevents stale auth state
## Architecture
The frontend mirrors the existing SSR templates structure:
### Pages
- `HomePage` - Landing page with recent documents
- `DocumentsPage` - Document listing with search and pagination
- `DocumentPage` - Single document view with details
- `ProgressPage` - Reading progress table
- `ActivityPage` - User activity log
- `SearchPage` - Search interface
- `SettingsPage` - User settings
- `LoginPage` - Authentication
### Components
- `Layout` - Main layout with navigation sidebar and header
- Generated API hooks from `api/v1/openapi.yaml`
## API Integration
The frontend uses **Orval** to generate TypeScript types and React Query hooks from the OpenAPI spec:
```bash
npm run generate:api
```
This generates:
- Type definitions for all API schemas
- React Query hooks (`useGetDocuments`, `useGetDocument`, etc.)
- Mutation hooks (`useLogin`, `useLogout`)
## Development
```bash
# Install dependencies
npm install
# Generate API types (if OpenAPI spec changes)
npm run generate:api
# Start development server
npm run dev
# Build for production
npm run build
```
## Deployment
The built output is in `dist/` and can be served by the Go backend or deployed separately.
## Migration from SSR
The frontend replicates the functionality of the following SSR templates:
- `templates/pages/home.tmpl``HomePage.tsx`
- `templates/pages/documents.tmpl``DocumentsPage.tsx`
- `templates/pages/document.tmpl``DocumentPage.tsx`
- `templates/pages/progress.tmpl``ProgressPage.tsx`
- `templates/pages/activity.tmpl``ActivityPage.tsx`
- `templates/pages/search.tmpl``SearchPage.tsx`
- `templates/pages/settings.tmpl``SettingsPage.tsx`
- `templates/pages/login.tmpl``LoginPage.tsx`
The styling follows the same Tailwind CSS classes as the original templates for consistency.

View File

@@ -0,0 +1,73 @@
# Frontend Testing Strategy
This project prefers meaningful frontend tests over high test counts.
## What we want to test
Prioritize tests for app-owned behavior such as:
- user-visible page and component behavior
- auth and routing behavior
- branching logic and business rules
- data normalization and error handling
- timing behavior with real app logic
- side effects that could regress, such as token handling or redirects
- algorithmic or formatting logic that defines product behavior
Good examples in this repo:
- login and registration flows
- protected-route behavior
- auth interceptor token injection and cleanup
- error message extraction
- debounce timing
- human-readable formatting logic
- graph/algorithm output where exact parity matters
## What we usually do not want to test
Avoid tests that mostly prove:
- the language/runtime works
- React forwards basic props correctly
- a third-party library behaves as documented
- exact Tailwind class strings with no product meaning
- implementation details not observable in behavior
- duplicated examples that re-assert the same logic
In other words, do not add tests equivalent to checking that JavaScript can compute `1 + 1`.
## Preferred test style
- Prefer behavior-focused assertions over implementation-detail assertions.
- Prefer user-visible outcomes over internal state inspection.
- Mock at module boundaries when needed.
- Keep test setup small and local.
- Use exact-output assertions only when the output itself is the contract.
## When exact assertions are appropriate
Exact assertions are appropriate when they protect a real contract, for example:
- a formatter's exact human-readable output
- auth decision outcomes for a given API response shape
- exact algorithm output that must remain stable
Exact assertions are usually not appropriate for:
- incidental class names
- framework internals
- non-observable React keys
## Cleanup rule of thumb
Keep tests that would catch meaningful regressions in product behavior.
Trim or remove tests that are brittle, duplicated, or mostly validate tooling rather than app logic.
## Validation
For frontend test work, validate with:
- `cd frontend && bun run lint`
- `cd frontend && bun run typecheck`
- `cd frontend && bun run test`

1350
frontend/bun.lock Normal file

File diff suppressed because it is too large Load Diff

82
frontend/eslint.config.js Normal file
View File

@@ -0,0 +1,82 @@
import js from "@eslint/js";
import typescriptParser from "@typescript-eslint/parser";
import typescriptPlugin from "@typescript-eslint/eslint-plugin";
import reactPlugin from "eslint-plugin-react";
import reactHooksPlugin from "eslint-plugin-react-hooks";
import tailwindcss from "eslint-plugin-tailwindcss";
import prettier from "eslint-plugin-prettier";
import eslintConfigPrettier from "eslint-config-prettier";
export default [
js.configs.recommended,
{
files: ["**/*.ts", "**/*.tsx"],
ignores: ["**/generated/**"],
languageOptions: {
parser: typescriptParser,
parserOptions: {
ecmaVersion: "latest",
sourceType: "module",
ecmaFeatures: {
jsx: true,
},
projectService: true,
},
globals: {
localStorage: "readonly",
sessionStorage: "readonly",
document: "readonly",
window: "readonly",
setTimeout: "readonly",
clearTimeout: "readonly",
setInterval: "readonly",
clearInterval: "readonly",
HTMLElement: "readonly",
HTMLDivElement: "readonly",
HTMLButtonElement: "readonly",
HTMLAnchorElement: "readonly",
MouseEvent: "readonly",
Node: "readonly",
File: "readonly",
Blob: "readonly",
FormData: "readonly",
alert: "readonly",
confirm: "readonly",
prompt: "readonly",
React: "readonly",
},
},
plugins: {
"@typescript-eslint": typescriptPlugin,
react: reactPlugin,
"react-hooks": reactHooksPlugin,
tailwindcss,
prettier,
},
rules: {
...eslintConfigPrettier.rules,
...tailwindcss.configs.recommended.rules,
"react/react-in-jsx-scope": "off",
"react/prop-types": "off",
"no-console": ["warn", { allow: ["warn", "error"] }],
"no-undef": "off",
"@typescript-eslint/no-explicit-any": "warn",
"no-unused-vars": "off",
"@typescript-eslint/no-unused-vars": [
"error",
{
argsIgnorePattern: "^_",
varsIgnorePattern: "^_",
caughtErrorsIgnorePattern: "^_",
ignoreRestSiblings: true,
},
],
"no-useless-catch": "off",
},
settings: {
react: {
version: "detect",
},
},
},
];

31
frontend/index.html Normal file
View File

@@ -0,0 +1,31 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta
name="viewport"
content="width=device-width, initial-scale=0.90, user-scalable=no, viewport-fit=cover"
/>
<meta name="apple-mobile-web-app-capable" content="yes" />
<meta
name="apple-mobile-web-app-status-bar-style"
content="black-translucent"
/>
<meta
name="theme-color"
content="#F3F4F6"
media="(prefers-color-scheme: light)"
/>
<meta
name="theme-color"
content="#1F2937"
media="(prefers-color-scheme: dark)"
/>
<title>AnthoLume</title>
<link rel="manifest" href="/manifest.json" />
</head>
<body>
<div id="root"></div>
<script type="module" src="/src/main.tsx"></script>
</body>
</html>

21
frontend/orval.config.ts Normal file
View File

@@ -0,0 +1,21 @@
import { defineConfig } from 'orval';
export default defineConfig({
antholume: {
output: {
mode: 'split',
baseUrl: '/api/v1',
target: 'src/generated',
schemas: 'src/generated/model',
client: 'react-query',
mock: false,
override: {
useQuery: true,
mutations: true,
},
},
input: {
target: '../api/v1/openapi.yaml',
},
},
});

56
frontend/package.json Normal file
View File

@@ -0,0 +1,56 @@
{
"name": "antholume-frontend",
"private": true,
"version": "1.0.0",
"type": "module",
"scripts": {
"dev": "vite",
"typecheck": "tsc --noEmit",
"build": "tsc && vite build",
"preview": "vite preview",
"generate:api": "orval",
"lint": "eslint src --max-warnings=0",
"lint:fix": "eslint src --fix",
"format": "prettier --check src",
"format:fix": "prettier --write src",
"test": "vitest run"
},
"dependencies": {
"@tanstack/react-query": "^5.62.16",
"ajv": "^8.18.0",
"axios": "^1.13.6",
"clsx": "^2.1.1",
"epubjs": "^0.3.93",
"nosleep.js": "^0.12.0",
"orval": "8.5.3",
"react": "^19.0.0",
"react-dom": "^19.0.0",
"react-router-dom": "^7.1.1",
"tailwind-merge": "^3.5.0"
},
"devDependencies": {
"@eslint/js": "^9.17.0",
"@testing-library/jest-dom": "^6.9.1",
"@testing-library/react": "^16.3.2",
"@testing-library/user-event": "^14.6.1",
"@types/react": "^19.0.8",
"@types/react-dom": "^19.0.8",
"@typescript-eslint/eslint-plugin": "^8.13.0",
"@typescript-eslint/parser": "^8.13.0",
"@vitejs/plugin-react": "^4.3.4",
"autoprefixer": "^10.4.20",
"eslint": "^9.17.0",
"eslint-config-prettier": "^9.1.0",
"eslint-plugin-prettier": "^5.2.1",
"eslint-plugin-react": "^7.37.5",
"eslint-plugin-react-hooks": "^5.0.0",
"eslint-plugin-tailwindcss": "^3.18.2",
"jsdom": "^29.0.1",
"postcss": "^8.4.49",
"prettier": "^3.3.3",
"tailwindcss": "^3.4.17",
"typescript": "~5.6.2",
"vite": "^6.0.5",
"vitest": "^4.1.0"
}
}

View File

@@ -0,0 +1,6 @@
export default {
plugins: {
tailwindcss: {},
autoprefixer: {},
},
};

12
frontend/src/App.tsx Normal file
View File

@@ -0,0 +1,12 @@
import { AuthProvider } from './auth/AuthContext';
import { Routes } from './Routes';
function App() {
return (
<AuthProvider>
<Routes />
</AuthProvider>
);
}
export default App;

134
frontend/src/Routes.tsx Normal file
View File

@@ -0,0 +1,134 @@
import { Route, Routes as ReactRoutes } from 'react-router-dom';
import Layout from './components/Layout';
import HomePage from './pages/HomePage';
import DocumentsPage from './pages/DocumentsPage';
import DocumentPage from './pages/DocumentPage';
import ProgressPage from './pages/ProgressPage';
import ActivityPage from './pages/ActivityPage';
import SearchPage from './pages/SearchPage';
import SettingsPage from './pages/SettingsPage';
import LoginPage from './pages/LoginPage';
import RegisterPage from './pages/RegisterPage';
import AdminPage from './pages/AdminPage';
import AdminImportPage from './pages/AdminImportPage';
import AdminImportResultsPage from './pages/AdminImportResultsPage';
import AdminUsersPage from './pages/AdminUsersPage';
import AdminLogsPage from './pages/AdminLogsPage';
import ReaderPage from './pages/ReaderPage';
import { ProtectedRoute } from './auth/ProtectedRoute';
export function Routes() {
return (
<ReactRoutes>
<Route path="/" element={<Layout />}>
<Route
index
element={
<ProtectedRoute>
<HomePage />
</ProtectedRoute>
}
/>
<Route
path="documents"
element={
<ProtectedRoute>
<DocumentsPage />
</ProtectedRoute>
}
/>
<Route
path="documents/:id"
element={
<ProtectedRoute>
<DocumentPage />
</ProtectedRoute>
}
/>
<Route
path="progress"
element={
<ProtectedRoute>
<ProgressPage />
</ProtectedRoute>
}
/>
<Route
path="activity"
element={
<ProtectedRoute>
<ActivityPage />
</ProtectedRoute>
}
/>
<Route
path="search"
element={
<ProtectedRoute>
<SearchPage />
</ProtectedRoute>
}
/>
<Route
path="settings"
element={
<ProtectedRoute>
<SettingsPage />
</ProtectedRoute>
}
/>
{/* Admin routes */}
<Route
path="admin"
element={
<ProtectedRoute>
<AdminPage />
</ProtectedRoute>
}
/>
<Route
path="admin/import"
element={
<ProtectedRoute>
<AdminImportPage />
</ProtectedRoute>
}
/>
<Route
path="admin/import-results"
element={
<ProtectedRoute>
<AdminImportResultsPage />
</ProtectedRoute>
}
/>
<Route
path="admin/users"
element={
<ProtectedRoute>
<AdminUsersPage />
</ProtectedRoute>
}
/>
<Route
path="admin/logs"
element={
<ProtectedRoute>
<AdminLogsPage />
</ProtectedRoute>
}
/>
</Route>
<Route
path="/reader/:id"
element={
<ProtectedRoute>
<ReaderPage />
</ProtectedRoute>
}
/>
<Route path="/login" element={<LoginPage />} />
<Route path="/register" element={<RegisterPage />} />
</ReactRoutes>
);
}

View File

@@ -0,0 +1,135 @@
import { createContext, useContext, useState, useEffect, ReactNode, useCallback } from 'react';
import { useQueryClient } from '@tanstack/react-query';
import { useNavigate } from 'react-router-dom';
import {
getGetMeQueryKey,
useLogin,
useLogout,
useGetMe,
useRegister,
} from '../generated/anthoLumeAPIV1';
import {
type AuthState,
getAuthenticatedAuthState,
getUnauthenticatedAuthState,
resolveAuthStateFromMe,
validateAuthMutationResponse,
} from './authHelpers';
interface AuthContextType extends AuthState {
login: (_username: string, _password: string) => Promise<void>;
register: (_username: string, _password: string) => Promise<void>;
logout: () => void;
}
const AuthContext = createContext<AuthContextType | undefined>(undefined);
const initialAuthState: AuthState = {
isAuthenticated: false,
user: null,
isCheckingAuth: true,
};
export function AuthProvider({ children }: { children: ReactNode }) {
const [authState, setAuthState] = useState<AuthState>(initialAuthState);
const loginMutation = useLogin();
const registerMutation = useRegister();
const logoutMutation = useLogout();
const { data: meData, error: meError, isLoading: meLoading } = useGetMe();
const queryClient = useQueryClient();
const navigate = useNavigate();
useEffect(() => {
setAuthState(prev =>
resolveAuthStateFromMe({
meData,
meError,
meLoading,
previousState: prev,
})
);
}, [meData, meError, meLoading]);
const login = useCallback(
async (username: string, password: string) => {
try {
const response = await loginMutation.mutateAsync({
data: {
username,
password,
},
});
const user = validateAuthMutationResponse(response, 200);
if (!user) {
setAuthState(getUnauthenticatedAuthState());
throw new Error('Login failed');
}
setAuthState(getAuthenticatedAuthState(user));
await queryClient.invalidateQueries({ queryKey: getGetMeQueryKey() });
navigate('/');
} catch (_error) {
setAuthState(getUnauthenticatedAuthState());
throw new Error('Login failed');
}
},
[loginMutation, navigate, queryClient]
);
const register = useCallback(
async (username: string, password: string) => {
try {
const response = await registerMutation.mutateAsync({
data: {
username,
password,
},
});
const user = validateAuthMutationResponse(response, 201);
if (!user) {
setAuthState(getUnauthenticatedAuthState());
throw new Error('Registration failed');
}
setAuthState(getAuthenticatedAuthState(user));
await queryClient.invalidateQueries({ queryKey: getGetMeQueryKey() });
navigate('/');
} catch (_error) {
setAuthState(getUnauthenticatedAuthState());
throw new Error('Registration failed');
}
},
[navigate, queryClient, registerMutation]
);
const logout = useCallback(() => {
logoutMutation.mutate(undefined, {
onSuccess: async () => {
setAuthState(getUnauthenticatedAuthState());
await queryClient.removeQueries({ queryKey: getGetMeQueryKey() });
navigate('/login');
},
});
}, [logoutMutation, navigate, queryClient]);
return (
<AuthContext.Provider value={{ ...authState, login, register, logout }}>
{children}
</AuthContext.Provider>
);
}
export function useAuth() {
const context = useContext(AuthContext);
if (context === undefined) {
throw new Error('useAuth must be used within an AuthProvider');
}
return context;
}

View File

@@ -0,0 +1,90 @@
import { describe, expect, it, vi, beforeEach } from 'vitest';
import { render, screen } from '@testing-library/react';
import { MemoryRouter, Route, Routes } from 'react-router-dom';
import { ProtectedRoute } from './ProtectedRoute';
import { useAuth } from './AuthContext';
vi.mock('./AuthContext', () => ({
useAuth: vi.fn(),
}));
const mockedUseAuth = vi.mocked(useAuth);
describe('ProtectedRoute', () => {
beforeEach(() => {
vi.clearAllMocks();
});
it('shows a loading state while auth is being checked', () => {
mockedUseAuth.mockReturnValue({
isAuthenticated: false,
isCheckingAuth: true,
user: null,
login: vi.fn(),
register: vi.fn(),
logout: vi.fn(),
});
render(
<MemoryRouter initialEntries={['/private']}>
<ProtectedRoute>
<div>Secret</div>
</ProtectedRoute>
</MemoryRouter>
);
expect(screen.getByText('Loading...')).toBeInTheDocument();
expect(screen.queryByText('Secret')).not.toBeInTheDocument();
});
it('redirects unauthenticated users to the login page', () => {
mockedUseAuth.mockReturnValue({
isAuthenticated: false,
isCheckingAuth: false,
user: null,
login: vi.fn(),
register: vi.fn(),
logout: vi.fn(),
});
render(
<MemoryRouter initialEntries={['/private']}>
<Routes>
<Route
path="/private"
element={
<ProtectedRoute>
<div>Secret</div>
</ProtectedRoute>
}
/>
<Route path="/login" element={<div>Login Page</div>} />
</Routes>
</MemoryRouter>
);
expect(screen.getByText('Login Page')).toBeInTheDocument();
expect(screen.queryByText('Secret')).not.toBeInTheDocument();
});
it('renders children for authenticated users', () => {
mockedUseAuth.mockReturnValue({
isAuthenticated: true,
isCheckingAuth: false,
user: { username: 'evan', is_admin: false },
login: vi.fn(),
register: vi.fn(),
logout: vi.fn(),
});
render(
<MemoryRouter>
<ProtectedRoute>
<div>Secret</div>
</ProtectedRoute>
</MemoryRouter>
);
expect(screen.getByText('Secret')).toBeInTheDocument();
});
});

View File

@@ -0,0 +1,21 @@
import { Navigate, useLocation } from 'react-router-dom';
import { useAuth } from './AuthContext';
interface ProtectedRouteProps {
children: React.ReactNode;
}
export function ProtectedRoute({ children }: ProtectedRouteProps) {
const { isAuthenticated, isCheckingAuth } = useAuth();
const location = useLocation();
if (isCheckingAuth) {
return <div className="text-content-muted">Loading...</div>;
}
if (!isAuthenticated) {
return <Navigate to="/login" state={{ from: location }} replace />;
}
return children;
}

View File

@@ -0,0 +1,157 @@
import { describe, expect, it } from 'vitest';
import {
getCheckingAuthState,
getUnauthenticatedAuthState,
normalizeAuthenticatedUser,
resolveAuthStateFromMe,
validateAuthMutationResponse,
type AuthState,
} from './authHelpers';
const previousState: AuthState = {
isAuthenticated: false,
user: null,
isCheckingAuth: true,
};
describe('authHelpers', () => {
it('normalizes a valid authenticated user payload', () => {
expect(normalizeAuthenticatedUser({ username: 'evan', is_admin: true })).toEqual({
username: 'evan',
is_admin: true,
});
});
it('rejects invalid authenticated user payloads', () => {
expect(normalizeAuthenticatedUser(null)).toBeNull();
expect(normalizeAuthenticatedUser({ username: 'evan' })).toBeNull();
expect(normalizeAuthenticatedUser({ username: 123, is_admin: true })).toBeNull();
expect(normalizeAuthenticatedUser({ username: 'evan', is_admin: 'yes' })).toBeNull();
});
it('returns a checking state while preserving previous auth information', () => {
expect(
getCheckingAuthState({
isAuthenticated: true,
user: { username: 'evan', is_admin: false },
isCheckingAuth: false,
})
).toEqual({
isAuthenticated: true,
user: { username: 'evan', is_admin: false },
isCheckingAuth: true,
});
});
it('resolves auth state from a successful /auth/me response', () => {
expect(
resolveAuthStateFromMe({
meData: {
status: 200,
data: { username: 'evan', is_admin: false },
},
meError: undefined,
meLoading: false,
previousState,
})
).toEqual({
isAuthenticated: true,
user: { username: 'evan', is_admin: false },
isCheckingAuth: false,
});
});
it('resolves auth state to unauthenticated on 401 or query error', () => {
expect(
resolveAuthStateFromMe({
meData: {
status: 401,
},
meError: undefined,
meLoading: false,
previousState,
})
).toEqual(getUnauthenticatedAuthState());
expect(
resolveAuthStateFromMe({
meData: undefined,
meError: new Error('failed'),
meLoading: false,
previousState,
})
).toEqual(getUnauthenticatedAuthState());
});
it('keeps checking state while /auth/me is still loading', () => {
expect(
resolveAuthStateFromMe({
meData: undefined,
meError: undefined,
meLoading: true,
previousState: {
isAuthenticated: true,
user: { username: 'evan', is_admin: true },
isCheckingAuth: false,
},
})
).toEqual({
isAuthenticated: true,
user: { username: 'evan', is_admin: true },
isCheckingAuth: true,
});
});
it('returns the previous state with checking disabled when there is no decisive me result', () => {
expect(
resolveAuthStateFromMe({
meData: {
status: 204,
},
meError: undefined,
meLoading: false,
previousState: {
isAuthenticated: false,
user: null,
isCheckingAuth: true,
},
})
).toEqual({
isAuthenticated: false,
user: null,
isCheckingAuth: false,
});
});
it('validates auth mutation responses by expected status and payload shape', () => {
expect(
validateAuthMutationResponse(
{
status: 200,
data: { username: 'evan', is_admin: false },
},
200
)
).toEqual({ username: 'evan', is_admin: false });
expect(
validateAuthMutationResponse(
{
status: 201,
data: { username: 'evan', is_admin: false },
},
200
)
).toBeNull();
expect(
validateAuthMutationResponse(
{
status: 200,
data: { username: 'evan' },
},
200
)
).toBeNull();
});
});

View File

@@ -0,0 +1,98 @@
export interface AuthUser {
username: string;
is_admin: boolean;
}
export interface AuthState {
isAuthenticated: boolean;
user: AuthUser | null;
isCheckingAuth: boolean;
}
interface ResponseLike {
status?: number;
data?: unknown;
}
export function getUnauthenticatedAuthState(): AuthState {
return {
isAuthenticated: false,
user: null,
isCheckingAuth: false,
};
}
export function getCheckingAuthState(previousState?: AuthState): AuthState {
return {
isAuthenticated: previousState?.isAuthenticated ?? false,
user: previousState?.user ?? null,
isCheckingAuth: true,
};
}
export function getAuthenticatedAuthState(user: AuthUser): AuthState {
return {
isAuthenticated: true,
user,
isCheckingAuth: false,
};
}
export function normalizeAuthenticatedUser(value: unknown): AuthUser | null {
if (!value || typeof value !== 'object') {
return null;
}
if (!('username' in value) || typeof value.username !== 'string') {
return null;
}
if (!('is_admin' in value) || typeof value.is_admin !== 'boolean') {
return null;
}
return {
username: value.username,
is_admin: value.is_admin,
};
}
export function resolveAuthStateFromMe(params: {
meData?: ResponseLike;
meError?: unknown;
meLoading: boolean;
previousState: AuthState;
}): AuthState {
const { meData, meError, meLoading, previousState } = params;
if (meLoading) {
return getCheckingAuthState(previousState);
}
if (meData?.status === 200) {
const user = normalizeAuthenticatedUser(meData.data);
if (user) {
return getAuthenticatedAuthState(user);
}
}
if (meError || meData?.status === 401) {
return getUnauthenticatedAuthState();
}
return {
...previousState,
isCheckingAuth: false,
};
}
export function validateAuthMutationResponse(
response: ResponseLike,
expectedStatus: number
): AuthUser | null {
if (response.status !== expectedStatus) {
return null;
}
return normalizeAuthenticatedUser(response.data);
}

View File

@@ -0,0 +1,11 @@
import { describe, expect, it } from 'vitest';
import { setupAuthInterceptors } from './authInterceptor';
describe('setupAuthInterceptors', () => {
it('is a no-op when auth is handled by HttpOnly cookies', () => {
const cleanup = setupAuthInterceptors();
expect(typeof cleanup).toBe('function');
expect(() => cleanup()).not.toThrow();
});
});

View File

@@ -0,0 +1,3 @@
export function setupAuthInterceptors() {
return () => {};
}

View File

@@ -0,0 +1,45 @@
import { ButtonHTMLAttributes, AnchorHTMLAttributes, forwardRef } from 'react';
interface BaseButtonProps {
variant?: 'default' | 'secondary';
children: React.ReactNode;
className?: string;
}
type ButtonProps = BaseButtonProps & ButtonHTMLAttributes<HTMLButtonElement>;
type LinkProps = BaseButtonProps & AnchorHTMLAttributes<HTMLAnchorElement> & { href: string };
const getVariantClasses = (variant: 'default' | 'secondary' = 'default'): string => {
const baseClass =
'h-full w-full px-2 py-1 font-medium transition duration-100 ease-in disabled:cursor-not-allowed disabled:opacity-50';
if (variant === 'secondary') {
return `${baseClass} bg-content text-content-inverse shadow-md hover:bg-content-muted disabled:hover:bg-content`;
}
return `${baseClass} bg-primary-500 text-primary-foreground hover:bg-primary-700 disabled:hover:bg-primary-500`;
};
export const Button = forwardRef<HTMLButtonElement, ButtonProps>(
({ variant = 'default', children, className = '', ...props }, ref) => {
return (
<button ref={ref} className={`${getVariantClasses(variant)} ${className}`.trim()} {...props}>
{children}
</button>
);
}
);
Button.displayName = 'Button';
export const ButtonLink = forwardRef<HTMLAnchorElement, LinkProps>(
({ variant = 'default', children, className = '', ...props }, ref) => {
return (
<a ref={ref} className={`${getVariantClasses(variant)} ${className}`.trim()} {...props}>
{children}
</a>
);
}
);
ButtonLink.displayName = 'ButtonLink';

View File

@@ -0,0 +1,41 @@
import { ReactNode } from 'react';
interface FieldProps {
label: ReactNode;
children: ReactNode;
isEditing?: boolean;
}
export function Field({ label, children, isEditing: _isEditing = false }: FieldProps) {
return (
<div className="relative rounded">
<div className="relative inline-flex gap-2 text-content-muted">{label}</div>
{children}
</div>
);
}
interface FieldLabelProps {
children: ReactNode;
}
export function FieldLabel({ children }: FieldLabelProps) {
return <p>{children}</p>;
}
interface FieldValueProps {
children: ReactNode;
className?: string;
}
export function FieldValue({ children, className = '' }: FieldValueProps) {
return <p className={`text-lg font-medium ${className}`}>{children}</p>;
}
interface FieldActionsProps {
children: ReactNode;
}
export function FieldActions({ children }: FieldActionsProps) {
return <div className="inline-flex gap-2">{children}</div>;
}

View File

@@ -0,0 +1,181 @@
import { useState } from 'react';
import { Link, useLocation } from 'react-router-dom';
import { HomeIcon, DocumentsIcon, ActivityIcon, SearchIcon, SettingsIcon, GitIcon } from '../icons';
import { useAuth } from '../auth/AuthContext';
import { useGetInfo } from '../generated/anthoLumeAPIV1';
interface NavItem {
path: string;
label: string;
icon: React.ElementType;
title: string;
}
const navItems: NavItem[] = [
{ path: '/', label: 'Home', icon: HomeIcon, title: 'Home' },
{ path: '/documents', label: 'Documents', icon: DocumentsIcon, title: 'Documents' },
{ path: '/progress', label: 'Progress', icon: ActivityIcon, title: 'Progress' },
{ path: '/activity', label: 'Activity', icon: ActivityIcon, title: 'Activity' },
{ path: '/search', label: 'Search', icon: SearchIcon, title: 'Search' },
];
const adminSubItems: NavItem[] = [
{ path: '/admin', label: 'General', icon: SettingsIcon, title: 'General' },
{ path: '/admin/import', label: 'Import', icon: SettingsIcon, title: 'Import' },
{ path: '/admin/users', label: 'Users', icon: SettingsIcon, title: 'Users' },
{ path: '/admin/logs', label: 'Logs', icon: SettingsIcon, title: 'Logs' },
];
function hasPrefix(path: string, prefix: string): boolean {
return path.startsWith(prefix);
}
export default function HamburgerMenu() {
const location = useLocation();
const { user } = useAuth();
const [isOpen, setIsOpen] = useState(false);
const isAdmin = user?.is_admin ?? false;
const { data: infoData } = useGetInfo({
query: {
staleTime: Infinity,
},
});
const version =
infoData && 'data' in infoData && infoData.data && 'version' in infoData.data
? infoData.data.version
: 'v1.0.0';
return (
<div className="relative z-40 ml-6 flex flex-col">
<input
type="checkbox"
className="absolute -top-2 z-50 flex size-7 cursor-pointer opacity-0 lg:hidden"
id="mobile-nav-checkbox"
checked={isOpen}
onChange={e => setIsOpen(e.target.checked)}
/>
<span
className="z-40 mt-0.5 h-0.5 w-7 bg-content transition-opacity duration-500 lg:hidden"
style={{
transformOrigin: '5px 0px',
transition:
'transform 0.5s cubic-bezier(0.77, 0.2, 0.05, 1), background 0.5s cubic-bezier(0.77, 0.2, 0.05, 1), opacity 0.55s ease',
transform: isOpen ? 'rotate(45deg) translate(2px, -2px)' : 'none',
}}
/>
<span
className="z-40 mt-1 h-0.5 w-7 bg-content transition-opacity duration-500 lg:hidden"
style={{
transformOrigin: '0% 100%',
transition:
'transform 0.5s cubic-bezier(0.77, 0.2, 0.05, 1), background 0.5s cubic-bezier(0.77, 0.2, 0.05, 1), opacity 0.55s ease',
opacity: isOpen ? 0 : 1,
transform: isOpen ? 'rotate(0deg) scale(0.2, 0.2)' : 'none',
}}
/>
<span
className="z-40 mt-1 h-0.5 w-7 bg-content transition-opacity duration-500 lg:hidden"
style={{
transformOrigin: '0% 0%',
transition:
'transform 0.5s cubic-bezier(0.77, 0.2, 0.05, 1), background 0.5s cubic-bezier(0.77, 0.2, 0.05, 1), opacity 0.55s ease',
transform: isOpen ? 'rotate(-45deg) translate(0, 6px)' : 'none',
}}
/>
<div
id="menu"
className="fixed -ml-6 h-full w-56 bg-surface shadow-lg lg:w-48"
style={{
top: 0,
paddingTop: 'env(safe-area-inset-top)',
transformOrigin: '0% 0%',
transform: isOpen ? 'none' : 'translate(-100%, 0)',
transition: 'transform 0.5s cubic-bezier(0.77, 0.2, 0.05, 1)',
}}
>
<style>{`
@media (min-width: 1024px) {
#menu {
transform: none !important;
}
}
`}</style>
<div className="flex h-16 justify-end lg:justify-around">
<p className="my-auto pr-8 text-right text-xl font-bold text-content lg:pr-0">AnthoLume</p>
</div>
<nav>
{navItems.map(item => (
<Link
key={item.path}
to={item.path}
onClick={() => setIsOpen(false)}
className={`my-2 flex w-full items-center justify-start border-l-4 p-2 pl-6 transition-colors duration-200 ${
location.pathname === item.path
? 'border-primary-500 text-content'
: 'border-transparent text-content-subtle hover:text-content'
}`}
>
<item.icon size={20} />
<span className="mx-4 text-sm font-normal">{item.label}</span>
</Link>
))}
{isAdmin && (
<div
className={`my-2 flex flex-col gap-4 border-l-4 p-2 pl-6 transition-colors duration-200 ${
hasPrefix(location.pathname, '/admin')
? 'border-primary-500 text-content'
: 'border-transparent text-content-subtle'
}`}
>
<Link
to="/admin"
onClick={() => setIsOpen(false)}
className={`flex w-full justify-start ${
location.pathname === '/admin' && !hasPrefix(location.pathname, '/admin/')
? 'text-content'
: 'text-content-subtle hover:text-content'
}`}
>
<SettingsIcon size={20} />
<span className="mx-4 text-sm font-normal">Admin</span>
</Link>
{hasPrefix(location.pathname, '/admin') && (
<div className="flex flex-col gap-4">
{adminSubItems.map(item => (
<Link
key={item.path}
to={item.path}
onClick={() => setIsOpen(false)}
className={`flex w-full justify-start ${
location.pathname === item.path
? 'text-content'
: 'text-content-subtle hover:text-content'
}`}
style={{ paddingLeft: '1.75em' }}
>
<span className="mx-4 text-sm font-normal">{item.label}</span>
</Link>
))}
</div>
)}
</div>
)}
</nav>
<a
className="absolute bottom-0 flex w-full flex-col items-center justify-center gap-2 p-6 text-content"
target="_blank"
href="https://gitea.va.reichard.io/evan/AnthoLume"
rel="noreferrer"
>
<GitIcon size={20} />
<span className="text-xs">{version}</span>
</a>
</div>
</div>
);
}

View File

@@ -0,0 +1,178 @@
import { useState, useEffect, useRef } from 'react';
import { Link, useLocation, Outlet, Navigate } from 'react-router-dom';
import { useGetMe } from '../generated/anthoLumeAPIV1';
import { useAuth } from '../auth/AuthContext';
import { UserIcon, DropdownIcon } from '../icons';
import { useTheme } from '../theme/ThemeProvider';
import type { ThemeMode } from '../utils/localSettings';
import HamburgerMenu from './HamburgerMenu';
const themeModes: ThemeMode[] = ['light', 'dark', 'system'];
export default function Layout() {
const location = useLocation();
const { isAuthenticated, user, logout, isCheckingAuth } = useAuth();
const { themeMode, setThemeMode } = useTheme();
const { data } = useGetMe(isAuthenticated ? {} : undefined);
const fetchedUser =
data?.status === 200 && data.data && 'username' in data.data ? data.data : null;
const userData = user ?? fetchedUser;
const [isUserDropdownOpen, setIsUserDropdownOpen] = useState(false);
const dropdownRef = useRef<HTMLDivElement>(null);
const handleLogout = () => {
logout();
setIsUserDropdownOpen(false);
};
useEffect(() => {
const handleClickOutside = (event: MouseEvent) => {
if (dropdownRef.current && !dropdownRef.current.contains(event.target as Node)) {
setIsUserDropdownOpen(false);
}
};
document.addEventListener('mousedown', handleClickOutside);
return () => {
document.removeEventListener('mousedown', handleClickOutside);
};
}, []);
const navItems = [
{ path: '/admin/import-results', title: 'Admin - Import' },
{ path: '/admin/import', title: 'Admin - Import' },
{ path: '/admin/users', title: 'Admin - Users' },
{ path: '/admin/logs', title: 'Admin - Logs' },
{ path: '/admin', title: 'Admin - General' },
{ path: '/documents', title: 'Documents' },
{ path: '/progress', title: 'Progress' },
{ path: '/activity', title: 'Activity' },
{ path: '/search', title: 'Search' },
{ path: '/settings', title: 'Settings' },
{ path: '/', title: 'Home' },
];
const currentPageTitle =
navItems.find(item =>
item.path === '/' ? location.pathname === item.path : location.pathname.startsWith(item.path)
)?.title || 'Home';
useEffect(() => {
document.title = `AnthoLume - ${currentPageTitle}`;
}, [currentPageTitle]);
if (isCheckingAuth) {
return <div className="text-content-muted">Loading...</div>;
}
if (!isAuthenticated) {
return <Navigate to="/login" replace />;
}
return (
<div className="min-h-screen bg-canvas">
<div className="flex h-16 w-full items-center justify-between">
<HamburgerMenu />
<h1 className="whitespace-nowrap px-6 text-xl font-bold text-content lg:ml-44">
{currentPageTitle}
</h1>
<div
className="relative flex w-full items-center justify-end space-x-4 p-4"
ref={dropdownRef}
>
<button
onClick={() => setIsUserDropdownOpen(!isUserDropdownOpen)}
className="relative block text-content"
>
<UserIcon size={20} />
</button>
{isUserDropdownOpen && (
<div className="absolute right-4 top-16 z-20 pt-4 transition duration-200">
<div className="w-64 origin-top-right rounded-md bg-surface shadow-lg ring-1 ring-border/30">
<div
className="border-b border-border px-4 py-3"
role="group"
aria-label="Theme mode"
>
<p className="mb-2 text-xs font-semibold uppercase tracking-wide text-content-subtle">
Theme
</p>
<div className="inline-flex w-full rounded border border-border bg-surface-muted p-1">
{themeModes.map(mode => (
<button
key={mode}
type="button"
onClick={() => setThemeMode(mode)}
className={`flex-1 rounded px-2 py-1 text-xs font-medium capitalize transition-colors ${
themeMode === mode
? 'bg-content text-content-inverse'
: 'text-content-muted hover:bg-surface hover:text-content'
}`}
>
{mode}
</button>
))}
</div>
</div>
<div
className="py-1"
role="menu"
aria-orientation="vertical"
aria-labelledby="options-menu"
>
<Link
to="/settings"
onClick={() => setIsUserDropdownOpen(false)}
className="block px-4 py-2 text-content-muted hover:bg-surface-muted hover:text-content"
role="menuitem"
>
<span className="flex flex-col">
<span>Settings</span>
</span>
</Link>
<button
onClick={handleLogout}
className="block w-full px-4 py-2 text-left text-content-muted hover:bg-surface-muted hover:text-content"
role="menuitem"
>
<span className="flex flex-col">
<span>Logout</span>
</span>
</button>
</div>
</div>
</div>
)}
<button
onClick={() => setIsUserDropdownOpen(!isUserDropdownOpen)}
className="flex cursor-pointer items-center gap-2 py-4 text-content-muted"
>
<span>{userData ? ('username' in userData ? userData.username : 'User') : 'User'}</span>
<span
className="text-content transition-transform duration-200"
style={{ transform: isUserDropdownOpen ? 'rotate(180deg)' : 'rotate(0deg)' }}
>
<DropdownIcon size={20} />
</span>
</button>
</div>
</div>
<main
className="relative overflow-hidden"
style={{ height: 'calc(100dvh - 4rem - env(safe-area-inset-top))' }}
>
<div
id="container"
className="h-dvh overflow-auto px-4 md:px-6 lg:ml-48"
style={{ paddingBottom: 'calc(5em + env(safe-area-inset-bottom) * 2)' }}
>
<Outlet />
</div>
</main>
</div>
);
}

View File

@@ -0,0 +1,21 @@
import { LoadingIcon } from '../icons';
import { cn } from '../utils/cn';
interface LoadingStateProps {
message?: string;
className?: string;
iconSize?: number;
}
export function LoadingState({
message = 'Loading...',
className = '',
iconSize = 24,
}: LoadingStateProps) {
return (
<div className={cn('flex items-center justify-center gap-3 text-content-muted', className)}>
<LoadingIcon size={iconSize} className="text-primary-500" />
<span className="text-sm font-medium">{message}</span>
</div>
);
}

View File

@@ -0,0 +1,203 @@
# UI Components
This directory contains reusable UI components for the AnthoLume application.
## Toast Notifications
### Usage
The toast system provides info, warning, and error notifications that respect the current theme and dark/light mode.
```tsx
import { useToasts } from './components/ToastContext';
function MyComponent() {
const { showInfo, showWarning, showError, showToast } = useToasts();
const handleAction = async () => {
try {
// Do something
showInfo('Operation completed successfully!');
} catch (error) {
showError('An error occurred while processing your request.');
}
};
return <button onClick={handleAction}>Click me</button>;
}
```
### API
- `showToast(message: string, type?: 'info' | 'warning' | 'error', duration?: number): string`
- Shows a toast notification
- Returns the toast ID for manual removal
- Default type: 'info'
- Default duration: 5000ms (0 = no auto-dismiss)
- `showInfo(message: string, duration?: number): string`
- Shortcut for showing an info toast
- `showWarning(message: string, duration?: number): string`
- Shortcut for showing a warning toast
- `showError(message: string, duration?: number): string`
- Shortcut for showing an error toast
- `removeToast(id: string): void`
- Manually remove a toast by ID
- `clearToasts(): void`
- Clear all active toasts
### Examples
```tsx
// Info toast (auto-dismisses after 5 seconds)
showInfo('Document saved successfully!');
// Warning toast (auto-dismisses after 10 seconds)
showWarning('Low disk space warning', 10000);
// Error toast (no auto-dismiss)
showError('Failed to load data', 0);
// Generic toast
showToast('Custom message', 'warning', 3000);
```
## Skeleton Loading
### Usage
Skeleton components provide placeholder content while data is loading. They automatically adapt to dark/light mode.
### Components
#### `Skeleton`
Basic skeleton element with various variants:
```tsx
import { Skeleton } from './components/Skeleton';
// Default (rounded rectangle)
<Skeleton className="w-full h-8" />
// Text variant
<Skeleton variant="text" className="w-3/4" />
// Circular variant (for avatars)
<Skeleton variant="circular" width={40} height={40} />
// Rectangular variant
<Skeleton variant="rectangular" width="100%" height={200} />
```
#### `SkeletonText`
Multiple lines of text skeleton:
```tsx
<SkeletonText lines={3} />
<SkeletonText lines={5} className="max-w-md" />
```
#### `SkeletonAvatar`
Avatar placeholder:
```tsx
<SkeletonAvatar size="md" />
<SkeletonAvatar size={56} />
```
#### `SkeletonCard`
Card placeholder with optional elements:
```tsx
// Default card
<SkeletonCard />
// With avatar
<SkeletonCard showAvatar />
// Custom configuration
<SkeletonCard
showAvatar
showTitle
showText
textLines={4}
className="max-w-sm"
/>
```
#### `SkeletonTable`
Table placeholder:
```tsx
<SkeletonTable rows={5} columns={4} />
<SkeletonTable rows={10} columns={6} showHeader={false} />
```
#### `SkeletonButton`
Button placeholder:
```tsx
<SkeletonButton width={120} />
<SkeletonButton className="w-full" />
```
#### `PageLoader`
Full-page loading indicator:
```tsx
<PageLoader message="Loading your documents..." />
```
#### `InlineLoader`
Small inline loading spinner:
```tsx
<InlineLoader size="sm" />
<InlineLoader size="md" />
<InlineLoader size="lg" />
```
## Integration with Table Component
The Table component now supports skeleton loading:
```tsx
import { Table, SkeletonTable } from './components/Table';
function DocumentList() {
const { data, isLoading } = useGetDocuments();
if (isLoading) {
return <SkeletonTable rows={10} columns={5} />;
}
return <Table columns={columns} data={data?.documents || []} />;
}
```
## Theme Support
All components automatically adapt to the current theme:
- **Light mode**: Uses gray tones for skeletons, appropriate colors for toasts
- **Dark mode**: Uses darker gray tones for skeletons, adjusted colors for toasts
The theme is controlled via Tailwind's `dark:` classes, which respond to the system preference or manual theme toggles.
## Dependencies
- `clsx` - Utility for constructing className strings
- `tailwind-merge` - Merges Tailwind CSS classes intelligently
- `lucide-react` - Icon library used by Toast component

View File

@@ -0,0 +1,53 @@
import { describe, expect, it } from 'vitest';
import { getSVGGraphData } from './ReadingHistoryGraph';
// Intentionally exact fixture data for algorithm parity coverage
const testInput = [
{ date: '2024-01-01', minutes_read: 10 },
{ date: '2024-01-02', minutes_read: 90 },
{ date: '2024-01-03', minutes_read: 50 },
{ date: '2024-01-04', minutes_read: 5 },
{ date: '2024-01-05', minutes_read: 10 },
{ date: '2024-01-06', minutes_read: 5 },
{ date: '2024-01-07', minutes_read: 70 },
{ date: '2024-01-08', minutes_read: 60 },
{ date: '2024-01-09', minutes_read: 50 },
{ date: '2024-01-10', minutes_read: 90 },
];
const svgWidth = 500;
const svgHeight = 100;
describe('ReadingHistoryGraph', () => {
describe('getSVGGraphData', () => {
it('should match exactly', () => {
const result = getSVGGraphData(testInput, svgWidth, svgHeight);
// Expected exact algorithm output
const expectedBezierPath =
'M 50,95 C63,95 80,50 100,50 C120,50 128,73 150,73 C172,73 180,98 200,98 C220,98 230,95 250,95 C270,95 279,98 300,98 C321,98 330,62 350,62 C370,62 380,67 400,67 C420,67 430,73 450,73 C470,73 489,50 500,50';
const expectedBezierFill = 'L 500,98 L 50,98 Z';
const expectedWidth = 500;
const expectedHeight = 100;
const expectedOffset = 50;
expect(result.BezierPath).toBe(expectedBezierPath);
expect(result.BezierFill).toBe(expectedBezierFill);
expect(svgWidth).toBe(expectedWidth);
expect(svgHeight).toBe(expectedHeight);
expect(result.Offset).toBe(expectedOffset);
// Verify line points are integer pixel values
result.LinePoints.forEach((p, _i) => {
expect(Number.isInteger(p.x)).toBe(true);
expect(Number.isInteger(p.y)).toBe(true);
});
// Expected line points from the current algorithm:
// idx 0: itemSize=5, itemY=95, lineX=50
// idx 1: itemSize=45, itemY=55, lineX=100
// idx 2: itemSize=25, itemY=75, lineX=150
// ...and so on
});
});
});

View File

@@ -0,0 +1,210 @@
import type { GraphDataPoint } from '../generated/model';
interface ReadingHistoryGraphProps {
data: GraphDataPoint[];
}
export interface SVGPoint {
x: number;
y: number;
}
function getSVGBezierOpposedLine(
pointA: SVGPoint,
pointB: SVGPoint
): { Length: number; Angle: number } {
const lengthX = pointB.x - pointA.x;
const lengthY = pointB.y - pointA.y;
return {
Length: Math.floor(Math.sqrt(lengthX * lengthX + lengthY * lengthY)),
Angle: Math.trunc(Math.atan2(lengthY, lengthX)),
};
}
function getBezierControlPoint(
currentPoint: SVGPoint,
prevPoint: SVGPoint | null,
nextPoint: SVGPoint | null,
isReverse: boolean
): SVGPoint {
let pPrev = prevPoint;
let pNext = nextPoint;
if (!pPrev) {
pPrev = currentPoint;
}
if (!pNext) {
pNext = currentPoint;
}
const smoothingRatio = 0.2;
const directionModifier = isReverse ? Math.PI : 0;
const opposingLine = getSVGBezierOpposedLine(pPrev, pNext);
const lineAngle = opposingLine.Angle + directionModifier;
const lineLength = opposingLine.Length * smoothingRatio;
return {
x: Math.floor(currentPoint.x + Math.trunc(Math.cos(lineAngle) * lineLength)),
y: Math.floor(currentPoint.y + Math.trunc(Math.sin(lineAngle) * lineLength)),
};
}
function getSVGBezierPath(points: SVGPoint[]): string {
if (points.length === 0) {
return '';
}
let bezierSVGPath = '';
for (let index = 0; index < points.length; index++) {
const point = points[index];
if (!point) {
continue;
}
if (index === 0) {
bezierSVGPath += `M ${point.x},${point.y}`;
continue;
}
const pointMinusOne = points[index - 1];
if (!pointMinusOne) {
continue;
}
const pointPlusOne = points[index + 1] ?? point;
const pointMinusTwo = index - 2 >= 0 ? (points[index - 2] ?? null) : null;
const startControlPoint = getBezierControlPoint(pointMinusOne, pointMinusTwo, point, false);
const endControlPoint = getBezierControlPoint(point, pointMinusOne, pointPlusOne, true);
bezierSVGPath += ` C${startControlPoint.x},${startControlPoint.y} ${endControlPoint.x},${endControlPoint.y} ${point.x},${point.y}`;
}
return bezierSVGPath;
}
export interface SVGGraphData {
LinePoints: SVGPoint[];
BezierPath: string;
BezierFill: string;
Offset: number;
}
export function getSVGGraphData(
inputData: GraphDataPoint[],
svgWidth: number,
svgHeight: number
): SVGGraphData {
let maxHeight = 0;
for (const item of inputData) {
if (item.minutes_read > maxHeight) {
maxHeight = item.minutes_read;
}
}
const sizePercentage = 0.5;
const sizeRatio = maxHeight > 0 ? (svgHeight * sizePercentage) / maxHeight : 0;
const blockOffset = inputData.length > 0 ? Math.floor(svgWidth / inputData.length) : 0;
const linePoints: SVGPoint[] = [];
let maxBX = 0;
let maxBY = 0;
let minBX = 0;
for (let idx = 0; idx < inputData.length; idx++) {
const item = inputData[idx];
if (!item) {
continue;
}
const itemSize = Math.floor(item.minutes_read * sizeRatio);
const itemY = svgHeight - itemSize;
const lineX = (idx + 1) * blockOffset;
linePoints.push({ x: lineX, y: itemY });
if (lineX > maxBX) {
maxBX = lineX;
}
if (lineX < minBX) {
minBX = lineX;
}
if (itemY > maxBY) {
maxBY = itemY;
}
}
return {
LinePoints: linePoints,
BezierPath: getSVGBezierPath(linePoints),
BezierFill: `L ${Math.floor(maxBX)},${Math.floor(maxBY)} L ${Math.floor(minBX + blockOffset)},${Math.floor(maxBY)} Z`,
Offset: blockOffset,
};
}
function formatDate(dateString: string): string {
const date = new Date(dateString);
const year = date.getUTCFullYear();
const month = String(date.getUTCMonth() + 1).padStart(2, '0');
const day = String(date.getUTCDate()).padStart(2, '0');
return `${year}-${month}-${day}`;
}
export default function ReadingHistoryGraph({ data }: ReadingHistoryGraphProps) {
const svgWidth = 800;
const svgHeight = 70;
if (!data || data.length < 2) {
return (
<div className="relative flex h-24 items-center justify-center bg-surface-muted">
<p className="text-content-subtle">No data available</p>
</div>
);
}
const { BezierPath, BezierFill } = getSVGGraphData(data, svgWidth, svgHeight);
return (
<div className="relative">
<svg viewBox={`26 0 755 ${svgHeight}`} preserveAspectRatio="none" width="100%" height="6em">
<path fill="rgb(var(--secondary-600))" fillOpacity="0.5" stroke="none" d={`${BezierPath} ${BezierFill}`} />
<path fill="none" stroke="rgb(var(--secondary-600))" d={BezierPath} />
</svg>
<div
className="absolute top-0 flex size-full"
style={{
width: 'calc(100% * 31 / 30)',
transform: 'translateX(-50%)',
left: '50%',
}}
>
{data.map((point, i) => (
<div
key={i}
className="w-full opacity-0 hover:opacity-100"
style={{
background:
'linear-gradient(rgba(128, 128, 128, 0.5), rgba(128, 128, 128, 0.5)) no-repeat center/2px 100%',
}}
>
<div
className="pointer-events-none absolute top-3 flex flex-col items-center rounded bg-surface/80 p-2 text-xs text-content"
style={{
transform: 'translateX(-50%)',
left: '50%',
}}
>
<span>{formatDate(point.date)}</span>
<span>{point.minutes_read} minutes</span>
</div>
</div>
))}
</div>
</div>
);
}

View File

@@ -0,0 +1,215 @@
import { cn } from '../utils/cn';
interface SkeletonProps {
className?: string;
variant?: 'default' | 'text' | 'circular' | 'rectangular';
width?: string | number;
height?: string | number;
animation?: 'pulse' | 'wave' | 'none';
}
export function Skeleton({
className = '',
variant = 'default',
width,
height,
animation = 'pulse',
}: SkeletonProps) {
const baseClasses = 'bg-surface-strong';
const variantClasses = {
default: 'rounded',
text: 'h-4 rounded-md',
circular: 'rounded-full',
rectangular: 'rounded-none',
};
const animationClasses = {
pulse: 'animate-pulse',
wave: 'animate-wave',
none: '',
};
const style = {
width: width !== undefined ? (typeof width === 'number' ? `${width}px` : width) : undefined,
height:
height !== undefined ? (typeof height === 'number' ? `${height}px` : height) : undefined,
};
return (
<div
className={cn(baseClasses, variantClasses[variant], animationClasses[animation], className)}
style={style}
/>
);
}
interface SkeletonTextProps {
lines?: number;
className?: string;
lineClassName?: string;
}
export function SkeletonText({ lines = 3, className = '', lineClassName = '' }: SkeletonTextProps) {
return (
<div className={cn('space-y-2', className)}>
{Array.from({ length: lines }).map((_, i) => (
<Skeleton
key={i}
variant="text"
className={cn(lineClassName, i === lines - 1 && lines > 1 ? 'w-3/4' : 'w-full')}
/>
))}
</div>
);
}
interface SkeletonAvatarProps {
size?: number | 'sm' | 'md' | 'lg';
className?: string;
}
export function SkeletonAvatar({ size = 'md', className = '' }: SkeletonAvatarProps) {
const sizeMap = {
sm: 32,
md: 40,
lg: 56,
};
const pixelSize = typeof size === 'number' ? size : sizeMap[size];
return <Skeleton variant="circular" width={pixelSize} height={pixelSize} className={className} />;
}
interface SkeletonCardProps {
className?: string;
showAvatar?: boolean;
showTitle?: boolean;
showText?: boolean;
textLines?: number;
}
export function SkeletonCard({
className = '',
showAvatar = false,
showTitle = true,
showText = true,
textLines = 3,
}: SkeletonCardProps) {
return (
<div className={cn('rounded-lg border border-border bg-surface p-4', className)}>
{showAvatar && (
<div className="mb-4 flex items-start gap-4">
<SkeletonAvatar />
<div className="flex-1">
<Skeleton variant="text" className="mb-2 w-3/4" />
<Skeleton variant="text" className="w-1/2" />
</div>
</div>
)}
{showTitle && <Skeleton variant="text" className="mb-4 h-6 w-1/2" />}
{showText && <SkeletonText lines={textLines} />}
</div>
);
}
interface SkeletonTableProps {
rows?: number;
columns?: number;
className?: string;
showHeader?: boolean;
}
export function SkeletonTable({
rows = 5,
columns = 4,
className = '',
showHeader = true,
}: SkeletonTableProps) {
return (
<div className={cn('overflow-hidden rounded-lg bg-surface', className)}>
<table className="min-w-full">
{showHeader && (
<thead>
<tr className="border-b border-border">
{Array.from({ length: columns }).map((_, i) => (
<th key={i} className="p-3">
<Skeleton variant="text" className="h-5 w-3/4" />
</th>
))}
</tr>
</thead>
)}
<tbody>
{Array.from({ length: rows }).map((_, rowIndex) => (
<tr key={rowIndex} className="border-b border-border last:border-0">
{Array.from({ length: columns }).map((_, colIndex) => (
<td key={colIndex} className="p-3">
<Skeleton
variant="text"
className={colIndex === columns - 1 ? 'w-1/2' : 'w-full'}
/>
</td>
))}
</tr>
))}
</tbody>
</table>
</div>
);
}
interface SkeletonButtonProps {
className?: string;
width?: string | number;
}
export function SkeletonButton({ className = '', width }: SkeletonButtonProps) {
return (
<Skeleton
variant="rectangular"
height={36}
width={width || '100%'}
className={cn('rounded', className)}
/>
);
}
interface PageLoaderProps {
message?: string;
className?: string;
}
export function PageLoader({ message = 'Loading...', className = '' }: PageLoaderProps) {
return (
<div className={cn('flex min-h-[400px] flex-col items-center justify-center gap-4', className)}>
<div className="relative">
<div className="size-12 animate-spin rounded-full border-4 border-surface-strong border-t-secondary-500" />
</div>
<p className="text-sm font-medium text-content-muted">{message}</p>
</div>
);
}
interface InlineLoaderProps {
size?: 'sm' | 'md' | 'lg';
className?: string;
}
export function InlineLoader({ size = 'md', className = '' }: InlineLoaderProps) {
const sizeMap = {
sm: 'h-4 w-4 border-2',
md: 'h-6 w-6 border-[3px]',
lg: 'h-8 w-8 border-4',
};
return (
<div className={cn('flex items-center justify-center', className)}>
<div
className={`${sizeMap[size]} animate-spin rounded-full border-surface-strong border-t-secondary-500`}
/>
</div>
);
}
export { SkeletonTable as SkeletonTableExport };

View File

@@ -0,0 +1,56 @@
import { describe, expect, it } from 'vitest';
import { render, screen } from '@testing-library/react';
import { Table, type Column } from './Table';
interface TestRow {
id: string;
name: string;
role: string;
}
const columns: Column<TestRow>[] = [
{
key: 'name',
header: 'Name',
},
{
key: 'role',
header: 'Role',
},
];
const data: TestRow[] = [
{ id: 'user-1', name: 'Ada', role: 'Admin' },
{ id: 'user-2', name: 'Grace', role: 'Reader' },
];
describe('Table', () => {
it('renders a skeleton table while loading', () => {
const { container } = render(<Table columns={columns} data={[]} loading />);
expect(screen.queryByText('No Results')).not.toBeInTheDocument();
expect(container.querySelectorAll('tbody tr')).toHaveLength(5);
});
it('renders the empty state message when there is no data', () => {
render(<Table columns={columns} data={[]} emptyMessage="Nothing here" />);
expect(screen.getByText('Nothing here')).toBeInTheDocument();
});
it('uses a custom render function for column output', () => {
const customColumns: Column<TestRow>[] = [
{
key: 'name',
header: 'Name',
render: (_value, row, index) => `${index + 1}. ${row.name.toUpperCase()}`,
},
];
render(<Table columns={customColumns} data={data} />);
expect(screen.getByText('1. ADA')).toBeInTheDocument();
expect(screen.getByText('2. GRACE')).toBeInTheDocument();
});
});

View File

@@ -0,0 +1,125 @@
import React from 'react';
import { Skeleton } from './Skeleton';
import { cn } from '../utils/cn';
export interface Column<T extends object> {
key: keyof T;
header: string;
render?: (value: T[keyof T], _row: T, _index: number) => React.ReactNode;
className?: string;
}
export interface TableProps<T extends object> {
columns: Column<T>[];
data: T[];
loading?: boolean;
emptyMessage?: string;
rowKey?: keyof T | ((row: T) => string);
}
function SkeletonTable({
rows = 5,
columns = 4,
className = '',
}: {
rows?: number;
columns?: number;
className?: string;
}) {
return (
<div className={cn('overflow-hidden rounded-lg bg-surface', className)}>
<table className="min-w-full">
<thead>
<tr className="border-b border-border">
{Array.from({ length: columns }).map((_, i) => (
<th key={i} className="p-3">
<Skeleton variant="text" className="h-5 w-3/4" />
</th>
))}
</tr>
</thead>
<tbody>
{Array.from({ length: rows }).map((_, rowIndex) => (
<tr key={rowIndex} className="border-b border-border last:border-0">
{Array.from({ length: columns }).map((_, colIndex) => (
<td key={colIndex} className="p-3">
<Skeleton
variant="text"
className={colIndex === columns - 1 ? 'w-1/2' : 'w-full'}
/>
</td>
))}
</tr>
))}
</tbody>
</table>
</div>
);
}
export function Table<T extends object>({
columns,
data,
loading = false,
emptyMessage = 'No Results',
rowKey,
}: TableProps<T>) {
const getRowKey = (row: T, index: number): string => {
if (typeof rowKey === 'function') {
return rowKey(row);
}
if (rowKey) {
return String(row[rowKey] ?? index);
}
return `row-${index}`;
};
if (loading) {
return <SkeletonTable rows={5} columns={columns.length} />;
}
return (
<div className="overflow-x-auto">
<div className="inline-block min-w-full overflow-hidden rounded shadow">
<table className="min-w-full bg-surface">
<thead>
<tr className="border-b border-border">
{columns.map(column => (
<th
key={String(column.key)}
className={`p-3 text-left text-content-muted ${column.className || ''}`}
>
{column.header}
</th>
))}
</tr>
</thead>
<tbody>
{data.length === 0 ? (
<tr>
<td colSpan={columns.length} className="p-3 text-center text-content-muted">
{emptyMessage}
</td>
</tr>
) : (
data.map((row, index) => (
<tr key={getRowKey(row, index)} className="border-b border-border">
{columns.map(column => (
<td
key={`${getRowKey(row, index)}-${String(column.key)}`}
className={`p-3 text-content ${column.className || ''}`}
>
{column.render
? column.render(row[column.key], row, index)
: (row[column.key] as React.ReactNode)}
</td>
))}
</tr>
))
)}
</tbody>
</table>
</div>
</div>
);
}

View File

@@ -0,0 +1,87 @@
import { useEffect, useState } from 'react';
import { InfoIcon, WarningIcon, ErrorIcon, CloseIcon } from '../icons';
export type ToastType = 'info' | 'warning' | 'error';
export interface ToastProps {
id: string;
type: ToastType;
message: string;
duration?: number;
onClose?: (id: string) => void;
}
const getToastStyles = (_type: ToastType) => {
const baseStyles =
'flex items-center gap-3 rounded-lg border-l-4 p-4 shadow-lg transition-all duration-300';
const typeStyles = {
info: 'border-secondary-500 bg-secondary-100',
warning: 'border-yellow-500 bg-yellow-100',
error: 'border-red-500 bg-red-100',
};
const iconStyles = {
info: 'text-secondary-700',
warning: 'text-yellow-700',
error: 'text-red-700',
};
const textStyles = {
info: 'text-secondary-900',
warning: 'text-yellow-900',
error: 'text-red-900',
};
return { baseStyles, typeStyles, iconStyles, textStyles };
};
export function Toast({ id, type, message, duration = 5000, onClose }: ToastProps) {
const [isVisible, setIsVisible] = useState(true);
const [isAnimatingOut, setIsAnimatingOut] = useState(false);
const { baseStyles, typeStyles, iconStyles, textStyles } = getToastStyles(type);
const handleClose = () => {
setIsAnimatingOut(true);
setTimeout(() => {
setIsVisible(false);
onClose?.(id);
}, 300);
};
useEffect(() => {
if (duration > 0) {
const timer = setTimeout(handleClose, duration);
return () => clearTimeout(timer);
}
}, [duration]);
if (!isVisible) {
return null;
}
const icons = {
info: <InfoIcon size={20} className={iconStyles[type]} />,
warning: <WarningIcon size={20} className={iconStyles[type]} />,
error: <ErrorIcon size={20} className={iconStyles[type]} />,
};
return (
<div
className={`${baseStyles} ${typeStyles[type]} ${
isAnimatingOut ? 'translate-x-full opacity-0' : 'animate-slideInRight opacity-100'
}`}
>
{icons[type]}
<p className={`flex-1 text-sm font-medium ${textStyles[type]}`}>{message}</p>
<button
onClick={handleClose}
className={`ml-2 opacity-70 transition-opacity hover:opacity-100 ${textStyles[type]}`}
aria-label="Close"
>
<CloseIcon size={18} />
</button>
</div>
);
}

View File

@@ -0,0 +1,95 @@
import { createContext, useContext, useState, useCallback, ReactNode } from 'react';
import { Toast, ToastType, ToastProps } from './Toast';
interface ToastContextType {
showToast: (message: string, type?: ToastType, duration?: number) => string;
showInfo: (message: string, duration?: number) => string;
showWarning: (message: string, duration?: number) => string;
showError: (message: string, duration?: number) => string;
removeToast: (id: string) => void;
clearToasts: () => void;
}
const ToastContext = createContext<ToastContextType | undefined>(undefined);
export function ToastProvider({ children }: { children: ReactNode }) {
const [toasts, setToasts] = useState<(ToastProps & { id: string })[]>([]);
const removeToast = useCallback((id: string) => {
setToasts(prev => prev.filter(toast => toast.id !== id));
}, []);
const showToast = useCallback(
(message: string, _type: ToastType = 'info', _duration?: number): string => {
const id = `toast-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`;
setToasts(prev => [
...prev,
{ id, type: _type, message, duration: _duration, onClose: removeToast },
]);
return id;
},
[removeToast]
);
const showInfo = useCallback(
(message: string, _duration?: number) => {
return showToast(message, 'info', _duration);
},
[showToast]
);
const showWarning = useCallback(
(message: string, _duration?: number) => {
return showToast(message, 'warning', _duration);
},
[showToast]
);
const showError = useCallback(
(message: string, _duration?: number) => {
return showToast(message, 'error', _duration);
},
[showToast]
);
const clearToasts = useCallback(() => {
setToasts([]);
}, []);
return (
<ToastContext.Provider
value={{ showToast, showInfo, showWarning, showError, removeToast, clearToasts }}
>
{children}
<ToastContainer toasts={toasts} />
</ToastContext.Provider>
);
}
interface ToastContainerProps {
toasts: (ToastProps & { id: string })[];
}
function ToastContainer({ toasts }: ToastContainerProps) {
if (toasts.length === 0) {
return null;
}
return (
<div className="pointer-events-none fixed bottom-4 right-4 z-50 flex w-full max-w-sm flex-col gap-2">
{toasts.map(toast => (
<div key={toast.id} className="pointer-events-auto">
<Toast {...toast} />
</div>
))}
</div>
);
}
export function useToasts() {
const context = useContext(ToastContext);
if (context === undefined) {
throw new Error('useToasts must be used within a ToastProvider');
}
return context;
}

View File

@@ -0,0 +1,23 @@
// Reading History Graph
export { default as ReadingHistoryGraph } from './ReadingHistoryGraph';
// Toast components
export { Toast } from './Toast';
export { ToastProvider, useToasts } from './ToastContext';
export type { ToastType, ToastProps } from './Toast';
// Skeleton components
export {
Skeleton,
SkeletonText,
SkeletonAvatar,
SkeletonCard,
SkeletonTable,
SkeletonButton,
PageLoader,
InlineLoader,
} from './Skeleton';
export { LoadingState } from './LoadingState';
// Field components
export { Field, FieldLabel, FieldValue, FieldActions } from './Field';

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,19 @@
/**
* Generated by orval v8.5.3 🍺
* Do not edit manually.
* AnthoLume API v1
* REST API for AnthoLume document management system
* OpenAPI spec version: 1.0.0
*/
export interface Activity {
document_id: string;
device_id: string;
start_time: string;
title?: string;
author?: string;
duration: number;
start_percentage: number;
end_percentage: number;
read_percentage: number;
}

View File

@@ -0,0 +1,12 @@
/**
* Generated by orval v8.5.3 🍺
* Do not edit manually.
* AnthoLume API v1
* REST API for AnthoLume document management system
* OpenAPI spec version: 1.0.0
*/
import type { Activity } from './activity';
export interface ActivityResponse {
activities: Activity[];
}

View File

@@ -0,0 +1,15 @@
/**
* Generated by orval v8.5.3 🍺
* Do not edit manually.
* AnthoLume API v1
* REST API for AnthoLume document management system
* OpenAPI spec version: 1.0.0
*/
export type BackupType = typeof BackupType[keyof typeof BackupType];
export const BackupType = {
COVERS: 'COVERS',
DOCUMENTS: 'DOCUMENTS',
} as const;

View File

@@ -0,0 +1,13 @@
/**
* Generated by orval v8.5.3 🍺
* Do not edit manually.
* AnthoLume API v1
* REST API for AnthoLume document management system
* OpenAPI spec version: 1.0.0
*/
export interface ConfigResponse {
version: string;
search_enabled: boolean;
registration_enabled: boolean;
}

View File

@@ -0,0 +1,15 @@
/**
* Generated by orval v8.5.3 🍺
* Do not edit manually.
* AnthoLume API v1
* REST API for AnthoLume document management system
* OpenAPI spec version: 1.0.0
*/
export interface CreateActivityItem {
document_id: string;
start_time: number;
duration: number;
page: number;
pages: number;
}

View File

@@ -0,0 +1,14 @@
/**
* Generated by orval v8.5.3 🍺
* Do not edit manually.
* AnthoLume API v1
* REST API for AnthoLume document management system
* OpenAPI spec version: 1.0.0
*/
import type { CreateActivityItem } from './createActivityItem';
export interface CreateActivityRequest {
device_id: string;
device_name: string;
activity: CreateActivityItem[];
}

Some files were not shown because too many files have changed in this diff Show More