This commit is contained in:
2026-03-22 10:44:24 -04:00
parent 7e96e41ba4
commit 27e651c4f5
25 changed files with 774 additions and 225 deletions

View File

@@ -538,6 +538,9 @@ type UpdateUserFormdataRequestBody UpdateUserFormdataBody
// LoginJSONRequestBody defines body for Login for application/json ContentType.
type LoginJSONRequestBody = LoginRequest
// RegisterJSONRequestBody defines body for Register for application/json ContentType.
type RegisterJSONRequestBody = LoginRequest
// CreateDocumentMultipartRequestBody defines body for CreateDocument for multipart/form-data ContentType.
type CreateDocumentMultipartRequestBody CreateDocumentMultipartBody
@@ -591,6 +594,9 @@ type ServerInterface interface {
// Get current user info
// (GET /auth/me)
GetMe(w http.ResponseWriter, r *http.Request)
// User registration
// (POST /auth/register)
Register(w http.ResponseWriter, r *http.Request)
// List documents
// (GET /documents)
GetDocuments(w http.ResponseWriter, r *http.Request, params GetDocumentsParams)
@@ -961,6 +967,20 @@ func (siw *ServerInterfaceWrapper) GetMe(w http.ResponseWriter, r *http.Request)
handler.ServeHTTP(w, r)
}
// Register operation middleware
func (siw *ServerInterfaceWrapper) Register(w http.ResponseWriter, r *http.Request) {
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
siw.Handler.Register(w, r)
}))
for _, middleware := range siw.HandlerMiddlewares {
handler = middleware(handler)
}
handler.ServeHTTP(w, r)
}
// GetDocuments operation middleware
func (siw *ServerInterfaceWrapper) GetDocuments(w http.ResponseWriter, r *http.Request) {
@@ -1606,6 +1626,7 @@ func HandlerWithOptions(si ServerInterface, options StdHTTPServerOptions) http.H
m.HandleFunc("POST "+options.BaseURL+"/auth/login", wrapper.Login)
m.HandleFunc("POST "+options.BaseURL+"/auth/logout", wrapper.Logout)
m.HandleFunc("GET "+options.BaseURL+"/auth/me", wrapper.GetMe)
m.HandleFunc("POST "+options.BaseURL+"/auth/register", wrapper.Register)
m.HandleFunc("GET "+options.BaseURL+"/documents", wrapper.GetDocuments)
m.HandleFunc("POST "+options.BaseURL+"/documents", wrapper.CreateDocument)
m.HandleFunc("GET "+options.BaseURL+"/documents/{id}", wrapper.GetDocument)
@@ -2072,6 +2093,50 @@ func (response GetMe401JSONResponse) VisitGetMeResponse(w http.ResponseWriter) e
return json.NewEncoder(w).Encode(response)
}
type RegisterRequestObject struct {
Body *RegisterJSONRequestBody
}
type RegisterResponseObject interface {
VisitRegisterResponse(w http.ResponseWriter) error
}
type Register201JSONResponse LoginResponse
func (response Register201JSONResponse) VisitRegisterResponse(w http.ResponseWriter) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(201)
return json.NewEncoder(w).Encode(response)
}
type Register400JSONResponse ErrorResponse
func (response Register400JSONResponse) VisitRegisterResponse(w http.ResponseWriter) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(400)
return json.NewEncoder(w).Encode(response)
}
type Register403JSONResponse ErrorResponse
func (response Register403JSONResponse) VisitRegisterResponse(w http.ResponseWriter) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(403)
return json.NewEncoder(w).Encode(response)
}
type Register500JSONResponse ErrorResponse
func (response Register500JSONResponse) VisitRegisterResponse(w http.ResponseWriter) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(500)
return json.NewEncoder(w).Encode(response)
}
type GetDocumentsRequestObject struct {
Params GetDocumentsParams
}
@@ -2864,6 +2929,9 @@ type StrictServerInterface interface {
// Get current user info
// (GET /auth/me)
GetMe(ctx context.Context, request GetMeRequestObject) (GetMeResponseObject, error)
// User registration
// (POST /auth/register)
Register(ctx context.Context, request RegisterRequestObject) (RegisterResponseObject, error)
// List documents
// (GET /documents)
GetDocuments(ctx context.Context, request GetDocumentsRequestObject) (GetDocumentsResponseObject, error)
@@ -3279,6 +3347,37 @@ func (sh *strictHandler) GetMe(w http.ResponseWriter, r *http.Request) {
}
}
// Register operation middleware
func (sh *strictHandler) Register(w http.ResponseWriter, r *http.Request) {
var request RegisterRequestObject
var body RegisterJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
sh.options.RequestErrorHandlerFunc(w, r, fmt.Errorf("can't decode JSON body: %w", err))
return
}
request.Body = &body
handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, request interface{}) (interface{}, error) {
return sh.ssi.Register(ctx, request.(RegisterRequestObject))
}
for _, middleware := range sh.middlewares {
handler = middleware(handler, "Register")
}
response, err := handler(r.Context(), w, r, request)
if err != nil {
sh.options.ResponseErrorHandlerFunc(w, r, err)
} else if validResponse, ok := response.(RegisterResponseObject); ok {
if err := validResponse.VisitRegisterResponse(w); err != nil {
sh.options.ResponseErrorHandlerFunc(w, r, err)
}
} else if response != nil {
sh.options.ResponseErrorHandlerFunc(w, r, fmt.Errorf("unexpected response type: %T", response))
}
}
// GetDocuments operation middleware
func (sh *strictHandler) GetDocuments(w http.ResponseWriter, r *http.Request, params GetDocumentsParams) {
var request GetDocumentsRequestObject

View File

@@ -36,44 +36,8 @@ func (s *Server) Login(ctx context.Context, request LoginRequestObject) (LoginRe
return Login401JSONResponse{Code: 401, Message: "Invalid credentials"}, nil
}
// Get request and response from context (set by middleware)
r := s.getRequestFromContext(ctx)
w := s.getResponseWriterFromContext(ctx)
if r == nil || w == nil {
return Login500JSONResponse{Code: 500, Message: "Internal context error"}, nil
}
// Create session with cookie options for Vite proxy compatibility
store := sessions.NewCookieStore([]byte(s.cfg.CookieAuthKey))
if s.cfg.CookieEncKey != "" {
if len(s.cfg.CookieEncKey) == 16 || len(s.cfg.CookieEncKey) == 32 {
store = sessions.NewCookieStore([]byte(s.cfg.CookieAuthKey), []byte(s.cfg.CookieEncKey))
}
}
session, err := store.Get(r, "token")
if err != nil {
return Login401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
// Configure cookie options to work with Vite proxy
// For localhost development, we need SameSite to allow cookies across ports
session.Options.SameSite = http.SameSiteLaxMode
session.Options.HttpOnly = true
if !s.cfg.CookieSecure {
session.Options.Secure = false // Allow HTTP for localhost development
} else {
session.Options.Secure = true
}
session.Values["authorizedUser"] = user.ID
session.Values["isAdmin"] = user.Admin
session.Values["expiresAt"] = time.Now().Unix() + (60 * 60 * 24 * 7)
session.Values["authHash"] = *user.AuthHash
if err := session.Save(r, w); err != nil {
return Login500JSONResponse{Code: 500, Message: "Failed to create session"}, nil
if err := s.saveUserSession(ctx, user.ID, user.Admin, *user.AuthHash); err != nil {
return Login500JSONResponse{Code: 500, Message: err.Error()}, nil
}
return Login200JSONResponse{
@@ -82,6 +46,46 @@ func (s *Server) Login(ctx context.Context, request LoginRequestObject) (LoginRe
}, nil
}
// POST /auth/register
func (s *Server) Register(ctx context.Context, request RegisterRequestObject) (RegisterResponseObject, error) {
if !s.cfg.RegistrationEnabled {
return Register403JSONResponse{Code: 403, Message: "Registration is disabled"}, nil
}
if request.Body == nil {
return Register400JSONResponse{Code: 400, Message: "Invalid request body"}, nil
}
req := *request.Body
if req.Username == "" || req.Password == "" {
return Register400JSONResponse{Code: 400, Message: "Invalid user or password"}, nil
}
currentUsers, err := s.db.Queries.GetUsers(ctx)
if err != nil {
return Register500JSONResponse{Code: 500, Message: "Failed to create user"}, nil
}
isAdmin := len(currentUsers) == 0
if err := s.createUser(ctx, req.Username, &req.Password, &isAdmin); err != nil {
return Register400JSONResponse{Code: 400, Message: err.Error()}, nil
}
user, err := s.db.Queries.GetUser(ctx, req.Username)
if err != nil {
return Register500JSONResponse{Code: 500, Message: "Failed to load created user"}, nil
}
if err := s.saveUserSession(ctx, user.ID, user.Admin, *user.AuthHash); err != nil {
return Register500JSONResponse{Code: 500, Message: err.Error()}, nil
}
return Register201JSONResponse{
Username: user.ID,
IsAdmin: user.Admin,
}, nil
}
// POST /auth/logout
func (s *Server) Logout(ctx context.Context, request LogoutRequestObject) (LogoutResponseObject, error) {
_, ok := s.getSessionFromContext(ctx)
@@ -96,28 +100,11 @@ func (s *Server) Logout(ctx context.Context, request LogoutRequestObject) (Logou
return Logout401JSONResponse{Code: 401, Message: "Internal context error"}, nil
}
// Create session store
store := sessions.NewCookieStore([]byte(s.cfg.CookieAuthKey))
if s.cfg.CookieEncKey != "" {
if len(s.cfg.CookieEncKey) == 16 || len(s.cfg.CookieEncKey) == 32 {
store = sessions.NewCookieStore([]byte(s.cfg.CookieAuthKey), []byte(s.cfg.CookieEncKey))
}
}
session, err := store.Get(r, "token")
session, err := s.getCookieSession(r)
if err != nil {
return Logout401JSONResponse{Code: 401, Message: "Unauthorized"}, nil
}
// Configure cookie options (same as login)
session.Options.SameSite = http.SameSiteLaxMode
session.Options.HttpOnly = true
if !s.cfg.CookieSecure {
session.Options.Secure = false
} else {
session.Options.Secure = true
}
session.Values = make(map[any]any)
if err := session.Save(r, w); err != nil {
@@ -140,6 +127,50 @@ func (s *Server) GetMe(ctx context.Context, request GetMeRequestObject) (GetMeRe
}, nil
}
func (s *Server) saveUserSession(ctx context.Context, username string, isAdmin bool, authHash string) error {
r := s.getRequestFromContext(ctx)
w := s.getResponseWriterFromContext(ctx)
if r == nil || w == nil {
return fmt.Errorf("internal context error")
}
session, err := s.getCookieSession(r)
if err != nil {
return fmt.Errorf("unauthorized")
}
session.Values["authorizedUser"] = username
session.Values["isAdmin"] = isAdmin
session.Values["expiresAt"] = time.Now().Unix() + (60 * 60 * 24 * 7)
session.Values["authHash"] = authHash
if err := session.Save(r, w); err != nil {
return fmt.Errorf("failed to create session")
}
return nil
}
func (s *Server) getCookieSession(r *http.Request) (*sessions.Session, error) {
store := sessions.NewCookieStore([]byte(s.cfg.CookieAuthKey))
if s.cfg.CookieEncKey != "" {
if len(s.cfg.CookieEncKey) == 16 || len(s.cfg.CookieEncKey) == 32 {
store = sessions.NewCookieStore([]byte(s.cfg.CookieAuthKey), []byte(s.cfg.CookieEncKey))
}
}
session, err := store.Get(r, "token")
if err != nil {
return nil, fmt.Errorf("failed to get session: %w", err)
}
session.Options.SameSite = http.SameSiteLaxMode
session.Options.HttpOnly = true
session.Options.Secure = s.cfg.CookieSecure
return session, nil
}
// getSessionFromContext extracts authData from context
func (s *Server) getSessionFromContext(ctx context.Context) (authData, bool) {
auth, ok := ctx.Value("auth").(authData)

View File

@@ -25,16 +25,16 @@ type AuthTestSuite struct {
func (suite *AuthTestSuite) setupConfig() *config.Config {
return &config.Config{
ListenPort: "8080",
DBType: "memory",
DBName: "test",
ConfigPath: "/tmp",
CookieAuthKey: "test-auth-key-32-bytes-long-enough",
CookieEncKey: "0123456789abcdef",
CookieSecure: false,
CookieHTTPOnly: true,
Version: "test",
DemoMode: false,
ListenPort: "8080",
DBType: "memory",
DBName: "test",
ConfigPath: "/tmp",
CookieAuthKey: "test-auth-key-32-bytes-long-enough",
CookieEncKey: "0123456789abcdef",
CookieSecure: false,
CookieHTTPOnly: true,
Version: "test",
DemoMode: false,
RegistrationEnabled: true,
}
}
@@ -126,6 +126,51 @@ func (suite *AuthTestSuite) TestAPILoginInvalidCredentials() {
suite.Equal(http.StatusUnauthorized, w.Code)
}
func (suite *AuthTestSuite) TestAPIRegister() {
reqBody := LoginRequest{
Username: "newuser",
Password: "newpass",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusCreated, w.Code)
var resp LoginResponse
suite.Require().NoError(json.Unmarshal(w.Body.Bytes(), &resp))
suite.Equal("newuser", resp.Username)
suite.True(resp.IsAdmin, "first registered user should mirror legacy admin bootstrap behavior")
cookies := w.Result().Cookies()
suite.Require().NotEmpty(cookies, "register should set a session cookie")
user, err := suite.db.Queries.GetUser(suite.T().Context(), "newuser")
suite.Require().NoError(err)
suite.True(user.Admin)
}
func (suite *AuthTestSuite) TestAPIRegisterDisabled() {
suite.cfg.RegistrationEnabled = false
suite.srv = NewServer(suite.db, suite.cfg, nil)
reqBody := LoginRequest{
Username: "newuser",
Password: "newpass",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
w := httptest.NewRecorder()
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusForbidden, w.Code)
}
func (suite *AuthTestSuite) TestAPILogout() {
suite.createTestUser("testuser", "testpass")
cookie := suite.login("testuser", "testpass")
@@ -163,4 +208,4 @@ func (suite *AuthTestSuite) TestAPIGetMeUnauthenticated() {
suite.srv.ServeHTTP(w, req)
suite.Equal(http.StatusUnauthorized, w.Code)
}
}

View File

@@ -11,9 +11,9 @@ import (
"strings"
"time"
log "github.com/sirupsen/logrus"
"reichard.io/antholume/database"
"reichard.io/antholume/metadata"
log "github.com/sirupsen/logrus"
)
// GET /documents
@@ -81,7 +81,7 @@ func (s *Server) GetDocuments(ctx context.Context, request GetDocumentsRequestOb
LastRead: parseInterfaceTime(row.LastRead),
CreatedAt: time.Now(), // Will be overwritten if we had a proper created_at from DB
UpdatedAt: time.Now(), // Will be overwritten if we had a proper updated_at from DB
Deleted: false, // Default, should be overridden if available
Deleted: false, // Default, should be overridden if available
}
if row.Words != nil {
wordCounts = append(wordCounts, WordCount{
@@ -217,10 +217,10 @@ func (s *Server) EditDocument(ctx context.Context, request EditDocumentRequestOb
Isbn13: request.Body.Isbn13,
Coverfile: coverFileName,
// Preserve existing values for non-editable fields
Md5: currentDoc.Md5,
Basepath: currentDoc.Basepath,
Filepath: currentDoc.Filepath,
Words: currentDoc.Words,
Md5: currentDoc.Md5,
Basepath: currentDoc.Basepath,
Filepath: currentDoc.Filepath,
Words: currentDoc.Words,
})
if err != nil {
log.Error("UpsertDocument DB Error:", err)
@@ -306,7 +306,7 @@ func deriveBaseFileName(metadataInfo *metadata.MetadataInfo) string {
}
// parseInterfaceTime converts an interface{} to time.Time for SQLC queries
func parseInterfaceTime(t interface{}) *time.Time {
func parseInterfaceTime(t any) *time.Time {
if t == nil {
return nil
}
@@ -380,7 +380,7 @@ func (s *Server) GetDocumentCover(ctx context.Context, request GetDocumentCoverR
} else {
// Derive Path
coverPath := filepath.Join(s.cfg.DataPath, "covers", *document.Coverfile)
// Validate File Exists
fileInfo, err := os.Stat(coverPath)
if os.IsNotExist(err) {
@@ -713,7 +713,7 @@ func (s *Server) CreateDocument(ctx context.Context, request CreateDocumentReque
}
file := fileField[0]
// Validate file extension
if !strings.HasSuffix(strings.ToLower(file.Filename), ".epub") {
return CreateDocument400JSONResponse{Code: 400, Message: "Only EPUB files are allowed"}, nil
@@ -771,17 +771,17 @@ func (s *Server) CreateDocument(ctx context.Context, request CreateDocumentReque
// Document already exists
existingDoc, _ := s.db.Queries.GetDocument(ctx, *metadataInfo.PartialMD5)
apiDoc := Document{
Id: existingDoc.ID,
Title: *existingDoc.Title,
Author: *existingDoc.Author,
Id: existingDoc.ID,
Title: *existingDoc.Title,
Author: *existingDoc.Author,
Description: existingDoc.Description,
Isbn10: existingDoc.Isbn10,
Isbn13: existingDoc.Isbn13,
Words: existingDoc.Words,
Isbn10: existingDoc.Isbn10,
Isbn13: existingDoc.Isbn13,
Words: existingDoc.Words,
Filepath: existingDoc.Filepath,
CreatedAt: parseTime(existingDoc.CreatedAt),
CreatedAt: parseTime(existingDoc.CreatedAt),
UpdatedAt: parseTime(existingDoc.UpdatedAt),
Deleted: existingDoc.Deleted,
Deleted: existingDoc.Deleted,
}
response := DocumentResponse{
Document: apiDoc,
@@ -818,17 +818,17 @@ func (s *Server) CreateDocument(ctx context.Context, request CreateDocumentReque
}
apiDoc := Document{
Id: doc.ID,
Title: *doc.Title,
Author: *doc.Author,
Id: doc.ID,
Title: *doc.Title,
Author: *doc.Author,
Description: doc.Description,
Isbn10: doc.Isbn10,
Isbn13: doc.Isbn13,
Words: doc.Words,
Isbn10: doc.Isbn10,
Isbn13: doc.Isbn13,
Words: doc.Words,
Filepath: doc.Filepath,
CreatedAt: parseTime(doc.CreatedAt),
CreatedAt: parseTime(doc.CreatedAt),
UpdatedAt: parseTime(doc.UpdatedAt),
Deleted: doc.Deleted,
Deleted: doc.Deleted,
}
response := DocumentResponse{

View File

@@ -1182,6 +1182,44 @@ paths:
schema:
$ref: '#/components/schemas/ErrorResponse'
/auth/register:
post:
summary: User registration
operationId: register
tags:
- Auth
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/LoginRequest'
responses:
201:
description: Successful registration
content:
application/json:
schema:
$ref: '#/components/schemas/LoginResponse'
400:
description: Bad request
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
403:
description: Registration disabled
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
500:
description: Internal server error
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
/auth/logout:
post:
summary: User logout

View File

@@ -46,8 +46,8 @@ func (s *Server) authMiddleware(handler StrictHandlerFunc, operationID string) S
ctx = context.WithValue(ctx, "request", r)
ctx = context.WithValue(ctx, "response", w)
// Skip auth for login and info endpoints - cover and file require auth via cookies
if operationID == "Login" || operationID == "GetInfo" {
// Skip auth for public auth and info endpoints - cover and file require auth via cookies
if operationID == "Login" || operationID == "Register" || operationID == "GetInfo" {
return handler(ctx, w, r, request)
}
@@ -92,10 +92,8 @@ func (s *Server) authMiddleware(handler StrictHandlerFunc, operationID string) S
// GetInfo returns server information
func (s *Server) GetInfo(ctx context.Context, request GetInfoRequestObject) (GetInfoResponseObject, error) {
return GetInfo200JSONResponse{
Version: s.cfg.Version,
SearchEnabled: s.cfg.SearchEnabled,
RegistrationEnabled: s.cfg.RegistrationEnabled,
Version: s.cfg.Version,
SearchEnabled: s.cfg.SearchEnabled,
RegistrationEnabled: s.cfg.RegistrationEnabled,
}, nil
}