feat(chat): stop active llm responses
This commit is contained in:
@@ -66,6 +66,7 @@ web/ # Embedded static assets (embed.go)
|
||||
| POST | `/api/chats/{chatId}` | PostChatMessage |
|
||||
| DELETE | `/api/chats/{chatId}` | DeleteChat |
|
||||
| GET | `/api/chats/{chatId}/stream` | GetChatStream |
|
||||
| POST | `/api/chats/{chatId}/stop` | StopChatGeneration |
|
||||
|
||||
## Testing
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"slices"
|
||||
"sync"
|
||||
@@ -19,6 +20,8 @@ type generationManager struct {
|
||||
|
||||
type generation struct {
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
subscribers map[chan *MessageChunk]struct{}
|
||||
done chan struct{}
|
||||
closed bool
|
||||
@@ -36,7 +39,10 @@ func (m *generationManager) start(chatID uuid.UUID, prepare func(*generation) er
|
||||
}
|
||||
|
||||
// Reserve Generation
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
gen := &generation{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
subscribers: make(map[chan *MessageChunk]struct{}),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
@@ -79,6 +85,19 @@ func (m *generationManager) subscribe(chatID uuid.UUID) (<-chan *MessageChunk, f
|
||||
return ch, func() { gen.unsubscribe(ch) }, true
|
||||
}
|
||||
|
||||
func (m *generationManager) stop(chatID uuid.UUID) bool {
|
||||
m.mu.RLock()
|
||||
gen, found := m.generations[chatID]
|
||||
m.mu.RUnlock()
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
|
||||
// Cancel Generation
|
||||
gen.cancel()
|
||||
return true
|
||||
}
|
||||
|
||||
func (g *generation) subscribe() chan *MessageChunk {
|
||||
ch := make(chan *MessageChunk, 64)
|
||||
|
||||
@@ -126,6 +145,7 @@ func (g *generation) close() {
|
||||
return
|
||||
}
|
||||
g.closed = true
|
||||
g.cancel()
|
||||
close(g.done)
|
||||
for subscriber := range g.subscribers {
|
||||
close(subscriber)
|
||||
|
||||
@@ -449,6 +449,32 @@ func (a *API) GetChatStream(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *API) StopChatGeneration(w http.ResponseWriter, r *http.Request) {
|
||||
log := a.logger.WithField("handler", "StopChatGenerationHandler")
|
||||
|
||||
// Parse Chat ID
|
||||
rawChatID := r.PathValue("chatId")
|
||||
if rawChatID == "" {
|
||||
log.Error("missing chat ID parameter")
|
||||
http.Error(w, "Chat ID is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
chatID, err := uuid.Parse(rawChatID)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("invalid chat ID format")
|
||||
http.Error(w, "Invalid chat ID format", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Stop Generation
|
||||
if !a.generationManager.stop(chatID) {
|
||||
http.Error(w, "Chat generation is not active", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (a *API) PostChatMessage(w http.ResponseWriter, r *http.Request) {
|
||||
log := a.logger.WithField("handler", "PostChatMessageHandler")
|
||||
|
||||
@@ -563,7 +589,7 @@ func (a *API) startMessageGeneration(chatID uuid.UUID, chatModel, userMessage st
|
||||
|
||||
func (a *API) runMessageGeneration(apiClient *client.Client, chat *store.Chat, assistantMsg *store.Message, chatModel string, gen *generation) {
|
||||
// Create Generation Context
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
|
||||
ctx, cancel := context.WithTimeout(gen.ctx, time.Minute*5)
|
||||
defer cancel()
|
||||
|
||||
// Send Message
|
||||
@@ -593,15 +619,37 @@ func (a *API) runMessageGeneration(apiClient *client.Client, chat *store.Chat, a
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
assistantMsg.Status = store.MessageStatusFailed
|
||||
// Handle Stopped Generation
|
||||
if errors.Is(gen.ctx.Err(), context.Canceled) {
|
||||
assistantMsg.Status = store.MessageStatusStopped
|
||||
if saveErr := a.store.SaveChatMessage(assistantMsg); saveErr != nil {
|
||||
a.logger.WithError(saveErr).WithField("chat_id", chat.ID).Error("failed to save stopped assistant message")
|
||||
}
|
||||
gen.broadcast(&MessageChunk{AssistantMessage: assistantMsg})
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Error Generation
|
||||
assistantMsg.Status = store.MessageStatusError
|
||||
if saveErr := a.store.SaveChatMessage(assistantMsg); saveErr != nil {
|
||||
a.logger.WithError(saveErr).WithField("chat_id", chat.ID).Error("failed to save failed assistant message")
|
||||
a.logger.WithError(saveErr).WithField("chat_id", chat.ID).Error("failed to save errored assistant message")
|
||||
}
|
||||
gen.broadcast(&MessageChunk{AssistantMessage: assistantMsg})
|
||||
a.logger.WithError(err).WithField("chat_id", chat.ID).Error("failed to generate text stream")
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Stopped Generation
|
||||
if errors.Is(gen.ctx.Err(), context.Canceled) {
|
||||
assistantMsg.Status = store.MessageStatusStopped
|
||||
if err := a.store.SaveChatMessage(assistantMsg); err != nil {
|
||||
a.logger.WithError(err).WithField("chat_id", chat.ID).Error("failed to save stopped assistant message")
|
||||
return
|
||||
}
|
||||
gen.broadcast(&MessageChunk{AssistantMessage: assistantMsg})
|
||||
return
|
||||
}
|
||||
|
||||
// Summarize & Update Chat Title
|
||||
if chat.Title == "" {
|
||||
var err error
|
||||
|
||||
@@ -49,6 +49,7 @@ func StartServer(settingsStore store.Store, dataDir, staticDir, listenAddress st
|
||||
mux.HandleFunc("POST /api/chats", api.PostChat)
|
||||
mux.HandleFunc("GET /api/chats/{chatId}", api.GetChat)
|
||||
mux.HandleFunc("GET /api/chats/{chatId}/stream", api.GetChatStream)
|
||||
mux.HandleFunc("POST /api/chats/{chatId}/stop", api.StopChatGeneration)
|
||||
mux.HandleFunc("POST /api/chats/{chatId}", api.PostChatMessage)
|
||||
mux.HandleFunc("DELETE /api/chats/{chatId}", api.DeleteChat)
|
||||
|
||||
|
||||
@@ -33,7 +33,8 @@ type MessageStatus string
|
||||
const (
|
||||
MessageStatusStreaming MessageStatus = "streaming"
|
||||
MessageStatusComplete MessageStatus = "complete"
|
||||
MessageStatusFailed MessageStatus = "failed"
|
||||
MessageStatusStopped MessageStatus = "stopped"
|
||||
MessageStatusError MessageStatus = "error"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
|
||||
Reference in New Issue
Block a user