- Restructure floating input: dominant textarea with compact bottom toolbar (model badge, thinking toggle, attach, send/stop). - Model badge sizes to the current selection (not widest option) via a layered transparent select, with truncate-on-overflow fallback. - Auto-expand the conversation sidebar on desktop and slide chat content right when open instead of overlaying. - Add per-request thinking toggle (brain icon, default on, persisted in localStorage) sending chat_template_kwargs.enable_thinking. - Always disable thinking for title summarization. - Generate chat titles before the main response to keep the SSE stream from staying open past visible completion and to avoid busting the KV cache between turns.
712 lines
22 KiB
Go
712 lines
22 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"path"
|
|
"path/filepath"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/openai/openai-go/v3"
|
|
"github.com/sirupsen/logrus"
|
|
"reichard.io/aethera/internal/client"
|
|
"reichard.io/aethera/internal/store"
|
|
"reichard.io/aethera/pkg/slices"
|
|
)
|
|
|
|
// API holds the dependencies shared by the HTTP handlers in this package.
type API struct {
	// logger is the handler logger; New scopes it with service=api.
	logger *logrus.Entry
	// store persists settings, chats and chat messages.
	store store.Store
	// client is the cached LLM client returned by getClient.
	client *client.Client
	// dataDir is the root directory for generated files; images are read from
	// and written to dataDir/generated/images.
	dataDir string
	// llmEndpoint is the base URL parsed to construct client.
	llmEndpoint string
	// llmKey is the API key passed to client.NewClient.
	llmKey string
	// generationManager coordinates per-chat generations (start, stop and
	// subscribe are all keyed by chat ID).
	generationManager *generationManager
}
|
|
|
|
func New(s store.Store, dataDir string, logger *logrus.Logger, llmEndpoint, llmKey string) *API {
|
|
return &API{
|
|
store: s,
|
|
dataDir: dataDir,
|
|
logger: logger.WithField("service", "api"),
|
|
llmEndpoint: llmEndpoint,
|
|
llmKey: llmKey,
|
|
generationManager: newGenerationManager(),
|
|
}
|
|
}
|
|
|
|
func normalizeSettings(settings *store.Settings) {
|
|
// Default Text Generation Timeout
|
|
if settings.TextGenerationTimeoutMinutes == 0 {
|
|
settings.TextGenerationTimeoutMinutes = 5
|
|
}
|
|
|
|
// Validate Text Generation Timeout
|
|
switch settings.TextGenerationTimeoutMinutes {
|
|
case 1, 5, 10, 15, 30:
|
|
return
|
|
default:
|
|
settings.TextGenerationTimeoutMinutes = 5
|
|
}
|
|
}
|
|
|
|
func (a *API) textGenerationTimeout() time.Duration {
|
|
// Load Settings
|
|
settings, err := a.store.GetSettings()
|
|
if err != nil {
|
|
a.logger.WithError(err).Error("failed to retrieve settings for text generation timeout")
|
|
return 5 * time.Minute
|
|
}
|
|
|
|
// Normalize Timeout
|
|
normalizeSettings(settings)
|
|
return time.Duration(settings.TextGenerationTimeoutMinutes) * time.Minute
|
|
}
|
|
|
|
func (a *API) GetSettings(w http.ResponseWriter, r *http.Request) {
|
|
log := a.logger.WithField("handler", "GetSettingsHandler")
|
|
|
|
settings, err := a.store.GetSettings()
|
|
if err != nil {
|
|
log.WithError(err).Error("failed to retrieve settings")
|
|
http.Error(w, "Failed to retrieve application settings", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Normalize Settings
|
|
normalizeSettings(settings)
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
if err := json.NewEncoder(w).Encode(settings); err != nil {
|
|
log.WithError(err).Error("failed to encode application settings response")
|
|
http.Error(w, "Failed to encode application settings response", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
|
|
func (a *API) PostSettings(w http.ResponseWriter, r *http.Request) {
|
|
log := a.logger.WithField("handler", "PostSettingsHandler")
|
|
|
|
var newSettings store.Settings
|
|
if err := json.NewDecoder(r.Body).Decode(&newSettings); err != nil {
|
|
log.WithError(err).Error("invalid JSON in settings update request")
|
|
http.Error(w, "Invalid request body format for settings", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Normalize Settings
|
|
normalizeSettings(&newSettings)
|
|
|
|
if err := a.store.SaveSettings(&newSettings); err != nil {
|
|
log.WithError(err).Error("failed to save settings")
|
|
http.Error(w, "Failed to save application settings", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
|
|
func (a *API) GetModels(w http.ResponseWriter, r *http.Request) {
|
|
log := a.logger.WithField("handler", "GetModelsHandler")
|
|
|
|
client, err := a.getClient()
|
|
if err != nil {
|
|
log.WithError(err).Error("failed to initialize API client")
|
|
http.Error(w, "Failed to initialize API client", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
models, err := client.GetModels(r.Context())
|
|
if err != nil {
|
|
log.WithError(err).Error("failed to retrieve available models")
|
|
http.Error(w, "Failed to retrieve available models from API", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if err := json.NewEncoder(w).Encode(models); err != nil {
|
|
log.WithError(err).Error("failed to encode available models response")
|
|
http.Error(w, "Failed to encode available models response", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
|
|
func (a *API) GetImages(w http.ResponseWriter, r *http.Request) {
|
|
log := a.logger.WithField("handler", "GetImagesHandler")
|
|
|
|
files, err := os.ReadDir(path.Join(a.dataDir, "generated/images"))
|
|
if err != nil {
|
|
log.WithError(err).Error("failed to read images directory")
|
|
http.Error(w, "Failed to read images directory", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
imageList := make([]ImageRecord, 0)
|
|
for _, file := range files {
|
|
if !file.IsDir() && strings.HasSuffix(strings.ToLower(file.Name()), ".png") {
|
|
info, err := file.Info()
|
|
if err != nil {
|
|
continue
|
|
}
|
|
imageList = append(imageList, ImageRecord{
|
|
Name: file.Name(),
|
|
Path: "/generated/images/" + file.Name(),
|
|
Size: info.Size(),
|
|
Date: info.ModTime().Format(time.RFC3339),
|
|
})
|
|
}
|
|
}
|
|
sort.Slice(imageList, func(i, j int) bool {
|
|
return imageList[i].Date > imageList[j].Date
|
|
})
|
|
|
|
if err := json.NewEncoder(w).Encode(imageList); err != nil {
|
|
log.WithError(err).Error("failed to encode image list metadata response")
|
|
http.Error(w, "Failed to encode image list metadata response", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
|
|
func (a *API) PostImage(w http.ResponseWriter, r *http.Request) {
|
|
log := a.logger.WithField("handler", "PostImageHandler")
|
|
|
|
client, err := a.getClient()
|
|
if err != nil {
|
|
log.WithError(err).Error("failed to initialize API client")
|
|
http.Error(w, "Failed to initialize API client", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
var genReq GenerateImageRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&genReq); err != nil {
|
|
log.WithError(err).Error("invalid JSON in image generation request")
|
|
http.Error(w, "Invalid request body format for image generation", http.StatusBadRequest)
|
|
return
|
|
}
|
|
if err := genReq.Validate(); err != nil {
|
|
log.WithError(err).Error("invalid request")
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Edit vs Generate Request
|
|
var images []openai.Image
|
|
var reqErr error
|
|
if genReq.isEdit() {
|
|
editParams, err := genReq.getEditParams()
|
|
if err != nil {
|
|
log.WithError(err).Error("invalid image edit parameters")
|
|
http.Error(w, "Invalid image edit parameters", http.StatusBadRequest)
|
|
return
|
|
|
|
}
|
|
images, reqErr = client.EditImage(r.Context(), *editParams)
|
|
} else {
|
|
genParams, err := genReq.getGenerateParams()
|
|
if err != nil {
|
|
log.WithError(err).Error("invalid image generation parameters")
|
|
http.Error(w, "Invalid image generation parameters", http.StatusBadRequest)
|
|
return
|
|
|
|
}
|
|
images, reqErr = client.GenerateImages(r.Context(), *genParams)
|
|
}
|
|
|
|
// Check Error
|
|
if reqErr != nil {
|
|
log.WithError(reqErr).Error("failed to generate images")
|
|
http.Error(w, "Failed to generate images via API", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Normalize Responses
|
|
imageRecords := make([]ImageRecord, 0)
|
|
for i, img := range images {
|
|
if img.B64JSON == "" {
|
|
log.Warnf("empty image data at index %d, skipping", i)
|
|
continue
|
|
}
|
|
|
|
// Decode Image
|
|
imgBytes, err := base64.StdEncoding.DecodeString(img.B64JSON)
|
|
if err != nil {
|
|
log.WithError(err).WithField("index", i).Error("failed to decode image")
|
|
continue
|
|
}
|
|
|
|
// Save Image
|
|
filename := fmt.Sprintf("image_%d_%d.png", time.Now().Unix(), i)
|
|
filePath := path.Join(a.dataDir, "generated/images", filename)
|
|
if err := os.WriteFile(filePath, imgBytes, 0644); err != nil {
|
|
log.WithError(err).WithField("file", filePath).Error("failed to save generated image")
|
|
continue
|
|
}
|
|
|
|
// Record Image
|
|
imageRecords = append(imageRecords, ImageRecord{
|
|
Name: filename,
|
|
Path: fmt.Sprintf("/generated/images/%s", filename),
|
|
Date: time.Now().Format(time.RFC3339),
|
|
Size: int64(len(imgBytes)),
|
|
})
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
if err := json.NewEncoder(w).Encode(imageRecords); err != nil {
|
|
log.WithError(err).Error("failed to encode generated images response")
|
|
http.Error(w, "Failed to encode generated images response", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
|
|
func (a *API) DeleteImage(w http.ResponseWriter, r *http.Request) {
|
|
log := a.logger.WithField("handler", "DeleteImageHandler")
|
|
|
|
filename := r.PathValue("filename")
|
|
if filename == "" {
|
|
log.Error("missing filename parameter")
|
|
http.Error(w, "Filename parameter is required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Delete Image
|
|
imgDir := path.Join(a.dataDir, "generated/images")
|
|
safePath := path.Join(imgDir, filepath.Base(filename))
|
|
if err := os.Remove(safePath); err != nil {
|
|
log.WithError(err).WithField("file", safePath).Error("failed to delete image file")
|
|
http.Error(w, "Failed to delete image file", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
w.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
func (a *API) GetChats(w http.ResponseWriter, r *http.Request) {
|
|
log := a.logger.WithField("handler", "GetChatsHandler")
|
|
|
|
chats, err := a.store.ListChats()
|
|
if err != nil {
|
|
log.WithError(err).Error("failed to list chats")
|
|
http.Error(w, "Failed to retrieve chats", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
sort.Slice(chats, func(i, j int) bool {
|
|
iLast, iFound := slices.Last(chats[i].Messages)
|
|
if !iFound {
|
|
return false
|
|
}
|
|
|
|
jLast, jFound := slices.Last(chats[j].Messages)
|
|
if !jFound {
|
|
return true
|
|
}
|
|
|
|
return iLast.CreatedAt.After(jLast.CreatedAt)
|
|
})
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
if err := json.NewEncoder(w).Encode(ChatListResponse{Chats: slices.Map(chats, toChatNoMessages)}); err != nil {
|
|
log.WithError(err).Error("failed to encode chats list response")
|
|
http.Error(w, "Failed to encode chats list response", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
|
|
func (a *API) PostChat(w http.ResponseWriter, r *http.Request) {
|
|
log := a.logger.WithField("handler", "PostChatHandler")
|
|
|
|
// Decode Request
|
|
var genReq GenerateTextRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&genReq); err != nil {
|
|
log.WithError(err).Error("invalid JSON in text generation request")
|
|
http.Error(w, "Invalid request body format for new chat", http.StatusBadRequest)
|
|
return
|
|
}
|
|
if err := genReq.Validate(); err != nil {
|
|
log.WithError(err).Error("invalid request")
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Create Chat
|
|
var chat store.Chat
|
|
if err := a.store.SaveChat(&chat); err != nil {
|
|
log.WithError(err).Error("failed to create new chat")
|
|
http.Error(w, "Failed to create new chat", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Start Message
|
|
chunk, err := a.startMessageGeneration(chat.ID, genReq.Model, genReq.Prompt, genReq.Images, genReq.EnableThinking())
|
|
if err != nil {
|
|
log.WithError(err).WithField("chat_id", chat.ID).Error("failed to start message generation")
|
|
http.Error(w, "Failed to start message generation", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
if err := json.NewEncoder(w).Encode(chunk); err != nil {
|
|
log.WithError(err).Error("failed to encode message generation response")
|
|
http.Error(w, "Failed to encode message generation response", http.StatusInternalServerError)
|
|
}
|
|
}
|
|
|
|
func (a *API) DeleteChat(w http.ResponseWriter, r *http.Request) {
|
|
log := a.logger.WithField("handler", "DeleteChatHandler")
|
|
|
|
chatIDStr := r.PathValue("chatId")
|
|
if chatIDStr == "" {
|
|
log.Error("missing chat ID parameter")
|
|
http.Error(w, "Chat ID is required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
chatID, err := uuid.Parse(chatIDStr)
|
|
if err != nil {
|
|
log.WithError(err).Error("invalid chat ID format")
|
|
http.Error(w, "Invalid chat ID format", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Delete Chat
|
|
if err := a.store.DeleteChat(chatID); err != nil {
|
|
log.WithError(err).WithField("chat_id", chatID).Error("failed to delete chat")
|
|
if errors.Is(err, store.ErrChatNotFound) {
|
|
http.Error(w, "Chat not found", http.StatusNotFound)
|
|
} else {
|
|
http.Error(w, "Failed to delete chat", http.StatusInternalServerError)
|
|
}
|
|
return
|
|
}
|
|
|
|
w.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
func (a *API) GetChat(w http.ResponseWriter, r *http.Request) {
|
|
log := a.logger.WithField("handler", "GetChatHandler")
|
|
|
|
chatID := r.PathValue("chatId")
|
|
if chatID == "" {
|
|
log.Error("missing chat ID parameter")
|
|
http.Error(w, "Chat ID is required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
parsedChatID, err := uuid.Parse(chatID)
|
|
if err != nil {
|
|
log.WithError(err).Error("invalid chat ID format")
|
|
http.Error(w, "Invalid chat ID format", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
chat, err := a.store.GetChat(parsedChatID)
|
|
if err != nil {
|
|
log.WithError(err).WithField("chat_id", parsedChatID).Error("failed to get chat")
|
|
http.Error(w, "Failed to get chat", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
if err := json.NewEncoder(w).Encode(toChat(chat)); err != nil {
|
|
log.WithError(err).Error("failed to encode chat messages response")
|
|
http.Error(w, "Failed to encode chat messages response", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
|
|
func (a *API) GetChatStream(w http.ResponseWriter, r *http.Request) {
|
|
log := a.logger.WithField("handler", "GetChatStreamHandler")
|
|
|
|
// Parse Chat ID
|
|
rawChatID := r.PathValue("chatId")
|
|
if rawChatID == "" {
|
|
log.Error("missing chat ID parameter")
|
|
http.Error(w, "Chat ID is required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
chatID, err := uuid.Parse(rawChatID)
|
|
if err != nil {
|
|
log.WithError(err).Error("invalid chat ID format")
|
|
http.Error(w, "Invalid chat ID format", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Subscribe Before Snapshot
|
|
updates, unsubscribe, active := a.generationManager.subscribe(chatID)
|
|
defer unsubscribe()
|
|
|
|
// Get Chat Snapshot
|
|
chat, err := a.store.GetChat(chatID)
|
|
if err != nil {
|
|
log.WithError(err).WithField("chat_id", chatID).Error("failed to get chat")
|
|
http.Error(w, "Failed to get chat", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Set Headers
|
|
w.Header().Set("Content-Type", "application/x-ndjson")
|
|
w.Header().Set("Cache-Control", "no-cache")
|
|
w.Header().Set("Connection", "keep-alive")
|
|
flushWriter := newFlushWriter(w)
|
|
|
|
// Send Snapshot
|
|
if err := json.NewEncoder(flushWriter).Encode(&MessageChunk{Chat: toChat(chat)}); err != nil {
|
|
log.WithError(err).WithField("chat_id", chatID).Warn("failed to send stream snapshot")
|
|
return
|
|
}
|
|
if !active {
|
|
return
|
|
}
|
|
|
|
// Forward Updates
|
|
for {
|
|
select {
|
|
case <-r.Context().Done():
|
|
return
|
|
case chunk, ok := <-updates:
|
|
if !ok {
|
|
return
|
|
}
|
|
if err := json.NewEncoder(flushWriter).Encode(chunk); err != nil {
|
|
log.WithError(err).WithField("chat_id", chatID).Warn("client stream disconnected")
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (a *API) StopChatGeneration(w http.ResponseWriter, r *http.Request) {
|
|
log := a.logger.WithField("handler", "StopChatGenerationHandler")
|
|
|
|
// Parse Chat ID
|
|
rawChatID := r.PathValue("chatId")
|
|
if rawChatID == "" {
|
|
log.Error("missing chat ID parameter")
|
|
http.Error(w, "Chat ID is required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
chatID, err := uuid.Parse(rawChatID)
|
|
if err != nil {
|
|
log.WithError(err).Error("invalid chat ID format")
|
|
http.Error(w, "Invalid chat ID format", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Stop Generation
|
|
if !a.generationManager.stop(chatID) {
|
|
http.Error(w, "Chat generation is not active", http.StatusNotFound)
|
|
return
|
|
}
|
|
|
|
w.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
func (a *API) PostChatMessage(w http.ResponseWriter, r *http.Request) {
|
|
log := a.logger.WithField("handler", "PostChatMessageHandler")
|
|
|
|
rawChatID := r.PathValue("chatId")
|
|
if rawChatID == "" {
|
|
log.Error("missing chat ID parameter")
|
|
http.Error(w, "Chat ID is required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
chatID, err := uuid.Parse(rawChatID)
|
|
if err != nil {
|
|
log.WithError(err).Error("invalid chat ID format")
|
|
http.Error(w, "Invalid chat ID format", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
var genReq GenerateTextRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&genReq); err != nil {
|
|
log.WithError(err).Error("invalid JSON in text generation request")
|
|
http.Error(w, "Invalid request body format for text generation", http.StatusBadRequest)
|
|
return
|
|
}
|
|
if err := genReq.Validate(); err != nil {
|
|
log.WithError(err).Error("invalid request")
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Start Message
|
|
chunk, err := a.startMessageGeneration(chatID, genReq.Model, genReq.Prompt, genReq.Images, genReq.EnableThinking())
|
|
if err != nil {
|
|
log.WithError(err).WithField("chat_id", chatID).Error("failed to start message generation")
|
|
if errors.Is(err, errGenerationActive) {
|
|
http.Error(w, "Chat generation already active", http.StatusConflict)
|
|
} else {
|
|
http.Error(w, "Failed to start message generation", http.StatusInternalServerError)
|
|
}
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
if err := json.NewEncoder(w).Encode(chunk); err != nil {
|
|
log.WithError(err).Error("failed to encode message generation response")
|
|
http.Error(w, "Failed to encode message generation response", http.StatusInternalServerError)
|
|
}
|
|
}
|
|
|
|
func (a *API) getClient() (*client.Client, error) {
|
|
if a.client != nil {
|
|
return a.client, nil
|
|
}
|
|
|
|
// Parse LLM Endpoint from Config
|
|
baseURL, err := url.Parse(a.llmEndpoint)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid API endpoint URL: %w", err)
|
|
}
|
|
|
|
a.client = client.NewClient(baseURL, a.llmKey)
|
|
return a.client, nil
|
|
}
|
|
|
|
func (a *API) startMessageGeneration(chatID uuid.UUID, chatModel, userMessage string, images []string, enableThinking bool) (*MessageChunk, error) {
|
|
apiClient, err := a.getClient()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get client: %w", err)
|
|
}
|
|
|
|
var chat *store.Chat
|
|
var userMsg *store.Message
|
|
var assistantMsg *store.Message
|
|
var initialChunk *MessageChunk
|
|
|
|
// Start Generation - The manager reserves the chat before messages are
|
|
// persisted, preventing concurrent completions from creating duplicate rows.
|
|
if err := a.generationManager.start(chatID, func(_ *generation) error {
|
|
// Create User Message
|
|
userMsg = &store.Message{ChatID: chatID, Role: "user", Content: userMessage, Images: images}
|
|
if err := a.store.SaveChatMessage(userMsg); err != nil {
|
|
return fmt.Errorf("failed to add user message to chat: %w", err)
|
|
}
|
|
|
|
// Get Chat History - Fetch before creating the in-progress assistant message so the
|
|
// LLM request does not include an empty assistant response prefill.
|
|
chat, err = a.store.GetChat(chatID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get chat: %w", err)
|
|
}
|
|
|
|
// Add Assistant Response
|
|
assistantMsg = &store.Message{ChatID: chatID, Role: "assistant", Status: store.MessageStatusStreaming}
|
|
if err := a.store.SaveChatMessage(assistantMsg); err != nil {
|
|
return fmt.Errorf("failed to add assistant message to chat: %w", err)
|
|
}
|
|
|
|
// Create Initial Chunk
|
|
initialChunk = &MessageChunk{
|
|
Chat: toChatNoMessages(chat),
|
|
UserMessage: userMsg,
|
|
AssistantMessage: assistantMsg,
|
|
}
|
|
return nil
|
|
}, func(gen *generation) {
|
|
a.runMessageGeneration(apiClient, chat, assistantMsg, chatModel, enableThinking, gen)
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return initialChunk, nil
|
|
}
|
|
|
|
func (a *API) runMessageGeneration(apiClient *client.Client, chat *store.Chat, assistantMsg *store.Message, chatModel string, enableThinking bool, gen *generation) {
|
|
// Create Generation Context
|
|
ctx, cancel := context.WithTimeout(gen.ctx, a.textGenerationTimeout())
|
|
defer cancel()
|
|
|
|
// Generate Title First - Doing this before the main response avoids busting the KV
|
|
// cache between the user prompt and the assistant reply, and keeps the stream from
|
|
// staying open past the visible response completing.
|
|
if chat.Title == "" && len(chat.Messages) > 0 {
|
|
title, err := apiClient.CreateTitle(ctx, chat.Messages[0].Content, chatModel)
|
|
if err != nil {
|
|
a.logger.WithError(err).WithField("chat_id", chat.ID).Error("failed to create chat title")
|
|
} else {
|
|
chat.Title = title
|
|
if err := a.store.SaveChat(chat); err != nil {
|
|
a.logger.WithError(err).WithField("chat_id", chat.ID).Error("failed to update chat")
|
|
} else {
|
|
gen.broadcast(&MessageChunk{Chat: toChatNoMessages(chat)})
|
|
}
|
|
}
|
|
}
|
|
|
|
// Send Message
|
|
if _, err := apiClient.SendMessage(ctx, chat.Messages, chatModel, enableThinking, func(m *client.MessageChunk) error {
|
|
messageChanged := false
|
|
|
|
if m.Stats != nil {
|
|
messageChanged = true
|
|
assistantMsg.Stats = m.Stats
|
|
}
|
|
if m.Message != nil {
|
|
messageChanged = true
|
|
assistantMsg.Content += *m.Message
|
|
}
|
|
if m.Thinking != nil {
|
|
messageChanged = true
|
|
assistantMsg.Thinking += *m.Thinking
|
|
}
|
|
|
|
// Save And Broadcast Progress
|
|
if messageChanged {
|
|
if err := a.store.SaveChatMessage(assistantMsg); err != nil {
|
|
return fmt.Errorf("failed to save assistant progress: %w", err)
|
|
}
|
|
gen.broadcast(&MessageChunk{AssistantMessage: assistantMsg})
|
|
}
|
|
|
|
return nil
|
|
}); err != nil {
|
|
// Handle Stopped Generation
|
|
if errors.Is(gen.ctx.Err(), context.Canceled) {
|
|
assistantMsg.Status = store.MessageStatusStopped
|
|
if saveErr := a.store.SaveChatMessage(assistantMsg); saveErr != nil {
|
|
a.logger.WithError(saveErr).WithField("chat_id", chat.ID).Error("failed to save stopped assistant message")
|
|
}
|
|
gen.broadcast(&MessageChunk{AssistantMessage: assistantMsg})
|
|
return
|
|
}
|
|
|
|
// Handle Error Generation
|
|
assistantMsg.Status = store.MessageStatusError
|
|
if saveErr := a.store.SaveChatMessage(assistantMsg); saveErr != nil {
|
|
a.logger.WithError(saveErr).WithField("chat_id", chat.ID).Error("failed to save errored assistant message")
|
|
}
|
|
gen.broadcast(&MessageChunk{AssistantMessage: assistantMsg})
|
|
a.logger.WithError(err).WithField("chat_id", chat.ID).Error("failed to generate text stream")
|
|
return
|
|
}
|
|
|
|
// Handle Stopped Generation
|
|
if errors.Is(gen.ctx.Err(), context.Canceled) {
|
|
assistantMsg.Status = store.MessageStatusStopped
|
|
if err := a.store.SaveChatMessage(assistantMsg); err != nil {
|
|
a.logger.WithError(err).WithField("chat_id", chat.ID).Error("failed to save stopped assistant message")
|
|
return
|
|
}
|
|
gen.broadcast(&MessageChunk{AssistantMessage: assistantMsg})
|
|
return
|
|
}
|
|
|
|
// Complete Assistant Message
|
|
assistantMsg.Status = store.MessageStatusComplete
|
|
if err := a.store.SaveChatMessage(assistantMsg); err != nil {
|
|
a.logger.WithError(err).WithField("chat_id", chat.ID).Error("failed to save assistant message")
|
|
return
|
|
}
|
|
gen.broadcast(&MessageChunk{Chat: toChatNoMessages(chat), AssistantMessage: assistantMsg})
|
|
}
|