feat(chat): stop active llm responses
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,2 +1,3 @@
|
|||||||
data
|
data
|
||||||
.opencode
|
.opencode
|
||||||
|
.env
|
||||||
|
|||||||
@@ -66,6 +66,7 @@ web/ # Embedded static assets (embed.go)
|
|||||||
| POST | `/api/chats/{chatId}` | PostChatMessage |
|
| POST | `/api/chats/{chatId}` | PostChatMessage |
|
||||||
| DELETE | `/api/chats/{chatId}` | DeleteChat |
|
| DELETE | `/api/chats/{chatId}` | DeleteChat |
|
||||||
| GET | `/api/chats/{chatId}/stream` | GetChatStream |
|
| GET | `/api/chats/{chatId}/stream` | GetChatStream |
|
||||||
|
| POST | `/api/chats/{chatId}/stop` | StopChatGeneration |
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -19,6 +20,8 @@ type generationManager struct {
|
|||||||
|
|
||||||
type generation struct {
|
type generation struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
subscribers map[chan *MessageChunk]struct{}
|
subscribers map[chan *MessageChunk]struct{}
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
closed bool
|
closed bool
|
||||||
@@ -36,7 +39,10 @@ func (m *generationManager) start(chatID uuid.UUID, prepare func(*generation) er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Reserve Generation
|
// Reserve Generation
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
gen := &generation{
|
gen := &generation{
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
subscribers: make(map[chan *MessageChunk]struct{}),
|
subscribers: make(map[chan *MessageChunk]struct{}),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
}
|
}
|
||||||
@@ -79,6 +85,19 @@ func (m *generationManager) subscribe(chatID uuid.UUID) (<-chan *MessageChunk, f
|
|||||||
return ch, func() { gen.unsubscribe(ch) }, true
|
return ch, func() { gen.unsubscribe(ch) }, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// stop requests cancellation of the generation registered for chatID.
//
// The generation is looked up in m.generations while holding the manager's
// read lock; the lock is released before the generation's cancel function is
// invoked. It returns false when no generation is registered for chatID and
// true once cancel has been called.
func (m *generationManager) stop(chatID uuid.UUID) bool {
|
||||||
|
m.mu.RLock()
|
||||||
|
gen, found := m.generations[chatID]
|
||||||
|
m.mu.RUnlock()
|
||||||
|
if !found {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cancel Generation
|
||||||
|
gen.cancel()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (g *generation) subscribe() chan *MessageChunk {
|
func (g *generation) subscribe() chan *MessageChunk {
|
||||||
ch := make(chan *MessageChunk, 64)
|
ch := make(chan *MessageChunk, 64)
|
||||||
|
|
||||||
@@ -126,6 +145,7 @@ func (g *generation) close() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
g.closed = true
|
g.closed = true
|
||||||
|
g.cancel()
|
||||||
close(g.done)
|
close(g.done)
|
||||||
for subscriber := range g.subscribers {
|
for subscriber := range g.subscribers {
|
||||||
close(subscriber)
|
close(subscriber)
|
||||||
|
|||||||
@@ -449,6 +449,32 @@ func (a *API) GetChatStream(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StopChatGeneration is the HTTP handler that asks the generation manager to
// stop the generation for the chat named by the "chatId" path value.
//
// Responses:
//   - 400 Bad Request when the path value is empty or is not a valid UUID.
//   - 404 Not Found when generationManager.stop reports no generation for
//     the chat.
//   - 204 No Content when stop reports success.
func (a *API) StopChatGeneration(w http.ResponseWriter, r *http.Request) {
|
||||||
|
log := a.logger.WithField("handler", "StopChatGenerationHandler")
|
||||||
|
|
||||||
|
// Parse Chat ID
|
||||||
|
rawChatID := r.PathValue("chatId")
|
||||||
|
if rawChatID == "" {
|
||||||
|
log.Error("missing chat ID parameter")
|
||||||
|
http.Error(w, "Chat ID is required", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
chatID, err := uuid.Parse(rawChatID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Error("invalid chat ID format")
|
||||||
|
http.Error(w, "Invalid chat ID format", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop Generation
|
||||||
|
if !a.generationManager.stop(chatID) {
|
||||||
|
http.Error(w, "Chat generation is not active", http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
func (a *API) PostChatMessage(w http.ResponseWriter, r *http.Request) {
|
func (a *API) PostChatMessage(w http.ResponseWriter, r *http.Request) {
|
||||||
log := a.logger.WithField("handler", "PostChatMessageHandler")
|
log := a.logger.WithField("handler", "PostChatMessageHandler")
|
||||||
|
|
||||||
@@ -563,7 +589,7 @@ func (a *API) startMessageGeneration(chatID uuid.UUID, chatModel, userMessage st
|
|||||||
|
|
||||||
func (a *API) runMessageGeneration(apiClient *client.Client, chat *store.Chat, assistantMsg *store.Message, chatModel string, gen *generation) {
|
func (a *API) runMessageGeneration(apiClient *client.Client, chat *store.Chat, assistantMsg *store.Message, chatModel string, gen *generation) {
|
||||||
// Create Generation Context
|
// Create Generation Context
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
|
ctx, cancel := context.WithTimeout(gen.ctx, time.Minute*5)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Send Message
|
// Send Message
|
||||||
@@ -593,15 +619,37 @@ func (a *API) runMessageGeneration(apiClient *client.Client, chat *store.Chat, a
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
assistantMsg.Status = store.MessageStatusFailed
|
// Handle Stopped Generation
|
||||||
|
if errors.Is(gen.ctx.Err(), context.Canceled) {
|
||||||
|
assistantMsg.Status = store.MessageStatusStopped
|
||||||
|
if saveErr := a.store.SaveChatMessage(assistantMsg); saveErr != nil {
|
||||||
|
a.logger.WithError(saveErr).WithField("chat_id", chat.ID).Error("failed to save stopped assistant message")
|
||||||
|
}
|
||||||
|
gen.broadcast(&MessageChunk{AssistantMessage: assistantMsg})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle Error Generation
|
||||||
|
assistantMsg.Status = store.MessageStatusError
|
||||||
if saveErr := a.store.SaveChatMessage(assistantMsg); saveErr != nil {
|
if saveErr := a.store.SaveChatMessage(assistantMsg); saveErr != nil {
|
||||||
a.logger.WithError(saveErr).WithField("chat_id", chat.ID).Error("failed to save failed assistant message")
|
a.logger.WithError(saveErr).WithField("chat_id", chat.ID).Error("failed to save errored assistant message")
|
||||||
}
|
}
|
||||||
gen.broadcast(&MessageChunk{AssistantMessage: assistantMsg})
|
gen.broadcast(&MessageChunk{AssistantMessage: assistantMsg})
|
||||||
a.logger.WithError(err).WithField("chat_id", chat.ID).Error("failed to generate text stream")
|
a.logger.WithError(err).WithField("chat_id", chat.ID).Error("failed to generate text stream")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle Stopped Generation
|
||||||
|
if errors.Is(gen.ctx.Err(), context.Canceled) {
|
||||||
|
assistantMsg.Status = store.MessageStatusStopped
|
||||||
|
if err := a.store.SaveChatMessage(assistantMsg); err != nil {
|
||||||
|
a.logger.WithError(err).WithField("chat_id", chat.ID).Error("failed to save stopped assistant message")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
gen.broadcast(&MessageChunk{AssistantMessage: assistantMsg})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Summarize & Update Chat Title
|
// Summarize & Update Chat Title
|
||||||
if chat.Title == "" {
|
if chat.Title == "" {
|
||||||
var err error
|
var err error
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ func StartServer(settingsStore store.Store, dataDir, staticDir, listenAddress st
|
|||||||
mux.HandleFunc("POST /api/chats", api.PostChat)
|
mux.HandleFunc("POST /api/chats", api.PostChat)
|
||||||
mux.HandleFunc("GET /api/chats/{chatId}", api.GetChat)
|
mux.HandleFunc("GET /api/chats/{chatId}", api.GetChat)
|
||||||
mux.HandleFunc("GET /api/chats/{chatId}/stream", api.GetChatStream)
|
mux.HandleFunc("GET /api/chats/{chatId}/stream", api.GetChatStream)
|
||||||
|
mux.HandleFunc("POST /api/chats/{chatId}/stop", api.StopChatGeneration)
|
||||||
mux.HandleFunc("POST /api/chats/{chatId}", api.PostChatMessage)
|
mux.HandleFunc("POST /api/chats/{chatId}", api.PostChatMessage)
|
||||||
mux.HandleFunc("DELETE /api/chats/{chatId}", api.DeleteChat)
|
mux.HandleFunc("DELETE /api/chats/{chatId}", api.DeleteChat)
|
||||||
|
|
||||||
|
|||||||
@@ -33,7 +33,8 @@ type MessageStatus string
|
|||||||
const (
|
const (
|
||||||
MessageStatusStreaming MessageStatus = "streaming"
|
MessageStatusStreaming MessageStatus = "streaming"
|
||||||
MessageStatusComplete MessageStatus = "complete"
|
MessageStatusComplete MessageStatus = "complete"
|
||||||
MessageStatusFailed MessageStatus = "failed"
|
MessageStatusStopped MessageStatus = "stopped"
|
||||||
|
MessageStatusError MessageStatus = "error"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
|
|||||||
@@ -38,6 +38,8 @@
|
|||||||
|
|
||||||
shellHook = ''
|
shellHook = ''
|
||||||
export LD_LIBRARY_PATH=${pkgs.stdenv.cc.cc.lib}/lib:$LD_LIBRARY_PATH
|
export LD_LIBRARY_PATH=${pkgs.stdenv.cc.cc.lib}/lib:$LD_LIBRARY_PATH
|
||||||
|
|
||||||
|
. .env
|
||||||
'';
|
'';
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -68,12 +68,18 @@
|
|||||||
x-html="renderMarkdown(message.content)"
|
x-html="renderMarkdown(message.content)"
|
||||||
></div>
|
></div>
|
||||||
|
|
||||||
<!-- Timestamp -->
|
<!-- Message Metadata -->
|
||||||
<div class="flex items-center justify-between gap-2 mt-2">
|
<div class="flex items-center justify-between gap-2 mt-2">
|
||||||
<div
|
<div
|
||||||
class="text-[10px] opacity-60"
|
class="text-[10px] opacity-60"
|
||||||
x-text="new Date(message.created_at).toLocaleTimeString()"
|
x-text="new Date(message.created_at).toLocaleTimeString()"
|
||||||
></div>
|
></div>
|
||||||
|
<div
|
||||||
|
x-show="message.role === 'assistant' && ['stopped', 'error', 'failed'].includes(message.status)"
|
||||||
|
:class="message.status === 'stopped' ? 'bg-primary-300/50 text-primary-700' : 'bg-tertiary-100 text-tertiary-700'"
|
||||||
|
class="px-2 py-0.5 rounded-full text-[10px] font-medium"
|
||||||
|
x-text="message.status === 'stopped' ? 'Stopped' : 'Error'"
|
||||||
|
></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -279,6 +285,26 @@
|
|||||||
</svg>
|
</svg>
|
||||||
</template>
|
</template>
|
||||||
</button>
|
</button>
|
||||||
|
|
||||||
|
<button
|
||||||
|
x-show="loading"
|
||||||
|
type="button"
|
||||||
|
@click="stopResponse()"
|
||||||
|
:disabled="!activeStreamChatID"
|
||||||
|
:class="!activeStreamChatID ? 'opacity-50 cursor-not-allowed' : 'hover:shadow-md hover:scale-105'"
|
||||||
|
class="self-stretch w-[44px] bg-tertiary-600 text-white rounded-xl transition-all flex items-center justify-center flex-shrink-0"
|
||||||
|
title="Stop response"
|
||||||
|
aria-label="Stop response"
|
||||||
|
>
|
||||||
|
<svg
|
||||||
|
class="h-4 w-4"
|
||||||
|
fill="currentColor"
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
aria-hidden="true"
|
||||||
|
>
|
||||||
|
<path d="M6 6h12v12H6z" />
|
||||||
|
</svg>
|
||||||
|
</button>
|
||||||
</form>
|
</form>
|
||||||
|
|
||||||
<!-- Error Message -->
|
<!-- Error Message -->
|
||||||
|
|||||||
@@ -124,6 +124,17 @@ export async function streamChatUpdates(
|
|||||||
return streamMessage(response, onChunk);
|
return streamMessage(response, onChunk);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Sends a POST request to `/api/chats/{chatId}/stop`.
 *
 * @param chatId - Chat identifier interpolated into the request path.
 * @returns A promise that resolves with no value when the response is OK.
 * @throws Error when the response is not OK. The message is whatever
 *   `readError` returns for the parsed JSON body (an empty object is used if
 *   the body cannot be parsed), or `HTTP <status>` when that value is falsy.
 */
export async function stopChatGeneration(chatId: string): Promise<void> {
|
||||||
|
const response = await fetch(`/api/chats/${chatId}/stop`, {
|
||||||
|
method: 'POST',
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
const errorData = await response.json().catch(() => ({}));
|
||||||
|
throw new Error(readError(errorData) || `HTTP ${response.status}`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export async function getChatMessages(chatId: string): Promise<Chat> {
|
export async function getChatMessages(chatId: string): Promise<Chat> {
|
||||||
const response = await fetch(`/api/chats/${chatId}`);
|
const response = await fetch(`/api/chats/${chatId}`);
|
||||||
const data = await response.json().catch(() => ({}));
|
const data = await response.json().catch(() => ({}));
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import {
|
|||||||
getModels,
|
getModels,
|
||||||
sendMessage,
|
sendMessage,
|
||||||
streamChatUpdates,
|
streamChatUpdates,
|
||||||
|
stopChatGeneration,
|
||||||
getChatMessages,
|
getChatMessages,
|
||||||
listChats,
|
listChats,
|
||||||
deleteChat,
|
deleteChat,
|
||||||
@@ -87,6 +88,18 @@ Alpine.data('chatManager', () => ({
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
/**
 * Calls `stopChatGeneration` for the chat in `this.activeStreamChatID`.
 *
 * Does nothing when `activeStreamChatID` is falsy. A failure from
 * `stopChatGeneration` is not rethrown: it is logged to the console and
 * its `parseError` result is stored in `this.error`.
 */
async stopResponse() {
|
||||||
|
if (!this.activeStreamChatID) return;
|
||||||
|
|
||||||
|
// Stop Active Generation
|
||||||
|
try {
|
||||||
|
await stopChatGeneration(this.activeStreamChatID);
|
||||||
|
} catch (err) {
|
||||||
|
console.error('Error stopping response:', err);
|
||||||
|
this.error = parseError(err);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
async sendMessage() {
|
async sendMessage() {
|
||||||
const message = this.inputMessage.trim();
|
const message = this.inputMessage.trim();
|
||||||
if ((!message && this.selectedImages.length === 0) || this.loading) return;
|
if ((!message && this.selectedImages.length === 0) || this.loading) return;
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ export interface Chat {
|
|||||||
messages: Message[];
|
messages: Message[];
|
||||||
}
|
}
|
||||||
|
|
||||||
export type MessageStatus = 'streaming' | 'complete' | 'failed';
|
export type MessageStatus = 'streaming' | 'complete' | 'stopped' | 'error' | 'failed';
|
||||||
|
|
||||||
export interface Message {
|
export interface Message {
|
||||||
id: string;
|
id: string;
|
||||||
|
|||||||
Reference in New Issue
Block a user