// Package client wraps an OpenAI-compatible chat completion API, adding
// streaming helpers, image generation, and structured (JSON-schema) output.
package client
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/url"
|
|
"reflect"
|
|
"time"
|
|
|
|
"github.com/google/jsonschema-go/jsonschema"
|
|
"github.com/openai/openai-go/v3"
|
|
"github.com/openai/openai-go/v3/option"
|
|
"github.com/openai/openai-go/v3/packages/respjson"
|
|
"github.com/openai/openai-go/v3/shared"
|
|
"reichard.io/aethera/internal/store"
|
|
"reichard.io/aethera/internal/types"
|
|
"reichard.io/aethera/pkg/ptr"
|
|
"reichard.io/aethera/pkg/slices"
|
|
)
|
|
|
|
// StreamCallback is invoked once per chunk while streaming a chat
// completion; returning a non-nil error aborts the stream.
type StreamCallback func(*MessageChunk) error

// Client is a thin wrapper around an OpenAI-compatible API client.
type Client struct {
	oaiClient *openai.Client
}
|
|
|
|
func (c *Client) GetModels(ctx context.Context) ([]Model, error) {
|
|
// Get Models
|
|
currPage, err := c.oaiClient.Models.List(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
allData := currPage.Data
|
|
|
|
// Pagination
|
|
for {
|
|
currPage, err = currPage.GetNextPage()
|
|
if err != nil {
|
|
return nil, err
|
|
} else if currPage == nil {
|
|
break
|
|
}
|
|
allData = append(allData, currPage.Data...)
|
|
}
|
|
|
|
// Convert
|
|
return slices.Map(allData, fromOpenAIModel), nil
|
|
}
|
|
|
|
func (c *Client) GenerateImages(ctx context.Context, body openai.ImageGenerateParams) ([]openai.Image, error) {
|
|
// Generate Images
|
|
resp, err := c.oaiClient.Images.Generate(ctx, body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return resp.Data, nil
|
|
}
|
|
|
|
func (c *Client) EditImage(ctx context.Context, body openai.ImageEditParams) ([]openai.Image, error) {
|
|
// Edit Image
|
|
resp, err := c.oaiClient.Images.Edit(ctx, body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return resp.Data, nil
|
|
}
|
|
|
|
func (c *Client) SendMessage(ctx context.Context, chatMessages []*store.Message, model string, cb StreamCallback) (string, error) {
|
|
// Ensure Callback
|
|
if cb == nil {
|
|
cb = func(mc *MessageChunk) error { return nil }
|
|
}
|
|
|
|
// Map Messages
|
|
messages := slices.Map(chatMessages, func(m *store.Message) openai.ChatCompletionMessageParamUnion {
|
|
if m.Role == "user" {
|
|
return openai.UserMessage(m.Content)
|
|
}
|
|
return openai.AssistantMessage(m.Content)
|
|
})
|
|
|
|
// Create Request
|
|
chatReq := openai.ChatCompletionNewParams{
|
|
Model: model,
|
|
Messages: messages,
|
|
StreamOptions: openai.ChatCompletionStreamOptionsParam{
|
|
IncludeUsage: openai.Bool(true),
|
|
},
|
|
}
|
|
chatReq.SetExtraFields(map[string]any{
|
|
"timings_per_token": true, // Llama.cpp
|
|
})
|
|
|
|
// Perform Request & Allocate Stats
|
|
msgStats := types.MessageStats{StartTime: time.Now()}
|
|
stream := c.oaiClient.Chat.Completions.NewStreaming(ctx, chatReq)
|
|
|
|
// Iterate Stream
|
|
var respContent string
|
|
for stream.Next() {
|
|
// Check Context
|
|
if ctx.Err() != nil {
|
|
return respContent, ctx.Err()
|
|
}
|
|
|
|
// Load Chunk
|
|
chunk := stream.Current()
|
|
msgChunk := &MessageChunk{Stats: &msgStats}
|
|
|
|
// Populate Timings
|
|
sendUpdate := populateLlamaCPPTimings(&msgStats, chunk.JSON.ExtraFields)
|
|
sendUpdate = populateUsageTimings(&msgStats, chunk.Usage) || sendUpdate
|
|
|
|
if len(chunk.Choices) > 0 {
|
|
delta := chunk.Choices[0].Delta
|
|
|
|
// Check Thinking
|
|
if thinkingField, found := delta.JSON.ExtraFields["reasoning_content"]; found {
|
|
var thinkingContent string
|
|
if err := json.Unmarshal([]byte(thinkingField.Raw()), &thinkingContent); err != nil {
|
|
return respContent, fmt.Errorf("thinking unmarshal error: %w", err)
|
|
} else if thinkingContent != "" {
|
|
msgStats.RecordFirstToken()
|
|
sendUpdate = true
|
|
msgChunk.Thinking = ptr.Of(thinkingContent)
|
|
}
|
|
}
|
|
|
|
// Check Content
|
|
if delta.Content != "" {
|
|
msgStats.RecordFirstToken()
|
|
sendUpdate = true
|
|
msgChunk.Message = ptr.Of(delta.Content)
|
|
respContent += delta.Content
|
|
}
|
|
}
|
|
|
|
// Send Timings
|
|
if sendUpdate {
|
|
msgStats.CalculateDerived()
|
|
if err := cb(msgChunk); err != nil {
|
|
return respContent, fmt.Errorf("chunk callback error: %w", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check Error
|
|
if err := stream.Err(); err != nil {
|
|
return respContent, fmt.Errorf("stream error: %w", err)
|
|
}
|
|
|
|
// Send Final Chunk
|
|
msgStats.RecordLastToken()
|
|
msgStats.CalculateDerived()
|
|
if err := cb(&MessageChunk{Stats: &msgStats}); err != nil {
|
|
return respContent, fmt.Errorf("chunk callback error: %w", err)
|
|
}
|
|
|
|
return respContent, nil
|
|
}
|
|
|
|
func (c *Client) CreateTitle(ctx context.Context, userMessage, model string) (string, error) {
|
|
prompt := "You are an agent responsible for creating titles for chats based on the initial message. " +
|
|
"Your titles should be succinct and short. Respond with JUST the chat title. Initial Message: \n\n" + userMessage
|
|
|
|
// Generate Text Stream
|
|
output, err := c.SendMessage(ctx, []*store.Message{{
|
|
Role: "user",
|
|
Content: prompt,
|
|
}}, model, nil)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to sent message: %w", err)
|
|
}
|
|
|
|
return output, nil
|
|
}
|
|
|
|
func (c *Client) StructuredOutput(ctx context.Context, target any, prompt, model string) error {
|
|
// Validate Target Pointer
|
|
v := reflect.ValueOf(target)
|
|
if v.Kind() != reflect.Pointer {
|
|
return fmt.Errorf("target must be a pointer, got %T", target)
|
|
}
|
|
if v.IsNil() {
|
|
return fmt.Errorf("target pointer is nil")
|
|
}
|
|
|
|
// Validate Target Struct
|
|
elem := v.Elem()
|
|
if elem.Kind() != reflect.Struct {
|
|
return fmt.Errorf("target must be a pointer to struct, got pointer to %s", elem.Kind())
|
|
}
|
|
|
|
// Build Schema
|
|
schema, err := buildJSONSchema(elem.Type())
|
|
if err != nil {
|
|
return fmt.Errorf("failed to build schema: %w", err)
|
|
}
|
|
|
|
// Perform Request
|
|
resp, err := c.oaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
|
|
Model: model,
|
|
Messages: []openai.ChatCompletionMessageParamUnion{
|
|
openai.UserMessage(prompt),
|
|
},
|
|
ResponseFormat: openai.ChatCompletionNewParamsResponseFormatUnion{
|
|
OfJSONSchema: schema,
|
|
},
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("API call failed: %w", err)
|
|
}
|
|
|
|
// Parse Response
|
|
content := resp.Choices[0].Message.Content
|
|
if err := json.Unmarshal([]byte(content), target); err != nil {
|
|
return fmt.Errorf("failed to unmarshal response: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func buildJSONSchema(rType reflect.Type) (*shared.ResponseFormatJSONSchemaParam, error) {
|
|
schema, err := jsonschema.ForType(rType, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &shared.ResponseFormatJSONSchemaParam{
|
|
JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
|
|
Name: rType.Name(),
|
|
Schema: map[string]any{
|
|
"type": schema.Type,
|
|
"properties": schema.Properties,
|
|
"required": schema.Required,
|
|
"additionalProperties": false,
|
|
},
|
|
Strict: openai.Bool(true),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
func populateLlamaCPPTimings(msgStats *types.MessageStats, extraFields map[string]respjson.Field) bool {
|
|
rawTimings, found := extraFields["timings"]
|
|
if !found {
|
|
return false
|
|
}
|
|
|
|
var llamaTimings llamaCPPTimings
|
|
if err := json.Unmarshal([]byte(rawTimings.Raw()), &llamaTimings); err != nil {
|
|
return false
|
|
}
|
|
|
|
if llamaTimings.PromptN != 0 {
|
|
msgStats.PromptTokens = ptr.Of(int32(llamaTimings.PromptN))
|
|
}
|
|
if llamaTimings.PredictedN != 0 {
|
|
msgStats.GeneratedTokens = ptr.Of(int32(llamaTimings.PredictedN))
|
|
}
|
|
|
|
msgStats.PromptPerSec = ptr.Of(float32(llamaTimings.PromptPerSecond))
|
|
msgStats.GeneratedPerSec = ptr.Of(float32(llamaTimings.PredictedPerSecond))
|
|
|
|
return true
|
|
}
|
|
|
|
func populateUsageTimings(msgStats *types.MessageStats, usage openai.CompletionUsage) (didChange bool) {
|
|
if usage.PromptTokens == 0 && usage.CompletionTokens == 0 {
|
|
return
|
|
}
|
|
|
|
if msgStats.PromptTokens == nil {
|
|
didChange = true
|
|
msgStats.PromptTokens = ptr.Of(int32(usage.PromptTokens))
|
|
}
|
|
|
|
if msgStats.GeneratedTokens == nil {
|
|
didChange = true
|
|
reasoningTokens := usage.CompletionTokensDetails.ReasoningTokens
|
|
msgStats.GeneratedTokens = ptr.Of(int32(usage.CompletionTokens + reasoningTokens))
|
|
}
|
|
|
|
return didChange
|
|
}
|
|
|
|
func NewClient(baseURL *url.URL) *Client {
|
|
oaiClient := openai.NewClient(option.WithBaseURL(baseURL.String()))
|
|
return &Client{oaiClient: &oaiClient}
|
|
}
|