From f359471a271389fde4a2c806632aa375e8899e1d Mon Sep 17 00:00:00 2001 From: Evan Reichard Date: Sat, 2 May 2026 16:26:10 -0400 Subject: [PATCH] feat(chat): stop active llm responses --- .gitignore | 1 + backend/AGENTS.md | 1 + backend/internal/api/generation.go | 20 ++++++++++ backend/internal/api/handlers.go | 54 ++++++++++++++++++++++++-- backend/internal/server/server.go | 1 + backend/internal/store/types.go | 3 +- flake.nix | 2 + frontend/public/pages/chats.html | 28 ++++++++++++- frontend/src/client.ts | 11 ++++++ frontend/src/components/chatManager.ts | 13 +++++++ frontend/src/types/index.ts | 2 +- 11 files changed, 130 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 18faab5..3e166e9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ data .opencode +.env diff --git a/backend/AGENTS.md b/backend/AGENTS.md index 59aa8ff..326a437 100644 --- a/backend/AGENTS.md +++ b/backend/AGENTS.md @@ -66,6 +66,7 @@ web/ # Embedded static assets (embed.go) | POST | `/api/chats/{chatId}` | PostChatMessage | | DELETE | `/api/chats/{chatId}` | DeleteChat | | GET | `/api/chats/{chatId}/stream` | GetChatStream | +| POST | `/api/chats/{chatId}/stop` | StopChatGeneration | ## Testing diff --git a/backend/internal/api/generation.go b/backend/internal/api/generation.go index 1804946..ef9ec6f 100644 --- a/backend/internal/api/generation.go +++ b/backend/internal/api/generation.go @@ -1,6 +1,7 @@ package api import ( + "context" "errors" "slices" "sync" @@ -19,6 +20,8 @@ type generationManager struct { type generation struct { mu sync.RWMutex + ctx context.Context + cancel context.CancelFunc subscribers map[chan *MessageChunk]struct{} done chan struct{} closed bool @@ -36,7 +39,10 @@ func (m *generationManager) start(chatID uuid.UUID, prepare func(*generation) er } // Reserve Generation + ctx, cancel := context.WithCancel(context.Background()) gen := &generation{ + ctx: ctx, + cancel: cancel, subscribers: make(map[chan *MessageChunk]struct{}), done: make(chan struct{}), } @@ -79,6 +85,19 @@ func (m *generationManager) subscribe(chatID uuid.UUID) (<-chan *MessageChunk, f return ch, func() { gen.unsubscribe(ch) }, true } +func (m *generationManager) stop(chatID uuid.UUID) bool { + m.mu.RLock() + gen, found := m.generations[chatID] + m.mu.RUnlock() + if !found { + return false + } + + // Cancel Generation + gen.cancel() + return true +} + func (g *generation) subscribe() chan *MessageChunk { ch := make(chan *MessageChunk, 64) @@ -126,6 +145,7 @@ func (g *generation) close() { return } g.closed = true + g.cancel() close(g.done) for subscriber := range g.subscribers { close(subscriber) diff --git a/backend/internal/api/handlers.go b/backend/internal/api/handlers.go index 23bd0f6..f423a5a 100644 --- a/backend/internal/api/handlers.go +++ b/backend/internal/api/handlers.go @@ -449,6 +449,32 @@ func (a *API) GetChatStream(w http.ResponseWriter, r *http.Request) { } } +func (a *API) StopChatGeneration(w http.ResponseWriter, r *http.Request) { + log := a.logger.WithField("handler", "StopChatGenerationHandler") + + // Parse Chat ID + rawChatID := r.PathValue("chatId") + if rawChatID == "" { + log.Error("missing chat ID parameter") + http.Error(w, "Chat ID is required", http.StatusBadRequest) + return + } + chatID, err := uuid.Parse(rawChatID) + if err != nil { + log.WithError(err).Error("invalid chat ID format") + http.Error(w, "Invalid chat ID format", http.StatusBadRequest) + return + } + + // Stop Generation + if !a.generationManager.stop(chatID) { + http.Error(w, "Chat generation is not active", http.StatusNotFound) + return + } + + w.WriteHeader(http.StatusNoContent) +} + func (a *API) PostChatMessage(w http.ResponseWriter, r *http.Request) { log := a.logger.WithField("handler", "PostChatMessageHandler") @@ -563,7 +589,7 @@ func (a *API) startMessageGeneration(chatID uuid.UUID, chatModel, userMessage st func (a *API) runMessageGeneration(apiClient *client.Client, chat *store.Chat, assistantMsg *store.Message, chatModel string, gen *generation) { // Create Generation Context - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + ctx, cancel := context.WithTimeout(gen.ctx, time.Minute*5) defer cancel() // Send Message @@ -593,15 +619,37 @@ func (a *API) runMessageGeneration(apiClient *client.Client, chat *store.Chat, a return nil }); err != nil { - assistantMsg.Status = store.MessageStatusFailed + // Handle Stopped Generation + if errors.Is(gen.ctx.Err(), context.Canceled) { + assistantMsg.Status = store.MessageStatusStopped + if saveErr := a.store.SaveChatMessage(assistantMsg); saveErr != nil { + a.logger.WithError(saveErr).WithField("chat_id", chat.ID).Error("failed to save stopped assistant message") + } + gen.broadcast(&MessageChunk{AssistantMessage: assistantMsg}) + return + } + + // Handle Error Generation + assistantMsg.Status = store.MessageStatusError if saveErr := a.store.SaveChatMessage(assistantMsg); saveErr != nil { - a.logger.WithError(saveErr).WithField("chat_id", chat.ID).Error("failed to save failed assistant message") + a.logger.WithError(saveErr).WithField("chat_id", chat.ID).Error("failed to save errored assistant message") } gen.broadcast(&MessageChunk{AssistantMessage: assistantMsg}) a.logger.WithError(err).WithField("chat_id", chat.ID).Error("failed to generate text stream") return } + // Handle Stopped Generation + if errors.Is(gen.ctx.Err(), context.Canceled) { + assistantMsg.Status = store.MessageStatusStopped + if err := a.store.SaveChatMessage(assistantMsg); err != nil { + a.logger.WithError(err).WithField("chat_id", chat.ID).Error("failed to save stopped assistant message") + return + } + gen.broadcast(&MessageChunk{AssistantMessage: assistantMsg}) + return + } + // Summarize & Update Chat Title if chat.Title == "" { var err error diff --git a/backend/internal/server/server.go b/backend/internal/server/server.go index 36d93f9..7af4d8e 100644 --- a/backend/internal/server/server.go +++ b/backend/internal/server/server.go @@ -49,6 +49,7 @@ func StartServer(settingsStore store.Store, dataDir, staticDir, listenAddress st mux.HandleFunc("POST /api/chats", api.PostChat) mux.HandleFunc("GET /api/chats/{chatId}", api.GetChat) mux.HandleFunc("GET /api/chats/{chatId}/stream", api.GetChatStream) + mux.HandleFunc("POST /api/chats/{chatId}/stop", api.StopChatGeneration) mux.HandleFunc("POST /api/chats/{chatId}", api.PostChatMessage) mux.HandleFunc("DELETE /api/chats/{chatId}", api.DeleteChat) diff --git a/backend/internal/store/types.go b/backend/internal/store/types.go index 0f7956f..a7b0eac 100644 --- a/backend/internal/store/types.go +++ b/backend/internal/store/types.go @@ -33,7 +33,8 @@ type MessageStatus string const ( MessageStatusStreaming MessageStatus = "streaming" MessageStatusComplete MessageStatus = "complete" - MessageStatusFailed MessageStatus = "failed" + MessageStatusStopped MessageStatus = "stopped" + MessageStatusError MessageStatus = "error" ) type Message struct { diff --git a/flake.nix b/flake.nix index 8f66f23..e848594 100644 --- a/flake.nix +++ b/flake.nix @@ -38,6 +38,8 @@ shellHook = '' export LD_LIBRARY_PATH=${pkgs.stdenv.cc.cc.lib}/lib:$LD_LIBRARY_PATH + + . .env ''; }; } diff --git a/frontend/public/pages/chats.html b/frontend/public/pages/chats.html index f2d94fa..cf54289 100644 --- a/frontend/public/pages/chats.html +++ b/frontend/public/pages/chats.html @@ -68,12 +68,18 @@ x-html="renderMarkdown(message.content)" > - +
+
@@ -279,6 +285,26 @@ + + diff --git a/frontend/src/client.ts b/frontend/src/client.ts index 5049d2c..4e7b436 100644 --- a/frontend/src/client.ts +++ b/frontend/src/client.ts @@ -124,6 +124,17 @@ export async function streamChatUpdates( return streamMessage(response, onChunk); } +export async function stopChatGeneration(chatId: string): Promise { + const response = await fetch(`/api/chats/${chatId}/stop`, { + method: 'POST', + }); + + if (!response.ok) { + const errorData = await response.json().catch(() => ({})); + throw new Error(readError(errorData) || `HTTP ${response.status}`); + } +} + export async function getChatMessages(chatId: string): Promise { const response = await fetch(`/api/chats/${chatId}`); const data = await response.json().catch(() => ({})); diff --git a/frontend/src/components/chatManager.ts b/frontend/src/components/chatManager.ts index ccfb7f4..fff2f73 100644 --- a/frontend/src/components/chatManager.ts +++ b/frontend/src/components/chatManager.ts @@ -7,6 +7,7 @@ import { getModels, sendMessage, streamChatUpdates, + stopChatGeneration, getChatMessages, listChats, deleteChat, @@ -87,6 +88,18 @@ Alpine.data('chatManager', () => ({ } }, + async stopResponse() { + if (!this.activeStreamChatID) return; + + // Stop Active Generation + try { + await stopChatGeneration(this.activeStreamChatID); + } catch (err) { + console.error('Error stopping response:', err); + this.error = parseError(err); + } + }, + async sendMessage() { const message = this.inputMessage.trim(); if ((!message && this.selectedImages.length === 0) || this.loading) return; diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 7db8f6a..a32c01b 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -7,7 +7,7 @@ export interface Chat { messages: Message[]; } -export type MessageStatus = 'streaming' | 'complete' | 'failed'; +export type MessageStatus = 'streaming' | 'complete' | 'stopped' | 'error' | 'failed'; export interface Message { id: string;