diff --git a/.gitignore b/.gitignore index 18faab5..3e166e9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ data .opencode +.env diff --git a/backend/AGENTS.md b/backend/AGENTS.md index 59aa8ff..326a437 100644 --- a/backend/AGENTS.md +++ b/backend/AGENTS.md @@ -66,6 +66,7 @@ web/ # Embedded static assets (embed.go) | POST | `/api/chats/{chatId}` | PostChatMessage | | DELETE | `/api/chats/{chatId}` | DeleteChat | | GET | `/api/chats/{chatId}/stream` | GetChatStream | +| POST | `/api/chats/{chatId}/stop` | StopChatGeneration | ## Testing diff --git a/backend/internal/api/generation.go b/backend/internal/api/generation.go index 1804946..ef9ec6f 100644 --- a/backend/internal/api/generation.go +++ b/backend/internal/api/generation.go @@ -1,6 +1,7 @@ package api import ( + "context" "errors" "slices" "sync" @@ -19,6 +20,8 @@ type generationManager struct { type generation struct { mu sync.RWMutex + ctx context.Context + cancel context.CancelFunc subscribers map[chan *MessageChunk]struct{} done chan struct{} closed bool @@ -36,7 +39,10 @@ func (m *generationManager) start(chatID uuid.UUID, prepare func(*generation) er } // Reserve Generation + ctx, cancel := context.WithCancel(context.Background()) gen := &generation{ + ctx: ctx, + cancel: cancel, subscribers: make(map[chan *MessageChunk]struct{}), done: make(chan struct{}), } @@ -79,6 +85,19 @@ func (m *generationManager) subscribe(chatID uuid.UUID) (<-chan *MessageChunk, f return ch, func() { gen.unsubscribe(ch) }, true } +func (m *generationManager) stop(chatID uuid.UUID) bool { + m.mu.RLock() + gen, found := m.generations[chatID] + m.mu.RUnlock() + if !found { + return false + } + + // Cancel Generation + gen.cancel() + return true +} + func (g *generation) subscribe() chan *MessageChunk { ch := make(chan *MessageChunk, 64) @@ -126,6 +145,7 @@ func (g *generation) close() { return } g.closed = true + g.cancel() close(g.done) for subscriber := range g.subscribers { close(subscriber) diff --git a/backend/internal/api/handlers.go b/backend/internal/api/handlers.go index 23bd0f6..f423a5a 100644 --- a/backend/internal/api/handlers.go +++ b/backend/internal/api/handlers.go @@ -449,6 +449,32 @@ func (a *API) GetChatStream(w http.ResponseWriter, r *http.Request) { } } +func (a *API) StopChatGeneration(w http.ResponseWriter, r *http.Request) { + log := a.logger.WithField("handler", "StopChatGenerationHandler") + + // Parse Chat ID + rawChatID := r.PathValue("chatId") + if rawChatID == "" { + log.Error("missing chat ID parameter") + http.Error(w, "Chat ID is required", http.StatusBadRequest) + return + } + chatID, err := uuid.Parse(rawChatID) + if err != nil { + log.WithError(err).Error("invalid chat ID format") + http.Error(w, "Invalid chat ID format", http.StatusBadRequest) + return + } + + // Stop Generation + if !a.generationManager.stop(chatID) { + http.Error(w, "Chat generation is not active", http.StatusNotFound) + return + } + + w.WriteHeader(http.StatusNoContent) +} + func (a *API) PostChatMessage(w http.ResponseWriter, r *http.Request) { log := a.logger.WithField("handler", "PostChatMessageHandler") @@ -563,7 +589,7 @@ func (a *API) startMessageGeneration(chatID uuid.UUID, chatModel, userMessage st func (a *API) runMessageGeneration(apiClient *client.Client, chat *store.Chat, assistantMsg *store.Message, chatModel string, gen *generation) { // Create Generation Context - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + ctx, cancel := context.WithTimeout(gen.ctx, time.Minute*5) defer cancel() // Send Message @@ -593,15 +619,37 @@ func (a *API) runMessageGeneration(apiClient *client.Client, chat *store.Chat, a return nil }); err != nil { - assistantMsg.Status = store.MessageStatusFailed + // Handle Stopped Generation + if errors.Is(gen.ctx.Err(), context.Canceled) { + assistantMsg.Status = store.MessageStatusStopped + if saveErr := a.store.SaveChatMessage(assistantMsg); saveErr != nil { + a.logger.WithError(saveErr).WithField("chat_id", chat.ID).Error("failed to save stopped assistant message") + } + gen.broadcast(&MessageChunk{AssistantMessage: assistantMsg}) + return + } + + // Handle Error Generation + assistantMsg.Status = store.MessageStatusError if saveErr := a.store.SaveChatMessage(assistantMsg); saveErr != nil { - a.logger.WithError(saveErr).WithField("chat_id", chat.ID).Error("failed to save failed assistant message") + a.logger.WithError(saveErr).WithField("chat_id", chat.ID).Error("failed to save errored assistant message") } gen.broadcast(&MessageChunk{AssistantMessage: assistantMsg}) a.logger.WithError(err).WithField("chat_id", chat.ID).Error("failed to generate text stream") return } + // Handle Stopped Generation + if errors.Is(gen.ctx.Err(), context.Canceled) { + assistantMsg.Status = store.MessageStatusStopped + if err := a.store.SaveChatMessage(assistantMsg); err != nil { + a.logger.WithError(err).WithField("chat_id", chat.ID).Error("failed to save stopped assistant message") + return + } + gen.broadcast(&MessageChunk{AssistantMessage: assistantMsg}) + return + } + // Summarize & Update Chat Title if chat.Title == "" { var err error diff --git a/backend/internal/server/server.go b/backend/internal/server/server.go index 36d93f9..7af4d8e 100644 --- a/backend/internal/server/server.go +++ b/backend/internal/server/server.go @@ -49,6 +49,7 @@ func StartServer(settingsStore store.Store, dataDir, staticDir, listenAddress st mux.HandleFunc("POST /api/chats", api.PostChat) mux.HandleFunc("GET /api/chats/{chatId}", api.GetChat) mux.HandleFunc("GET /api/chats/{chatId}/stream", api.GetChatStream) + mux.HandleFunc("POST /api/chats/{chatId}/stop", api.StopChatGeneration) mux.HandleFunc("POST /api/chats/{chatId}", api.PostChatMessage) mux.HandleFunc("DELETE /api/chats/{chatId}", api.DeleteChat) diff --git a/backend/internal/store/types.go b/backend/internal/store/types.go index 0f7956f..a7b0eac 100644 --- a/backend/internal/store/types.go +++ b/backend/internal/store/types.go @@ -33,7 +33,8 @@ type MessageStatus string const ( MessageStatusStreaming MessageStatus = "streaming" MessageStatusComplete MessageStatus = "complete" - MessageStatusFailed MessageStatus = "failed" + MessageStatusStopped MessageStatus = "stopped" + MessageStatusError MessageStatus = "error" ) type Message struct { diff --git a/flake.nix b/flake.nix index 8f66f23..e848594 100644 --- a/flake.nix +++ b/flake.nix @@ -38,6 +38,8 @@ shellHook = '' export LD_LIBRARY_PATH=${pkgs.stdenv.cc.cc.lib}/lib:$LD_LIBRARY_PATH + + . .env ''; }; } diff --git a/frontend/public/pages/chats.html b/frontend/public/pages/chats.html index f2d94fa..cf54289 100644 --- a/frontend/public/pages/chats.html +++ b/frontend/public/pages/chats.html @@ -68,12 +68,18 @@ x-html="renderMarkdown(message.content)" > - +