initial commit
This commit is contained in:
290
backend/internal/client/client.go
Normal file
290
backend/internal/client/client.go
Normal file
@@ -0,0 +1,290 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/google/jsonschema-go/jsonschema"
|
||||
"github.com/openai/openai-go/v3"
|
||||
"github.com/openai/openai-go/v3/option"
|
||||
"github.com/openai/openai-go/v3/packages/respjson"
|
||||
"github.com/openai/openai-go/v3/shared"
|
||||
"reichard.io/aethera/internal/store"
|
||||
"reichard.io/aethera/internal/types"
|
||||
"reichard.io/aethera/pkg/ptr"
|
||||
"reichard.io/aethera/pkg/slices"
|
||||
)
|
||||
|
||||
// StreamCallback is invoked for each MessageChunk produced while streaming a
// completion; returning a non-nil error aborts the stream.
type StreamCallback func(*MessageChunk) error
|
||||
|
||||
// Client wraps an OpenAI-compatible API client (chat completions, images,
// models) for use by the rest of the backend.
type Client struct {
	// oaiClient is the underlying OpenAI SDK client; set via NewClient.
	oaiClient *openai.Client
}
|
||||
|
||||
func (c *Client) GetModels(ctx context.Context) ([]Model, error) {
|
||||
// Get Models
|
||||
currPage, err := c.oaiClient.Models.List(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allData := currPage.Data
|
||||
|
||||
// Pagination
|
||||
for {
|
||||
currPage, err = currPage.GetNextPage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if currPage == nil {
|
||||
break
|
||||
}
|
||||
allData = append(allData, currPage.Data...)
|
||||
}
|
||||
|
||||
// Convert
|
||||
return slices.Map(allData, fromOpenAIModel), nil
|
||||
}
|
||||
|
||||
func (c *Client) GenerateImages(ctx context.Context, body openai.ImageGenerateParams) ([]openai.Image, error) {
|
||||
// Generate Images
|
||||
resp, err := c.oaiClient.Images.Generate(ctx, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp.Data, nil
|
||||
}
|
||||
|
||||
func (c *Client) EditImage(ctx context.Context, body openai.ImageEditParams) ([]openai.Image, error) {
|
||||
// Edit Image
|
||||
resp, err := c.oaiClient.Images.Edit(ctx, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp.Data, nil
|
||||
}
|
||||
|
||||
// SendMessage streams a chat completion for the given message history and
// model. Each streamed delta (content, reasoning, and timing stats) is
// delivered to cb as a MessageChunk; a final stats-only chunk is sent after
// the stream ends. Returns the accumulated assistant response text, which is
// also returned (partially) alongside any error.
func (c *Client) SendMessage(ctx context.Context, chatMessages []*store.Message, model string, cb StreamCallback) (string, error) {
	// Ensure Callback — a no-op callback lets the rest of the function skip
	// nil checks.
	if cb == nil {
		cb = func(mc *MessageChunk) error { return nil }
	}

	// Map Messages — anything not from the "user" role is treated as an
	// assistant message.
	messages := slices.Map(chatMessages, func(m *store.Message) openai.ChatCompletionMessageParamUnion {
		if m.Role == "user" {
			return openai.UserMessage(m.Content)
		}
		return openai.AssistantMessage(m.Content)
	})

	// Create Request — usage is requested so token counts arrive in-stream.
	chatReq := openai.ChatCompletionNewParams{
		Model:    model,
		Messages: messages,
		StreamOptions: openai.ChatCompletionStreamOptionsParam{
			IncludeUsage: openai.Bool(true),
		},
	}
	chatReq.SetExtraFields(map[string]any{
		"timings_per_token": true, // Llama.cpp
	})

	// Perform Request & Allocate Stats
	msgStats := types.MessageStats{StartTime: time.Now()}
	stream := c.oaiClient.Chat.Completions.NewStreaming(ctx, chatReq)

	// Iterate Stream
	var respContent string
	for stream.Next() {
		// Check Context — bail out promptly on cancellation, returning what
		// was accumulated so far.
		if ctx.Err() != nil {
			return respContent, ctx.Err()
		}

		// Load Chunk — every chunk shares the same stats struct.
		chunk := stream.Current()
		msgChunk := &MessageChunk{Stats: &msgStats}

		// Populate Timings — llama.cpp timings first, then the standard usage
		// block; sendUpdate stays true if either produced data. Note the
		// `|| sendUpdate` ordering so populateUsageTimings always runs.
		sendUpdate := populateLlamaCPPTimings(&msgStats, chunk.JSON.ExtraFields)
		sendUpdate = populateUsageTimings(&msgStats, chunk.Usage) || sendUpdate

		if len(chunk.Choices) > 0 {
			delta := chunk.Choices[0].Delta

			// Check Thinking — "reasoning_content" is a non-standard field
			// (e.g. llama.cpp / reasoning models) carried in the raw JSON.
			if thinkingField, found := delta.JSON.ExtraFields["reasoning_content"]; found {
				var thinkingContent string
				if err := json.Unmarshal([]byte(thinkingField.Raw()), &thinkingContent); err != nil {
					return respContent, fmt.Errorf("thinking unmarshal error: %w", err)
				} else if thinkingContent != "" {
					msgStats.RecordFirstToken()
					sendUpdate = true
					msgChunk.Thinking = ptr.Of(thinkingContent)
				}
			}

			// Check Content — accumulate the visible assistant text.
			if delta.Content != "" {
				msgStats.RecordFirstToken()
				sendUpdate = true
				msgChunk.Message = ptr.Of(delta.Content)
				respContent += delta.Content
			}
		}

		// Send Timings — only invoke the callback when this chunk actually
		// carried new content or stats.
		if sendUpdate {
			msgStats.CalculateDerived()
			if err := cb(msgChunk); err != nil {
				return respContent, fmt.Errorf("chunk callback error: %w", err)
			}
		}
	}

	// Check Error
	if err := stream.Err(); err != nil {
		return respContent, fmt.Errorf("stream error: %w", err)
	}

	// Send Final Chunk — a stats-only chunk with final derived values.
	msgStats.RecordLastToken()
	msgStats.CalculateDerived()
	if err := cb(&MessageChunk{Stats: &msgStats}); err != nil {
		return respContent, fmt.Errorf("chunk callback error: %w", err)
	}

	return respContent, nil
}
|
||||
|
||||
func (c *Client) CreateTitle(ctx context.Context, userMessage, model string) (string, error) {
|
||||
prompt := "You are an agent responsible for creating titles for chats based on the initial message. " +
|
||||
"Your titles should be succinct and short. Respond with JUST the chat title. Initial Message: \n\n" + userMessage
|
||||
|
||||
// Generate Text Stream
|
||||
output, err := c.SendMessage(ctx, []*store.Message{{
|
||||
Role: "user",
|
||||
Content: prompt,
|
||||
}}, model, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sent message: %w", err)
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
func (c *Client) StructuredOutput(ctx context.Context, target any, prompt, model string) error {
|
||||
// Validate Target Pointer
|
||||
v := reflect.ValueOf(target)
|
||||
if v.Kind() != reflect.Pointer {
|
||||
return fmt.Errorf("target must be a pointer, got %T", target)
|
||||
}
|
||||
if v.IsNil() {
|
||||
return fmt.Errorf("target pointer is nil")
|
||||
}
|
||||
|
||||
// Validate Target Struct
|
||||
elem := v.Elem()
|
||||
if elem.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("target must be a pointer to struct, got pointer to %s", elem.Kind())
|
||||
}
|
||||
|
||||
// Build Schema
|
||||
schema, err := buildJSONSchema(elem.Type())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build schema: %w", err)
|
||||
}
|
||||
|
||||
// Perform Request
|
||||
resp, err := c.oaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
|
||||
Model: model,
|
||||
Messages: []openai.ChatCompletionMessageParamUnion{
|
||||
openai.UserMessage(prompt),
|
||||
},
|
||||
ResponseFormat: openai.ChatCompletionNewParamsResponseFormatUnion{
|
||||
OfJSONSchema: schema,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("API call failed: %w", err)
|
||||
}
|
||||
|
||||
// Parse Response
|
||||
content := resp.Choices[0].Message.Content
|
||||
if err := json.Unmarshal([]byte(content), target); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildJSONSchema(rType reflect.Type) (*shared.ResponseFormatJSONSchemaParam, error) {
|
||||
schema, err := jsonschema.ForType(rType, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &shared.ResponseFormatJSONSchemaParam{
|
||||
JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
|
||||
Name: rType.Name(),
|
||||
Schema: map[string]any{
|
||||
"type": schema.Type,
|
||||
"properties": schema.Properties,
|
||||
"required": schema.Required,
|
||||
"additionalProperties": false,
|
||||
},
|
||||
Strict: openai.Bool(true),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func populateLlamaCPPTimings(msgStats *types.MessageStats, extraFields map[string]respjson.Field) bool {
|
||||
rawTimings, found := extraFields["timings"]
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
|
||||
var llamaTimings llamaCPPTimings
|
||||
if err := json.Unmarshal([]byte(rawTimings.Raw()), &llamaTimings); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if llamaTimings.PromptN != 0 {
|
||||
msgStats.PromptTokens = ptr.Of(int32(llamaTimings.PromptN))
|
||||
}
|
||||
if llamaTimings.PredictedN != 0 {
|
||||
msgStats.GeneratedTokens = ptr.Of(int32(llamaTimings.PredictedN))
|
||||
}
|
||||
|
||||
msgStats.PromptPerSec = ptr.Of(float32(llamaTimings.PromptPerSecond))
|
||||
msgStats.GeneratedPerSec = ptr.Of(float32(llamaTimings.PredictedPerSecond))
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func populateUsageTimings(msgStats *types.MessageStats, usage openai.CompletionUsage) (didChange bool) {
|
||||
if usage.PromptTokens == 0 && usage.CompletionTokens == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if msgStats.PromptTokens == nil {
|
||||
didChange = true
|
||||
msgStats.PromptTokens = ptr.Of(int32(usage.PromptTokens))
|
||||
}
|
||||
|
||||
if msgStats.GeneratedTokens == nil {
|
||||
didChange = true
|
||||
reasoningTokens := usage.CompletionTokensDetails.ReasoningTokens
|
||||
msgStats.GeneratedTokens = ptr.Of(int32(usage.CompletionTokens + reasoningTokens))
|
||||
}
|
||||
|
||||
return didChange
|
||||
}
|
||||
|
||||
func NewClient(baseURL *url.URL) *Client {
|
||||
oaiClient := openai.NewClient(option.WithBaseURL(baseURL.String()))
|
||||
return &Client{oaiClient: &oaiClient}
|
||||
}
|
||||
Reference in New Issue
Block a user