- Restructure floating input: dominant textarea with compact bottom toolbar (model badge, thinking toggle, attach, send/stop). - Model badge sizes to the current selection (not widest option) via a layered transparent select, with truncate-on-overflow fallback. - Auto-expand the conversation sidebar on desktop and slide chat content right when open instead of overlaying. - Add per-request thinking toggle (brain icon, default on, persisted in localStorage) sending chat_template_kwargs.enable_thinking. - Always disable thinking for title summarization. - Generate chat titles before the main response to keep the SSE stream from staying open past visible completion and to avoid busting the KV cache between turns.
327 lines
8.6 KiB
Go
327 lines
8.6 KiB
Go
package client
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/url"
|
|
"reflect"
|
|
"time"
|
|
|
|
"github.com/google/jsonschema-go/jsonschema"
|
|
"github.com/openai/openai-go/v3"
|
|
"github.com/openai/openai-go/v3/option"
|
|
"github.com/openai/openai-go/v3/packages/respjson"
|
|
"github.com/openai/openai-go/v3/shared"
|
|
"reichard.io/aethera/internal/store"
|
|
"reichard.io/aethera/internal/types"
|
|
"reichard.io/aethera/pkg/ptr"
|
|
"reichard.io/aethera/pkg/slices"
|
|
)
|
|
|
|
// StreamCallback receives the chunks SendMessage produces while streaming a
// chat completion. Returning a non-nil error makes SendMessage stop and return
// that error wrapped as a "chunk callback error".
type StreamCallback func(*MessageChunk) error
|
|
|
|
type Client struct {
|
|
oaiClient *openai.Client
|
|
}
|
|
|
|
func (c *Client) GetModels(ctx context.Context) ([]Model, error) {
|
|
// Get Models
|
|
currPage, err := c.oaiClient.Models.List(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
allData := currPage.Data
|
|
|
|
// Pagination
|
|
for {
|
|
currPage, err = currPage.GetNextPage()
|
|
if err != nil {
|
|
return nil, err
|
|
} else if currPage == nil {
|
|
break
|
|
}
|
|
allData = append(allData, currPage.Data...)
|
|
}
|
|
|
|
// Convert
|
|
return slices.Map(allData, fromOpenAIModel), nil
|
|
}
|
|
|
|
func (c *Client) GenerateImages(ctx context.Context, body openai.ImageGenerateParams) ([]openai.Image, error) {
|
|
// Generate Images
|
|
resp, err := c.oaiClient.Images.Generate(ctx, body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return resp.Data, nil
|
|
}
|
|
|
|
func (c *Client) EditImage(ctx context.Context, body openai.ImageEditParams) ([]openai.Image, error) {
|
|
// Edit Image
|
|
resp, err := c.oaiClient.Images.Edit(ctx, body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return resp.Data, nil
|
|
}
|
|
|
|
func (c *Client) SendMessage(ctx context.Context, chatMessages []*store.Message, model string, enableThinking bool, cb StreamCallback) (string, error) {
|
|
// Ensure Callback
|
|
if cb == nil {
|
|
cb = func(mc *MessageChunk) error { return nil }
|
|
}
|
|
|
|
// Map Messages
|
|
messages := slices.Map(chatMessages, func(m *store.Message) openai.ChatCompletionMessageParamUnion {
|
|
if m.Role == "user" {
|
|
return buildUserMessage(m)
|
|
}
|
|
return openai.AssistantMessage(m.Content)
|
|
})
|
|
|
|
// Create Request
|
|
chatReq := openai.ChatCompletionNewParams{
|
|
Model: model,
|
|
Messages: messages,
|
|
StreamOptions: openai.ChatCompletionStreamOptionsParam{
|
|
IncludeUsage: openai.Bool(true),
|
|
},
|
|
}
|
|
chatReq.SetExtraFields(map[string]any{
|
|
"timings_per_token": true, // Llama.cpp
|
|
"chat_template_kwargs": map[string]any{"enable_thinking": enableThinking},
|
|
})
|
|
|
|
// Perform Request & Allocate Stats
|
|
msgStats := types.MessageStats{StartTime: time.Now()}
|
|
stream := c.oaiClient.Chat.Completions.NewStreaming(ctx, chatReq)
|
|
|
|
// Iterate Stream
|
|
var respContent string
|
|
for stream.Next() {
|
|
// Check Context
|
|
if ctx.Err() != nil {
|
|
return respContent, ctx.Err()
|
|
}
|
|
|
|
// Load Chunk
|
|
chunk := stream.Current()
|
|
msgChunk := &MessageChunk{Stats: &msgStats}
|
|
|
|
// Populate Timings
|
|
sendUpdate := populateLlamaCPPTimings(&msgStats, chunk.JSON.ExtraFields)
|
|
sendUpdate = populateUsageTimings(&msgStats, chunk.Usage) || sendUpdate
|
|
|
|
if len(chunk.Choices) > 0 {
|
|
delta := chunk.Choices[0].Delta
|
|
|
|
// Check Thinking - Support both "reasoning_content" (DeepSeek)
|
|
// and "reasoning" (vLLM) field names.
|
|
for _, thinkingKey := range []string{"reasoning_content", "reasoning"} {
|
|
if thinkingField, found := delta.JSON.ExtraFields[thinkingKey]; found {
|
|
var thinkingContent string
|
|
if err := json.Unmarshal([]byte(thinkingField.Raw()), &thinkingContent); err != nil {
|
|
return respContent, fmt.Errorf("thinking unmarshal error: %w", err)
|
|
} else if thinkingContent != "" {
|
|
msgStats.RecordFirstToken()
|
|
sendUpdate = true
|
|
msgChunk.Thinking = ptr.Of(thinkingContent)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
|
|
// Check Content
|
|
if delta.Content != "" {
|
|
msgStats.RecordFirstToken()
|
|
sendUpdate = true
|
|
msgChunk.Message = ptr.Of(delta.Content)
|
|
respContent += delta.Content
|
|
}
|
|
}
|
|
|
|
// Send Timings
|
|
if sendUpdate {
|
|
msgStats.CalculateDerived()
|
|
if err := cb(msgChunk); err != nil {
|
|
return respContent, fmt.Errorf("chunk callback error: %w", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check Error
|
|
if err := stream.Err(); err != nil {
|
|
return respContent, fmt.Errorf("stream error: %w", err)
|
|
}
|
|
|
|
// Send Final Chunk
|
|
msgStats.RecordLastToken()
|
|
msgStats.CalculateDerived()
|
|
if err := cb(&MessageChunk{Stats: &msgStats}); err != nil {
|
|
return respContent, fmt.Errorf("chunk callback error: %w", err)
|
|
}
|
|
|
|
return respContent, nil
|
|
}
|
|
|
|
func (c *Client) CreateTitle(ctx context.Context, userMessage, model string) (string, error) {
|
|
prompt := "You are an agent responsible for creating titles for chats based on the initial message. " +
|
|
"Your titles should be succinct and short. Respond with JUST the chat title. Initial Message: \n\n" + userMessage
|
|
|
|
// Generate Text Stream
|
|
output, err := c.SendMessage(ctx, []*store.Message{{
|
|
Role: "user",
|
|
Content: prompt,
|
|
}}, model, false, nil)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to sent message: %w", err)
|
|
}
|
|
|
|
return output, nil
|
|
}
|
|
|
|
func (c *Client) StructuredOutput(ctx context.Context, target any, prompt, model string) error {
|
|
// Validate Target Pointer
|
|
v := reflect.ValueOf(target)
|
|
if v.Kind() != reflect.Pointer {
|
|
return fmt.Errorf("target must be a pointer, got %T", target)
|
|
}
|
|
if v.IsNil() {
|
|
return fmt.Errorf("target pointer is nil")
|
|
}
|
|
|
|
// Validate Target Struct
|
|
elem := v.Elem()
|
|
if elem.Kind() != reflect.Struct {
|
|
return fmt.Errorf("target must be a pointer to struct, got pointer to %s", elem.Kind())
|
|
}
|
|
|
|
// Build Schema
|
|
schema, err := buildJSONSchema(elem.Type())
|
|
if err != nil {
|
|
return fmt.Errorf("failed to build schema: %w", err)
|
|
}
|
|
|
|
// Perform Request
|
|
resp, err := c.oaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
|
|
Model: model,
|
|
Messages: []openai.ChatCompletionMessageParamUnion{
|
|
openai.UserMessage(prompt),
|
|
},
|
|
ResponseFormat: openai.ChatCompletionNewParamsResponseFormatUnion{
|
|
OfJSONSchema: schema,
|
|
},
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("API call failed: %w", err)
|
|
}
|
|
|
|
// Parse Response
|
|
content := resp.Choices[0].Message.Content
|
|
if err := json.Unmarshal([]byte(content), target); err != nil {
|
|
return fmt.Errorf("failed to unmarshal response: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func buildJSONSchema(rType reflect.Type) (*shared.ResponseFormatJSONSchemaParam, error) {
|
|
schema, err := jsonschema.ForType(rType, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &shared.ResponseFormatJSONSchemaParam{
|
|
JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
|
|
Name: rType.Name(),
|
|
Schema: map[string]any{
|
|
"type": schema.Type,
|
|
"properties": schema.Properties,
|
|
"required": schema.Required,
|
|
"additionalProperties": false,
|
|
},
|
|
Strict: openai.Bool(true),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
func populateLlamaCPPTimings(msgStats *types.MessageStats, extraFields map[string]respjson.Field) bool {
|
|
rawTimings, found := extraFields["timings"]
|
|
if !found {
|
|
return false
|
|
}
|
|
|
|
var llamaTimings llamaCPPTimings
|
|
if err := json.Unmarshal([]byte(rawTimings.Raw()), &llamaTimings); err != nil {
|
|
return false
|
|
}
|
|
|
|
if llamaTimings.PromptN != 0 {
|
|
msgStats.PromptTokens = ptr.Of(int32(llamaTimings.PromptN))
|
|
}
|
|
if llamaTimings.PredictedN != 0 {
|
|
msgStats.GeneratedTokens = ptr.Of(int32(llamaTimings.PredictedN))
|
|
}
|
|
|
|
msgStats.PromptPerSec = ptr.Of(float32(llamaTimings.PromptPerSecond))
|
|
msgStats.GeneratedPerSec = ptr.Of(float32(llamaTimings.PredictedPerSecond))
|
|
|
|
return true
|
|
}
|
|
|
|
func populateUsageTimings(msgStats *types.MessageStats, usage openai.CompletionUsage) (didChange bool) {
|
|
if usage.PromptTokens == 0 && usage.CompletionTokens == 0 {
|
|
return
|
|
}
|
|
|
|
if msgStats.PromptTokens == nil {
|
|
didChange = true
|
|
msgStats.PromptTokens = ptr.Of(int32(usage.PromptTokens))
|
|
}
|
|
|
|
if msgStats.GeneratedTokens == nil {
|
|
didChange = true
|
|
reasoningTokens := usage.CompletionTokensDetails.ReasoningTokens
|
|
msgStats.GeneratedTokens = ptr.Of(int32(usage.CompletionTokens + reasoningTokens))
|
|
}
|
|
|
|
return didChange
|
|
}
|
|
|
|
func NewClient(baseURL *url.URL, apiKey string) *Client {
|
|
opts := []option.RequestOption{option.WithBaseURL(baseURL.String())}
|
|
if apiKey != "" {
|
|
opts = append(opts, option.WithAPIKey(apiKey))
|
|
}
|
|
oaiClient := openai.NewClient(opts...)
|
|
return &Client{oaiClient: &oaiClient}
|
|
}
|
|
|
|
func buildUserMessage(m *store.Message) openai.ChatCompletionMessageParamUnion {
|
|
// Simple Text Message
|
|
if len(m.Images) == 0 {
|
|
return openai.UserMessage(m.Content)
|
|
}
|
|
|
|
// Build Multimodal Content Parts
|
|
parts := make([]openai.ChatCompletionContentPartUnionParam, 0, len(m.Images)+1)
|
|
|
|
// Add Image Parts
|
|
for _, imgURL := range m.Images {
|
|
parts = append(parts, openai.ImageContentPart(
|
|
openai.ChatCompletionContentPartImageImageURLParam{
|
|
URL: imgURL,
|
|
},
|
|
))
|
|
}
|
|
|
|
// Add Text Part
|
|
if m.Content != "" {
|
|
parts = append(parts, openai.TextContentPart(m.Content))
|
|
}
|
|
|
|
// Build User Message with Content Parts
|
|
return openai.UserMessage(parts)
|
|
}
|