feat: v1 API + frontend migration
Some checks failed
continuous-integration/drone/pr Build is failing

This commit is contained in:
2026-07-03 14:06:33 -04:00
parent ff44db311d
commit c9e4d345b8
212 changed files with 26747 additions and 398 deletions

View File

@@ -16,7 +16,6 @@ import (
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
"github.com/microcosm-cc/bluemonday"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"reichard.io/antholume/config"
"reichard.io/antholume/database"
@@ -114,6 +113,11 @@ func (api *API) Start() error {
return api.httpServer.ListenAndServe()
}
// Handler returns the underlying http.Handler for the Gin router
func (api *API) Handler() http.Handler {
return api.httpServer.Handler
}
func (api *API) Stop() error {
// Stop server
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -293,7 +297,7 @@ func (api *API) loadTemplates(
templateDirectory := fmt.Sprintf("templates/%ss", basePath)
allFiles, err := fs.ReadDir(api.assets, templateDirectory)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("unable to read template dir: %s", templateDirectory))
return fmt.Errorf("unable to read template dir %s: %w", templateDirectory, err)
}
// Generate Templates
@@ -305,7 +309,7 @@ func (api *API) loadTemplates(
// Read Template
b, err := fs.ReadFile(api.assets, templatePath)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("unable to read template: %s", templateName))
return fmt.Errorf("unable to read template %s: %w", templateName, err)
}
// Clone? (Pages - Don't Stomp)
@@ -316,7 +320,7 @@ func (api *API) loadTemplates(
// Parse Template
baseTemplate, err = baseTemplate.New(templateName).Parse(string(b))
if err != nil {
return errors.Wrap(err, fmt.Sprintf("unable to parse template: %s", templateName))
return fmt.Errorf("unable to parse template %s: %w", templateName, err)
}
allTemplates[templateName] = baseTemplate

View File

@@ -22,7 +22,6 @@ import (
"github.com/gabriel-vasile/mimetype"
"github.com/gin-gonic/gin"
"github.com/itchyny/gojq"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"reichard.io/antholume/database"
"reichard.io/antholume/metadata"
@@ -722,7 +721,7 @@ func (api *API) createBackup(ctx context.Context, w io.Writer, directories []str
// Vacuum DB
_, err := api.db.DB.ExecContext(ctx, "VACUUM;")
if err != nil {
return errors.Wrap(err, "Unable to vacuum database")
return fmt.Errorf("Unable to vacuum database: %w", err)
}
ar := zip.NewWriter(w)
@@ -796,7 +795,7 @@ func (api *API) createBackup(ctx context.Context, w io.Writer, directories []str
func (api *API) isLastAdmin(ctx context.Context, userID string) (bool, error) {
allUsers, err := api.db.Queries.GetUsers(ctx)
if err != nil {
return false, errors.Wrap(err, fmt.Sprintf("GetUsers DB Error: %v", err))
return false, fmt.Errorf("GetUsers DB Error: %w", err)
}
hasAdmin := false
@@ -873,7 +872,7 @@ func (api *API) updateUser(ctx context.Context, user string, rawPassword *string
} else {
user, err := api.db.Queries.GetUser(ctx, user)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("GetUser DB Error: %v", err))
return fmt.Errorf("GetUser DB Error: %w", err)
}
updateParams.Admin = user.Admin
}
@@ -911,7 +910,7 @@ func (api *API) updateUser(ctx context.Context, user string, rawPassword *string
// Update User
_, err := api.db.Queries.UpdateUser(ctx, updateParams)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("UpdateUser DB Error: %v", err))
return fmt.Errorf("UpdateUser DB Error: %w", err)
}
return nil
@@ -943,7 +942,7 @@ func (api *API) deleteUser(ctx context.Context, user string) error {
// Delete User
_, err = api.db.Queries.DeleteUser(ctx, user)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("DeleteUser DB Error: %v", err))
return fmt.Errorf("DeleteUser DB Error: %w", err)
}
return nil

View File

@@ -124,10 +124,13 @@ func (api *API) appGetDocuments(c *gin.Context) {
return
}
length, err := api.db.Queries.GetDocumentsSize(c, query)
length, err := api.db.Queries.GetDocumentsWithStatsCount(c, database.GetDocumentsWithStatsCountParams{
Deleted: ptr.Of(false),
Query: query,
})
if err != nil {
log.Error("GetDocumentsSize DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocumentsSize DB Error: %v", err))
log.Error("GetDocumentsWithStatsCount DB Error: ", err)
appErrorPage(c, http.StatusInternalServerError, fmt.Sprintf("GetDocumentsWithStatsCount DB Error: %v", err))
return
}

View File

@@ -49,7 +49,45 @@ func (api *API) authorizeCredentials(ctx context.Context, username string, passw
}
}
// resolveDevAuth returns an authData for the dev user when DISABLE_AUTH is
// set. If DISABLE_AUTH_USER names a specific user, that user is looked up;
// otherwise the first user in the database is used.
func (api *API) resolveDevAuth(c *gin.Context) (authData, bool) {
if api.cfg.DisableAuthUser != "" {
user, err := api.db.Queries.GetUser(c, api.cfg.DisableAuthUser)
if err != nil {
log.Errorf("DISABLE_AUTH_USER=%q not found in database: %v", api.cfg.DisableAuthUser, err)
return authData{}, false
}
return authData{
UserName: user.ID,
IsAdmin: user.Admin,
AuthHash: *user.AuthHash,
}, true
}
users, err := api.db.Queries.GetUsers(c)
if err != nil || len(users) == 0 {
return authData{}, false
}
return authData{
UserName: users[0].ID,
IsAdmin: users[0].Admin,
AuthHash: *users[0].AuthHash,
}, true
}
func (api *API) authKOMiddleware(c *gin.Context) {
// Dev Auth Bypass
if api.cfg.DisableAuth {
if auth, ok := api.resolveDevAuth(c); ok {
c.Set("Authorization", auth)
c.Header("Cache-Control", "private")
c.Next()
return
}
}
session := sessions.Default(c)
// Check Session First
@@ -89,6 +127,16 @@ func (api *API) authKOMiddleware(c *gin.Context) {
}
func (api *API) authOPDSMiddleware(c *gin.Context) {
// Dev Auth Bypass
if api.cfg.DisableAuth {
if auth, ok := api.resolveDevAuth(c); ok {
c.Set("Authorization", auth)
c.Header("Cache-Control", "private")
c.Next()
return
}
}
c.Header("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
user, rawPassword, hasAuth := c.Request.BasicAuth()
@@ -113,6 +161,16 @@ func (api *API) authOPDSMiddleware(c *gin.Context) {
}
func (api *API) authWebAppMiddleware(c *gin.Context) {
// Dev Auth Bypass
if api.cfg.DisableAuth {
if auth, ok := api.resolveDevAuth(c); ok {
c.Set("Authorization", auth)
c.Header("Cache-Control", "private")
c.Next()
return
}
}
session := sessions.Default(c)
// Check Session

View File

@@ -407,7 +407,10 @@ func (api *API) koCheckDocumentsSync(c *gin.Context) {
return
}
wantedDocs, err := api.db.Queries.GetWantedDocuments(c, string(jsonHaves))
wantedDocs, err := api.db.Queries.GetWantedDocuments(c, database.GetWantedDocumentsParams{
JsonEach: string(jsonHaves),
DocumentIds: string(jsonHaves),
})
if err != nil {
log.Error("GetWantedDocuments DB Error", err)
apiErrorPage(c, http.StatusBadRequest, "Invalid Request")

176
api/v1/activity.go Normal file
View File

@@ -0,0 +1,176 @@
package v1
import (
"context"
"time"
log "github.com/sirupsen/logrus"
"reichard.io/antholume/database"
)
// GET /activity
func (s *Server) GetActivity(ctx context.Context, request GetActivityRequestObject) (GetActivityResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return GetActivity401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
docFilter := false
if request.Params.DocFilter != nil {
docFilter = *request.Params.DocFilter
}
documentID := ""
if request.Params.DocumentId != nil {
documentID = *request.Params.DocumentId
}
page := int64(1)
if request.Params.Page != nil {
page = *request.Params.Page
}
limit := int64(25)
if request.Params.Limit != nil {
limit = *request.Params.Limit
}
activities, err := s.db.Queries.GetActivity(ctx, database.GetActivityParams{
UserID: auth.UserName,
DocFilter: docFilter,
DocumentID: documentID,
Offset: (page - 1) * limit,
Limit: limit,
})
if err != nil {
return GetActivity500JSONResponse{Code: 500, Message: err.Error()}, nil
}
// Get Total Count
total, err := s.db.Queries.GetActivityCount(ctx, database.GetActivityCountParams{
UserID: auth.UserName,
DocFilter: docFilter,
DocumentID: documentID,
})
if err != nil {
return GetActivity500JSONResponse{Code: 500, Message: err.Error()}, nil
}
// Calculate Pagination
var nextPage *int64
var previousPage *int64
if page*limit < total {
nextPage = ptrOf(page + 1)
}
if page > 1 {
previousPage = ptrOf(page - 1)
}
apiActivities := make([]Activity, len(activities))
for i, a := range activities {
// Convert StartTime from interface{} to string
startTimeStr := ""
if a.StartTime != nil {
if str, ok := a.StartTime.(string); ok {
startTimeStr = str
}
}
apiActivities[i] = Activity{
DocumentId: a.DocumentID,
DeviceId: a.DeviceID,
StartTime: startTimeStr,
Title: a.Title,
Author: a.Author,
Duration: a.Duration,
StartPercentage: float32(a.StartPercentage),
EndPercentage: float32(a.EndPercentage),
ReadPercentage: float32(a.ReadPercentage),
}
}
response := ActivityResponse{
Activities: apiActivities,
Page: page,
Limit: limit,
Total: total,
NextPage: nextPage,
PreviousPage: previousPage,
}
return GetActivity200JSONResponse(response), nil
}
// POST /activity
func (s *Server) CreateActivity(ctx context.Context, request CreateActivityRequestObject) (CreateActivityResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return CreateActivity401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
if request.Body == nil {
return CreateActivity400JSONResponse{Code: 400, Message: "Request body is required"}, nil
}
tx, err := s.db.DB.Begin()
if err != nil {
log.Error("Transaction Begin DB Error:", err)
return CreateActivity500JSONResponse{Code: 500, Message: "Database error"}, nil
}
committed := false
defer func() {
if committed {
return
}
if rollbackErr := tx.Rollback(); rollbackErr != nil {
log.Debug("Transaction Rollback DB Error:", rollbackErr)
}
}()
qtx := s.db.Queries.WithTx(tx)
allDocumentsMap := make(map[string]struct{})
for _, item := range request.Body.Activity {
allDocumentsMap[item.DocumentId] = struct{}{}
}
for documentID := range allDocumentsMap {
if _, err := qtx.UpsertDocument(ctx, database.UpsertDocumentParams{ID: documentID}); err != nil {
log.Error("UpsertDocument DB Error:", err)
return CreateActivity400JSONResponse{Code: 400, Message: "Invalid document"}, nil
}
}
if _, err := qtx.UpsertDevice(ctx, database.UpsertDeviceParams{
ID: request.Body.DeviceId,
UserID: auth.UserName,
DeviceName: request.Body.DeviceName,
LastSynced: time.Now().UTC().Format(time.RFC3339),
}); err != nil {
log.Error("UpsertDevice DB Error:", err)
return CreateActivity400JSONResponse{Code: 400, Message: "Invalid device"}, nil
}
for _, item := range request.Body.Activity {
if _, err := qtx.AddActivity(ctx, database.AddActivityParams{
UserID: auth.UserName,
DocumentID: item.DocumentId,
DeviceID: request.Body.DeviceId,
StartTime: time.Unix(item.StartTime, 0).UTC().Format(time.RFC3339),
Duration: item.Duration,
StartPercentage: float64(item.Page) / float64(item.Pages),
EndPercentage: float64(item.Page+1) / float64(item.Pages),
}); err != nil {
log.Error("AddActivity DB Error:", err)
return CreateActivity400JSONResponse{Code: 400, Message: "Invalid activity"}, nil
}
}
if err := tx.Commit(); err != nil {
log.Error("Transaction Commit DB Error:", err)
return CreateActivity500JSONResponse{Code: 500, Message: "Database error"}, nil
}
committed = true
response := CreateActivityResponse{Added: int64(len(request.Body.Activity))}
return CreateActivity200JSONResponse(response), nil
}

1070
api/v1/admin.go Normal file

File diff suppressed because it is too large Load Diff

152
api/v1/admin_test.go Normal file
View File

@@ -0,0 +1,152 @@
package v1
import (
"bytes"
"context"
"crypto/md5"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
argon2 "github.com/alexedwards/argon2id"
"github.com/stretchr/testify/require"
"reichard.io/antholume/config"
"reichard.io/antholume/database"
)
func createAdminTestUser(t *testing.T, db *database.DBManager, username, password string) {
t.Helper()
md5Hash := fmt.Sprintf("%x", md5.Sum([]byte(password)))
hashedPassword, err := argon2.CreateHash(md5Hash, argon2.DefaultParams)
require.NoError(t, err)
authHash := "test-auth-hash"
_, err = db.Queries.CreateUser(context.Background(), database.CreateUserParams{
ID: username,
Pass: &hashedPassword,
AuthHash: &authHash,
Admin: true,
})
require.NoError(t, err)
}
func loginAdminTestUser(t *testing.T, srv *Server, username, password string) *http.Cookie {
t.Helper()
body, err := json.Marshal(LoginRequest{Username: username, Password: password})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(body))
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
cookies := w.Result().Cookies()
require.Len(t, cookies, 1)
return cookies[0]
}
func TestGetLogsPagination(t *testing.T) {
configPath := t.TempDir()
require.NoError(t, os.MkdirAll(filepath.Join(configPath, "logs"), 0o755))
require.NoError(t, os.WriteFile(filepath.Join(configPath, "logs", "antholume.log"), []byte(
"{\"level\":\"info\",\"msg\":\"one\"}\n"+
"plain two\n"+
"{\"level\":\"error\",\"msg\":\"three\"}\n"+
"plain four\n",
), 0o644))
cfg := &config.Config{
ListenPort: "8080",
DBType: "memory",
DBName: "test",
ConfigPath: configPath,
CookieAuthKey: "test-auth-key-32-bytes-long-enough",
CookieEncKey: "0123456789abcdef",
CookieSecure: false,
CookieHTTPOnly: true,
Version: "test",
DemoMode: false,
RegistrationEnabled: true,
}
db := database.NewMgr(cfg)
srv := NewServer(db, cfg, nil)
createAdminTestUser(t, db, "admin", "password")
cookie := loginAdminTestUser(t, srv, "admin", "password")
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/logs?page=2&limit=2", nil)
req.AddCookie(cookie)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
var resp LogsResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
require.NotNil(t, resp.Logs)
require.Len(t, *resp.Logs, 2)
require.NotNil(t, resp.Page)
require.Equal(t, int64(2), *resp.Page)
require.NotNil(t, resp.Limit)
require.Equal(t, int64(2), *resp.Limit)
require.NotNil(t, resp.Total)
require.Equal(t, int64(4), *resp.Total)
require.Nil(t, resp.NextPage)
require.NotNil(t, resp.PreviousPage)
require.Equal(t, int64(1), *resp.PreviousPage)
require.Contains(t, (*resp.Logs)[0], "three")
require.Contains(t, (*resp.Logs)[1], "plain four")
}
func TestGetLogsPaginationWithBasicFilter(t *testing.T) {
configPath := t.TempDir()
require.NoError(t, os.MkdirAll(filepath.Join(configPath, "logs"), 0o755))
require.NoError(t, os.WriteFile(filepath.Join(configPath, "logs", "antholume.log"), []byte(
"{\"level\":\"info\",\"msg\":\"match-1\"}\n"+
"{\"level\":\"info\",\"msg\":\"skip\"}\n"+
"plain match-2\n"+
"{\"level\":\"info\",\"msg\":\"match-3\"}\n",
), 0o644))
cfg := &config.Config{
ListenPort: "8080",
DBType: "memory",
DBName: "test",
ConfigPath: configPath,
CookieAuthKey: "test-auth-key-32-bytes-long-enough",
CookieEncKey: "0123456789abcdef",
CookieSecure: false,
CookieHTTPOnly: true,
Version: "test",
DemoMode: false,
RegistrationEnabled: true,
}
db := database.NewMgr(cfg)
srv := NewServer(db, cfg, nil)
createAdminTestUser(t, db, "admin", "password")
cookie := loginAdminTestUser(t, srv, "admin", "password")
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/logs?filter=%22match%22&page=1&limit=2", nil)
req.AddCookie(cookie)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
var resp LogsResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
require.NotNil(t, resp.Logs)
require.Len(t, *resp.Logs, 2)
require.NotNil(t, resp.Total)
require.Equal(t, int64(3), *resp.Total)
require.NotNil(t, resp.NextPage)
require.Equal(t, int64(2), *resp.NextPage)
}

4151
api/v1/api.gen.go Normal file

File diff suppressed because it is too large Load Diff

286
api/v1/auth.go Normal file
View File

@@ -0,0 +1,286 @@
package v1
import (
"context"
"crypto/md5"
"fmt"
"net/http"
"time"
argon2 "github.com/alexedwards/argon2id"
"github.com/gorilla/sessions"
log "github.com/sirupsen/logrus"
)
// POST /auth/login
func (s *Server) Login(ctx context.Context, request LoginRequestObject) (LoginResponseObject, error) {
if request.Body == nil {
return Login400JSONResponse{Code: 400, Message: "Invalid request body"}, nil
}
req := *request.Body
if req.Username == "" || req.Password == "" {
return Login400JSONResponse{Code: 400, Message: "Invalid credentials"}, nil
}
// MD5 - KOSync compatibility
password := fmt.Sprintf("%x", md5.Sum([]byte(req.Password)))
// Verify credentials
user, err := s.db.Queries.GetUser(ctx, req.Username)
if err != nil {
return Login401JSONResponse{Code: 401, Message: "Invalid credentials"}, nil
}
if match, err := argon2.ComparePasswordAndHash(password, *user.Pass); err != nil || !match {
return Login401JSONResponse{Code: 401, Message: "Invalid credentials"}, nil
}
if err := s.saveUserSession(ctx, user.ID, user.Admin, *user.AuthHash); err != nil {
return Login500JSONResponse{Code: 500, Message: err.Error()}, nil
}
return Login200JSONResponse{
Body: LoginResponse{
Username: user.ID,
IsAdmin: user.Admin,
},
Headers: Login200ResponseHeaders{
SetCookie: s.getSetCookieFromContext(ctx),
},
}, nil
}
// POST /auth/register
func (s *Server) Register(ctx context.Context, request RegisterRequestObject) (RegisterResponseObject, error) {
if !s.cfg.RegistrationEnabled {
return Register403JSONResponse{Code: 403, Message: "Registration is disabled"}, nil
}
if request.Body == nil {
return Register400JSONResponse{Code: 400, Message: "Invalid request body"}, nil
}
req := *request.Body
if req.Username == "" || req.Password == "" {
return Register400JSONResponse{Code: 400, Message: "Invalid user or password"}, nil
}
currentUsers, err := s.db.Queries.GetUsers(ctx)
if err != nil {
return Register500JSONResponse{Code: 500, Message: "Failed to create user"}, nil
}
isAdmin := len(currentUsers) == 0
if err := s.createUser(ctx, req.Username, &req.Password, &isAdmin); err != nil {
return Register400JSONResponse{Code: 400, Message: err.Error()}, nil
}
user, err := s.db.Queries.GetUser(ctx, req.Username)
if err != nil {
return Register500JSONResponse{Code: 500, Message: "Failed to load created user"}, nil
}
if err := s.saveUserSession(ctx, user.ID, user.Admin, *user.AuthHash); err != nil {
return Register500JSONResponse{Code: 500, Message: err.Error()}, nil
}
return Register201JSONResponse{
Body: LoginResponse{
Username: user.ID,
IsAdmin: user.Admin,
},
Headers: Register201ResponseHeaders{
SetCookie: s.getSetCookieFromContext(ctx),
},
}, nil
}
// POST /auth/logout
func (s *Server) Logout(ctx context.Context, request LogoutRequestObject) (LogoutResponseObject, error) {
_, ok := s.getSessionFromContext(ctx)
if !ok {
return Logout401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
r := s.getRequestFromContext(ctx)
w := s.getResponseWriterFromContext(ctx)
if r == nil || w == nil {
return Logout401JSONResponse{Code: 401, Message: "Internal context error"}, nil
}
session, err := s.getCookieSession(r)
if err != nil {
return Logout401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
session.Values = make(map[any]any)
if err := session.Save(r, w); err != nil {
return Logout401JSONResponse{Code: 401, Message: "Failed to logout"}, nil
}
return Logout200Response{}, nil
}
// GET /auth/me
func (s *Server) GetMe(ctx context.Context, request GetMeRequestObject) (GetMeResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return GetMe401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
return GetMe200JSONResponse{
Username: auth.UserName,
IsAdmin: auth.IsAdmin,
}, nil
}
func (s *Server) saveUserSession(ctx context.Context, username string, isAdmin bool, authHash string) error {
r := s.getRequestFromContext(ctx)
w := s.getResponseWriterFromContext(ctx)
if r == nil || w == nil {
return fmt.Errorf("internal context error")
}
session, err := s.getCookieSession(r)
if err != nil {
return fmt.Errorf("unauthorized")
}
session.Values["authorizedUser"] = username
session.Values["isAdmin"] = isAdmin
session.Values["expiresAt"] = time.Now().Unix() + (60 * 60 * 24 * 7)
session.Values["authHash"] = authHash
if err := session.Save(r, w); err != nil {
return fmt.Errorf("failed to create session")
}
return nil
}
func (s *Server) getCookieSession(r *http.Request) (*sessions.Session, error) {
store := sessions.NewCookieStore([]byte(s.cfg.CookieAuthKey))
if s.cfg.CookieEncKey != "" {
if len(s.cfg.CookieEncKey) == 16 || len(s.cfg.CookieEncKey) == 32 {
store = sessions.NewCookieStore([]byte(s.cfg.CookieAuthKey), []byte(s.cfg.CookieEncKey))
}
}
session, err := store.Get(r, "token")
if err != nil {
return nil, fmt.Errorf("failed to get session: %w", err)
}
session.Options.SameSite = http.SameSiteLaxMode
session.Options.HttpOnly = true
session.Options.Secure = s.cfg.CookieSecure
return session, nil
}
// getSessionFromContext extracts authData from context
func (s *Server) getSessionFromContext(ctx context.Context) (authData, bool) {
auth, ok := ctx.Value("auth").(authData)
if !ok {
return authData{}, false
}
return auth, true
}
// isAdmin checks if a user has admin privileges
func (s *Server) isAdmin(ctx context.Context) bool {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return false
}
return auth.IsAdmin
}
// getRequestFromContext extracts the HTTP request from context
func (s *Server) getRequestFromContext(ctx context.Context) *http.Request {
r, ok := ctx.Value("request").(*http.Request)
if !ok {
return nil
}
return r
}
// getResponseWriterFromContext extracts the response writer from context
func (s *Server) getResponseWriterFromContext(ctx context.Context) http.ResponseWriter {
w, ok := ctx.Value("response").(http.ResponseWriter)
if !ok {
return nil
}
return w
}
func (s *Server) getSetCookieFromContext(ctx context.Context) string {
w := s.getResponseWriterFromContext(ctx)
if w == nil {
return ""
}
return w.Header().Get("Set-Cookie")
}
// getSession retrieves auth data from the session cookie
func (s *Server) getSession(r *http.Request) (auth authData, ok bool) {
// Get session from cookie store
store := sessions.NewCookieStore([]byte(s.cfg.CookieAuthKey))
if s.cfg.CookieEncKey != "" {
if len(s.cfg.CookieEncKey) == 16 || len(s.cfg.CookieEncKey) == 32 {
store = sessions.NewCookieStore([]byte(s.cfg.CookieAuthKey), []byte(s.cfg.CookieEncKey))
} else {
log.Error("invalid cookie encryption key (must be 16 or 32 bytes)")
return authData{}, false
}
}
session, err := store.Get(r, "token")
if err != nil {
return authData{}, false
}
// Get session values
authorizedUser := session.Values["authorizedUser"]
isAdmin := session.Values["isAdmin"]
expiresAt := session.Values["expiresAt"]
authHash := session.Values["authHash"]
if authorizedUser == nil || isAdmin == nil || expiresAt == nil || authHash == nil {
return authData{}, false
}
auth = authData{
UserName: authorizedUser.(string),
IsAdmin: isAdmin.(bool),
AuthHash: authHash.(string),
}
// Validate auth hash
ctx := r.Context()
correctAuthHash, err := s.getUserAuthHash(ctx, auth.UserName)
if err != nil || correctAuthHash != auth.AuthHash {
return authData{}, false
}
return auth, true
}
// getUserAuthHash retrieves the user's auth hash from DB or cache
func (s *Server) getUserAuthHash(ctx context.Context, username string) (string, error) {
user, err := s.db.Queries.GetUser(ctx, username)
if err != nil {
return "", err
}
return *user.AuthHash, nil
}
// authData represents authenticated user information
type authData struct {
UserName string
IsAdmin bool
AuthHash string
}

228
api/v1/auth_test.go Normal file
View File

@@ -0,0 +1,228 @@
package v1
import (
"bytes"
"crypto/md5"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/suite"
argon2 "github.com/alexedwards/argon2id"
"reichard.io/antholume/config"
"reichard.io/antholume/database"
)
type AuthTestSuite struct {
suite.Suite
db *database.DBManager
cfg *config.Config
srv *Server
}
func (suite *AuthTestSuite) setupConfig() *config.Config {
return &config.Config{
ListenPort: "8080",
DBType: "memory",
DBName: "test",
ConfigPath: "/tmp",
CookieAuthKey: "test-auth-key-32-bytes-long-enough",
CookieEncKey: "0123456789abcdef",
CookieSecure: false,
CookieHTTPOnly: true,
Version: "test",
DemoMode: false,
RegistrationEnabled: true,
}
}
func TestAuth(t *testing.T) {
suite.Run(t, new(AuthTestSuite))
}
func (suite *AuthTestSuite) SetupTest() {
suite.cfg = suite.setupConfig()
suite.db = database.NewMgr(suite.cfg)
suite.srv = NewServer(suite.db, suite.cfg, nil)
}
func (suite *AuthTestSuite) createTestUser(username, password string) {
md5Hash := fmt.Sprintf("%x", md5.Sum([]byte(password)))
hashedPassword, err := argon2.CreateHash(md5Hash, argon2.DefaultParams)
suite.Require().NoError(err)
authHash := "test-auth-hash"
_, err = suite.db.Queries.CreateUser(suite.T().Context(), database.CreateUserParams{
ID: username,
Pass: &hashedPassword,
AuthHash: &authHash,
Admin: true,
})
suite.Require().NoError(err)
}
func (suite *AuthTestSuite) assertSessionCookie(cookie *http.Cookie) {
suite.Require().NotNil(cookie)
suite.Equal("token", cookie.Name)
suite.NotEmpty(cookie.Value)
suite.True(cookie.HttpOnly)
}
func (suite *AuthTestSuite) login(username, password string) *http.Cookie {
reqBody := LoginRequest{
Username: username,
Password: password,
}
body, err := json.Marshal(reqBody)
suite.Require().NoError(err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(body))
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusOK, w.Code, "login should return 200")
var resp LoginResponse
suite.Require().NoError(json.Unmarshal(w.Body.Bytes(), &resp))
cookies := w.Result().Cookies()
suite.Require().Len(cookies, 1, "should have session cookie")
suite.assertSessionCookie(cookies[0])
return cookies[0]
}
func (suite *AuthTestSuite) TestAPILogin() {
suite.createTestUser("testuser", "testpass")
reqBody := LoginRequest{
Username: "testuser",
Password: "testpass",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(body))
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusOK, w.Code)
var resp LoginResponse
suite.Require().NoError(json.Unmarshal(w.Body.Bytes(), &resp))
suite.Equal("testuser", resp.Username)
cookies := w.Result().Cookies()
suite.Require().Len(cookies, 1)
suite.assertSessionCookie(cookies[0])
}
func (suite *AuthTestSuite) TestAPILoginInvalidCredentials() {
reqBody := LoginRequest{
Username: "testuser",
Password: "wrongpass",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(body))
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusUnauthorized, w.Code)
}
func (suite *AuthTestSuite) TestAPIRegister() {
reqBody := LoginRequest{
Username: "newuser",
Password: "newpass",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusCreated, w.Code)
var resp LoginResponse
suite.Require().NoError(json.Unmarshal(w.Body.Bytes(), &resp))
suite.Equal("newuser", resp.Username)
suite.True(resp.IsAdmin, "first registered user should mirror legacy admin bootstrap behavior")
cookies := w.Result().Cookies()
suite.Require().Len(cookies, 1, "register should set a session cookie")
suite.assertSessionCookie(cookies[0])
user, err := suite.db.Queries.GetUser(suite.T().Context(), "newuser")
suite.Require().NoError(err)
suite.True(user.Admin)
}
func (suite *AuthTestSuite) TestAPIRegisterDisabled() {
suite.cfg.RegistrationEnabled = false
suite.srv = NewServer(suite.db, suite.cfg, nil)
reqBody := LoginRequest{
Username: "newuser",
Password: "newpass",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusForbidden, w.Code)
}
func (suite *AuthTestSuite) TestAPILogout() {
suite.createTestUser("testuser", "testpass")
cookie := suite.login("testuser", "testpass")
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", nil)
req.AddCookie(cookie)
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusOK, w.Code)
cookies := w.Result().Cookies()
suite.Require().Len(cookies, 1)
suite.Equal("token", cookies[0].Name)
}
func (suite *AuthTestSuite) TestAPIGetMe() {
suite.createTestUser("testuser", "testpass")
cookie := suite.login("testuser", "testpass")
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil)
req.AddCookie(cookie)
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusOK, w.Code)
var resp UserData
suite.Require().NoError(json.Unmarshal(w.Body.Bytes(), &resp))
suite.Equal("testuser", resp.Username)
}
func (suite *AuthTestSuite) TestAPIGetMeUnauthenticated() {
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil)
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusUnauthorized, w.Code)
}

837
api/v1/documents.go Normal file
View File

@@ -0,0 +1,837 @@
package v1
import (
"context"
"fmt"
"io"
"io/fs"
"net/http"
"os"
"path/filepath"
"strings"
"time"
log "github.com/sirupsen/logrus"
"reichard.io/antholume/database"
"reichard.io/antholume/metadata"
)
// GET /documents
func (s *Server) GetDocuments(ctx context.Context, request GetDocumentsRequestObject) (GetDocumentsResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return GetDocuments401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
page := int64(1)
if request.Params.Page != nil {
page = *request.Params.Page
}
limit := int64(9)
if request.Params.Limit != nil {
limit = *request.Params.Limit
}
var search *string
if request.Params.Search != nil && *request.Params.Search != "" {
search = ptrOf("%" + *request.Params.Search + "%")
}
rows, err := s.db.Queries.GetDocumentsWithStats(
ctx,
database.GetDocumentsWithStatsParams{
UserID: auth.UserName,
Query: search,
Deleted: ptrOf(false),
Offset: (page - 1) * limit,
Limit: limit,
},
)
if err != nil {
return GetDocuments500JSONResponse{Code: 500, Message: err.Error()}, nil
}
// Get Total Count
total, err := s.db.Queries.GetDocumentsWithStatsCount(
ctx,
database.GetDocumentsWithStatsCountParams{
Query: search,
Deleted: ptrOf(false),
},
)
if err != nil {
return GetDocuments500JSONResponse{Code: 500, Message: err.Error()}, nil
}
// Calculate Pagination
var nextPage *int64
var previousPage *int64
if page*limit < total {
nextPage = ptrOf(page + 1)
}
if page > 1 {
previousPage = ptrOf(page - 1)
}
apiDocuments := make([]Document, len(rows))
for i, row := range rows {
apiDocuments[i] = Document{
Id: row.ID,
Title: *row.Title,
Author: *row.Author,
Description: row.Description,
Isbn10: row.Isbn10,
Isbn13: row.Isbn13,
Words: row.Words,
Filepath: row.Filepath,
Percentage: ptrOf(float32(row.Percentage)),
TotalTimeSeconds: ptrOf(row.TotalTimeSeconds),
Wpm: ptrOf(float32(row.Wpm)),
SecondsPerPercent: ptrOf(row.SecondsPerPercent),
LastRead: parseInterfaceTime(row.LastRead),
CreatedAt: time.Now(), // Will be overwritten if we had a proper created_at from DB
UpdatedAt: time.Now(), // Will be overwritten if we had a proper updated_at from DB
Deleted: false, // Default, should be overridden if available
}
}
response := DocumentsResponse{
Documents: apiDocuments,
Total: total,
Page: page,
Limit: limit,
NextPage: nextPage,
PreviousPage: previousPage,
Search: request.Params.Search,
}
return GetDocuments200JSONResponse(response), nil
}
// GET /documents/{id}
func (s *Server) GetDocument(ctx context.Context, request GetDocumentRequestObject) (GetDocumentResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return GetDocument401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
// Use GetDocumentsWithStats to get document with stats
docs, err := s.db.Queries.GetDocumentsWithStats(
ctx,
database.GetDocumentsWithStatsParams{
UserID: auth.UserName,
ID: &request.Id,
Deleted: ptrOf(false),
Offset: 0,
Limit: 1,
},
)
if err != nil || len(docs) == 0 {
return GetDocument404JSONResponse{Code: 404, Message: "Document not found"}, nil
}
doc := docs[0]
apiDoc := Document{
Id: doc.ID,
Title: *doc.Title,
Author: *doc.Author,
Description: doc.Description,
Isbn10: doc.Isbn10,
Isbn13: doc.Isbn13,
Words: doc.Words,
Filepath: doc.Filepath,
Percentage: ptrOf(float32(doc.Percentage)),
TotalTimeSeconds: ptrOf(doc.TotalTimeSeconds),
Wpm: ptrOf(float32(doc.Wpm)),
SecondsPerPercent: ptrOf(doc.SecondsPerPercent),
LastRead: parseInterfaceTime(doc.LastRead),
CreatedAt: time.Now(), // Will be overwritten if we had a proper created_at from DB
UpdatedAt: time.Now(), // Will be overwritten if we had a proper updated_at from DB
Deleted: false, // Default, should be overridden if available
}
response := DocumentResponse{
Document: apiDoc,
}
return GetDocument200JSONResponse(response), nil
}
// POST /documents/{id}
func (s *Server) EditDocument(ctx context.Context, request EditDocumentRequestObject) (EditDocumentResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return EditDocument401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
if request.Body == nil {
return EditDocument400JSONResponse{Code: 400, Message: "Missing request body"}, nil
}
// Validate document exists and get current state
currentDoc, err := s.db.Queries.GetDocument(ctx, request.Id)
if err != nil {
return EditDocument404JSONResponse{Code: 404, Message: "Document not found"}, nil
}
// Validate at least one editable field is provided
if request.Body.Title == nil &&
request.Body.Author == nil &&
request.Body.Description == nil &&
request.Body.Isbn10 == nil &&
request.Body.Isbn13 == nil &&
request.Body.CoverGbid == nil {
return EditDocument400JSONResponse{Code: 400, Message: "No editable fields provided"}, nil
}
// Handle cover via Google Books ID
var coverFileName *string
if request.Body.CoverGbid != nil {
coverDir := filepath.Join(s.cfg.DataPath, "covers")
fileName, err := metadata.CacheCoverWithContext(ctx, *request.Body.CoverGbid, coverDir, request.Id, true)
if err == nil {
coverFileName = fileName
}
}
// Update document with provided editable fields only
_, err = s.db.Queries.UpsertDocument(ctx, database.UpsertDocumentParams{
ID: request.Id,
Title: request.Body.Title,
Author: request.Body.Author,
Description: request.Body.Description,
Isbn10: request.Body.Isbn10,
Isbn13: request.Body.Isbn13,
Coverfile: coverFileName,
// Preserve existing values for non-editable fields
Md5: currentDoc.Md5,
Basepath: currentDoc.Basepath,
Filepath: currentDoc.Filepath,
Words: currentDoc.Words,
})
if err != nil {
log.Error("UpsertDocument DB Error:", err)
return EditDocument500JSONResponse{Code: 500, Message: "Failed to update document"}, nil
}
// Use GetDocumentsWithStats to get document with stats for the response
docs, err := s.db.Queries.GetDocumentsWithStats(
ctx,
database.GetDocumentsWithStatsParams{
UserID: auth.UserName,
ID: &request.Id,
Deleted: ptrOf(false),
Offset: 0,
Limit: 1,
},
)
if err != nil || len(docs) == 0 {
return EditDocument404JSONResponse{Code: 404, Message: "Document not found"}, nil
}
doc := docs[0]
apiDoc := Document{
Id: doc.ID,
Title: *doc.Title,
Author: *doc.Author,
Description: doc.Description,
Isbn10: doc.Isbn10,
Isbn13: doc.Isbn13,
Words: doc.Words,
Filepath: doc.Filepath,
Percentage: ptrOf(float32(doc.Percentage)),
TotalTimeSeconds: ptrOf(doc.TotalTimeSeconds),
Wpm: ptrOf(float32(doc.Wpm)),
SecondsPerPercent: ptrOf(doc.SecondsPerPercent),
LastRead: parseInterfaceTime(doc.LastRead),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
Deleted: false,
}
response := DocumentResponse{
Document: apiDoc,
}
return EditDocument200JSONResponse(response), nil
}
// deriveBaseFileName builds the base filename for a given MetadataInfo object.
func deriveBaseFileName(metadataInfo *metadata.MetadataInfo) string {
// Derive New FileName
var newFileName string
if metadataInfo.Author != nil && *metadataInfo.Author != "" {
newFileName = newFileName + *metadataInfo.Author
} else {
newFileName = newFileName + "Unknown"
}
if metadataInfo.Title != nil && *metadataInfo.Title != "" {
newFileName = newFileName + " - " + *metadataInfo.Title
} else {
newFileName = newFileName + " - Unknown"
}
// Remove Slashes
fileName := strings.ReplaceAll(newFileName, "/", "")
return "." + filepath.Clean(fmt.Sprintf("/%s [%s]%s", fileName, *metadataInfo.PartialMD5, metadataInfo.Type))
}
// parseInterfaceTime converts an interface{} to time.Time for SQLC queries
func parseInterfaceTime(t any) *time.Time {
if t == nil {
return nil
}
switch v := t.(type) {
case string:
parsed, err := time.Parse(time.RFC3339, v)
if err != nil {
return nil
}
return &parsed
case time.Time:
return &v
default:
return nil
}
}
// serveNoCover serves the default no-cover image from assets
func (s *Server) serveNoCover() (fs.File, string, int64, error) {
// Try to open the no-cover image from assets
file, err := s.assets.Open("assets/images/no-cover.jpg")
if err != nil {
return nil, "", 0, err
}
// Get file info
info, err := file.Stat()
if err != nil {
file.Close()
return nil, "", 0, err
}
return file, "image/jpeg", info.Size(), nil
}
// openFileReader opens a file and returns it as an io.ReaderCloser
func openFileReader(path string) (*os.File, error) {
return os.Open(path)
}
// GET /documents/{id}/cover
func (s *Server) GetDocumentCover(ctx context.Context, request GetDocumentCoverRequestObject) (GetDocumentCoverResponseObject, error) {
// Authentication is handled by middleware, which also adds auth data to context
// This endpoint just serves the cover image
// Validate Document Exists in DB
document, err := s.db.Queries.GetDocument(ctx, request.Id)
if err != nil {
log.Error("GetDocument DB Error:", err)
return GetDocumentCover404JSONResponse{Code: 404, Message: "Document not found"}, nil
}
var coverFile fs.File
var contentType string
var contentLength int64
var needMetadataFetch bool
// Handle Identified Document
if document.Coverfile != nil {
if *document.Coverfile == "UNKNOWN" {
// Serve no-cover image
file, ct, size, err := s.serveNoCover()
if err != nil {
log.Error("Failed to open no-cover image:", err)
return GetDocumentCover404JSONResponse{Code: 404, Message: "Cover not found"}, nil
}
coverFile = file
contentType = ct
contentLength = size
needMetadataFetch = true
} else {
// Derive Path
coverPath := filepath.Join(s.cfg.DataPath, "covers", *document.Coverfile)
// Validate File Exists
fileInfo, err := os.Stat(coverPath)
if os.IsNotExist(err) {
log.Error("Cover file should but doesn't exist: ", err)
// Serve no-cover image
file, ct, size, err := s.serveNoCover()
if err != nil {
log.Error("Failed to open no-cover image:", err)
return GetDocumentCover404JSONResponse{Code: 404, Message: "Cover not found"}, nil
}
coverFile = file
contentType = ct
contentLength = size
needMetadataFetch = true
} else {
// Open the cover file
file, err := openFileReader(coverPath)
if err != nil {
log.Error("Failed to open cover file:", err)
return GetDocumentCover500JSONResponse{Code: 500, Message: "Failed to open cover"}, nil
}
coverFile = file
contentLength = fileInfo.Size()
// Determine content type based on file extension
contentType = "image/jpeg"
if strings.HasSuffix(coverPath, ".png") {
contentType = "image/png"
}
}
}
} else {
needMetadataFetch = true
}
// Attempt Metadata fetch if needed
var cachedCoverFile string = "UNKNOWN"
var coverDir string = filepath.Join(s.cfg.DataPath, "covers")
if needMetadataFetch {
// Create context with timeout for metadata service calls
metadataCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
// Identify Documents & Save Covers
metadataResults, err := metadata.SearchMetadataWithContext(metadataCtx, metadata.SOURCE_GBOOK, metadata.MetadataInfo{
Title: document.Title,
Author: document.Author,
})
if err == nil && len(metadataResults) > 0 && metadataResults[0].ID != nil {
firstResult := metadataResults[0]
// Save Cover
fileName, err := metadata.CacheCoverWithContext(metadataCtx, *firstResult.ID, coverDir, document.ID, false)
if err == nil {
cachedCoverFile = *fileName
}
// Store First Metadata Result
if _, err = s.db.Queries.AddMetadata(ctx, database.AddMetadataParams{
DocumentID: document.ID,
Title: firstResult.Title,
Author: firstResult.Author,
Description: firstResult.Description,
Gbid: firstResult.ID,
Olid: nil,
Isbn10: firstResult.ISBN10,
Isbn13: firstResult.ISBN13,
}); err != nil {
log.Error("AddMetadata DB Error:", err)
}
}
// Upsert Document
if _, err = s.db.Queries.UpsertDocument(ctx, database.UpsertDocumentParams{
ID: document.ID,
Coverfile: &cachedCoverFile,
}); err != nil {
log.Warn("UpsertDocument DB Error:", err)
}
// Update cover file if we got a new cover
if cachedCoverFile != "UNKNOWN" {
coverPath := filepath.Join(coverDir, cachedCoverFile)
fileInfo, err := os.Stat(coverPath)
if err != nil {
log.Error("Failed to stat cached cover:", err)
// Keep the no-cover image
} else {
file, err := openFileReader(coverPath)
if err != nil {
log.Error("Failed to open cached cover:", err)
// Keep the no-cover image
} else {
_ = coverFile.Close() // Close the previous file
coverFile = file
contentLength = fileInfo.Size()
// Determine content type based on file extension
contentType = "image/jpeg"
if strings.HasSuffix(coverPath, ".png") {
contentType = "image/png"
}
}
}
}
}
return &GetDocumentCover200Response{
Body: coverFile,
ContentLength: contentLength,
ContentType: contentType,
}, nil
}
// POST /documents/{id}/cover
func (s *Server) UploadDocumentCover(ctx context.Context, request UploadDocumentCoverRequestObject) (UploadDocumentCoverResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return UploadDocumentCover401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
if request.Body == nil {
return UploadDocumentCover400JSONResponse{Code: 400, Message: "Missing request body"}, nil
}
// Validate document exists
_, err := s.db.Queries.GetDocument(ctx, request.Id)
if err != nil {
return UploadDocumentCover404JSONResponse{Code: 404, Message: "Document not found"}, nil
}
// Read multipart form
form, err := request.Body.ReadForm(32 << 20) // 32MB max
if err != nil {
log.Error("ReadForm error:", err)
return UploadDocumentCover500JSONResponse{Code: 500, Message: "Failed to read form"}, nil
}
// Get file from form
fileField := form.File["cover_file"]
if len(fileField) == 0 {
return UploadDocumentCover400JSONResponse{Code: 400, Message: "No file provided"}, nil
}
file := fileField[0]
// Validate file extension
if !strings.HasSuffix(strings.ToLower(file.Filename), ".jpg") && !strings.HasSuffix(strings.ToLower(file.Filename), ".png") {
return UploadDocumentCover400JSONResponse{Code: 400, Message: "Only JPG and PNG files are allowed"}, nil
}
// Open file
f, err := file.Open()
if err != nil {
log.Error("Open file error:", err)
return UploadDocumentCover500JSONResponse{Code: 500, Message: "Failed to open file"}, nil
}
defer f.Close()
// Read file content
data, err := io.ReadAll(f)
if err != nil {
log.Error("Read file error:", err)
return UploadDocumentCover500JSONResponse{Code: 500, Message: "Failed to read file"}, nil
}
// Validate actual content type
contentType := http.DetectContentType(data)
allowedTypes := map[string]bool{
"image/jpeg": true,
"image/png": true,
}
if !allowedTypes[contentType] {
return UploadDocumentCover400JSONResponse{
Code: 400,
Message: fmt.Sprintf("Invalid file type: %s. Only JPG and PNG files are allowed.", contentType),
}, nil
}
// Derive storage path
coverDir := filepath.Join(s.cfg.DataPath, "covers")
fileName := fmt.Sprintf("%s%s", request.Id, strings.ToLower(filepath.Ext(file.Filename)))
safePath := filepath.Join(coverDir, fileName)
// Save file
err = os.WriteFile(safePath, data, 0644)
if err != nil {
log.Error("Save file error:", err)
return UploadDocumentCover500JSONResponse{Code: 500, Message: "Unable to save cover"}, nil
}
// Upsert document with new cover
_, err = s.db.Queries.UpsertDocument(ctx, database.UpsertDocumentParams{
ID: request.Id,
Coverfile: &fileName,
})
if err != nil {
log.Error("UpsertDocument DB error:", err)
return UploadDocumentCover500JSONResponse{Code: 500, Message: "Failed to save cover"}, nil
}
// Use GetDocumentsWithStats to get document with stats for the response
docs, err := s.db.Queries.GetDocumentsWithStats(
ctx,
database.GetDocumentsWithStatsParams{
UserID: auth.UserName,
ID: &request.Id,
Deleted: ptrOf(false),
Offset: 0,
Limit: 1,
},
)
if err != nil || len(docs) == 0 {
return UploadDocumentCover404JSONResponse{Code: 404, Message: "Document not found"}, nil
}
doc := docs[0]
apiDoc := Document{
Id: doc.ID,
Title: *doc.Title,
Author: *doc.Author,
Description: doc.Description,
Isbn10: doc.Isbn10,
Isbn13: doc.Isbn13,
Words: doc.Words,
Filepath: doc.Filepath,
Percentage: ptrOf(float32(doc.Percentage)),
TotalTimeSeconds: ptrOf(doc.TotalTimeSeconds),
Wpm: ptrOf(float32(doc.Wpm)),
SecondsPerPercent: ptrOf(doc.SecondsPerPercent),
LastRead: parseInterfaceTime(doc.LastRead),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
Deleted: false,
}
response := DocumentResponse{
Document: apiDoc,
}
return UploadDocumentCover200JSONResponse(response), nil
}
// GET /documents/{id}/file
func (s *Server) GetDocumentFile(ctx context.Context, request GetDocumentFileRequestObject) (GetDocumentFileResponseObject, error) {
// Authentication is handled by middleware, which also adds auth data to context
// This endpoint just serves the document file download
// Get Document
document, err := s.db.Queries.GetDocument(ctx, request.Id)
if err != nil {
log.Error("GetDocument DB Error:", err)
return GetDocumentFile404JSONResponse{Code: 404, Message: "Document not found"}, nil
}
if document.Filepath == nil {
log.Error("Document Doesn't Have File:", request.Id)
return GetDocumentFile404JSONResponse{Code: 404, Message: "Document file not found"}, nil
}
// Derive Basepath
basepath := filepath.Join(s.cfg.DataPath, "documents")
if document.Basepath != nil && *document.Basepath != "" {
basepath = *document.Basepath
}
// Derive Storage Location
filePath := filepath.Join(basepath, *document.Filepath)
// Validate File Exists
fileInfo, err := os.Stat(filePath)
if os.IsNotExist(err) {
log.Error("File should but doesn't exist:", err)
return GetDocumentFile404JSONResponse{Code: 404, Message: "Document file not found"}, nil
}
// Open file
file, err := os.Open(filePath)
if err != nil {
log.Error("Failed to open document file:", err)
return GetDocumentFile500JSONResponse{Code: 500, Message: "Failed to open document"}, nil
}
return &GetDocumentFile200Response{
Body: file,
ContentLength: fileInfo.Size(),
Filename: filepath.Base(*document.Filepath),
}, nil
}
// POST /documents
func (s *Server) CreateDocument(ctx context.Context, request CreateDocumentRequestObject) (CreateDocumentResponseObject, error) {
_, ok := s.getSessionFromContext(ctx)
if !ok {
return CreateDocument401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
if request.Body == nil {
return CreateDocument400JSONResponse{Code: 400, Message: "Missing request body"}, nil
}
// Read multipart form
form, err := request.Body.ReadForm(32 << 20) // 32MB max memory
if err != nil {
log.Error("ReadForm error:", err)
return CreateDocument500JSONResponse{Code: 500, Message: "Failed to read form"}, nil
}
// Get file from form
fileField := form.File["document_file"]
if len(fileField) == 0 {
return CreateDocument400JSONResponse{Code: 400, Message: "No file provided"}, nil
}
file := fileField[0]
// Validate file extension
if !strings.HasSuffix(strings.ToLower(file.Filename), ".epub") {
return CreateDocument400JSONResponse{Code: 400, Message: "Only EPUB files are allowed"}, nil
}
// Open file
f, err := file.Open()
if err != nil {
log.Error("Open file error:", err)
return CreateDocument500JSONResponse{Code: 500, Message: "Failed to open file"}, nil
}
defer f.Close()
// Read file content
data, err := io.ReadAll(f)
if err != nil {
log.Error("Read file error:", err)
return CreateDocument500JSONResponse{Code: 500, Message: "Failed to read file"}, nil
}
// Validate actual content type
contentType := http.DetectContentType(data)
if contentType != "application/epub+zip" && contentType != "application/zip" {
return CreateDocument400JSONResponse{
Code: 400,
Message: fmt.Sprintf("Invalid file type: %s. Only EPUB files are allowed.", contentType),
}, nil
}
// Create temp file to get metadata
tempFile, err := os.CreateTemp("", "book")
if err != nil {
log.Error("Temp file create error:", err)
return CreateDocument500JSONResponse{Code: 500, Message: "Unable to create temp file"}, nil
}
defer os.Remove(tempFile.Name())
defer tempFile.Close()
// Write data to temp file
if _, err := tempFile.Write(data); err != nil {
log.Error("Write temp file error:", err)
return CreateDocument500JSONResponse{Code: 500, Message: "Unable to write temp file"}, nil
}
// Get metadata using metadata package
metadataInfo, err := metadata.GetMetadata(tempFile.Name())
if err != nil {
log.Error("GetMetadata error:", err)
return CreateDocument500JSONResponse{Code: 500, Message: "Unable to acquire metadata"}, nil
}
// Check if already exists
_, err = s.db.Queries.GetDocument(ctx, *metadataInfo.PartialMD5)
if err == nil {
// Document already exists
existingDoc, _ := s.db.Queries.GetDocument(ctx, *metadataInfo.PartialMD5)
apiDoc := Document{
Id: existingDoc.ID,
Title: *existingDoc.Title,
Author: *existingDoc.Author,
Description: existingDoc.Description,
Isbn10: existingDoc.Isbn10,
Isbn13: existingDoc.Isbn13,
Words: existingDoc.Words,
Filepath: existingDoc.Filepath,
CreatedAt: parseTime(existingDoc.CreatedAt),
UpdatedAt: parseTime(existingDoc.UpdatedAt),
Deleted: existingDoc.Deleted,
}
response := DocumentResponse{
Document: apiDoc,
}
return CreateDocument200JSONResponse(response), nil
}
// Derive & sanitize file name
fileName := deriveBaseFileName(metadataInfo)
basePath := filepath.Join(s.cfg.DataPath, "documents")
safePath := filepath.Join(basePath, fileName)
// Save file to storage
err = os.WriteFile(safePath, data, 0644)
if err != nil {
log.Error("Save file error:", err)
return CreateDocument500JSONResponse{Code: 500, Message: "Unable to save file"}, nil
}
// Upsert document
doc, err := s.db.Queries.UpsertDocument(ctx, database.UpsertDocumentParams{
ID: *metadataInfo.PartialMD5,
Title: metadataInfo.Title,
Author: metadataInfo.Author,
Description: metadataInfo.Description,
Md5: metadataInfo.MD5,
Words: metadataInfo.WordCount,
Filepath: &fileName,
Basepath: &basePath,
})
if err != nil {
log.Error("UpsertDocument DB error:", err)
return CreateDocument500JSONResponse{Code: 500, Message: "Failed to save document"}, nil
}
apiDoc := Document{
Id: doc.ID,
Title: *doc.Title,
Author: *doc.Author,
Description: doc.Description,
Isbn10: doc.Isbn10,
Isbn13: doc.Isbn13,
Words: doc.Words,
Filepath: doc.Filepath,
CreatedAt: parseTime(doc.CreatedAt),
UpdatedAt: parseTime(doc.UpdatedAt),
Deleted: doc.Deleted,
}
response := DocumentResponse{
Document: apiDoc,
}
return CreateDocument200JSONResponse(response), nil
}
// GetDocumentCover200Response is a custom response type that allows setting content type
type GetDocumentCover200Response struct {
Body io.Reader
ContentLength int64
ContentType string
}
func (response GetDocumentCover200Response) VisitGetDocumentCoverResponse(w http.ResponseWriter) error {
w.Header().Set("Content-Type", response.ContentType)
if response.ContentLength != 0 {
w.Header().Set("Content-Length", fmt.Sprint(response.ContentLength))
}
w.WriteHeader(200)
if closer, ok := response.Body.(io.Closer); ok {
defer closer.Close()
}
_, err := io.Copy(w, response.Body)
return err
}
// GetDocumentFile200Response is a custom response type that allows setting filename for download
type GetDocumentFile200Response struct {
Body io.Reader
ContentLength int64
Filename string
}
func (response GetDocumentFile200Response) VisitGetDocumentFileResponse(w http.ResponseWriter) error {
w.Header().Set("Content-Type", "application/octet-stream")
if response.ContentLength != 0 {
w.Header().Set("Content-Length", fmt.Sprint(response.ContentLength))
}
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", response.Filename))
w.WriteHeader(200)
if closer, ok := response.Body.(io.Closer); ok {
defer closer.Close()
}
_, err := io.Copy(w, response.Body)
return err
}

178
api/v1/documents_test.go Normal file
View File

@@ -0,0 +1,178 @@
package v1
import (
"bytes"
"crypto/md5"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/suite"
argon2 "github.com/alexedwards/argon2id"
"reichard.io/antholume/config"
"reichard.io/antholume/database"
"reichard.io/antholume/pkg/ptr"
)
type DocumentsTestSuite struct {
suite.Suite
db *database.DBManager
cfg *config.Config
srv *Server
}
func (suite *DocumentsTestSuite) setupConfig() *config.Config {
return &config.Config{
ListenPort: "8080",
DBType: "memory",
DBName: "test",
ConfigPath: "/tmp",
CookieAuthKey: "test-auth-key-32-bytes-long-enough",
CookieEncKey: "0123456789abcdef",
CookieSecure: false,
CookieHTTPOnly: true,
Version: "test",
DemoMode: false,
RegistrationEnabled: true,
}
}
func TestDocuments(t *testing.T) {
suite.Run(t, new(DocumentsTestSuite))
}
func (suite *DocumentsTestSuite) SetupTest() {
suite.cfg = suite.setupConfig()
suite.db = database.NewMgr(suite.cfg)
suite.srv = NewServer(suite.db, suite.cfg, nil)
}
func (suite *DocumentsTestSuite) createTestUser(username, password string) {
suite.authTestSuiteHelper(username, password)
}
func (suite *DocumentsTestSuite) login(username, password string) *http.Cookie {
return suite.authLoginHelper(username, password)
}
func (suite *DocumentsTestSuite) authTestSuiteHelper(username, password string) {
// MD5 hash for KOSync compatibility (matches existing system)
md5Hash := fmt.Sprintf("%x", md5.Sum([]byte(password)))
// Then argon2 hash the MD5
hashedPassword, err := argon2.CreateHash(md5Hash, argon2.DefaultParams)
suite.Require().NoError(err)
_, err = suite.db.Queries.CreateUser(suite.T().Context(), database.CreateUserParams{
ID: username,
Pass: &hashedPassword,
AuthHash: ptr.Of("test-auth-hash"),
Admin: true,
})
suite.Require().NoError(err)
}
func (suite *DocumentsTestSuite) authLoginHelper(username, password string) *http.Cookie {
reqBody := LoginRequest{Username: username, Password: password}
body, err := json.Marshal(reqBody)
suite.Require().NoError(err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(body))
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusOK, w.Code)
cookies := w.Result().Cookies()
suite.Require().Len(cookies, 1)
return cookies[0]
}
func (suite *DocumentsTestSuite) TestAPIGetDocuments() {
suite.createTestUser("testuser", "testpass")
cookie := suite.login("testuser", "testpass")
req := httptest.NewRequest(http.MethodGet, "/api/v1/documents?page=1&limit=9", nil)
req.AddCookie(cookie)
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusOK, w.Code)
var resp DocumentsResponse
suite.Require().NoError(json.Unmarshal(w.Body.Bytes(), &resp))
suite.Equal(int64(1), resp.Page)
suite.Equal(int64(9), resp.Limit)
}
func (suite *DocumentsTestSuite) TestAPIGetDocumentsUnauthenticated() {
req := httptest.NewRequest(http.MethodGet, "/api/v1/documents", nil)
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusUnauthorized, w.Code)
}
func (suite *DocumentsTestSuite) TestAPIGetDocument() {
suite.createTestUser("testuser", "testpass")
docID := "test-doc-1"
_, err := suite.db.Queries.UpsertDocument(suite.T().Context(), database.UpsertDocumentParams{
ID: docID,
Title: ptr.Of("Test Document"),
Author: ptr.Of("Test Author"),
})
suite.Require().NoError(err)
cookie := suite.login("testuser", "testpass")
req := httptest.NewRequest(http.MethodGet, "/api/v1/documents/"+docID, nil)
req.AddCookie(cookie)
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusOK, w.Code)
var resp DocumentResponse
suite.Require().NoError(json.Unmarshal(w.Body.Bytes(), &resp))
suite.Equal(docID, resp.Document.Id)
suite.Equal("Test Document", resp.Document.Title)
}
func (suite *DocumentsTestSuite) TestAPIGetDocumentNotFound() {
suite.createTestUser("testuser", "testpass")
cookie := suite.login("testuser", "testpass")
req := httptest.NewRequest(http.MethodGet, "/api/v1/documents/non-existent", nil)
req.AddCookie(cookie)
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusNotFound, w.Code)
}
func (suite *DocumentsTestSuite) TestAPIGetDocumentCoverUnauthenticated() {
req := httptest.NewRequest(http.MethodGet, "/api/v1/documents/test-id/cover", nil)
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusUnauthorized, w.Code)
}
func (suite *DocumentsTestSuite) TestAPIGetDocumentFileUnauthenticated() {
req := httptest.NewRequest(http.MethodGet, "/api/v1/documents/test-id/file", nil)
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusUnauthorized, w.Code)
}

3
api/v1/generate.go Normal file
View File

@@ -0,0 +1,3 @@
package v1
//go:generate oapi-codegen -config oapi-codegen.yaml openapi.yaml

226
api/v1/home.go Normal file
View File

@@ -0,0 +1,226 @@
package v1
import (
"context"
"sort"
log "github.com/sirupsen/logrus"
"reichard.io/antholume/database"
"reichard.io/antholume/graph"
)
// GET /home
func (s *Server) GetHome(ctx context.Context, request GetHomeRequestObject) (GetHomeResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return GetHome401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
// Get database info
dbInfo, err := s.db.Queries.GetDatabaseInfo(ctx, auth.UserName)
if err != nil {
log.Error("GetDatabaseInfo DB Error:", err)
return GetHome500JSONResponse{Code: 500, Message: "Database error"}, nil
}
// Get streaks
streaks, err := s.db.Queries.GetUserStreaks(ctx, auth.UserName)
if err != nil {
log.Error("GetUserStreaks DB Error:", err)
return GetHome500JSONResponse{Code: 500, Message: "Database error"}, nil
}
// Get graph data
graphData, err := s.db.Queries.GetDailyReadStats(ctx, auth.UserName)
if err != nil {
log.Error("GetDailyReadStats DB Error:", err)
return GetHome500JSONResponse{Code: 500, Message: "Database error"}, nil
}
// Get user statistics
userStats, err := s.db.Queries.GetUserStatistics(ctx)
if err != nil {
log.Error("GetUserStatistics DB Error:", err)
return GetHome500JSONResponse{Code: 500, Message: "Database error"}, nil
}
// Build response
response := HomeResponse{
DatabaseInfo: DatabaseInfo{
DocumentsSize: dbInfo.DocumentsSize,
ActivitySize: dbInfo.ActivitySize,
ProgressSize: dbInfo.ProgressSize,
DevicesSize: dbInfo.DevicesSize,
},
Streaks: StreaksResponse{
Streaks: convertStreaks(streaks),
},
GraphData: GraphDataResponse{
GraphData: convertGraphData(graphData),
},
UserStatistics: arrangeUserStatistics(userStats),
}
return GetHome200JSONResponse(response), nil
}
// GET /home/streaks
func (s *Server) GetStreaks(ctx context.Context, request GetStreaksRequestObject) (GetStreaksResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return GetStreaks401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
streaks, err := s.db.Queries.GetUserStreaks(ctx, auth.UserName)
if err != nil {
log.Error("GetUserStreaks DB Error:", err)
return GetStreaks500JSONResponse{Code: 500, Message: "Database error"}, nil
}
response := StreaksResponse{
Streaks: convertStreaks(streaks),
}
return GetStreaks200JSONResponse(response), nil
}
// GET /home/graph
func (s *Server) GetGraphData(ctx context.Context, request GetGraphDataRequestObject) (GetGraphDataResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return GetGraphData401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
graphData, err := s.db.Queries.GetDailyReadStats(ctx, auth.UserName)
if err != nil {
log.Error("GetDailyReadStats DB Error:", err)
return GetGraphData500JSONResponse{Code: 500, Message: "Database error"}, nil
}
response := GraphDataResponse{
GraphData: convertGraphData(graphData),
}
return GetGraphData200JSONResponse(response), nil
}
// GET /home/statistics
func (s *Server) GetUserStatistics(ctx context.Context, request GetUserStatisticsRequestObject) (GetUserStatisticsResponseObject, error) {
_, ok := s.getSessionFromContext(ctx)
if !ok {
return GetUserStatistics401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
userStats, err := s.db.Queries.GetUserStatistics(ctx)
if err != nil {
log.Error("GetUserStatistics DB Error:", err)
return GetUserStatistics500JSONResponse{Code: 500, Message: "Database error"}, nil
}
response := arrangeUserStatistics(userStats)
return GetUserStatistics200JSONResponse(response), nil
}
func convertStreaks(streaks []database.UserStreak) []UserStreak {
result := make([]UserStreak, len(streaks))
for i, streak := range streaks {
result[i] = UserStreak{
Window: streak.Window,
MaxStreak: streak.MaxStreak,
MaxStreakStartDate: streak.MaxStreakStartDate,
MaxStreakEndDate: streak.MaxStreakEndDate,
CurrentStreak: streak.CurrentStreak,
CurrentStreakStartDate: streak.CurrentStreakStartDate,
CurrentStreakEndDate: streak.CurrentStreakEndDate,
}
}
return result
}
func convertGraphData(graphData []database.GetDailyReadStatsRow) []GraphDataPoint {
result := make([]GraphDataPoint, len(graphData))
for i, data := range graphData {
result[i] = GraphDataPoint{
Date: data.Date,
MinutesRead: data.MinutesRead,
}
}
return result
}
func arrangeUserStatistics(userStatistics []database.GetUserStatisticsRow) UserStatisticsResponse {
// Sort by WPM for each period
sortByWPM := func(stats []database.GetUserStatisticsRow, getter func(database.GetUserStatisticsRow) float64) []LeaderboardEntry {
sorted := append([]database.GetUserStatisticsRow(nil), stats...)
sort.SliceStable(sorted, func(i, j int) bool {
return getter(sorted[i]) > getter(sorted[j])
})
result := make([]LeaderboardEntry, len(sorted))
for i, item := range sorted {
result[i] = LeaderboardEntry{UserId: item.UserID, Value: getter(item)}
}
return result
}
// Sort by duration (seconds) for each period
sortByDuration := func(stats []database.GetUserStatisticsRow, getter func(database.GetUserStatisticsRow) int64) []LeaderboardEntry {
sorted := append([]database.GetUserStatisticsRow(nil), stats...)
sort.SliceStable(sorted, func(i, j int) bool {
return getter(sorted[i]) > getter(sorted[j])
})
result := make([]LeaderboardEntry, len(sorted))
for i, item := range sorted {
result[i] = LeaderboardEntry{UserId: item.UserID, Value: float64(getter(item))}
}
return result
}
// Sort by words for each period
sortByWords := func(stats []database.GetUserStatisticsRow, getter func(database.GetUserStatisticsRow) int64) []LeaderboardEntry {
sorted := append([]database.GetUserStatisticsRow(nil), stats...)
sort.SliceStable(sorted, func(i, j int) bool {
return getter(sorted[i]) > getter(sorted[j])
})
result := make([]LeaderboardEntry, len(sorted))
for i, item := range sorted {
result[i] = LeaderboardEntry{UserId: item.UserID, Value: float64(getter(item))}
}
return result
}
return UserStatisticsResponse{
Wpm: LeaderboardData{
All: sortByWPM(userStatistics, func(s database.GetUserStatisticsRow) float64 { return s.TotalWpm }),
Year: sortByWPM(userStatistics, func(s database.GetUserStatisticsRow) float64 { return s.YearlyWpm }),
Month: sortByWPM(userStatistics, func(s database.GetUserStatisticsRow) float64 { return s.MonthlyWpm }),
Week: sortByWPM(userStatistics, func(s database.GetUserStatisticsRow) float64 { return s.WeeklyWpm }),
},
Duration: LeaderboardData{
All: sortByDuration(userStatistics, func(s database.GetUserStatisticsRow) int64 { return s.TotalSeconds }),
Year: sortByDuration(userStatistics, func(s database.GetUserStatisticsRow) int64 { return s.YearlySeconds }),
Month: sortByDuration(userStatistics, func(s database.GetUserStatisticsRow) int64 { return s.MonthlySeconds }),
Week: sortByDuration(userStatistics, func(s database.GetUserStatisticsRow) int64 { return s.WeeklySeconds }),
},
Words: LeaderboardData{
All: sortByWords(userStatistics, func(s database.GetUserStatisticsRow) int64 { return s.TotalWordsRead }),
Year: sortByWords(userStatistics, func(s database.GetUserStatisticsRow) int64 { return s.YearlyWordsRead }),
Month: sortByWords(userStatistics, func(s database.GetUserStatisticsRow) int64 { return s.MonthlyWordsRead }),
Week: sortByWords(userStatistics, func(s database.GetUserStatisticsRow) int64 { return s.WeeklyWordsRead }),
},
}
}
// GetSVGGraphData generates SVG bezier path for graph visualization
func GetSVGGraphData(inputData []GraphDataPoint, svgWidth int, svgHeight int) graph.SVGGraphData {
// Convert to int64 slice expected by graph package
intData := make([]int64, len(inputData))
for i, data := range inputData {
intData[i] = int64(data.MinutesRead)
}
return graph.GetSVGGraphData(intData, svgWidth, svgHeight)
}

6
api/v1/oapi-codegen.yaml Normal file
View File

@@ -0,0 +1,6 @@
package: v1
generate:
std-http-server: true
strict-server: true
models: true
output: api.gen.go

1995
api/v1/openapi.yaml Normal file

File diff suppressed because it is too large Load Diff

170
api/v1/progress.go Normal file
View File

@@ -0,0 +1,170 @@
package v1
import (
"context"
"time"
log "github.com/sirupsen/logrus"
"reichard.io/antholume/database"
)
// GET /progress
func (s *Server) GetProgressList(ctx context.Context, request GetProgressListRequestObject) (GetProgressListResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return GetProgressList401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
page := int64(1)
if request.Params.Page != nil {
page = *request.Params.Page
}
limit := int64(15)
if request.Params.Limit != nil {
limit = *request.Params.Limit
}
filter := database.GetProgressParams{
UserID: auth.UserName,
Offset: (page - 1) * limit,
Limit: limit,
}
if request.Params.Document != nil && *request.Params.Document != "" {
filter.DocFilter = true
filter.DocumentID = *request.Params.Document
}
progress, err := s.db.Queries.GetProgress(ctx, filter)
if err != nil {
log.Error("GetProgress DB Error:", err)
return GetProgressList500JSONResponse{Code: 500, Message: "Database error"}, nil
}
// Get Total Count
total, err := s.db.Queries.GetProgressCount(ctx, database.GetProgressCountParams{
UserID: auth.UserName,
DocFilter: filter.DocFilter,
DocumentID: filter.DocumentID,
})
if err != nil {
log.Error("GetProgressCount DB Error:", err)
return GetProgressList500JSONResponse{Code: 500, Message: "Database error"}, nil
}
// Calculate Pagination
var nextPage *int64
var previousPage *int64
if page*limit < total {
nextPage = ptrOf(page + 1)
}
if page > 1 {
previousPage = ptrOf(page - 1)
}
apiProgress := make([]Progress, len(progress))
for i, row := range progress {
apiProgress[i] = Progress{
Title: row.Title,
Author: row.Author,
DeviceName: &row.DeviceName,
Percentage: &row.Percentage,
DocumentId: &row.DocumentID,
UserId: &row.UserID,
CreatedAt: parseTimePtr(row.CreatedAt),
}
}
response := ProgressListResponse{
Progress: &apiProgress,
Page: &page,
Limit: &limit,
NextPage: nextPage,
PreviousPage: previousPage,
Total: &total,
}
return GetProgressList200JSONResponse(response), nil
}
// GET /progress/{id}
func (s *Server) GetProgress(ctx context.Context, request GetProgressRequestObject) (GetProgressResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return GetProgress401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
row, err := s.db.Queries.GetDocumentProgress(ctx, database.GetDocumentProgressParams{
UserID: auth.UserName,
DocumentID: request.Id,
})
if err != nil {
log.Error("GetDocumentProgress DB Error:", err)
return GetProgress404JSONResponse{Code: 404, Message: "Progress not found"}, nil
}
apiProgress := Progress{
DeviceName: &row.DeviceName,
DeviceId: &row.DeviceID,
Percentage: &row.Percentage,
Progress: &row.Progress,
DocumentId: &row.DocumentID,
UserId: &row.UserID,
CreatedAt: parseTimePtr(row.CreatedAt),
}
response := ProgressResponse{
Progress: &apiProgress,
}
return GetProgress200JSONResponse(response), nil
}
// PUT /progress
func (s *Server) UpdateProgress(ctx context.Context, request UpdateProgressRequestObject) (UpdateProgressResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return UpdateProgress401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
if request.Body == nil {
return UpdateProgress400JSONResponse{Code: 400, Message: "Request body is required"}, nil
}
if _, err := s.db.Queries.UpsertDevice(ctx, database.UpsertDeviceParams{
ID: request.Body.DeviceId,
UserID: auth.UserName,
DeviceName: request.Body.DeviceName,
LastSynced: time.Now().UTC().Format(time.RFC3339),
}); err != nil {
log.Error("UpsertDevice DB Error:", err)
return UpdateProgress500JSONResponse{Code: 500, Message: "Database error"}, nil
}
if _, err := s.db.Queries.UpsertDocument(ctx, database.UpsertDocumentParams{
ID: request.Body.DocumentId,
}); err != nil {
log.Error("UpsertDocument DB Error:", err)
return UpdateProgress500JSONResponse{Code: 500, Message: "Database error"}, nil
}
progress, err := s.db.Queries.UpdateProgress(ctx, database.UpdateProgressParams{
Percentage: request.Body.Percentage,
DocumentID: request.Body.DocumentId,
DeviceID: request.Body.DeviceId,
UserID: auth.UserName,
Progress: request.Body.Progress,
})
if err != nil {
log.Error("UpdateProgress DB Error:", err)
return UpdateProgress400JSONResponse{Code: 400, Message: "Invalid request"}, nil
}
response := UpdateProgressResponse{
DocumentId: progress.DocumentID,
Timestamp: parseTime(progress.CreatedAt),
}
return UpdateProgress200JSONResponse(response), nil
}

59
api/v1/search.go Normal file
View File

@@ -0,0 +1,59 @@
package v1
import (
"context"
"reichard.io/antholume/search"
log "github.com/sirupsen/logrus"
)
// GET /search
func (s *Server) GetSearch(ctx context.Context, request GetSearchRequestObject) (GetSearchResponseObject, error) {
if request.Params.Query == "" {
return GetSearch400JSONResponse{Code: 400, Message: "Invalid query"}, nil
}
query := request.Params.Query
source := string(request.Params.Source)
// Validate source
if source != "LibGen" && source != "Annas Archive" {
return GetSearch400JSONResponse{Code: 400, Message: "Invalid source"}, nil
}
searchResults, err := search.SearchBook(query, search.Source(source))
if err != nil {
log.Error("Search Error:", err)
return GetSearch500JSONResponse{Code: 500, Message: "Search error"}, nil
}
apiResults := make([]SearchItem, len(searchResults))
for i, item := range searchResults {
apiResults[i] = SearchItem{
Id: ptrOf(item.ID),
Title: ptrOf(item.Title),
Author: ptrOf(item.Author),
Language: ptrOf(item.Language),
Series: ptrOf(item.Series),
FileType: ptrOf(item.FileType),
FileSize: ptrOf(item.FileSize),
UploadDate: ptrOf(item.UploadDate),
}
}
response := SearchResponse{
Results: apiResults,
Source: source,
Query: query,
}
return GetSearch200JSONResponse(response), nil
}
// POST /search
func (s *Server) PostSearch(ctx context.Context, request PostSearchRequestObject) (PostSearchResponseObject, error) {
// This endpoint is used by the SSR template to queue a download
// For the API, we just return success - the actual download happens via /documents POST
return PostSearch200Response{}, nil
}

140
api/v1/server.go Normal file
View File

@@ -0,0 +1,140 @@
package v1
import (
"context"
"encoding/json"
"io/fs"
"net/http"
log "github.com/sirupsen/logrus"
"reichard.io/antholume/config"
"reichard.io/antholume/database"
)
var _ StrictServerInterface = (*Server)(nil)
type Server struct {
mux *http.ServeMux
db *database.DBManager
cfg *config.Config
assets fs.FS
}
// NewServer creates a new native HTTP server
func NewServer(db *database.DBManager, cfg *config.Config, assets fs.FS) *Server {
s := &Server{
mux: http.NewServeMux(),
db: db,
cfg: cfg,
assets: assets,
}
if cfg.DisableAuth {
log.Warn("DISABLE_AUTH is set — all API requests will bypass authentication")
}
// Create strict handler with authentication middleware
strictHandler := NewStrictHandler(s, []StrictMiddlewareFunc{s.authMiddleware})
s.mux = HandlerFromMuxWithBaseURL(strictHandler, s.mux, "/api/v1").(*http.ServeMux)
return s
}
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.mux.ServeHTTP(w, r)
}
// authMiddleware adds authentication context to requests
func (s *Server) authMiddleware(handler StrictHandlerFunc, operationID string) StrictHandlerFunc {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request, request any) (any, error) {
// Store request and response in context for all handlers
ctx = context.WithValue(ctx, "request", r)
ctx = context.WithValue(ctx, "response", w)
// Skip auth for public auth and info endpoints - cover and file require auth via cookies
if operationID == "Login" || operationID == "Register" || operationID == "GetInfo" {
return handler(ctx, w, r, request)
}
// Dev Auth Bypass - Inject an admin session when DISABLE_AUTH is set.
// This avoids repeated logins during local development. Uses the
// first user in the database so that DB queries using the user ID
// return real data.
if s.cfg.DisableAuth {
devAuth, ok := s.resolveDevAuth(ctx)
if !ok {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(500)
json.NewEncoder(w).Encode(ErrorResponse{Code: 500, Message: "DISABLE_AUTH: no users in database; register one first"})
return nil, nil
}
ctx = context.WithValue(ctx, "auth", devAuth)
return handler(ctx, w, r, request)
}
auth, ok := s.getSession(r)
if !ok {
// Write 401 response directly
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(401)
json.NewEncoder(w).Encode(ErrorResponse{Code: 401, Message: "Unauthorized"})
return nil, nil
}
// Check admin status for admin-only endpoints
adminEndpoints := []string{
"GetAdmin",
"PostAdminAction",
"GetUsers",
"UpdateUser",
"GetImportDirectory",
"PostImport",
"GetImportResults",
"GetLogs",
}
for _, adminEndpoint := range adminEndpoints {
if operationID == adminEndpoint && !auth.IsAdmin {
// Write 403 response directly
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(403)
json.NewEncoder(w).Encode(ErrorResponse{Code: 403, Message: "Admin privileges required"})
return nil, nil
}
}
// Store auth in context for handlers to access
ctx = context.WithValue(ctx, "auth", auth)
return handler(ctx, w, r, request)
}
}
// resolveDevAuth determines the dev user identity when DISABLE_AUTH is set.
// If DISABLE_AUTH_USER is specified, that user is looked up; otherwise the
// first user in the database is used.
func (s *Server) resolveDevAuth(ctx context.Context) (authData, bool) {
if s.cfg.DisableAuthUser != "" {
user, err := s.db.Queries.GetUser(ctx, s.cfg.DisableAuthUser)
if err != nil {
log.Errorf("DISABLE_AUTH_USER=%q not found in database: %v", s.cfg.DisableAuthUser, err)
return authData{}, false
}
return authData{UserName: user.ID, IsAdmin: user.Admin}, true
}
users, err := s.db.Queries.GetUsers(ctx)
if err != nil || len(users) == 0 {
return authData{}, false
}
return authData{UserName: users[0].ID, IsAdmin: users[0].Admin}, true
}
// GetInfo returns server information
func (s *Server) GetInfo(ctx context.Context, request GetInfoRequestObject) (GetInfoResponseObject, error) {
return GetInfo200JSONResponse{
Version: s.cfg.Version,
SearchEnabled: s.cfg.SearchEnabled,
RegistrationEnabled: s.cfg.RegistrationEnabled,
}, nil
}

58
api/v1/server_test.go Normal file
View File

@@ -0,0 +1,58 @@
package v1
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/suite"
"reichard.io/antholume/config"
"reichard.io/antholume/database"
)
type ServerTestSuite struct {
suite.Suite
db *database.DBManager
cfg *config.Config
srv *Server
}
func TestServer(t *testing.T) {
suite.Run(t, new(ServerTestSuite))
}
func (suite *ServerTestSuite) SetupTest() {
suite.cfg = &config.Config{
ListenPort: "8080",
DBType: "memory",
DBName: "test",
ConfigPath: "/tmp",
CookieAuthKey: "test-auth-key-32-bytes-long-enough",
CookieEncKey: "0123456789abcdef",
CookieSecure: false,
CookieHTTPOnly: true,
Version: "test",
DemoMode: false,
RegistrationEnabled: true,
}
suite.db = database.NewMgr(suite.cfg)
suite.srv = NewServer(suite.db, suite.cfg, nil)
}
func (suite *ServerTestSuite) TestNewServer() {
suite.NotNil(suite.srv)
suite.NotNil(suite.srv.mux)
suite.NotNil(suite.srv.db)
suite.NotNil(suite.srv.cfg)
}
func (suite *ServerTestSuite) TestServerServeHTTP() {
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil)
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusUnauthorized, w.Code)
}

157
api/v1/settings.go Normal file
View File

@@ -0,0 +1,157 @@
package v1
import (
"context"
"crypto/md5"
"fmt"
"reichard.io/antholume/database"
argon2id "github.com/alexedwards/argon2id"
)
// GET /settings
func (s *Server) GetSettings(ctx context.Context, request GetSettingsRequestObject) (GetSettingsResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return GetSettings401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
user, err := s.db.Queries.GetUser(ctx, auth.UserName)
if err != nil {
return GetSettings500JSONResponse{Code: 500, Message: err.Error()}, nil
}
devices, err := s.db.Queries.GetDevices(ctx, auth.UserName)
if err != nil {
return GetSettings500JSONResponse{Code: 500, Message: err.Error()}, nil
}
apiDevices := make([]Device, len(devices))
for i, device := range devices {
apiDevices[i] = Device{
Id: &device.ID,
DeviceName: &device.DeviceName,
CreatedAt: parseTimePtr(device.CreatedAt),
LastSynced: parseTimePtr(device.LastSynced),
}
}
response := SettingsResponse{
User: UserData{Username: auth.UserName, IsAdmin: auth.IsAdmin},
Timezone: user.Timezone,
Devices: &apiDevices,
}
return GetSettings200JSONResponse(response), nil
}
// authorizeCredentials verifies if credentials are valid
func (s *Server) authorizeCredentials(ctx context.Context, username string, password string) bool {
user, err := s.db.Queries.GetUser(ctx, username)
if err != nil {
return false
}
// Try argon2 hash comparison
if match, err := argon2id.ComparePasswordAndHash(password, *user.Pass); err == nil && match {
return true
}
return false
}
// PUT /settings
func (s *Server) UpdateSettings(ctx context.Context, request UpdateSettingsRequestObject) (UpdateSettingsResponseObject, error) {
auth, ok := s.getSessionFromContext(ctx)
if !ok {
return UpdateSettings401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
if request.Body == nil {
return UpdateSettings400JSONResponse{Code: 400, Message: "Request body is required"}, nil
}
user, err := s.db.Queries.GetUser(ctx, auth.UserName)
if err != nil {
return UpdateSettings500JSONResponse{Code: 500, Message: err.Error()}, nil
}
updateParams := database.UpdateUserParams{
UserID: auth.UserName,
Admin: auth.IsAdmin,
}
// Update password if provided
if request.Body.NewPassword != nil {
if request.Body.Password == nil {
return UpdateSettings400JSONResponse{Code: 400, Message: "Current password is required to set new password"}, nil
}
// Verify current password - first try bcrypt (new format), then argon2, then MD5 (legacy format)
currentPasswordMatched := false
// Try argon2 (current format)
if !currentPasswordMatched {
currentPassword := fmt.Sprintf("%x", md5.Sum([]byte(*request.Body.Password)))
if match, err := argon2id.ComparePasswordAndHash(currentPassword, *user.Pass); err == nil && match {
currentPasswordMatched = true
}
}
if !currentPasswordMatched {
return UpdateSettings400JSONResponse{Code: 400, Message: "Invalid current password"}, nil
}
// Hash new password with argon2
newPassword := fmt.Sprintf("%x", md5.Sum([]byte(*request.Body.NewPassword)))
hashedPassword, err := argon2id.CreateHash(newPassword, argon2id.DefaultParams)
if err != nil {
return UpdateSettings500JSONResponse{Code: 500, Message: "Failed to hash password"}, nil
}
updateParams.Password = &hashedPassword
}
// Update timezone if provided
if request.Body.Timezone != nil {
updateParams.Timezone = request.Body.Timezone
}
// If nothing to update, return error
if request.Body.NewPassword == nil && request.Body.Timezone == nil {
return UpdateSettings400JSONResponse{Code: 400, Message: "At least one field must be provided"}, nil
}
// Update user
_, err = s.db.Queries.UpdateUser(ctx, updateParams)
if err != nil {
return UpdateSettings500JSONResponse{Code: 500, Message: err.Error()}, nil
}
// Get updated settings to return
user, err = s.db.Queries.GetUser(ctx, auth.UserName)
if err != nil {
return UpdateSettings500JSONResponse{Code: 500, Message: err.Error()}, nil
}
devices, err := s.db.Queries.GetDevices(ctx, auth.UserName)
if err != nil {
return UpdateSettings500JSONResponse{Code: 500, Message: err.Error()}, nil
}
apiDevices := make([]Device, len(devices))
for i, device := range devices {
apiDevices[i] = Device{
Id: &device.ID,
DeviceName: &device.DeviceName,
CreatedAt: parseTimePtr(device.CreatedAt),
LastSynced: parseTimePtr(device.LastSynced),
}
}
response := SettingsResponse{
User: UserData{Username: auth.UserName, IsAdmin: auth.IsAdmin},
Timezone: user.Timezone,
Devices: &apiDevices,
}
return UpdateSettings200JSONResponse(response), nil
}

84
api/v1/utils.go Normal file
View File

@@ -0,0 +1,84 @@
package v1
import (
"encoding/json"
"net/http"
"net/url"
"strconv"
"time"
)
// writeJSON writes a JSON response (deprecated - used by tests only)
func writeJSON(w http.ResponseWriter, status int, data any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
if err := json.NewEncoder(w).Encode(data); err != nil {
writeJSONError(w, http.StatusInternalServerError, "Failed to encode response")
}
}
// writeJSONError writes a JSON error response (deprecated - used by tests only)
func writeJSONError(w http.ResponseWriter, status int, message string) {
writeJSON(w, status, ErrorResponse{
Code: status,
Message: message,
})
}
// QueryParams represents parsed query parameters (deprecated - used by tests only)
type QueryParams struct {
Page int64
Limit int64
Search *string
}
// parseQueryParams parses URL query parameters (deprecated - used by tests only)
func parseQueryParams(query url.Values, defaultLimit int64) QueryParams {
page, _ := strconv.ParseInt(query.Get("page"), 10, 64)
if page == 0 {
page = 1
}
limit, _ := strconv.ParseInt(query.Get("limit"), 10, 64)
if limit == 0 {
limit = defaultLimit
}
search := query.Get("search")
var searchPtr *string
if search != "" {
searchPtr = ptrOf("%" + search + "%")
}
return QueryParams{
Page: page,
Limit: limit,
Search: searchPtr,
}
}
// ptrOf returns a pointer to the given value
func ptrOf[T any](v T) *T {
return &v
}
// parseTime parses a string to time.Time
func parseTime(s string) time.Time {
t, _ := time.Parse(time.RFC3339, s)
if t.IsZero() {
t, _ = time.Parse("2006-01-02T15:04:05", s)
}
return t
}
// parseTimePtr parses an interface{} (from SQL) to *time.Time
func parseTimePtr(v interface{}) *time.Time {
if v == nil {
return nil
}
if s, ok := v.(string); ok {
t := parseTime(s)
if t.IsZero() {
return nil
}
return &t
}
return nil
}

76
api/v1/utils_test.go Normal file
View File

@@ -0,0 +1,76 @@
package v1
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/suite"
)
type UtilsTestSuite struct {
suite.Suite
}
func TestUtils(t *testing.T) {
suite.Run(t, new(UtilsTestSuite))
}
func (suite *UtilsTestSuite) TestWriteJSON() {
w := httptest.NewRecorder()
data := map[string]string{"test": "value"}
writeJSON(w, http.StatusOK, data)
suite.Equal("application/json", w.Header().Get("Content-Type"))
suite.Equal(http.StatusOK, w.Code)
var resp map[string]string
suite.Require().NoError(json.Unmarshal(w.Body.Bytes(), &resp))
suite.Equal("value", resp["test"])
}
func (suite *UtilsTestSuite) TestWriteJSONError() {
w := httptest.NewRecorder()
writeJSONError(w, http.StatusBadRequest, "test error")
suite.Equal(http.StatusBadRequest, w.Code)
var resp ErrorResponse
suite.Require().NoError(json.Unmarshal(w.Body.Bytes(), &resp))
suite.Equal(http.StatusBadRequest, resp.Code)
suite.Equal("test error", resp.Message)
}
func (suite *UtilsTestSuite) TestParseQueryParams() {
query := make(map[string][]string)
query["page"] = []string{"2"}
query["limit"] = []string{"15"}
query["search"] = []string{"test"}
params := parseQueryParams(query, 9)
suite.Equal(int64(2), params.Page)
suite.Equal(int64(15), params.Limit)
suite.NotNil(params.Search)
}
func (suite *UtilsTestSuite) TestParseQueryParamsDefaults() {
query := make(map[string][]string)
params := parseQueryParams(query, 9)
suite.Equal(int64(1), params.Page)
suite.Equal(int64(9), params.Limit)
suite.Nil(params.Search)
}
func (suite *UtilsTestSuite) TestPtrOf() {
value := "test"
ptr := ptrOf(value)
suite.NotNil(ptr)
suite.Equal("test", *ptr)
}