package client import ( "context" "encoding/json" "fmt" "net/url" "reflect" "time" "github.com/google/jsonschema-go/jsonschema" "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/option" "github.com/openai/openai-go/v3/packages/respjson" "github.com/openai/openai-go/v3/shared" "reichard.io/aethera/internal/store" "reichard.io/aethera/internal/types" "reichard.io/aethera/pkg/ptr" "reichard.io/aethera/pkg/slices" ) type StreamCallback func(*MessageChunk) error type Client struct { oaiClient *openai.Client } func (c *Client) GetModels(ctx context.Context) ([]Model, error) { // Get Models currPage, err := c.oaiClient.Models.List(ctx) if err != nil { return nil, err } allData := currPage.Data // Pagination for { currPage, err = currPage.GetNextPage() if err != nil { return nil, err } else if currPage == nil { break } allData = append(allData, currPage.Data...) } // Convert return slices.Map(allData, fromOpenAIModel), nil } func (c *Client) GenerateImages(ctx context.Context, body openai.ImageGenerateParams) ([]openai.Image, error) { // Generate Images resp, err := c.oaiClient.Images.Generate(ctx, body) if err != nil { return nil, err } return resp.Data, nil } func (c *Client) EditImage(ctx context.Context, body openai.ImageEditParams) ([]openai.Image, error) { // Edit Image resp, err := c.oaiClient.Images.Edit(ctx, body) if err != nil { return nil, err } return resp.Data, nil } func (c *Client) SendMessage(ctx context.Context, chatMessages []*store.Message, model string, cb StreamCallback) (string, error) { // Ensure Callback if cb == nil { cb = func(mc *MessageChunk) error { return nil } } // Map Messages messages := slices.Map(chatMessages, func(m *store.Message) openai.ChatCompletionMessageParamUnion { if m.Role == "user" { return openai.UserMessage(m.Content) } return openai.AssistantMessage(m.Content) }) // Create Request chatReq := openai.ChatCompletionNewParams{ Model: model, Messages: messages, StreamOptions: openai.ChatCompletionStreamOptionsParam{ IncludeUsage: openai.Bool(true), }, } chatReq.SetExtraFields(map[string]any{ "timings_per_token": true, // Llama.cpp }) // Perform Request & Allocate Stats msgStats := types.MessageStats{StartTime: time.Now()} stream := c.oaiClient.Chat.Completions.NewStreaming(ctx, chatReq) // Iterate Stream var respContent string for stream.Next() { // Check Context if ctx.Err() != nil { return respContent, ctx.Err() } // Load Chunk chunk := stream.Current() msgChunk := &MessageChunk{Stats: &msgStats} // Populate Timings sendUpdate := populateLlamaCPPTimings(&msgStats, chunk.JSON.ExtraFields) sendUpdate = populateUsageTimings(&msgStats, chunk.Usage) || sendUpdate if len(chunk.Choices) > 0 { delta := chunk.Choices[0].Delta // Check Thinking if thinkingField, found := delta.JSON.ExtraFields["reasoning_content"]; found { var thinkingContent string if err := json.Unmarshal([]byte(thinkingField.Raw()), &thinkingContent); err != nil { return respContent, fmt.Errorf("thinking unmarshal error: %w", err) } else if thinkingContent != "" { msgStats.RecordFirstToken() sendUpdate = true msgChunk.Thinking = ptr.Of(thinkingContent) } } // Check Content if delta.Content != "" { msgStats.RecordFirstToken() sendUpdate = true msgChunk.Message = ptr.Of(delta.Content) respContent += delta.Content } } // Send Timings if sendUpdate { msgStats.CalculateDerived() if err := cb(msgChunk); err != nil { return respContent, fmt.Errorf("chunk callback error: %w", err) } } } // Check Error if err := stream.Err(); err != nil { return respContent, fmt.Errorf("stream error: %w", err) } // Send Final Chunk msgStats.RecordLastToken() msgStats.CalculateDerived() if err := cb(&MessageChunk{Stats: &msgStats}); err != nil { return respContent, fmt.Errorf("chunk callback error: %w", err) } return respContent, nil } func (c *Client) CreateTitle(ctx context.Context, userMessage, model string) (string, error) { prompt := "You are an agent responsible for creating titles for chats based on the initial message. " + "Your titles should be succinct and short. Respond with JUST the chat title. Initial Message: \n\n" + userMessage // Generate Text Stream output, err := c.SendMessage(ctx, []*store.Message{{ Role: "user", Content: prompt, }}, model, nil) if err != nil { return "", fmt.Errorf("failed to sent message: %w", err) } return output, nil } func (c *Client) StructuredOutput(ctx context.Context, target any, prompt, model string) error { // Validate Target Pointer v := reflect.ValueOf(target) if v.Kind() != reflect.Pointer { return fmt.Errorf("target must be a pointer, got %T", target) } if v.IsNil() { return fmt.Errorf("target pointer is nil") } // Validate Target Struct elem := v.Elem() if elem.Kind() != reflect.Struct { return fmt.Errorf("target must be a pointer to struct, got pointer to %s", elem.Kind()) } // Build Schema schema, err := buildJSONSchema(elem.Type()) if err != nil { return fmt.Errorf("failed to build schema: %w", err) } // Perform Request resp, err := c.oaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ Model: model, Messages: []openai.ChatCompletionMessageParamUnion{ openai.UserMessage(prompt), }, ResponseFormat: openai.ChatCompletionNewParamsResponseFormatUnion{ OfJSONSchema: schema, }, }) if err != nil { return fmt.Errorf("API call failed: %w", err) } // Parse Response content := resp.Choices[0].Message.Content if err := json.Unmarshal([]byte(content), target); err != nil { return fmt.Errorf("failed to unmarshal response: %w", err) } return nil } func buildJSONSchema(rType reflect.Type) (*shared.ResponseFormatJSONSchemaParam, error) { schema, err := jsonschema.ForType(rType, nil) if err != nil { return nil, err } return &shared.ResponseFormatJSONSchemaParam{ JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{ Name: rType.Name(), Schema: map[string]any{ "type": schema.Type, "properties": schema.Properties, "required": schema.Required, "additionalProperties": false, }, Strict: openai.Bool(true), }, }, nil } func populateLlamaCPPTimings(msgStats *types.MessageStats, extraFields map[string]respjson.Field) bool { rawTimings, found := extraFields["timings"] if !found { return false } var llamaTimings llamaCPPTimings if err := json.Unmarshal([]byte(rawTimings.Raw()), &llamaTimings); err != nil { return false } if llamaTimings.PromptN != 0 { msgStats.PromptTokens = ptr.Of(int32(llamaTimings.PromptN)) } if llamaTimings.PredictedN != 0 { msgStats.GeneratedTokens = ptr.Of(int32(llamaTimings.PredictedN)) } msgStats.PromptPerSec = ptr.Of(float32(llamaTimings.PromptPerSecond)) msgStats.GeneratedPerSec = ptr.Of(float32(llamaTimings.PredictedPerSecond)) return true } func populateUsageTimings(msgStats *types.MessageStats, usage openai.CompletionUsage) (didChange bool) { if usage.PromptTokens == 0 && usage.CompletionTokens == 0 { return } if msgStats.PromptTokens == nil { didChange = true msgStats.PromptTokens = ptr.Of(int32(usage.PromptTokens)) } if msgStats.GeneratedTokens == nil { didChange = true reasoningTokens := usage.CompletionTokensDetails.ReasoningTokens msgStats.GeneratedTokens = ptr.Of(int32(usage.CompletionTokens + reasoningTokens)) } return didChange } func NewClient(baseURL *url.URL) *Client { oaiClient := openai.NewClient(option.WithBaseURL(baseURL.String())) return &Client{oaiClient: &oaiClient} }