feat: stream persistent
This commit is contained in:
194
backend/internal/api/generation.go
Normal file
194
backend/internal/api/generation.go
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"reichard.io/aethera/internal/store"
|
||||||
|
"reichard.io/aethera/internal/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
// errGenerationActive is returned by generationManager.start when the chat
// already has a generation reserved or running.
var (
	errGenerationActive = errors.New("generation already active")
)
|
||||||
|
|
||||||
|
type generationManager struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
generations map[uuid.UUID]*generation
|
||||||
|
}
|
||||||
|
|
||||||
|
type generation struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
subscribers map[chan *MessageChunk]struct{}
|
||||||
|
done chan struct{}
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newGenerationManager() *generationManager {
|
||||||
|
return &generationManager{generations: make(map[uuid.UUID]*generation)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *generationManager) start(chatID uuid.UUID, prepare func(*generation) error, run func(*generation)) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
if _, found := m.generations[chatID]; found {
|
||||||
|
m.mu.Unlock()
|
||||||
|
return errGenerationActive
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reserve Generation
|
||||||
|
gen := &generation{
|
||||||
|
subscribers: make(map[chan *MessageChunk]struct{}),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
}
|
||||||
|
m.generations[chatID] = gen
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
// Prepare Generation - This runs while the generation is reserved so a
|
||||||
|
// concurrent request cannot persist duplicate user/assistant messages.
|
||||||
|
if err := prepare(gen); err != nil {
|
||||||
|
gen.close()
|
||||||
|
m.mu.Lock()
|
||||||
|
delete(m.generations, chatID)
|
||||||
|
m.mu.Unlock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run Generation
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
gen.close()
|
||||||
|
m.mu.Lock()
|
||||||
|
delete(m.generations, chatID)
|
||||||
|
m.mu.Unlock()
|
||||||
|
}()
|
||||||
|
run(gen)
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *generationManager) subscribe(chatID uuid.UUID) (<-chan *MessageChunk, func(), bool) {
|
||||||
|
m.mu.RLock()
|
||||||
|
gen, found := m.generations[chatID]
|
||||||
|
m.mu.RUnlock()
|
||||||
|
if !found {
|
||||||
|
return nil, func() {}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := gen.subscribe()
|
||||||
|
return ch, func() { gen.unsubscribe(ch) }, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *generation) subscribe() chan *MessageChunk {
|
||||||
|
ch := make(chan *MessageChunk, 64)
|
||||||
|
|
||||||
|
// Add Subscriber
|
||||||
|
g.mu.Lock()
|
||||||
|
if g.closed {
|
||||||
|
close(ch)
|
||||||
|
} else {
|
||||||
|
g.subscribers[ch] = struct{}{}
|
||||||
|
}
|
||||||
|
g.mu.Unlock()
|
||||||
|
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *generation) unsubscribe(ch chan *MessageChunk) {
|
||||||
|
// Remove Subscriber
|
||||||
|
g.mu.Lock()
|
||||||
|
if _, found := g.subscribers[ch]; found {
|
||||||
|
delete(g.subscribers, ch)
|
||||||
|
close(ch)
|
||||||
|
}
|
||||||
|
g.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *generation) broadcast(chunk *MessageChunk) {
|
||||||
|
g.mu.RLock()
|
||||||
|
defer g.mu.RUnlock()
|
||||||
|
|
||||||
|
// Broadcast Chunk
|
||||||
|
for subscriber := range g.subscribers {
|
||||||
|
select {
|
||||||
|
case subscriber <- cloneMessageChunk(chunk):
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *generation) close() {
|
||||||
|
g.mu.Lock()
|
||||||
|
defer g.mu.Unlock()
|
||||||
|
|
||||||
|
// Close Subscribers
|
||||||
|
if g.closed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
g.closed = true
|
||||||
|
close(g.done)
|
||||||
|
for subscriber := range g.subscribers {
|
||||||
|
close(subscriber)
|
||||||
|
delete(g.subscribers, subscriber)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneMessageChunk(chunk *MessageChunk) *MessageChunk {
|
||||||
|
if chunk == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone Chunk
|
||||||
|
cloned := &MessageChunk{
|
||||||
|
Chat: chunk.Chat,
|
||||||
|
UserMessage: cloneStoreMessage(chunk.UserMessage),
|
||||||
|
AssistantMessage: cloneStoreMessage(chunk.AssistantMessage),
|
||||||
|
}
|
||||||
|
return cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneStoreMessage(msg *store.Message) *store.Message {
|
||||||
|
if msg == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone Message
|
||||||
|
cloned := *msg
|
||||||
|
if msg.Stats != nil {
|
||||||
|
stats := *msg.Stats
|
||||||
|
cloned.Stats = &stats
|
||||||
|
cloneMessageStatsPointers(msg.Stats, cloned.Stats)
|
||||||
|
}
|
||||||
|
return &cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneMessageStatsPointers(src, dst *types.MessageStats) {
|
||||||
|
// Clone Pointer Fields
|
||||||
|
if src.EndTime != nil {
|
||||||
|
v := *src.EndTime
|
||||||
|
dst.EndTime = &v
|
||||||
|
}
|
||||||
|
if src.PromptTokens != nil {
|
||||||
|
v := *src.PromptTokens
|
||||||
|
dst.PromptTokens = &v
|
||||||
|
}
|
||||||
|
if src.GeneratedTokens != nil {
|
||||||
|
v := *src.GeneratedTokens
|
||||||
|
dst.GeneratedTokens = &v
|
||||||
|
}
|
||||||
|
if src.PromptPerSec != nil {
|
||||||
|
v := *src.PromptPerSec
|
||||||
|
dst.PromptPerSec = &v
|
||||||
|
}
|
||||||
|
if src.GeneratedPerSec != nil {
|
||||||
|
v := *src.GeneratedPerSec
|
||||||
|
dst.GeneratedPerSec = &v
|
||||||
|
}
|
||||||
|
if src.TimeToFirstToken != nil {
|
||||||
|
v := *src.TimeToFirstToken
|
||||||
|
dst.TimeToFirstToken = &v
|
||||||
|
}
|
||||||
|
if src.TimeToLastToken != nil {
|
||||||
|
v := *src.TimeToLastToken
|
||||||
|
dst.TimeToLastToken = &v
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -28,6 +28,7 @@ type API struct {
|
|||||||
store store.Store
|
store store.Store
|
||||||
client *client.Client
|
client *client.Client
|
||||||
dataDir string
|
dataDir string
|
||||||
|
generationManager *generationManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(s store.Store, dataDir string, logger *logrus.Logger) *API {
|
func New(s store.Store, dataDir string, logger *logrus.Logger) *API {
|
||||||
@@ -35,6 +36,7 @@ func New(s store.Store, dataDir string, logger *logrus.Logger) *API {
|
|||||||
store: s,
|
store: s,
|
||||||
dataDir: dataDir,
|
dataDir: dataDir,
|
||||||
logger: logger.WithField("service", "api"),
|
logger: logger.WithField("service", "api"),
|
||||||
|
generationManager: newGenerationManager(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -322,13 +324,18 @@ func (a *API) PostChat(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send Message
|
// Start Message
|
||||||
responseStarted, err := a.sendMessage(r.Context(), w, chat.ID, genReq.Model, genReq.Prompt)
|
chunk, err := a.startMessageGeneration(chat.ID, genReq.Model, genReq.Prompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).WithField("chat_id", chat.ID).Error("failed to send message")
|
log.WithError(err).WithField("chat_id", chat.ID).Error("failed to start message generation")
|
||||||
if !responseStarted {
|
http.Error(w, "Failed to start message generation", http.StatusInternalServerError)
|
||||||
http.Error(w, "Failed to send message", http.StatusInternalServerError)
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if err := json.NewEncoder(w).Encode(chunk); err != nil {
|
||||||
|
log.WithError(err).Error("failed to encode message generation response")
|
||||||
|
http.Error(w, "Failed to encode message generation response", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -395,6 +402,68 @@ func (a *API) GetChat(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *API) GetChatStream(w http.ResponseWriter, r *http.Request) {
|
||||||
|
log := a.logger.WithField("handler", "GetChatStreamHandler")
|
||||||
|
|
||||||
|
// Parse Chat ID
|
||||||
|
rawChatID := r.PathValue("chatId")
|
||||||
|
if rawChatID == "" {
|
||||||
|
log.Error("missing chat ID parameter")
|
||||||
|
http.Error(w, "Chat ID is required", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
chatID, err := uuid.Parse(rawChatID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Error("invalid chat ID format")
|
||||||
|
http.Error(w, "Invalid chat ID format", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe Before Snapshot
|
||||||
|
updates, unsubscribe, active := a.generationManager.subscribe(chatID)
|
||||||
|
defer unsubscribe()
|
||||||
|
|
||||||
|
// Get Chat Snapshot
|
||||||
|
chat, err := a.store.GetChat(chatID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).WithField("chat_id", chatID).Error("failed to get chat")
|
||||||
|
http.Error(w, "Failed to get chat", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set Headers
|
||||||
|
w.Header().Set("Content-Type", "application/x-ndjson")
|
||||||
|
w.Header().Set("Cache-Control", "no-cache")
|
||||||
|
w.Header().Set("Connection", "keep-alive")
|
||||||
|
w.Header().Set("Transfer-Encoding", "chunked")
|
||||||
|
flushWriter := newFlushWriter(w)
|
||||||
|
|
||||||
|
// Send Snapshot
|
||||||
|
if err := json.NewEncoder(flushWriter).Encode(&MessageChunk{Chat: toChat(chat)}); err != nil {
|
||||||
|
log.WithError(err).WithField("chat_id", chatID).Warn("failed to send stream snapshot")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !active {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward Updates
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-r.Context().Done():
|
||||||
|
return
|
||||||
|
case chunk, ok := <-updates:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := json.NewEncoder(flushWriter).Encode(chunk); err != nil {
|
||||||
|
log.WithError(err).WithField("chat_id", chatID).Warn("client stream disconnected")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (a *API) PostChatMessage(w http.ResponseWriter, r *http.Request) {
|
func (a *API) PostChatMessage(w http.ResponseWriter, r *http.Request) {
|
||||||
log := a.logger.WithField("handler", "PostChatMessageHandler")
|
log := a.logger.WithField("handler", "PostChatMessageHandler")
|
||||||
|
|
||||||
@@ -424,13 +493,22 @@ func (a *API) PostChatMessage(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send Message
|
// Start Message
|
||||||
responseStarted, err := a.sendMessage(r.Context(), w, chatID, genReq.Model, genReq.Prompt)
|
chunk, err := a.startMessageGeneration(chatID, genReq.Model, genReq.Prompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).WithField("chat_id", chatID).Error("failed to send message")
|
log.WithError(err).WithField("chat_id", chatID).Error("failed to start message generation")
|
||||||
if !responseStarted {
|
if errors.Is(err, errGenerationActive) {
|
||||||
http.Error(w, "Failed to send message", http.StatusInternalServerError)
|
http.Error(w, "Chat generation already active", http.StatusConflict)
|
||||||
|
} else {
|
||||||
|
http.Error(w, "Failed to start message generation", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if err := json.NewEncoder(w).Encode(chunk); err != nil {
|
||||||
|
log.WithError(err).Error("failed to encode message generation response")
|
||||||
|
http.Error(w, "Failed to encode message generation response", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -456,99 +534,99 @@ func (a *API) getClient() (*client.Client, error) {
|
|||||||
return a.client, nil
|
return a.client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) sendMessage(ctx context.Context, w http.ResponseWriter, chatID uuid.UUID, chatModel, userMessage string) (bool, error) {
|
func (a *API) startMessageGeneration(chatID uuid.UUID, chatModel, userMessage string) (*MessageChunk, error) {
|
||||||
apiClient, err := a.getClient()
|
apiClient, err := a.getClient()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("failed to get client: %w", err)
|
return nil, fmt.Errorf("failed to get client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Detach Request Context
|
var chat *store.Chat
|
||||||
ctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), time.Minute*5)
|
var userMsg *store.Message
|
||||||
defer cancel()
|
var assistantMsg *store.Message
|
||||||
|
var initialChunk *MessageChunk
|
||||||
|
|
||||||
|
// Start Generation - The manager reserves the chat before messages are
|
||||||
|
// persisted, preventing concurrent completions from creating duplicate rows.
|
||||||
|
if err := a.generationManager.start(chatID, func(_ *generation) error {
|
||||||
// Create User Message
|
// Create User Message
|
||||||
userMsg := &store.Message{ChatID: chatID, Role: "user", Content: userMessage}
|
userMsg = &store.Message{ChatID: chatID, Role: "user", Content: userMessage}
|
||||||
if err := a.store.SaveChatMessage(userMsg); err != nil {
|
if err := a.store.SaveChatMessage(userMsg); err != nil {
|
||||||
return false, fmt.Errorf("failed to add user message to chat: %w", err)
|
return fmt.Errorf("failed to add user message to chat: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get Chat History - Fetch before creating the in-progress assistant message so the
|
// Get Chat History - Fetch before creating the in-progress assistant message so the
|
||||||
// LLM request does not include an empty assistant response prefill.
|
// LLM request does not include an empty assistant response prefill.
|
||||||
chat, err := a.store.GetChat(chatID)
|
chat, err = a.store.GetChat(chatID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("failed to get chat: %w", err)
|
return fmt.Errorf("failed to get chat: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add Assistant Response - TODO: Ensure InProgress Flag?
|
// Add Assistant Response
|
||||||
assistantMsg := &store.Message{ChatID: chatID, Role: "assistant"}
|
assistantMsg = &store.Message{ChatID: chatID, Role: "assistant", Status: store.MessageStatusStreaming}
|
||||||
if err := a.store.SaveChatMessage(assistantMsg); err != nil {
|
if err := a.store.SaveChatMessage(assistantMsg); err != nil {
|
||||||
return false, fmt.Errorf("failed to add assistant message to chat: %w", err)
|
return fmt.Errorf("failed to add assistant message to chat: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set Headers
|
// Create Initial Chunk
|
||||||
w.Header().Set("Content-Type", "application/x-ndjson")
|
initialChunk = &MessageChunk{
|
||||||
w.Header().Set("Cache-Control", "no-cache")
|
|
||||||
w.Header().Set("Connection", "keep-alive")
|
|
||||||
w.Header().Set("Transfer-Encoding", "chunked")
|
|
||||||
|
|
||||||
// Create Flush Writer
|
|
||||||
flushWriter := newFlushWriter(w)
|
|
||||||
|
|
||||||
// Send Initial Chunk - User Message & Chat
|
|
||||||
if err := json.NewEncoder(flushWriter).Encode(&MessageChunk{
|
|
||||||
Chat: toChatNoMessages(chat),
|
Chat: toChatNoMessages(chat),
|
||||||
UserMessage: userMsg,
|
UserMessage: userMsg,
|
||||||
}); err != nil {
|
AssistantMessage: assistantMsg,
|
||||||
return false, fmt.Errorf("failed to send initial chunk: %w", err)
|
|
||||||
}
|
}
|
||||||
responseStarted := true
|
return nil
|
||||||
streamToClient := true
|
}, func(gen *generation) {
|
||||||
|
a.runMessageGeneration(apiClient, chat, assistantMsg, chatModel, gen)
|
||||||
|
}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return initialChunk, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *API) runMessageGeneration(apiClient *client.Client, chat *store.Chat, assistantMsg *store.Message, chatModel string, gen *generation) {
|
||||||
|
// Create Generation Context
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
// Send Message
|
// Send Message
|
||||||
if _, err := apiClient.SendMessage(ctx, chat.Messages, chatModel, func(m *client.MessageChunk) error {
|
if _, err := apiClient.SendMessage(ctx, chat.Messages, chatModel, func(m *client.MessageChunk) error {
|
||||||
var apiMsgChunk MessageChunk
|
|
||||||
messageChanged := false
|
messageChanged := false
|
||||||
|
|
||||||
if m.Stats != nil {
|
if m.Stats != nil {
|
||||||
messageChanged = true
|
messageChanged = true
|
||||||
assistantMsg.Stats = m.Stats
|
assistantMsg.Stats = m.Stats
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.Message != nil {
|
if m.Message != nil {
|
||||||
messageChanged = true
|
messageChanged = true
|
||||||
assistantMsg.Content += *m.Message
|
assistantMsg.Content += *m.Message
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.Thinking != nil {
|
if m.Thinking != nil {
|
||||||
messageChanged = true
|
messageChanged = true
|
||||||
assistantMsg.Thinking += *m.Thinking
|
assistantMsg.Thinking += *m.Thinking
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save Assistant Progress - Persist each streamed update so partial content
|
// Save And Broadcast Progress
|
||||||
// survives client disconnects or upstream stream failures.
|
|
||||||
if messageChanged {
|
if messageChanged {
|
||||||
if err := a.store.SaveChatMessage(assistantMsg); err != nil {
|
if err := a.store.SaveChatMessage(assistantMsg); err != nil {
|
||||||
return fmt.Errorf("failed to save assistant progress: %w", err)
|
return fmt.Errorf("failed to save assistant progress: %w", err)
|
||||||
}
|
}
|
||||||
apiMsgChunk.AssistantMessage = assistantMsg
|
gen.broadcast(&MessageChunk{AssistantMessage: assistantMsg})
|
||||||
}
|
|
||||||
|
|
||||||
// Send Progress Chunk - If the browser disconnects, keep the detached
|
|
||||||
// generation running and continue saving streamed content to the store.
|
|
||||||
if streamToClient {
|
|
||||||
if err := json.NewEncoder(flushWriter).Encode(apiMsgChunk); err != nil {
|
|
||||||
streamToClient = false
|
|
||||||
a.logger.WithError(err).WithField("chat_id", chat.ID).Warn("client stream disconnected")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return responseStarted, fmt.Errorf("failed to generate text stream: %w", err)
|
assistantMsg.Status = store.MessageStatusFailed
|
||||||
|
if saveErr := a.store.SaveChatMessage(assistantMsg); saveErr != nil {
|
||||||
|
a.logger.WithError(saveErr).WithField("chat_id", chat.ID).Error("failed to save failed assistant message")
|
||||||
|
}
|
||||||
|
gen.broadcast(&MessageChunk{AssistantMessage: assistantMsg})
|
||||||
|
a.logger.WithError(err).WithField("chat_id", chat.ID).Error("failed to generate text stream")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Summarize & Update Chat Title
|
// Summarize & Update Chat Title
|
||||||
if chat.Title == "" {
|
if chat.Title == "" {
|
||||||
|
var err error
|
||||||
chat.Title, err = apiClient.CreateTitle(ctx, chat.Messages[0].Content, chatModel)
|
chat.Title, err = apiClient.CreateTitle(ctx, chat.Messages[0].Content, chatModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.logger.WithError(err).WithField("chat_id", chat.ID).Error("failed to create chat title")
|
a.logger.WithError(err).WithField("chat_id", chat.ID).Error("failed to create chat title")
|
||||||
@@ -557,20 +635,11 @@ func (a *API) sendMessage(ctx context.Context, w http.ResponseWriter, chatID uui
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update Assistant Message
|
// Complete Assistant Message
|
||||||
|
assistantMsg.Status = store.MessageStatusComplete
|
||||||
if err := a.store.SaveChatMessage(assistantMsg); err != nil {
|
if err := a.store.SaveChatMessage(assistantMsg); err != nil {
|
||||||
return responseStarted, fmt.Errorf("failed to save assistant message to chat: %w", err)
|
a.logger.WithError(err).WithField("chat_id", chat.ID).Error("failed to save assistant message")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
gen.broadcast(&MessageChunk{Chat: toChatNoMessages(chat), AssistantMessage: assistantMsg})
|
||||||
// Send Final Chunk
|
|
||||||
if streamToClient {
|
|
||||||
if err := json.NewEncoder(flushWriter).Encode(&MessageChunk{
|
|
||||||
Chat: toChatNoMessages(chat),
|
|
||||||
AssistantMessage: assistantMsg,
|
|
||||||
}); err != nil {
|
|
||||||
return responseStarted, fmt.Errorf("failed to send final chunk: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return responseStarted, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ func StartServer(settingsStore store.Store, dataDir, staticDir, listenAddress st
|
|||||||
mux.HandleFunc("GET /api/chats", api.GetChats)
|
mux.HandleFunc("GET /api/chats", api.GetChats)
|
||||||
mux.HandleFunc("POST /api/chats", api.PostChat)
|
mux.HandleFunc("POST /api/chats", api.PostChat)
|
||||||
mux.HandleFunc("GET /api/chats/{chatId}", api.GetChat)
|
mux.HandleFunc("GET /api/chats/{chatId}", api.GetChat)
|
||||||
|
mux.HandleFunc("GET /api/chats/{chatId}/stream", api.GetChatStream)
|
||||||
mux.HandleFunc("POST /api/chats/{chatId}", api.PostChatMessage)
|
mux.HandleFunc("POST /api/chats/{chatId}", api.PostChatMessage)
|
||||||
mux.HandleFunc("DELETE /api/chats/{chatId}", api.DeleteChat)
|
mux.HandleFunc("DELETE /api/chats/{chatId}", api.DeleteChat)
|
||||||
|
|
||||||
|
|||||||
@@ -28,6 +28,14 @@ type Chat struct {
|
|||||||
Messages []*Message `json:"messages"`
|
Messages []*Message `json:"messages"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MessageStatus is the lifecycle state of a message, serialized as a plain
// string.
type MessageStatus string

const (
	// MessageStatusStreaming marks a message whose content is still being
	// generated.
	MessageStatusStreaming MessageStatus = "streaming"

	// MessageStatusComplete marks a message whose generation finished.
	MessageStatusComplete MessageStatus = "complete"

	// MessageStatusFailed marks a message whose generation ended in an error.
	MessageStatusFailed MessageStatus = "failed"
)
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
baseModel
|
baseModel
|
||||||
|
|
||||||
@@ -35,5 +43,6 @@ type Message struct {
|
|||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Thinking string `json:"thinking"`
|
Thinking string `json:"thinking"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
|
Status MessageStatus `json:"status,omitempty"`
|
||||||
Stats *types.MessageStats `json:"stats,omitempty"`
|
Stats *types.MessageStats `json:"stats,omitempty"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -90,13 +90,37 @@ export async function sendMessage(
|
|||||||
requestData: GenerateTextRequest,
|
requestData: GenerateTextRequest,
|
||||||
onChunk: (chunk: MessageChunk) => void,
|
onChunk: (chunk: MessageChunk) => void,
|
||||||
) {
|
) {
|
||||||
|
const initialChunk = await startMessage(chatId, requestData);
|
||||||
|
onChunk(initialChunk);
|
||||||
|
|
||||||
|
if (!initialChunk.chat) return;
|
||||||
|
return streamChatUpdates(initialChunk.chat.id, onChunk);
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function startMessage(
|
||||||
|
chatId: string,
|
||||||
|
requestData: GenerateTextRequest,
|
||||||
|
): Promise<MessageChunk> {
|
||||||
const url = chatId ? `/api/chats/${chatId}` : '/api/chats';
|
const url = chatId ? `/api/chats/${chatId}` : '/api/chats';
|
||||||
const response = await fetch(url, {
|
const response = await fetch(url, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: { 'Content-Type': 'application/json' },
|
headers: { 'Content-Type': 'application/json' },
|
||||||
body: JSON.stringify(requestData),
|
body: JSON.stringify(requestData),
|
||||||
});
|
});
|
||||||
|
const data = await response.json().catch(() => ({}));
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(readError(data) || `HTTP ${response.status}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function streamChatUpdates(
|
||||||
|
chatId: string,
|
||||||
|
onChunk: (chunk: MessageChunk) => void,
|
||||||
|
) {
|
||||||
|
const response = await fetch(`/api/chats/${chatId}/stream`);
|
||||||
return streamMessage(response, onChunk);
|
return streamMessage(response, onChunk);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -133,6 +157,15 @@ export async function deleteChat(chatId: string): Promise<void> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Extracts the `error` string from a decoded response body.
 *
 * Returns '' when the body is not an object, has no `error` key, or the
 * value under that key is not a string.
 */
function readError(data: unknown): string {
  const hasErrorKey =
    typeof data === 'object' && data !== null && 'error' in data;
  if (!hasErrorKey) return '';

  const value = (data as { error: unknown }).error;
  if (typeof value === 'string') return value;
  return '';
}
|
||||||
|
|
||||||
async function streamMessage(
|
async function streamMessage(
|
||||||
response: Response,
|
response: Response,
|
||||||
onChunk: (chunk: MessageChunk) => void,
|
onChunk: (chunk: MessageChunk) => void,
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import {
|
|||||||
getSettings,
|
getSettings,
|
||||||
getModels,
|
getModels,
|
||||||
sendMessage,
|
sendMessage,
|
||||||
|
streamChatUpdates,
|
||||||
getChatMessages,
|
getChatMessages,
|
||||||
listChats,
|
listChats,
|
||||||
deleteChat,
|
deleteChat,
|
||||||
@@ -41,6 +42,7 @@ Alpine.data('chatManager', () => ({
|
|||||||
selectedChatID: null as string | null,
|
selectedChatID: null as string | null,
|
||||||
chatListOpen: false,
|
chatListOpen: false,
|
||||||
loading: false,
|
loading: false,
|
||||||
|
activeStreamChatID: null as string | null,
|
||||||
|
|
||||||
async init() {
|
async init() {
|
||||||
// Acquire Data
|
// Acquire Data
|
||||||
@@ -109,66 +111,27 @@ Alpine.data('chatManager', () => ({
|
|||||||
this.selectedChatID = IN_PROGRESS_UUID;
|
this.selectedChatID = IN_PROGRESS_UUID;
|
||||||
}
|
}
|
||||||
|
|
||||||
// New User Message
|
// Add Optimistic User Message
|
||||||
let userMessage: Message = {
|
const currentChat: Chat = this.chats.find(
|
||||||
|
(c) => c.id === this.selectedChatID,
|
||||||
|
)!;
|
||||||
|
currentChat.messages.push({
|
||||||
id: IN_PROGRESS_UUID,
|
id: IN_PROGRESS_UUID,
|
||||||
chat_id: this.selectedChatID,
|
chat_id: this.selectedChatID,
|
||||||
role: 'user',
|
role: 'user',
|
||||||
thinking: '',
|
thinking: '',
|
||||||
content: message,
|
content: message,
|
||||||
created_at: new Date().toISOString(),
|
created_at: new Date().toISOString(),
|
||||||
};
|
});
|
||||||
|
|
||||||
// Get Chat
|
|
||||||
let currentChat: Chat = this.chats.find(
|
|
||||||
(c) => c.id === this.selectedChatID,
|
|
||||||
)!;
|
|
||||||
|
|
||||||
// Add User Message
|
|
||||||
currentChat.messages.push(userMessage);
|
|
||||||
currentChat.message_count += 1;
|
currentChat.message_count += 1;
|
||||||
|
|
||||||
// Assistant Message Placeholder
|
|
||||||
let assistantMessage: Message | undefined;
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
await sendMessage(
|
await sendMessage(
|
||||||
this.selectedChatID === IN_PROGRESS_UUID ? '' : this.selectedChatID,
|
this.selectedChatID === IN_PROGRESS_UUID ? '' : this.selectedChatID,
|
||||||
{ model: this.selectedModel, prompt: message },
|
{ model: this.selectedModel, prompt: message },
|
||||||
(chunk: MessageChunk) => {
|
(chunk: MessageChunk) => {
|
||||||
// Handle Chat
|
if (chunk.chat) this.activeStreamChatID = chunk.chat.id;
|
||||||
if (chunk.chat) {
|
this.applyMessageChunk(chunk);
|
||||||
Object.assign(currentChat, {
|
|
||||||
...chunk.chat,
|
|
||||||
messages: currentChat.messages,
|
|
||||||
});
|
|
||||||
this.selectedChatID = chunk.chat.id;
|
|
||||||
this.updateHash(chunk.chat.id);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle User Message
|
|
||||||
if (chunk.user_message) {
|
|
||||||
Object.assign(userMessage, chunk.user_message);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle Assistant Message
|
|
||||||
if (chunk.assistant_message) {
|
|
||||||
if (!assistantMessage) {
|
|
||||||
assistantMessage = chunk.assistant_message;
|
|
||||||
currentChat.messages.push(assistantMessage);
|
|
||||||
} else {
|
|
||||||
const index = currentChat.messages.findIndex(
|
|
||||||
(m) => m.id === assistantMessage!.id,
|
|
||||||
);
|
|
||||||
if (index !== -1) {
|
|
||||||
currentChat.messages[index] = {
|
|
||||||
...assistantMessage,
|
|
||||||
...chunk.assistant_message,
|
|
||||||
};
|
|
||||||
currentChat.messages = [...currentChat.messages];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
@@ -176,9 +139,56 @@ Alpine.data('chatManager', () => ({
|
|||||||
this.error = parseError(err);
|
this.error = parseError(err);
|
||||||
} finally {
|
} finally {
|
||||||
this.loading = false;
|
this.loading = false;
|
||||||
|
this.activeStreamChatID = null;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
// Merges one stream chunk into local state: chat metadata first, then the
// user and assistant messages it carries.
applyMessageChunk(chunk: MessageChunk) {
  // Handle Chat
  if (chunk.chat) {
    const incoming = chunk.chat;

    // Prefer the chat with the same ID; otherwise adopt the optimistic
    // placeholder that was created before the server assigned an ID.
    let chat =
      this.chats.find((c) => c.id === incoming.id) ??
      this.chats.find((c) => c.id === IN_PROGRESS_UUID);

    if (!chat) {
      chat = { ...incoming, messages: incoming.messages || [] };
      this.chats.unshift(chat);
    } else {
      // Capture the local messages BEFORE Object.assign: the incoming chat
      // may carry an empty or null `messages`, which would otherwise
      // overwrite the list that the fallback below is meant to keep.
      const existingMessages = chat.messages ?? [];
      Object.assign(chat, incoming);
      chat.messages = incoming.messages?.length
        ? incoming.messages
        : existingMessages;
    }

    // Only touch the selection/URL when the ID actually changes, so repeated
    // chat-bearing chunks do not stack duplicate history entries.
    if (this.selectedChatID !== incoming.id) {
      this.selectedChatID = incoming.id;
      this.updateHash(incoming.id);
    }
  }

  const chatID = chunk.chat?.id || this.selectedChatID;
  const currentChat = this.chats.find((c) => c.id === chatID);
  if (!currentChat) return;

  // Handle Messages
  if (chunk.user_message) this.upsertMessage(currentChat, chunk.user_message);
  if (chunk.assistant_message)
    this.upsertMessage(currentChat, chunk.assistant_message);
},
|
||||||
|
|
||||||
|
// Inserts `message` into `chat`, or merges it over the stored copy. A stored
// message matches when it has the same ID, or when it is an optimistic
// placeholder (IN_PROGRESS_UUID) with the same role.
upsertMessage(chat: Chat, message: Message) {
  // Upsert Message
  const isSameMessage = (m: Message) =>
    m.id === message.id ||
    (m.id === IN_PROGRESS_UUID && m.role === message.role);

  const at = chat.messages.findIndex(isSameMessage);
  if (at !== -1) {
    chat.messages[at] = { ...chat.messages[at], ...message };
  } else {
    chat.messages.push(message);
  }

  // Reassign a fresh array so the change is picked up as a new value.
  chat.messages = [...chat.messages];
},
|
||||||
|
|
||||||
updateHash(chatID: string | null) {
|
updateHash(chatID: string | null) {
|
||||||
const newRoute = CHAT_ROUTE + (chatID ? '/' + chatID : '');
|
const newRoute = CHAT_ROUTE + (chatID ? '/' + chatID : '');
|
||||||
window.history.pushState(null, '', newRoute);
|
window.history.pushState(null, '', newRoute);
|
||||||
@@ -202,13 +212,46 @@ Alpine.data('chatManager', () => ({
|
|||||||
(c) => c.id == this.selectedChatID,
|
(c) => c.id == this.selectedChatID,
|
||||||
);
|
);
|
||||||
|
|
||||||
this.chats[chatIndex].messages = response.messages || [];
|
if (chatIndex === -1) return;
|
||||||
|
this.chats[chatIndex] = {
|
||||||
|
...this.chats[chatIndex],
|
||||||
|
...response,
|
||||||
|
messages: response.messages || [],
|
||||||
|
};
|
||||||
|
await this.reconnectChatStream(response);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error('Error loading chat messages:', err);
|
console.error('Error loading chat messages:', err);
|
||||||
this.error = 'Failed to load messages';
|
this.error = 'Failed to load messages';
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
// Re-attaches to the server stream when the chat's newest message is an
// assistant message that is still streaming and no stream for this chat is
// already being consumed.
async reconnectChatStream(chat: Chat) {
  const latestMessage = chat.messages[chat.messages.length - 1];
  const stillStreaming =
    !!latestMessage &&
    latestMessage.role === 'assistant' &&
    latestMessage.status === 'streaming';
  if (!stillStreaming) return;
  if (this.activeStreamChatID === chat.id) return;

  // Reconnect Stream
  this.loading = true;
  this.activeStreamChatID = chat.id;
  try {
    const onChunk = (chunk: MessageChunk) => this.applyMessageChunk(chunk);
    await streamChatUpdates(chat.id, onChunk);
  } catch (err) {
    console.error('Error reconnecting chat stream:', err);
    this.error = parseError(err);
  } finally {
    this.loading = false;
    this.activeStreamChatID = null;
  }
},
|
||||||
|
|
||||||
get models(): Model[] {
|
get models(): Model[] {
|
||||||
if (!this.settings.text_generation_selector) return this._models;
|
if (!this.settings.text_generation_selector) return this._models;
|
||||||
return applyFilter(this._models, this.settings.text_generation_selector);
|
return applyFilter(this._models, this.settings.text_generation_selector);
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ export interface Chat {
|
|||||||
messages: Message[];
|
messages: Message[];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export type MessageStatus = 'streaming' | 'complete' | 'failed';
|
||||||
|
|
||||||
export interface Message {
|
export interface Message {
|
||||||
id: string;
|
id: string;
|
||||||
chat_id: string;
|
chat_id: string;
|
||||||
@@ -14,6 +16,7 @@ export interface Message {
|
|||||||
role: 'user' | 'assistant';
|
role: 'user' | 'assistant';
|
||||||
thinking: string;
|
thinking: string;
|
||||||
content: string;
|
content: string;
|
||||||
|
status?: MessageStatus;
|
||||||
stats?: MessageStats;
|
stats?: MessageStats;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user