diff --git a/backend/internal/api/handlers.go b/backend/internal/api/handlers.go index de8cf50..ddbd9bf 100644 --- a/backend/internal/api/handlers.go +++ b/backend/internal/api/handlers.go @@ -345,7 +345,7 @@ func (a *API) PostChat(w http.ResponseWriter, r *http.Request) { } // Start Message - chunk, err := a.startMessageGeneration(chat.ID, genReq.Model, genReq.Prompt, genReq.Images) + chunk, err := a.startMessageGeneration(chat.ID, genReq.Model, genReq.Prompt, genReq.Images, genReq.EnableThinking()) if err != nil { log.WithError(err).WithField("chat_id", chat.ID).Error("failed to start message generation") http.Error(w, "Failed to start message generation", http.StatusInternalServerError) @@ -539,7 +539,7 @@ func (a *API) PostChatMessage(w http.ResponseWriter, r *http.Request) { } // Start Message - chunk, err := a.startMessageGeneration(chatID, genReq.Model, genReq.Prompt, genReq.Images) + chunk, err := a.startMessageGeneration(chatID, genReq.Model, genReq.Prompt, genReq.Images, genReq.EnableThinking()) if err != nil { log.WithError(err).WithField("chat_id", chatID).Error("failed to start message generation") if errors.Is(err, errGenerationActive) { @@ -572,7 +572,7 @@ func (a *API) getClient() (*client.Client, error) { return a.client, nil } -func (a *API) startMessageGeneration(chatID uuid.UUID, chatModel, userMessage string, images []string) (*MessageChunk, error) { +func (a *API) startMessageGeneration(chatID uuid.UUID, chatModel, userMessage string, images []string, enableThinking bool) (*MessageChunk, error) { apiClient, err := a.getClient() if err != nil { return nil, fmt.Errorf("failed to get client: %w", err) @@ -613,7 +613,7 @@ func (a *API) startMessageGeneration(chatID uuid.UUID, chatModel, userMessage st } return nil }, func(gen *generation) { - a.runMessageGeneration(apiClient, chat, assistantMsg, chatModel, gen) + a.runMessageGeneration(apiClient, chat, assistantMsg, chatModel, enableThinking, gen) }); err != nil { return nil, err } @@ -621,13 +621,30 @@ func (a *API) startMessageGeneration(chatID uuid.UUID, chatModel, userMessage st return initialChunk, nil } -func (a *API) runMessageGeneration(apiClient *client.Client, chat *store.Chat, assistantMsg *store.Message, chatModel string, gen *generation) { +func (a *API) runMessageGeneration(apiClient *client.Client, chat *store.Chat, assistantMsg *store.Message, chatModel string, enableThinking bool, gen *generation) { // Create Generation Context ctx, cancel := context.WithTimeout(gen.ctx, a.textGenerationTimeout()) defer cancel() + // Generate Title First - Doing this before the main response avoids busting the KV + // cache between the user prompt and the assistant reply, and keeps the stream from + // staying open past the visible response completing. + if chat.Title == "" && len(chat.Messages) > 0 { + title, err := apiClient.CreateTitle(ctx, chat.Messages[0].Content, chatModel) + if err != nil { + a.logger.WithError(err).WithField("chat_id", chat.ID).Error("failed to create chat title") + } else { + chat.Title = title + if err := a.store.SaveChat(chat); err != nil { + a.logger.WithError(err).WithField("chat_id", chat.ID).Error("failed to update chat") + } else { + gen.broadcast(&MessageChunk{Chat: toChatNoMessages(chat)}) + } + } + } + // Send Message - if _, err := apiClient.SendMessage(ctx, chat.Messages, chatModel, func(m *client.MessageChunk) error { + if _, err := apiClient.SendMessage(ctx, chat.Messages, chatModel, enableThinking, func(m *client.MessageChunk) error { messageChanged := false if m.Stats != nil { @@ -684,17 +701,6 @@ func (a *API) runMessageGeneration(apiClient *client.Client, chat *store.Chat, a return } - // Summarize & Update Chat Title - if chat.Title == "" { - var err error - chat.Title, err = apiClient.CreateTitle(ctx, chat.Messages[0].Content, chatModel) - if err != nil { - a.logger.WithError(err).WithField("chat_id", chat.ID).Error("failed to create chat title") - } else if err := a.store.SaveChat(chat); err != nil { - a.logger.WithError(err).WithField("chat_id", chat.ID).Error("failed to update chat") - } - } - // Complete Assistant Message assistantMsg.Status = store.MessageStatusComplete if err := a.store.SaveChatMessage(assistantMsg); err != nil { diff --git a/backend/internal/api/types.go b/backend/internal/api/types.go index e481e30..c6403db 100644 --- a/backend/internal/api/types.go +++ b/backend/internal/api/types.go @@ -69,9 +69,17 @@ type ImageRecord struct { } type GenerateTextRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt"` - Images []string `json:"images,omitempty"` + Model string `json:"model"` + Prompt string `json:"prompt"` + Images []string `json:"images,omitempty"` + Thinking *bool `json:"thinking,omitempty"` +} + +func (r *GenerateTextRequest) EnableThinking() bool { + if r.Thinking == nil { + return true + } + return *r.Thinking } func (r *GenerateTextRequest) Validate() error { diff --git a/backend/internal/client/client.go b/backend/internal/client/client.go index d4a65eb..4639496 100644 --- a/backend/internal/client/client.go +++ b/backend/internal/client/client.go @@ -66,7 +66,7 @@ func (c *Client) EditImage(ctx context.Context, body openai.ImageEditParams) ([] return resp.Data, nil } -func (c *Client) SendMessage(ctx context.Context, chatMessages []*store.Message, model string, cb StreamCallback) (string, error) { +func (c *Client) SendMessage(ctx context.Context, chatMessages []*store.Message, model string, enableThinking bool, cb StreamCallback) (string, error) { // Ensure Callback if cb == nil { cb = func(mc *MessageChunk) error { return nil } @@ -89,7 +89,8 @@ func (c *Client) SendMessage(ctx context.Context, chatMessages []*store.Message, }, } chatReq.SetExtraFields(map[string]any{ - "timings_per_token": true, // Llama.cpp + "timings_per_token": true, // Llama.cpp + "chat_template_kwargs": map[string]any{"enable_thinking": enableThinking}, }) // Perform Request & Allocate Stats @@ -172,7 +173,7 @@ func (c *Client) CreateTitle(ctx context.Context, userMessage, model string) (st output, err := c.SendMessage(ctx, []*store.Message{{ Role: "user", Content: prompt, - }}, model, nil) + }}, model, false, nil) if err != nil { return "", fmt.Errorf("failed to sent message: %w", err) } diff --git a/backend/internal/client/client_test.go b/backend/internal/client/client_test.go index c666ade..9a5398a 100644 --- a/backend/internal/client/client_test.go +++ b/backend/internal/client/client_test.go @@ -32,7 +32,7 @@ func TestSendMessage(t *testing.T) { _, err = client.SendMessage(ctx, []*store.Message{{ Role: "user", Content: "What is 2+2? Think step by step.", - }}, model, func(mc *MessageChunk) error { + }}, model, true, func(mc *MessageChunk) error { if mc.Thinking != nil { _, err := thinkingBuf.Write([]byte(*mc.Thinking)) return err @@ -118,7 +118,7 @@ func TestSendMessageWithImage(t *testing.T) { Role: "user", Content: "Describe this image in detail.", Images: []string{dataURL}, - }}, model, func(mc *MessageChunk) error { + }}, model, true, func(mc *MessageChunk) error { if mc.Message != nil { outputBuf.WriteString(*mc.Message) } diff --git a/frontend/public/pages/chats.html b/frontend/public/pages/chats.html index cf54289..00a9875 100644 --- a/frontend/public/pages/chats.html +++ b/frontend/public/pages/chats.html @@ -1,7 +1,8 @@