diff --git a/api/api.go b/api/api.go index 49f3697..168e82a 100644 --- a/api/api.go +++ b/api/api.go @@ -114,6 +114,11 @@ func (api *API) Start() error { return api.httpServer.ListenAndServe() } +// Handler returns the underlying http.Handler for the Gin router +func (api *API) Handler() http.Handler { + return api.httpServer.Handler +} + func (api *API) Stop() error { // Stop server ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) diff --git a/api/v1/auth.go b/api/v1/auth.go new file mode 100644 index 0000000..ee127f7 --- /dev/null +++ b/api/v1/auth.go @@ -0,0 +1,179 @@ +package v1 + +import ( + "context" + "crypto/md5" + "encoding/json" + "fmt" + "net/http" + "time" + + argon2 "github.com/alexedwards/argon2id" + "github.com/gorilla/sessions" + log "github.com/sirupsen/logrus" +) + +// authData represents session authentication data +type authData struct { + UserName string + IsAdmin bool + AuthHash string +} + +// withAuth wraps a handler with session authentication +func (s *Server) withAuth(handler http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + auth, ok := s.getSession(r) + if !ok { + writeJSONError(w, http.StatusUnauthorized, "Unauthorized") + return + } + ctx := context.WithValue(r.Context(), "auth", auth) + handler(w, r.WithContext(ctx)) + } +} + +// getSession retrieves auth data from the session cookie +func (s *Server) getSession(r *http.Request) (auth authData, ok bool) { + // Get session from cookie store + store := sessions.NewCookieStore([]byte(s.cfg.CookieAuthKey)) + if s.cfg.CookieEncKey != "" { + if len(s.cfg.CookieEncKey) == 16 || len(s.cfg.CookieEncKey) == 32 { + store = sessions.NewCookieStore([]byte(s.cfg.CookieAuthKey), []byte(s.cfg.CookieEncKey)) + } else { + log.Error("invalid cookie encryption key (must be 16 or 32 bytes)") + return authData{}, false + } + } + + session, err := store.Get(r, "token") + if err != nil { + return authData{}, false + } + + // Get session values + authorizedUser := session.Values["authorizedUser"] + isAdmin := session.Values["isAdmin"] + expiresAt := session.Values["expiresAt"] + authHash := session.Values["authHash"] + + if authorizedUser == nil || isAdmin == nil || expiresAt == nil || authHash == nil { + return authData{}, false + } + + auth = authData{ + UserName: authorizedUser.(string), + IsAdmin: isAdmin.(bool), + AuthHash: authHash.(string), + } + + // Validate auth hash + ctx := r.Context() + correctAuthHash, err := s.getUserAuthHash(ctx, auth.UserName) + if err != nil || correctAuthHash != auth.AuthHash { + return authData{}, false + } + + return auth, true +} + +// getUserAuthHash retrieves the user's auth hash from DB or cache +func (s *Server) getUserAuthHash(ctx context.Context, username string) (string, error) { + user, err := s.db.Queries.GetUser(ctx, username) + if err != nil { + return "", err + } + return *user.AuthHash, nil +} + +// apiLogin handles POST /api/v1/auth/login +func (s *Server) apiLogin(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeJSONError(w, http.StatusMethodNotAllowed, "Method not allowed") + return + } + + var req LoginRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSONError(w, http.StatusBadRequest, "Invalid JSON") + return + } + + if req.Username == "" || req.Password == "" { + writeJSONError(w, http.StatusBadRequest, "Invalid credentials") + return + } + + // MD5 - KOSync compatibility + password := fmt.Sprintf("%x", md5.Sum([]byte(req.Password))) + + // Verify credentials + user, err := s.db.Queries.GetUser(r.Context(), req.Username) + if err != nil { + writeJSONError(w, http.StatusUnauthorized, "Invalid credentials") + return + } + + if match, err := argon2.ComparePasswordAndHash(password, *user.Pass); err != nil || !match { + writeJSONError(w, http.StatusUnauthorized, "Invalid credentials") + return + } + + // Create session + store := sessions.NewCookieStore([]byte(s.cfg.CookieAuthKey)) + if s.cfg.CookieEncKey != "" { + if len(s.cfg.CookieEncKey) == 16 || len(s.cfg.CookieEncKey) == 32 { + store = sessions.NewCookieStore([]byte(s.cfg.CookieAuthKey), []byte(s.cfg.CookieEncKey)) + } + } + + session, _ := store.Get(r, "token") + session.Values["authorizedUser"] = user.ID + session.Values["isAdmin"] = user.Admin + session.Values["expiresAt"] = time.Now().Unix() + (60 * 60 * 24 * 7) + session.Values["authHash"] = *user.AuthHash + + if err := session.Save(r, w); err != nil { + writeJSONError(w, http.StatusInternalServerError, "Failed to create session") + return + } + + writeJSON(w, http.StatusOK, LoginResponse{ + Username: user.ID, + IsAdmin: user.Admin, + }) +} + +// apiLogout handles POST /api/v1/auth/logout +func (s *Server) apiLogout(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeJSONError(w, http.StatusMethodNotAllowed, "Method not allowed") + return + } + + store := sessions.NewCookieStore([]byte(s.cfg.CookieAuthKey)) + session, _ := store.Get(r, "token") + session.Values = make(map[any]any) + + if err := session.Save(r, w); err != nil { + writeJSONError(w, http.StatusInternalServerError, "Failed to logout") + return + } + + writeJSON(w, http.StatusOK, map[string]string{"status": "logged out"}) +} + +// apiGetMe handles GET /api/v1/auth/me +func (s *Server) apiGetMe(w http.ResponseWriter, r *http.Request) { + auth, ok := r.Context().Value("auth").(authData) + if !ok { + writeJSONError(w, http.StatusUnauthorized, "Unauthorized") + return + } + + writeJSON(w, http.StatusOK, UserData{ + Username: auth.UserName, + IsAdmin: auth.IsAdmin, + }) +} + diff --git a/api/v1/auth_test.go b/api/v1/auth_test.go new file mode 100644 index 0000000..f41e50b --- /dev/null +++ b/api/v1/auth_test.go @@ -0,0 +1,184 @@ +package v1 + +import ( + "bytes" + "crypto/md5" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + argon2 "github.com/alexedwards/argon2id" + "reichard.io/antholume/database" +) + +func TestAPILogin(t *testing.T) { + db := setupTestDB(t) + cfg := testConfig() + server := NewServer(db, cfg) + + // First, create a user + createTestUser(t, db, "testuser", "testpass") + + // Test login + reqBody := LoginRequest{ + Username: "testuser", + Password: "testpass", + } + body, _ := json.Marshal(reqBody) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(body)) + w := httptest.NewRecorder() + + server.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp LoginResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + if resp.Username != "testuser" { + t.Errorf("Expected username 'testuser', got '%s'", resp.Username) + } +} + +func TestAPILoginInvalidCredentials(t *testing.T) { + db := setupTestDB(t) + cfg := testConfig() + server := NewServer(db, cfg) + + reqBody := LoginRequest{ + Username: "testuser", + Password: "wrongpass", + } + body, _ := json.Marshal(reqBody) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(body)) + w := httptest.NewRecorder() + + server.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("Expected 401, got %d", w.Code) + } +} + +func TestAPILogout(t *testing.T) { + db := setupTestDB(t) + cfg := testConfig() + server := NewServer(db, cfg) + + // Create user and login + createTestUser(t, db, "testuser", "testpass") + + // Login first + reqBody := LoginRequest{Username: "testuser", Password: "testpass"} + body, _ := json.Marshal(reqBody) + loginReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(body)) + loginResp := httptest.NewRecorder() + server.ServeHTTP(loginResp, loginReq) + + // Get session cookie + cookies := loginResp.Result().Cookies() + if len(cookies) == 0 { + t.Fatal("No session cookie returned") + } + + // Logout + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", nil) + req.AddCookie(cookies[0]) + w := httptest.NewRecorder() + + server.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d", w.Code) + } +} + +func TestAPIGetMe(t *testing.T) { + db := setupTestDB(t) + cfg := testConfig() + server := NewServer(db, cfg) + + // Create user and login + createTestUser(t, db, "testuser", "testpass") + + // Login first + reqBody := LoginRequest{Username: "testuser", Password: "testpass"} + body, _ := json.Marshal(reqBody) + loginReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(body)) + loginResp := httptest.NewRecorder() + server.ServeHTTP(loginResp, loginReq) + + // Get session cookie + cookies := loginResp.Result().Cookies() + if len(cookies) == 0 { + t.Fatal("No session cookie returned") + } + + // Get me + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil) + req.AddCookie(cookies[0]) + w := httptest.NewRecorder() + + server.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d", w.Code) + } + + var resp UserData + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + if resp.Username != "testuser" { + t.Errorf("Expected username 'testuser', got '%s'", resp.Username) + } +} + +func TestAPIGetMeUnauthenticated(t *testing.T) { + db := setupTestDB(t) + cfg := testConfig() + server := NewServer(db, cfg) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil) + w := httptest.NewRecorder() + + server.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("Expected 401, got %d", w.Code) + } +} + +func createTestUser(t *testing.T, db *database.DBManager, username, password string) { + t.Helper() + + // MD5 hash for KOSync compatibility (matches existing system) + md5Hash := fmt.Sprintf("%x", md5.Sum([]byte(password))) + + // Then argon2 hash the MD5 + hashedPassword, err := argon2.CreateHash(md5Hash, argon2.DefaultParams) + if err != nil { + t.Fatalf("Failed to hash password: %v", err) + } + + authHash := "test-auth-hash" + + _, err = db.Queries.CreateUser(t.Context(), database.CreateUserParams{ + ID: username, + Pass: &hashedPassword, + AuthHash: &authHash, + Admin: true, + }) + if err != nil { + t.Fatalf("Failed to create user: %v", err) + } +} \ No newline at end of file diff --git a/api/v1/documents.go b/api/v1/documents.go new file mode 100644 index 0000000..6b386c2 --- /dev/null +++ b/api/v1/documents.go @@ -0,0 +1,141 @@ +package v1 + +import ( + "net/http" + "strconv" + "strings" + + "reichard.io/antholume/database" + "reichard.io/antholume/pkg/ptr" +) + +// apiGetDocuments handles GET /api/v1/documents +// Deprecated: Use GetDocuments with DocumentListRequest instead +func (s *Server) apiGetDocuments(w http.ResponseWriter, r *http.Request) { + // Parse query params + query := r.URL.Query() + page, _ := strconv.ParseInt(query.Get("page"), 10, 64) + if page == 0 { + page = 1 + } + limit, _ := strconv.ParseInt(query.Get("limit"), 10, 64) + if limit == 0 { + limit = 9 + } + search := query.Get("search") + + // Get auth from context + auth, ok := r.Context().Value("auth").(authData) + if !ok { + writeJSONError(w, http.StatusUnauthorized, "Unauthorized") + return + } + + // Build query + var queryPtr *string + if search != "" { + queryPtr = ptr.Of("%" + search + "%") + } + + // Query database + rows, err := s.db.Queries.GetDocumentsWithStats( + r.Context(), + database.GetDocumentsWithStatsParams{ + UserID: auth.UserName, + Query: queryPtr, + Deleted: ptr.Of(false), + Offset: (page - 1) * limit, + Limit: limit, + }, + ) + if err != nil { + writeJSONError(w, http.StatusInternalServerError, err.Error()) + return + } + + // Calculate pagination + total := int64(len(rows)) + var nextPage *int64 + var previousPage *int64 + if page*limit < total { + nextPage = ptr.Of(page + 1) + } + if page > 1 { + previousPage = ptr.Of(page - 1) + } + + // Get word counts + wordCounts := make([]WordCount, 0, len(rows)) + for _, row := range rows { + if row.Words != nil { + wordCounts = append(wordCounts, WordCount{ + DocumentID: row.ID, + Count: *row.Words, + }) + } + } + + // Return response + writeJSON(w, http.StatusOK, DocumentsResponse{ + Documents: rows, + Total: total, + Page: page, + Limit: limit, + NextPage: nextPage, + PreviousPage: previousPage, + Search: ptr.Of(search), + User: UserData{Username: auth.UserName, IsAdmin: auth.IsAdmin}, + WordCounts: wordCounts, + }) +} + +// apiGetDocument handles GET /api/v1/documents/:id +// Deprecated: Use GetDocument with DocumentRequest instead +func (s *Server) apiGetDocument(w http.ResponseWriter, r *http.Request) { + // Extract ID from URL path + path := strings.TrimPrefix(r.URL.Path, "/api/v1/documents/") + id := strings.TrimPrefix(path, "/") + + if id == "" { + writeJSONError(w, http.StatusBadRequest, "Document ID required") + return + } + + // Get auth from context + auth, ok := r.Context().Value("auth").(authData) + if !ok { + writeJSONError(w, http.StatusUnauthorized, "Unauthorized") + return + } + + // Query database + doc, err := s.db.Queries.GetDocument(r.Context(), id) + if err != nil { + writeJSONError(w, http.StatusNotFound, "Document not found") + return + } + + // Get progress + progressRow, err := s.db.Queries.GetDocumentProgress(r.Context(), database.GetDocumentProgressParams{ + UserID: auth.UserName, + DocumentID: id, + }) + var progress *Progress + if err == nil { + progress = &Progress{ + UserID: progressRow.UserID, + DocumentID: progressRow.DocumentID, + DeviceID: progressRow.DeviceID, + Percentage: progressRow.Percentage, + Progress: progressRow.Progress, + CreatedAt: progressRow.CreatedAt, + } + } + + // Return response + writeJSON(w, http.StatusOK, DocumentResponse{ + Document: doc, + User: UserData{Username: auth.UserName, IsAdmin: auth.IsAdmin}, + Progress: progress, + }) +} \ No newline at end of file diff --git a/api/v1/documents_test.go b/api/v1/documents_test.go new file mode 100644 index 0000000..fe40cef --- /dev/null +++ b/api/v1/documents_test.go @@ -0,0 +1,164 @@ +package v1 + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "reichard.io/antholume/database" + "reichard.io/antholume/pkg/ptr" +) + +func TestAPIGetDocuments(t *testing.T) { + db := setupTestDB(t) + cfg := testConfig() + server := NewServer(db, cfg) + + // Create user and login + createTestUser(t, db, "testuser", "testpass") + + // Login first + reqBody := LoginRequest{Username: "testuser", Password: "testpass"} + body, _ := json.Marshal(reqBody) + loginReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(body)) + loginResp := httptest.NewRecorder() + server.ServeHTTP(loginResp, loginReq) + + // Get session cookie + cookies := loginResp.Result().Cookies() + if len(cookies) == 0 { + t.Fatal("No session cookie returned") + } + + // Get documents + req := httptest.NewRequest(http.MethodGet, "/api/v1/documents?page=1&limit=9", nil) + req.AddCookie(cookies[0]) + w := httptest.NewRecorder() + + server.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp DocumentsResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + if resp.Page != 1 { + t.Errorf("Expected page 1, got %d", resp.Page) + } + + if resp.Limit != 9 { + t.Errorf("Expected limit 9, got %d", resp.Limit) + } + + if resp.User.Username != "testuser" { + t.Errorf("Expected username 'testuser', got '%s'", resp.User.Username) + } +} + +func TestAPIGetDocumentsUnauthenticated(t *testing.T) { + db := setupTestDB(t) + cfg := testConfig() + server := NewServer(db, cfg) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/documents", nil) + w := httptest.NewRecorder() + + server.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("Expected 401, got %d", w.Code) + } +} + +func TestAPIGetDocument(t *testing.T) { + db := setupTestDB(t) + cfg := testConfig() + server := NewServer(db, cfg) + + // Create user + createTestUser(t, db, "testuser", "testpass") + + // Create a document using UpsertDocument + docID := "test-doc-1" + _, err := db.Queries.UpsertDocument(t.Context(), database.UpsertDocumentParams{ + ID: docID, + Title: ptr.Of("Test Document"), + Author: ptr.Of("Test Author"), + }) + if err != nil { + t.Fatalf("Failed to create document: %v", err) + } + + // Login + reqBody := LoginRequest{Username: "testuser", Password: "testpass"} + body, _ := json.Marshal(reqBody) + loginReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(body)) + loginResp := httptest.NewRecorder() + server.ServeHTTP(loginResp, loginReq) + + cookies := loginResp.Result().Cookies() + if len(cookies) == 0 { + t.Fatal("No session cookie returned") + } + + // Get document + req := httptest.NewRequest(http.MethodGet, "/api/v1/documents/"+docID, nil) + req.AddCookie(cookies[0]) + w := httptest.NewRecorder() + + server.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp DocumentResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + if resp.Document.ID != docID { + t.Errorf("Expected document ID '%s', got '%s'", docID, resp.Document.ID) + } + + if *resp.Document.Title != "Test Document" { + t.Errorf("Expected title 'Test Document', got '%s'", *resp.Document.Title) + } +} + +func TestAPIGetDocumentNotFound(t *testing.T) { + db := setupTestDB(t) + cfg := testConfig() + server := NewServer(db, cfg) + + // Create user and login + createTestUser(t, db, "testuser", "testpass") + + reqBody := LoginRequest{Username: "testuser", Password: "testpass"} + body, _ := json.Marshal(reqBody) + loginReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(body)) + loginResp := httptest.NewRecorder() + server.ServeHTTP(loginResp, loginReq) + + cookies := loginResp.Result().Cookies() + if len(cookies) == 0 { + t.Fatal("No session cookie returned") + } + + // Get non-existent document + req := httptest.NewRequest(http.MethodGet, "/api/v1/documents/non-existent", nil) + req.AddCookie(cookies[0]) + w := httptest.NewRecorder() + + server.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Fatalf("Expected 404, got %d", w.Code) + } +} \ No newline at end of file diff --git a/api/v1/handlers.go b/api/v1/handlers.go new file mode 100644 index 0000000..d44d5e5 --- /dev/null +++ b/api/v1/handlers.go @@ -0,0 +1,294 @@ +package v1 + +import ( + "context" + "net/http" + "strconv" + "strings" + + "reichard.io/antholume/database" +) + +// DocumentRequest represents a request for a single document +type DocumentRequest struct { + ID string +} + +// DocumentListRequest represents a request for listing documents +type DocumentListRequest struct { + Page int64 + Limit int64 + Search *string +} + +// ProgressRequest represents a request for document progress +type ProgressRequest struct { + ID string +} + +// ActivityRequest represents a request for activity data +type ActivityRequest struct { + DocFilter bool + DocumentID string + Offset int64 + Limit int64 +} + +// SettingsRequest represents a request for settings data +type SettingsRequest struct{} + +// GetDocument handles GET /api/v1/documents/:id +func (s *Server) GetDocument(ctx context.Context, req DocumentRequest) (DocumentResponse, error) { + auth := getAuthFromContext(ctx) + if auth == nil { + return DocumentResponse{}, &apiError{status: http.StatusUnauthorized, message: "Unauthorized"} + } + + doc, err := s.db.Queries.GetDocument(ctx, req.ID) + if err != nil { + return DocumentResponse{}, &apiError{status: http.StatusNotFound, message: "Document not found"} + } + + progressRow, err := s.db.Queries.GetDocumentProgress(ctx, database.GetDocumentProgressParams{ + UserID: auth.UserName, + DocumentID: req.ID, + }) + var progress *Progress + if err == nil { + progress = &Progress{ + UserID: progressRow.UserID, + DocumentID: progressRow.DocumentID, + DeviceID: progressRow.DeviceID, + Percentage: progressRow.Percentage, + Progress: progressRow.Progress, + CreatedAt: progressRow.CreatedAt, + } + } + + return DocumentResponse{ + Document: doc, + User: UserData{Username: auth.UserName, IsAdmin: auth.IsAdmin}, + Progress: progress, + }, nil +} + +// GetDocuments handles GET /api/v1/documents +func (s *Server) GetDocuments(ctx context.Context, req DocumentListRequest) (DocumentsResponse, error) { + auth := getAuthFromContext(ctx) + if auth == nil { + return DocumentsResponse{}, &apiError{status: http.StatusUnauthorized, message: "Unauthorized"} + } + + rows, err := s.db.Queries.GetDocumentsWithStats( + ctx, + database.GetDocumentsWithStatsParams{ + UserID: auth.UserName, + Query: req.Search, + Deleted: ptrOf(false), + Offset: (req.Page - 1) * req.Limit, + Limit: req.Limit, + }, + ) + if err != nil { + return DocumentsResponse{}, &apiError{status: http.StatusInternalServerError, message: err.Error()} + } + + total := int64(len(rows)) + var nextPage *int64 + var previousPage *int64 + if req.Page*req.Limit < total { + nextPage = ptrOf(req.Page + 1) + } + if req.Page > 1 { + previousPage = ptrOf(req.Page - 1) + } + + wordCounts := make([]WordCount, 0, len(rows)) + for _, row := range rows { + if row.Words != nil { + wordCounts = append(wordCounts, WordCount{ + DocumentID: row.ID, + Count: *row.Words, + }) + } + } + + return DocumentsResponse{ + Documents: rows, + Total: total, + Page: req.Page, + Limit: req.Limit, + NextPage: nextPage, + PreviousPage: previousPage, + Search: req.Search, + User: UserData{Username: auth.UserName, IsAdmin: auth.IsAdmin}, + WordCounts: wordCounts, + }, nil +} + +// GetProgress handles GET /api/v1/progress/:id +func (s *Server) GetProgress(ctx context.Context, req ProgressRequest) (Progress, error) { + auth := getAuthFromContext(ctx) + if auth == nil { + return Progress{}, &apiError{status: http.StatusUnauthorized, message: "Unauthorized"} + } + + if req.ID == "" { + return Progress{}, &apiError{status: http.StatusBadRequest, message: "Document ID required"} + } + + progressRow, err := s.db.Queries.GetDocumentProgress(ctx, database.GetDocumentProgressParams{ + UserID: auth.UserName, + DocumentID: req.ID, + }) + if err != nil { + return Progress{}, &apiError{status: http.StatusNotFound, message: "Progress not found"} + } + + return Progress{ + UserID: progressRow.UserID, + DocumentID: progressRow.DocumentID, + DeviceID: progressRow.DeviceID, + Percentage: progressRow.Percentage, + Progress: progressRow.Progress, + CreatedAt: progressRow.CreatedAt, + }, nil +} + +// GetActivity handles GET /api/v1/activity +func (s *Server) GetActivity(ctx context.Context, req ActivityRequest) (ActivityResponse, error) { + auth := getAuthFromContext(ctx) + if auth == nil { + return ActivityResponse{}, &apiError{status: http.StatusUnauthorized, message: "Unauthorized"} + } + + activities, err := s.db.Queries.GetActivity(ctx, database.GetActivityParams{ + UserID: auth.UserName, + DocFilter: req.DocFilter, + DocumentID: req.DocumentID, + Offset: req.Offset, + Limit: req.Limit, + }) + if err != nil { + return ActivityResponse{}, &apiError{status: http.StatusInternalServerError, message: err.Error()} + } + + return ActivityResponse{ + Activities: activities, + User: UserData{Username: auth.UserName, IsAdmin: auth.IsAdmin}, + }, nil +} + +// GetSettings handles GET /api/v1/settings +func (s *Server) GetSettings(ctx context.Context, req SettingsRequest) (SettingsResponse, error) { + auth := getAuthFromContext(ctx) + if auth == nil { + return SettingsResponse{}, &apiError{status: http.StatusUnauthorized, message: "Unauthorized"} + } + + user, err := s.db.Queries.GetUser(ctx, auth.UserName) + if err != nil { + return SettingsResponse{}, &apiError{status: http.StatusInternalServerError, message: err.Error()} + } + + return SettingsResponse{ + Settings: []database.Setting{}, + User: UserData{Username: auth.UserName, IsAdmin: auth.IsAdmin}, + Timezone: user.Timezone, + }, nil +} + +// getAuthFromContext extracts authData from context +func getAuthFromContext(ctx context.Context) *authData { + auth, ok := ctx.Value("auth").(authData) + if !ok { + return nil + } + return &auth +} + +// apiError represents an API error with status code +type apiError struct { + status int + message string +} + +// Error implements error interface +func (e *apiError) Error() string { + return e.message +} + +// handlerFunc is a generic API handler function +type handlerFunc[T, R any] func(context.Context, T) (R, error) + +// requestParser parses an HTTP request into a request struct +type requestParser[T any] func(*http.Request) T + +// handle wraps an API handler function with HTTP response writing +func handle[T, R any](fn handlerFunc[T, R], parser requestParser[T]) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + req := parser(r) + resp, err := fn(r.Context(), req) + if err != nil { + if apiErr, ok := err.(*apiError); ok { + writeJSONError(w, apiErr.status, apiErr.message) + } else { + writeJSONError(w, http.StatusInternalServerError, err.Error()) + } + return + } + writeJSON(w, http.StatusOK, resp) + } +} + +// parseDocumentRequest extracts document request from HTTP request +func parseDocumentRequest(r *http.Request) DocumentRequest { + path := strings.TrimPrefix(r.URL.Path, "/api/v1/documents/") + id := strings.TrimPrefix(path, "/") + return DocumentRequest{ID: id} +} + +// parseDocumentListRequest extracts document list request from URL query +func parseDocumentListRequest(r *http.Request) DocumentListRequest { + query := r.URL.Query() + page, _ := strconv.ParseInt(query.Get("page"), 10, 64) + if page == 0 { + page = 1 + } + limit, _ := strconv.ParseInt(query.Get("limit"), 10, 64) + if limit == 0 { + limit = 9 + } + search := query.Get("search") + var searchPtr *string + if search != "" { + searchPtr = ptrOf("%" + search + "%") + } + return DocumentListRequest{ + Page: page, + Limit: limit, + Search: searchPtr, + } +} + +// parseProgressRequest extracts progress request from HTTP request +func parseProgressRequest(r *http.Request) ProgressRequest { + path := strings.TrimPrefix(r.URL.Path, "/api/v1/progress/") + id := strings.TrimPrefix(path, "/") + return ProgressRequest{ID: id} +} + +// parseActivityRequest extracts activity request from HTTP request +func parseActivityRequest(r *http.Request) ActivityRequest { + return ActivityRequest{ + DocFilter: false, + DocumentID: "", + Offset: 0, + Limit: 100, + } +} + +// parseSettingsRequest extracts settings request from HTTP request +func parseSettingsRequest(r *http.Request) SettingsRequest { + return SettingsRequest{} +} \ No newline at end of file diff --git a/api/v1/server.go b/api/v1/server.go new file mode 100644 index 0000000..35e870f --- /dev/null +++ b/api/v1/server.go @@ -0,0 +1,50 @@ +package v1 + +import ( + "net/http" + + "reichard.io/antholume/config" + "reichard.io/antholume/database" +) + +type Server struct { + mux *http.ServeMux + db *database.DBManager + cfg *config.Config +} + +// NewServer creates a new native HTTP server +func NewServer(db *database.DBManager, cfg *config.Config) *Server { + s := &Server{ + mux: http.NewServeMux(), + db: db, + cfg: cfg, + } + s.registerRoutes() + return s +} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.mux.ServeHTTP(w, r) +} + +// registerRoutes sets up all API routes +func (s *Server) registerRoutes() { + // Documents endpoints + s.mux.HandleFunc("/api/v1/documents", s.withAuth(wrapRequest(s.GetDocuments, parseDocumentListRequest))) + s.mux.HandleFunc("/api/v1/documents/", s.withAuth(wrapRequest(s.GetDocument, parseDocumentRequest))) + + // Progress endpoints + s.mux.HandleFunc("/api/v1/progress/", s.withAuth(wrapRequest(s.GetProgress, parseProgressRequest))) + + // Activity endpoints + s.mux.HandleFunc("/api/v1/activity", s.withAuth(wrapRequest(s.GetActivity, parseActivityRequest))) + + // Settings endpoints + s.mux.HandleFunc("/api/v1/settings", s.withAuth(wrapRequest(s.GetSettings, parseSettingsRequest))) + + // Auth endpoints + s.mux.HandleFunc("/api/v1/auth/login", s.apiLogin) + s.mux.HandleFunc("/api/v1/auth/logout", s.withAuth(s.apiLogout)) + s.mux.HandleFunc("/api/v1/auth/me", s.withAuth(s.apiGetMe)) +} diff --git a/api/v1/server_test.go b/api/v1/server_test.go new file mode 100644 index 0000000..d1d357c --- /dev/null +++ b/api/v1/server_test.go @@ -0,0 +1,74 @@ +package v1 + +import ( + "net/http" + "net/http/httptest" + "testing" + + "reichard.io/antholume/config" + "reichard.io/antholume/database" +) + +func TestNewServer(t *testing.T) { + db := setupTestDB(t) + cfg := testConfig() + + server := NewServer(db, cfg) + + if server == nil { + t.Fatal("NewServer returned nil") + } + + if server.mux == nil { + t.Fatal("Server mux is nil") + } + + if server.db == nil { + t.Fatal("Server db is nil") + } + + if server.cfg == nil { + t.Fatal("Server cfg is nil") + } +} + +func TestServerServeHTTP(t *testing.T) { + db := setupTestDB(t) + cfg := testConfig() + + server := NewServer(db, cfg) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil) + w := httptest.NewRecorder() + + server.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("Expected 401 for unauthenticated request, got %d", w.Code) + } +} + +func setupTestDB(t *testing.T) *database.DBManager { + t.Helper() + + cfg := testConfig() + cfg.DBType = "memory" + + return database.NewMgr(cfg) +} + +func testConfig() *config.Config { + return &config.Config{ + ListenPort: "8080", + DBType: "memory", + DBName: "test", + ConfigPath: "/tmp", + CookieAuthKey: "test-auth-key-32-bytes-long-enough", + CookieEncKey: "0123456789abcdef", // Exactly 16 bytes + CookieSecure: false, + CookieHTTPOnly: true, + Version: "test", + DemoMode: false, + RegistrationEnabled: true, + } +} \ No newline at end of file diff --git a/api/v1/types.go b/api/v1/types.go new file mode 100644 index 0000000..e9d74d7 --- /dev/null +++ b/api/v1/types.go @@ -0,0 +1,76 @@ +package v1 + +import "reichard.io/antholume/database" + +// DocumentsResponse is the API response for document list endpoints +type DocumentsResponse struct { + Documents []database.GetDocumentsWithStatsRow `json:"documents"` + Total int64 `json:"total"` + Page int64 `json:"page"` + Limit int64 `json:"limit"` + NextPage *int64 `json:"next_page"` + PreviousPage *int64 `json:"previous_page"` + Search *string `json:"search"` + User UserData `json:"user"` + WordCounts []WordCount `json:"word_counts"` +} + +// DocumentResponse is the API response for single document endpoints +type DocumentResponse struct { + Document database.Document `json:"document"` + User UserData `json:"user"` + Progress *Progress `json:"progress"` +} + +// UserData represents authenticated user context +type UserData struct { + Username string `json:"username"` + IsAdmin bool `json:"is_admin"` +} + +// WordCount represents computed word count statistics +type WordCount struct { + DocumentID string `json:"document_id"` + Count int64 `json:"count"` +} + +// Progress represents reading progress for a document +type Progress struct { + UserID string `json:"user_id"` + DocumentID string `json:"document_id"` + DeviceID string `json:"device_id"` + Percentage float64 `json:"percentage"` + Progress string `json:"progress"` + CreatedAt string `json:"created_at"` +} + +// ActivityResponse is the API response for activity endpoints +type ActivityResponse struct { + Activities []database.GetActivityRow `json:"activities"` + User UserData `json:"user"` +} + +// SettingsResponse is the API response for settings endpoints +type SettingsResponse struct { + Settings []database.Setting `json:"settings"` + User UserData `json:"user"` + Timezone *string `json:"timezone"` +} + +// LoginRequest is the request body for login +type LoginRequest struct { + Username string `json:"username"` + Password string `json:"password"` +} + +// LoginResponse is the response for successful login +type LoginResponse struct { + Username string `json:"username"` + IsAdmin bool `json:"is_admin"` +} + +// ErrorResponse represents an API error +type ErrorResponse struct { + Code int `json:"code"` + Message string `json:"message"` +} \ No newline at end of file diff --git a/api/v1/utils.go b/api/v1/utils.go new file mode 100644 index 0000000..1d818f0 --- /dev/null +++ b/api/v1/utils.go @@ -0,0 +1,59 @@ +package v1 + +import ( + "encoding/json" + "net/http" + "net/url" + "strconv" +) + +// writeJSON writes a JSON response +func writeJSON(w http.ResponseWriter, status int, data any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if err := json.NewEncoder(w).Encode(data); err != nil { + writeJSONError(w, http.StatusInternalServerError, "Failed to encode response") + } +} + +// writeJSONError writes a JSON error response +func writeJSONError(w http.ResponseWriter, status int, message string) { + writeJSON(w, status, ErrorResponse{ + Code: status, + Message: message, + }) +} + +// QueryParams represents parsed query parameters +type QueryParams struct { + Page int64 + Limit int64 + Search *string +} + +// parseQueryParams parses URL query parameters +func parseQueryParams(query url.Values, defaultLimit int64) QueryParams { + page, _ := strconv.ParseInt(query.Get("page"), 10, 64) + if page == 0 { + page = 1 + } + limit, _ := strconv.ParseInt(query.Get("limit"), 10, 64) + if limit == 0 { + limit = defaultLimit + } + search := query.Get("search") + var searchPtr *string + if search != "" { + searchPtr = ptrOf("%" + search + "%") + } + return QueryParams{ + Page: page, + Limit: limit, + Search: searchPtr, + } +} + +// ptrOf returns a pointer to the given value +func ptrOf[T any](v T) *T { + return &v +} \ No newline at end of file diff --git a/api/v1/utils_test.go b/api/v1/utils_test.go new file mode 100644 index 0000000..de0e647 --- /dev/null +++ b/api/v1/utils_test.go @@ -0,0 +1,107 @@ +package v1 + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestWriteJSON(t *testing.T) { + w := httptest.NewRecorder() + data := map[string]string{"test": "value"} + + writeJSON(w, http.StatusOK, data) + + if w.Header().Get("Content-Type") != "application/json" { + t.Errorf("Expected Content-Type 'application/json', got '%s'", w.Header().Get("Content-Type")) + } + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + if resp["test"] != "value" { + t.Errorf("Expected 'value', got '%s'", resp["test"]) + } +} + +func TestWriteJSONError(t *testing.T) { + w := httptest.NewRecorder() + + writeJSONError(w, http.StatusBadRequest, "test error") + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d", w.Code) + } + + var resp ErrorResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + if resp.Code != http.StatusBadRequest { + t.Errorf("Expected code 400, got %d", resp.Code) + } + + if resp.Message != "test error" { + t.Errorf("Expected message 'test error', got '%s'", resp.Message) + } +} + +func TestParseQueryParams(t *testing.T) { + query := make(map[string][]string) + query["page"] = []string{"2"} + query["limit"] = []string{"15"} + query["search"] = []string{"test"} + + params := parseQueryParams(query, 9) + + if params.Page != 2 { + t.Errorf("Expected page 2, got %d", params.Page) + } + + if params.Limit != 15 { + t.Errorf("Expected limit 15, got %d", params.Limit) + } + + if params.Search == nil { + t.Fatal("Expected search to be set") + } +} + +func TestParseQueryParamsDefaults(t *testing.T) { + query := make(map[string][]string) + + params := parseQueryParams(query, 9) + + if params.Page != 1 { + t.Errorf("Expected page 1, got %d", params.Page) + } + + if params.Limit != 9 { + t.Errorf("Expected limit 9, got %d", params.Limit) + } + + if params.Search != nil { + t.Errorf("Expected search to be nil, got '%v'", params.Search) + } +} + +func TestPtrOf(t *testing.T) { + value := "test" + ptr := ptrOf(value) + + if ptr == nil { + t.Fatal("Expected non-nil pointer") + } + + if *ptr != "test" { + t.Errorf("Expected 'test', got '%s'", *ptr) + } +} \ No newline at end of file diff --git a/server/server.go b/server/server.go index f9cf34a..7d5d7ce 100644 --- a/server/server.go +++ b/server/server.go @@ -12,36 +12,62 @@ import ( "reichard.io/antholume/api" "reichard.io/antholume/config" "reichard.io/antholume/database" + v1 "reichard.io/antholume/api/v1" ) type server struct { - db *database.DBManager - api *api.API - done chan int - wg sync.WaitGroup + db *database.DBManager + ginAPI *api.API + v1API *v1.Server + httpServer *http.Server + done chan int + wg sync.WaitGroup } -// Create new server +// Create new server with both Gin and v1 API running in parallel func New(c *config.Config, assets fs.FS) *server { db := database.NewMgr(c) - api := api.NewApi(db, c, assets) + ginAPI := api.NewApi(db, c, assets) + v1API := v1.NewServer(db, c) + + // Create combined mux that handles both Gin and v1 API + mux := http.NewServeMux() + + // Register v1 API routes first (they take precedence) + mux.Handle("/api/v1/", v1API) + + // Register Gin API routes (handles all other routes including /) + // Gin's router implements http.Handler + mux.Handle("/", ginAPI.Handler()) + + // Create HTTP server with combined mux + httpServer := &http.Server{ + Handler: mux, + Addr: ":" + c.ListenPort, + } return &server{ - db: db, - api: api, - done: make(chan int), + db: db, + ginAPI: ginAPI, + v1API: v1API, + httpServer: httpServer, + done: make(chan int), } } -// Start server +// Start server - runs both Gin and v1 API concurrently func (s *server) Start() { - log.Info("Starting server...") + log.Info("Starting server with both Gin (templates) and v1 (API)...") + log.Info("v1 API endpoints available at /api/v1/*") + log.Info("Gin template endpoints available at /") + s.wg.Add(2) go func() { defer s.wg.Done() - err := s.api.Start() + log.Infof("HTTP server listening on %s", s.httpServer.Addr) + err := s.httpServer.ListenAndServe() if err != nil && err != http.ErrServerClosed { log.Error("Starting server failed: ", err) } @@ -66,20 +92,28 @@ func (s *server) Start() { } }() - log.Info("Server started") + log.Info("Server started - running both Gin and v1 API concurrently") } -// Stop server +// Stop server - gracefully shuts down both APIs func (s *server) Stop() { log.Info("Stopping server...") - if err := s.api.Stop(); err != nil { - log.Error("HTTP server stop failed: ", err) + // Shutdown HTTP server (both Gin and v1) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := s.httpServer.Shutdown(ctx); err != nil { + log.Error("HTTP server shutdown failed: ", err) } close(s.done) s.wg.Wait() + // Close DB + if err := s.db.DB.Close(); err != nil { + log.Error("DB close failed: ", err) + } + log.Info("Server stopped") } @@ -90,4 +124,4 @@ func (s *server) runScheduledTasks(ctx context.Context) { log.Warn("Refreshing temp table cache failed: ", err) } log.Debug("Completed in: ", time.Since(start)) -} +} \ No newline at end of file