initial commit

This commit is contained in:
2025-12-31 15:33:16 -05:00
commit 4641e7d0ef
51 changed files with 4779 additions and 0 deletions

View File

@@ -0,0 +1,290 @@
package client
import (
"context"
"encoding/json"
"fmt"
"net/url"
"reflect"
"time"
"github.com/google/jsonschema-go/jsonschema"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/packages/respjson"
"github.com/openai/openai-go/v3/shared"
"reichard.io/aethera/internal/store"
"reichard.io/aethera/internal/types"
"reichard.io/aethera/pkg/ptr"
"reichard.io/aethera/pkg/slices"
)
// StreamCallback receives each streamed MessageChunk as it arrives;
// returning a non-nil error aborts the stream.
type StreamCallback func(*MessageChunk) error

// Client wraps an OpenAI-compatible API client.
type Client struct {
oaiClient *openai.Client
}
// GetModels returns every model exposed by the API, following pagination
// until the server reports no further pages.
func (c *Client) GetModels(ctx context.Context) ([]Model, error) {
	// Fetch the first page of models.
	currPage, err := c.oaiClient.Models.List(ctx)
	if err != nil {
		return nil, err
	}
	allData := currPage.Data

	// Follow pagination; GetNextPage yields a nil page once exhausted.
	for {
		currPage, err = currPage.GetNextPage()
		if err != nil {
			return nil, err
		}
		if currPage == nil {
			break
		}
		allData = append(allData, currPage.Data...)
	}

	// Convert API models to the local Model representation.
	return slices.Map(allData, fromOpenAIModel), nil
}
// GenerateImages asks the API to create images from the given request body
// and returns the resulting image set.
func (c *Client) GenerateImages(ctx context.Context, body openai.ImageGenerateParams) ([]openai.Image, error) {
	result, err := c.oaiClient.Images.Generate(ctx, body)
	if err != nil {
		return nil, err
	}
	return result.Data, nil
}
// EditImage asks the API to edit an existing image per the given request
// body and returns the resulting image set.
func (c *Client) EditImage(ctx context.Context, body openai.ImageEditParams) ([]openai.Image, error) {
	result, err := c.oaiClient.Images.Edit(ctx, body)
	if err != nil {
		return nil, err
	}
	return result.Data, nil
}
// SendMessage streams a chat completion for the given message history on the
// given model, invoking cb (may be nil) with each chunk of content, thinking
// text, and timing stats. It returns the accumulated assistant content; on
// error, the content received so far is returned alongside the error.
func (c *Client) SendMessage(ctx context.Context, chatMessages []*store.Message, model string, cb StreamCallback) (string, error) {
	// Default to a no-op callback so the loop can invoke cb unconditionally.
	if cb == nil {
		cb = func(mc *MessageChunk) error { return nil }
	}

	// Map stored messages onto OpenAI roles; anything that is not a user
	// message is sent as an assistant message.
	messages := slices.Map(chatMessages, func(m *store.Message) openai.ChatCompletionMessageParamUnion {
		if m.Role == "user" {
			return openai.UserMessage(m.Content)
		}
		return openai.AssistantMessage(m.Content)
	})

	// Build the streaming request; usage is requested so token counts can be
	// reported back through the callback.
	chatReq := openai.ChatCompletionNewParams{
		Model:    model,
		Messages: messages,
		StreamOptions: openai.ChatCompletionStreamOptionsParam{
			IncludeUsage: openai.Bool(true),
		},
	}
	chatReq.SetExtraFields(map[string]any{
		"timings_per_token": true, // Llama.cpp
	})

	// Perform request & allocate stats.
	msgStats := types.MessageStats{StartTime: time.Now()}
	stream := c.oaiClient.Chat.Completions.NewStreaming(ctx, chatReq)
	// Release the underlying response body even on early return
	// (context cancellation or callback error mid-stream).
	defer stream.Close()

	// Iterate the stream, accumulating assistant content.
	var respContent string
	for stream.Next() {
		// Bail out promptly if the caller cancelled.
		if ctx.Err() != nil {
			return respContent, ctx.Err()
		}

		chunk := stream.Current()
		msgChunk := &MessageChunk{Stats: &msgStats}

		// Populate timing stats from llama.cpp extras and/or usage data.
		sendUpdate := populateLlamaCPPTimings(&msgStats, chunk.JSON.ExtraFields)
		sendUpdate = populateUsageTimings(&msgStats, chunk.Usage) || sendUpdate

		if len(chunk.Choices) > 0 {
			delta := chunk.Choices[0].Delta

			// Reasoning ("thinking") content arrives as a non-standard extra
			// field on the delta, JSON-encoded as a string.
			if thinkingField, found := delta.JSON.ExtraFields["reasoning_content"]; found {
				var thinkingContent string
				if err := json.Unmarshal([]byte(thinkingField.Raw()), &thinkingContent); err != nil {
					return respContent, fmt.Errorf("thinking unmarshal error: %w", err)
				} else if thinkingContent != "" {
					msgStats.RecordFirstToken()
					sendUpdate = true
					msgChunk.Thinking = ptr.Of(thinkingContent)
				}
			}

			// Regular assistant content.
			if delta.Content != "" {
				msgStats.RecordFirstToken()
				sendUpdate = true
				msgChunk.Message = ptr.Of(delta.Content)
				respContent += delta.Content
			}
		}

		// Forward the chunk when anything changed.
		if sendUpdate {
			msgStats.CalculateDerived()
			if err := cb(msgChunk); err != nil {
				return respContent, fmt.Errorf("chunk callback error: %w", err)
			}
		}
	}

	// Check for a stream-level error after iteration ends.
	if err := stream.Err(); err != nil {
		return respContent, fmt.Errorf("stream error: %w", err)
	}

	// Send a final stats-only chunk once the stream completes.
	msgStats.RecordLastToken()
	msgStats.CalculateDerived()
	if err := cb(&MessageChunk{Stats: &msgStats}); err != nil {
		return respContent, fmt.Errorf("chunk callback error: %w", err)
	}
	return respContent, nil
}
// CreateTitle generates a short chat title from the user's initial message
// using the given model.
func (c *Client) CreateTitle(ctx context.Context, userMessage, model string) (string, error) {
	prompt := "You are an agent responsible for creating titles for chats based on the initial message. " +
		"Your titles should be succinct and short. Respond with JUST the chat title. Initial Message: \n\n" + userMessage

	// Generate the title via a non-streaming send (nil callback).
	output, err := c.SendMessage(ctx, []*store.Message{{
		Role:    "user",
		Content: prompt,
	}}, model, nil)
	if err != nil {
		// Fixed typo: "sent" -> "send".
		return "", fmt.Errorf("failed to send message: %w", err)
	}
	return output, nil
}
// StructuredOutput asks the model to answer the prompt as JSON matching the
// schema derived from target's struct type, then unmarshals the response into
// target. target must be a non-nil pointer to a struct.
func (c *Client) StructuredOutput(ctx context.Context, target any, prompt, model string) error {
	// Validate that target is a usable pointer.
	v := reflect.ValueOf(target)
	if v.Kind() != reflect.Pointer {
		return fmt.Errorf("target must be a pointer, got %T", target)
	}
	if v.IsNil() {
		return fmt.Errorf("target pointer is nil")
	}

	// Validate that it points at a struct.
	elem := v.Elem()
	if elem.Kind() != reflect.Struct {
		return fmt.Errorf("target must be a pointer to struct, got pointer to %s", elem.Kind())
	}

	// Build a strict JSON schema from the struct type.
	schema, err := buildJSONSchema(elem.Type())
	if err != nil {
		return fmt.Errorf("failed to build schema: %w", err)
	}

	// Perform the (non-streaming) request with a JSON-schema response format.
	resp, err := c.oaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
		Model: model,
		Messages: []openai.ChatCompletionMessageParamUnion{
			openai.UserMessage(prompt),
		},
		ResponseFormat: openai.ChatCompletionNewParamsResponseFormatUnion{
			OfJSONSchema: schema,
		},
	})
	if err != nil {
		return fmt.Errorf("API call failed: %w", err)
	}

	// Guard against an empty choice list before indexing (avoids a panic on
	// malformed or empty API responses).
	if len(resp.Choices) == 0 {
		return fmt.Errorf("API returned no choices")
	}

	// Parse the JSON response into the caller's struct.
	content := resp.Choices[0].Message.Content
	if err := json.Unmarshal([]byte(content), target); err != nil {
		return fmt.Errorf("failed to unmarshal response: %w", err)
	}
	return nil
}
// buildJSONSchema derives a strict OpenAI JSON-schema response format from
// the given struct type, named after the type itself.
func buildJSONSchema(rType reflect.Type) (*shared.ResponseFormatJSONSchemaParam, error) {
	schema, err := jsonschema.ForType(rType, nil)
	if err != nil {
		return nil, err
	}

	// Strict mode requires additionalProperties to be false.
	schemaMap := map[string]any{
		"type":                 schema.Type,
		"properties":           schema.Properties,
		"required":             schema.Required,
		"additionalProperties": false,
	}

	return &shared.ResponseFormatJSONSchemaParam{
		JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
			Name:   rType.Name(),
			Schema: schemaMap,
			Strict: openai.Bool(true),
		},
	}, nil
}
// populateLlamaCPPTimings copies llama.cpp "timings" extra-field data into
// msgStats, reporting whether any timing information was found and applied.
func populateLlamaCPPTimings(msgStats *types.MessageStats, extraFields map[string]respjson.Field) bool {
	rawTimings, found := extraFields["timings"]
	if !found {
		return false
	}

	var timings llamaCPPTimings
	if json.Unmarshal([]byte(rawTimings.Raw()), &timings) != nil {
		return false
	}

	// Token counts of zero are treated as "not reported" and left unset.
	if timings.PromptN != 0 {
		msgStats.PromptTokens = ptr.Of(int32(timings.PromptN))
	}
	if timings.PredictedN != 0 {
		msgStats.GeneratedTokens = ptr.Of(int32(timings.PredictedN))
	}

	// Throughput figures are always copied, even when zero.
	msgStats.PromptPerSec = ptr.Of(float32(timings.PromptPerSecond))
	msgStats.GeneratedPerSec = ptr.Of(float32(timings.PredictedPerSecond))
	return true
}
// populateUsageTimings fills token counts from the API usage block when they
// were not already populated (e.g. by llama.cpp timings), reporting whether
// msgStats was changed.
func populateUsageTimings(msgStats *types.MessageStats, usage openai.CompletionUsage) bool {
	// An all-zero usage block means the server sent no usage data yet.
	if usage.PromptTokens == 0 && usage.CompletionTokens == 0 {
		return false
	}

	changed := false
	if msgStats.PromptTokens == nil {
		changed = true
		msgStats.PromptTokens = ptr.Of(int32(usage.PromptTokens))
	}
	if msgStats.GeneratedTokens == nil {
		changed = true
		// NOTE(review): OpenAI's completion_tokens normally already includes
		// reasoning tokens; adding them again here may double-count — confirm
		// the backend's semantics.
		msgStats.GeneratedTokens = ptr.Of(int32(usage.CompletionTokens + usage.CompletionTokensDetails.ReasoningTokens))
	}
	return changed
}
// NewClient constructs a Client targeting the given OpenAI-compatible base URL.
func NewClient(baseURL *url.URL) *Client {
	oai := openai.NewClient(option.WithBaseURL(baseURL.String()))
	return &Client{oaiClient: &oai}
}