initial commit
backend/internal/client/client.go (new file, 290 lines)
@@ -0,0 +1,290 @@
package client

import (
	"context"
	"encoding/json"
	"fmt"
	"net/url"
	"reflect"
	"time"

	"github.com/google/jsonschema-go/jsonschema"
	"github.com/openai/openai-go/v3"
	"github.com/openai/openai-go/v3/option"
	"github.com/openai/openai-go/v3/packages/respjson"
	"github.com/openai/openai-go/v3/shared"
	"reichard.io/aethera/internal/store"
	"reichard.io/aethera/internal/types"
	"reichard.io/aethera/pkg/ptr"
	"reichard.io/aethera/pkg/slices"
)

// StreamCallback is invoked once per streamed message chunk.
type StreamCallback func(*MessageChunk) error

// Client wraps an OpenAI-compatible chat completion API client.
type Client struct {
	oaiClient *openai.Client
}

// GetModels lists all models available from the backing API, following
// pagination until every page has been consumed.
func (c *Client) GetModels(ctx context.Context) ([]Model, error) {
	// Get Models
	currPage, err := c.oaiClient.Models.List(ctx)
	if err != nil {
		return nil, err
	}
	allData := currPage.Data

	// Pagination
	for {
		currPage, err = currPage.GetNextPage()
		if err != nil {
			return nil, err
		} else if currPage == nil {
			break
		}
		allData = append(allData, currPage.Data...)
	}

	// Convert
	return slices.Map(allData, fromOpenAIModel), nil
}

// GenerateImages generates images from the given parameters.
func (c *Client) GenerateImages(ctx context.Context, body openai.ImageGenerateParams) ([]openai.Image, error) {
	// Generate Images
	resp, err := c.oaiClient.Images.Generate(ctx, body)
	if err != nil {
		return nil, err
	}
	return resp.Data, nil
}

// EditImage edits an existing image using the given parameters.
func (c *Client) EditImage(ctx context.Context, body openai.ImageEditParams) ([]openai.Image, error) {
	// Edit Image
	resp, err := c.oaiClient.Images.Edit(ctx, body)
	if err != nil {
		return nil, err
	}
	return resp.Data, nil
}

// SendMessage streams a chat completion for the given message history and
// model, invoking cb for every chunk that carries new content, reasoning
// output, or updated statistics. It returns the full assistant response.
func (c *Client) SendMessage(ctx context.Context, chatMessages []*store.Message, model string, cb StreamCallback) (string, error) {
	// Ensure Callback
	if cb == nil {
		cb = func(mc *MessageChunk) error { return nil }
	}

	// Map Messages
	messages := slices.Map(chatMessages, func(m *store.Message) openai.ChatCompletionMessageParamUnion {
		if m.Role == "user" {
			return openai.UserMessage(m.Content)
		}
		return openai.AssistantMessage(m.Content)
	})

	// Create Request
	chatReq := openai.ChatCompletionNewParams{
		Model:    model,
		Messages: messages,
		StreamOptions: openai.ChatCompletionStreamOptionsParam{
			IncludeUsage: openai.Bool(true),
		},
	}
	chatReq.SetExtraFields(map[string]any{
		"timings_per_token": true, // Llama.cpp
	})

	// Perform Request & Allocate Stats
	msgStats := types.MessageStats{StartTime: time.Now()}
	stream := c.oaiClient.Chat.Completions.NewStreaming(ctx, chatReq)

	// Iterate Stream
	var respContent string
	for stream.Next() {
		// Check Context
		if ctx.Err() != nil {
			return respContent, ctx.Err()
		}

		// Load Chunk
		chunk := stream.Current()
		msgChunk := &MessageChunk{Stats: &msgStats}

		// Populate Timings
		sendUpdate := populateLlamaCPPTimings(&msgStats, chunk.JSON.ExtraFields)
		sendUpdate = populateUsageTimings(&msgStats, chunk.Usage) || sendUpdate

		if len(chunk.Choices) > 0 {
			delta := chunk.Choices[0].Delta

			// Check Thinking
			if thinkingField, found := delta.JSON.ExtraFields["reasoning_content"]; found {
				var thinkingContent string
				if err := json.Unmarshal([]byte(thinkingField.Raw()), &thinkingContent); err != nil {
					return respContent, fmt.Errorf("thinking unmarshal error: %w", err)
				} else if thinkingContent != "" {
					msgStats.RecordFirstToken()
					sendUpdate = true
					msgChunk.Thinking = ptr.Of(thinkingContent)
				}
			}

			// Check Content
			if delta.Content != "" {
				msgStats.RecordFirstToken()
				sendUpdate = true
				msgChunk.Message = ptr.Of(delta.Content)
				respContent += delta.Content
			}
		}

		// Send Timings
		if sendUpdate {
			msgStats.CalculateDerived()
			if err := cb(msgChunk); err != nil {
				return respContent, fmt.Errorf("chunk callback error: %w", err)
			}
		}
	}

	// Check Error
	if err := stream.Err(); err != nil {
		return respContent, fmt.Errorf("stream error: %w", err)
	}

	// Send Final Chunk
	msgStats.RecordLastToken()
	msgStats.CalculateDerived()
	if err := cb(&MessageChunk{Stats: &msgStats}); err != nil {
		return respContent, fmt.Errorf("chunk callback error: %w", err)
	}

	return respContent, nil
}

// CreateTitle asks the model to generate a short chat title from the user's
// initial message.
func (c *Client) CreateTitle(ctx context.Context, userMessage, model string) (string, error) {
	prompt := "You are an agent responsible for creating titles for chats based on the initial message. " +
		"Your titles should be succinct and short. Respond with JUST the chat title. Initial Message: \n\n" + userMessage

	// Generate Title
	output, err := c.SendMessage(ctx, []*store.Message{{
		Role:    "user",
		Content: prompt,
	}}, model, nil)
	if err != nil {
		return "", fmt.Errorf("failed to send message: %w", err)
	}

	return output, nil
}

// StructuredOutput prompts the model with a JSON-schema-constrained response
// format and unmarshals the result into target, which must be a non-nil
// pointer to a struct.
func (c *Client) StructuredOutput(ctx context.Context, target any, prompt, model string) error {
	// Validate Target Pointer
	v := reflect.ValueOf(target)
	if v.Kind() != reflect.Pointer {
		return fmt.Errorf("target must be a pointer, got %T", target)
	}
	if v.IsNil() {
		return fmt.Errorf("target pointer is nil")
	}

	// Validate Target Struct
	elem := v.Elem()
	if elem.Kind() != reflect.Struct {
		return fmt.Errorf("target must be a pointer to struct, got pointer to %s", elem.Kind())
	}

	// Build Schema
	schema, err := buildJSONSchema(elem.Type())
	if err != nil {
		return fmt.Errorf("failed to build schema: %w", err)
	}

	// Perform Request
	resp, err := c.oaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
		Model: model,
		Messages: []openai.ChatCompletionMessageParamUnion{
			openai.UserMessage(prompt),
		},
		ResponseFormat: openai.ChatCompletionNewParamsResponseFormatUnion{
			OfJSONSchema: schema,
		},
	})
	if err != nil {
		return fmt.Errorf("API call failed: %w", err)
	}
	if len(resp.Choices) == 0 {
		return fmt.Errorf("response contained no choices")
	}

	// Parse Response
	content := resp.Choices[0].Message.Content
	if err := json.Unmarshal([]byte(content), target); err != nil {
		return fmt.Errorf("failed to unmarshal response: %w", err)
	}

	return nil
}

// buildJSONSchema derives a strict OpenAI JSON schema response format from
// the given struct type.
func buildJSONSchema(rType reflect.Type) (*shared.ResponseFormatJSONSchemaParam, error) {
	schema, err := jsonschema.ForType(rType, nil)
	if err != nil {
		return nil, err
	}

	return &shared.ResponseFormatJSONSchemaParam{
		JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
			Name: rType.Name(),
			Schema: map[string]any{
				"type":                 schema.Type,
				"properties":           schema.Properties,
				"required":             schema.Required,
				"additionalProperties": false,
			},
			Strict: openai.Bool(true),
		},
	}, nil
}

// populateLlamaCPPTimings copies llama.cpp's non-standard "timings" extra
// field into msgStats, reporting whether anything was found.
func populateLlamaCPPTimings(msgStats *types.MessageStats, extraFields map[string]respjson.Field) bool {
	rawTimings, found := extraFields["timings"]
	if !found {
		return false
	}

	var llamaTimings llamaCPPTimings
	if err := json.Unmarshal([]byte(rawTimings.Raw()), &llamaTimings); err != nil {
		return false
	}

	if llamaTimings.PromptN != 0 {
		msgStats.PromptTokens = ptr.Of(int32(llamaTimings.PromptN))
	}
	if llamaTimings.PredictedN != 0 {
		msgStats.GeneratedTokens = ptr.Of(int32(llamaTimings.PredictedN))
	}

	msgStats.PromptPerSec = ptr.Of(float32(llamaTimings.PromptPerSecond))
	msgStats.GeneratedPerSec = ptr.Of(float32(llamaTimings.PredictedPerSecond))

	return true
}

// populateUsageTimings fills token counts from the standard usage block,
// without overwriting values already populated from llama.cpp timings.
func populateUsageTimings(msgStats *types.MessageStats, usage openai.CompletionUsage) (didChange bool) {
	if usage.PromptTokens == 0 && usage.CompletionTokens == 0 {
		return
	}

	if msgStats.PromptTokens == nil {
		didChange = true
		msgStats.PromptTokens = ptr.Of(int32(usage.PromptTokens))
	}

	if msgStats.GeneratedTokens == nil {
		didChange = true
		reasoningTokens := usage.CompletionTokensDetails.ReasoningTokens
		msgStats.GeneratedTokens = ptr.Of(int32(usage.CompletionTokens + reasoningTokens))
	}

	return didChange
}

// NewClient constructs a Client pointed at the given OpenAI-compatible
// base URL.
func NewClient(baseURL *url.URL) *Client {
	oaiClient := openai.NewClient(option.WithBaseURL(baseURL.String()))
	return &Client{oaiClient: &oaiClient}
}
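
A minimal usage sketch of the client above, assuming an OpenAI-compatible server at a placeholder URL; the endpoint and model name are illustrative, not values the repository defines:

package main

import (
	"context"
	"fmt"
	"net/url"

	"reichard.io/aethera/internal/client"
	"reichard.io/aethera/internal/store"
)

func main() {
	// Placeholder endpoint; any OpenAI-compatible server works here.
	baseURL, _ := url.Parse("http://localhost:8080/v1")
	c := client.NewClient(baseURL)

	// Stream a reply, printing content chunks as they arrive.
	reply, err := c.SendMessage(context.Background(), []*store.Message{{
		Role:    "user",
		Content: "Hello!",
	}}, "some-model", func(mc *client.MessageChunk) error {
		if mc.Message != nil {
			fmt.Print(*mc.Message)
		}
		return nil
	})
	if err != nil {
		panic(err)
	}
	fmt.Println("\nfull reply length:", len(reply))
}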
backend/internal/client/client_test.go (new file, 79 lines)
@@ -0,0 +1,79 @@
package client

import (
	"bytes"
	"context"
	"net/url"
	"testing"
	"time"

	"reichard.io/aethera/internal/store"
)

const model = "devstral-small-2-instruct"

func TestSendMessage(t *testing.T) {
	// Initialize Client
	baseURL, err := url.Parse("https://llm-api.va.reichard.io/v1")
	if err != nil {
		t.Fatalf("Failed to parse base URL: %v", err)
	}
	client := NewClient(baseURL)

	// Create Context
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// Generate Text Stream
	var buf bytes.Buffer
	_, err = client.SendMessage(ctx, []*store.Message{{
		Role:    "user",
		Content: "Hello, how are you?",
	}}, model, func(mc *MessageChunk) error {
		if mc.Message != nil {
			_, err := buf.Write([]byte(*mc.Message))
			return err
		}
		return nil
	})
	if err != nil {
		t.Fatalf("Failed to generate text stream: %v", err)
	}

	// Verify Results
	output := buf.String()
	if output == "" {
		t.Error("No content was written to the buffer")
	} else {
		t.Logf("Successfully received %d bytes from the stream", len(output))
		t.Logf("Output: %s", output)
	}
}

func TestSummarizeChat(t *testing.T) {
	// Initialize Client
	baseURL, err := url.Parse("https://llm-api.va.reichard.io/v1")
	if err != nil {
		t.Fatalf("Failed to parse base URL: %v", err)
	}
	client := NewClient(baseURL)

	// Create Context
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// Generate Title
	userMessage := "Write me a go program that reads in a zip file and prints the contents along with their sizes and mimetype."
	output, err := client.CreateTitle(ctx, userMessage, model)
	if err != nil {
		t.Fatalf("Failed to create title: %v", err)
	}

	// Verify Results
	if output == "" {
		t.Error("No title was returned")
	} else {
		t.Logf("Successfully received %d bytes", len(output))
		t.Logf("Output: %s", output)
	}
}
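
Both tests exercise a live endpoint, so they fail when that host is unreachable. One possible way to make them skippable is an env-var guard; this is a sketch only, the AETHERA_TEST_BASE_URL variable name is an assumption (nothing in the repository defines it), and it assumes an extra "os" import in the test file:

func testBaseURL(t *testing.T) *url.URL {
	// Hypothetical env var; skip when no endpoint is configured.
	raw := os.Getenv("AETHERA_TEST_BASE_URL")
	if raw == "" {
		t.Skip("AETHERA_TEST_BASE_URL not set; skipping live API test")
	}
	baseURL, err := url.Parse(raw)
	if err != nil {
		t.Fatalf("Failed to parse base URL: %v", err)
	}
	return baseURL
}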
backend/internal/client/convert.go (new file, 41 lines)
@@ -0,0 +1,41 @@
package client

import (
	"encoding/json"

	"github.com/openai/openai-go/v3"
)

// fromOpenAIModel converts an openai.Model into our Model, pulling the
// display name and llama-swap metadata out of any extra JSON fields.
func fromOpenAIModel(m openai.Model) Model {
	newModel := Model{
		Model: m,
		Name:  m.ID,
	}

	extraFields := make(map[string]any)
	for k, v := range m.JSON.ExtraFields {
		var val any
		if err := json.Unmarshal([]byte(v.Raw()), &val); err != nil {
			continue
		}
		extraFields[k] = val
	}

	// Extract Name
	if rawName, found := extraFields["name"]; found {
		if name, ok := rawName.(string); ok {
			newModel.Name = name
		}
	}

	// Extract Meta
	if rawMeta, found := extraFields["meta"]; found {
		if parsedMeta, ok := rawMeta.(map[string]any); ok {
			if llamaMeta, ok := parsedMeta["llamaswap"].(map[string]any); ok {
				newModel.Meta = llamaMeta
			}
		}
	}

	return newModel
}
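
For reference, a /v1/models entry of roughly the following shape would survive this conversion with both fields populated. The payload is illustrative (the field values are made up); only the "name" and "meta.llamaswap" keys are what fromOpenAIModel actually looks for:

{
	"id": "devstral-small-2-instruct",
	"object": "model",
	"created": 1700000000,
	"owned_by": "llama-swap",
	"name": "Devstral Small 2 (Instruct)",
	"meta": {
		"llamaswap": {
			"state": "ready"
		}
	}
}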
backend/internal/client/types.go (new file, 31 lines)
@@ -0,0 +1,31 @@
package client

import (
	"github.com/openai/openai-go/v3"
	"reichard.io/aethera/internal/types"
)

// Model wraps an openai.Model with a display name and optional
// backend-specific metadata.
type Model struct {
	openai.Model

	Name string         `json:"name"`
	Meta map[string]any `json:"meta,omitempty"`
}

// MessageChunk is a single streamed update: reasoning text, message
// content, and/or updated statistics.
type MessageChunk struct {
	Thinking *string             `json:"thinking,omitempty"`
	Message  *string             `json:"message,omitempty"`
	Stats    *types.MessageStats `json:"stats,omitempty"`
}

// llamaCPPTimings mirrors the non-standard "timings" object that llama.cpp
// attaches to streamed chat completion chunks.
type llamaCPPTimings struct {
	CacheN              int     `json:"cache_n"`
	PredictedMS         float64 `json:"predicted_ms"`
	PredictedN          int     `json:"predicted_n"`
	PredictedPerSecond  float64 `json:"predicted_per_second"`
	PredictedPerTokenMS float64 `json:"predicted_per_token_ms"`
	PromptMS            float64 `json:"prompt_ms"`
	PromptN             int     `json:"prompt_n"`
	PromptPerSecond     float64 `json:"prompt_per_second"`
	PromptPerTokenMS    float64 `json:"prompt_per_token_ms"`
}
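
As a reference point, a llama.cpp stream chunk carries a top-level "timings" object shaped like the struct above; the numbers below are made up for illustration, and populateLlamaCPPTimings consumes exactly this shape:

{
	"cache_n": 0,
	"prompt_n": 24,
	"prompt_ms": 180.5,
	"prompt_per_token_ms": 7.5,
	"prompt_per_second": 132.9,
	"predicted_n": 96,
	"predicted_ms": 2400.0,
	"predicted_per_token_ms": 25.0,
	"predicted_per_second": 40.0
}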