All checks were successful
continuous-integration/drone/push Build is passing
vLLM sends thinking content in a "reasoning" delta field, unlike DeepSeek which uses "reasoning_content". Check both field names so thinking blocks render for vLLM-hosted models like qwen3.6-27b-thinking. Also update client tests to exercise thinking output and skip them by default so they don't run in Drone CI (they require a live LLM API).
295 lines · 7.7 KiB · Go
package client
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/url"
|
|
"reflect"
|
|
"time"
|
|
|
|
"github.com/google/jsonschema-go/jsonschema"
|
|
"github.com/openai/openai-go/v3"
|
|
"github.com/openai/openai-go/v3/option"
|
|
"github.com/openai/openai-go/v3/packages/respjson"
|
|
"github.com/openai/openai-go/v3/shared"
|
|
"reichard.io/aethera/internal/store"
|
|
"reichard.io/aethera/internal/types"
|
|
"reichard.io/aethera/pkg/ptr"
|
|
"reichard.io/aethera/pkg/slices"
|
|
)
|
|
|
|
// StreamCallback is invoked for each streamed MessageChunk; returning a
// non-nil error aborts the stream.
type StreamCallback func(*MessageChunk) error

// Client wraps an OpenAI-compatible API client.
type Client struct {
	oaiClient *openai.Client
}
|
|
|
|
func (c *Client) GetModels(ctx context.Context) ([]Model, error) {
|
|
// Get Models
|
|
currPage, err := c.oaiClient.Models.List(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
allData := currPage.Data
|
|
|
|
// Pagination
|
|
for {
|
|
currPage, err = currPage.GetNextPage()
|
|
if err != nil {
|
|
return nil, err
|
|
} else if currPage == nil {
|
|
break
|
|
}
|
|
allData = append(allData, currPage.Data...)
|
|
}
|
|
|
|
// Convert
|
|
return slices.Map(allData, fromOpenAIModel), nil
|
|
}
|
|
|
|
func (c *Client) GenerateImages(ctx context.Context, body openai.ImageGenerateParams) ([]openai.Image, error) {
|
|
// Generate Images
|
|
resp, err := c.oaiClient.Images.Generate(ctx, body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return resp.Data, nil
|
|
}
|
|
|
|
func (c *Client) EditImage(ctx context.Context, body openai.ImageEditParams) ([]openai.Image, error) {
|
|
// Edit Image
|
|
resp, err := c.oaiClient.Images.Edit(ctx, body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return resp.Data, nil
|
|
}
|
|
|
|
// SendMessage streams a chat completion for the given message history and
// model, invoking cb for every chunk carrying new content, thinking output,
// or updated timing stats, plus one final stats-only chunk. It returns the
// accumulated assistant response text; on error the partial text gathered so
// far is returned alongside the error. A nil cb is replaced with a no-op.
func (c *Client) SendMessage(ctx context.Context, chatMessages []*store.Message, model string, cb StreamCallback) (string, error) {
	// Ensure Callback
	if cb == nil {
		cb = func(mc *MessageChunk) error { return nil }
	}

	// Map Messages - "user" role becomes a user message; every other role
	// is sent as assistant output.
	messages := slices.Map(chatMessages, func(m *store.Message) openai.ChatCompletionMessageParamUnion {
		if m.Role == "user" {
			return openai.UserMessage(m.Content)
		}
		return openai.AssistantMessage(m.Content)
	})

	// Create Request - ask the server to include usage totals in the stream.
	chatReq := openai.ChatCompletionNewParams{
		Model:    model,
		Messages: messages,
		StreamOptions: openai.ChatCompletionStreamOptionsParam{
			IncludeUsage: openai.Bool(true),
		},
	}
	chatReq.SetExtraFields(map[string]any{
		"timings_per_token": true, // Llama.cpp
	})

	// Perform Request & Allocate Stats
	msgStats := types.MessageStats{StartTime: time.Now()}
	stream := c.oaiClient.Chat.Completions.NewStreaming(ctx, chatReq)

	// Iterate Stream
	var respContent string
	for stream.Next() {
		// Check Context - abort promptly if the caller cancelled.
		if ctx.Err() != nil {
			return respContent, ctx.Err()
		}

		// Load Chunk
		chunk := stream.Current()
		msgChunk := &MessageChunk{Stats: &msgStats}

		// Populate Timings - either source marks the chunk for delivery.
		// Note the || ordering keeps populateUsageTimings from being
		// short-circuited away.
		sendUpdate := populateLlamaCPPTimings(&msgStats, chunk.JSON.ExtraFields)
		sendUpdate = populateUsageTimings(&msgStats, chunk.Usage) || sendUpdate

		if len(chunk.Choices) > 0 {
			delta := chunk.Choices[0].Delta

			// Check Thinking - Support both "reasoning_content" (DeepSeek)
			// and "reasoning" (vLLM) field names. Only the first field that
			// is present is used; the break stops the key scan either way.
			for _, thinkingKey := range []string{"reasoning_content", "reasoning"} {
				if thinkingField, found := delta.JSON.ExtraFields[thinkingKey]; found {
					var thinkingContent string
					if err := json.Unmarshal([]byte(thinkingField.Raw()), &thinkingContent); err != nil {
						return respContent, fmt.Errorf("thinking unmarshal error: %w", err)
					} else if thinkingContent != "" {
						msgStats.RecordFirstToken()
						sendUpdate = true
						msgChunk.Thinking = ptr.Of(thinkingContent)
					}
					break
				}
			}

			// Check Content - accumulate the visible assistant text.
			if delta.Content != "" {
				msgStats.RecordFirstToken()
				sendUpdate = true
				msgChunk.Message = ptr.Of(delta.Content)
				respContent += delta.Content
			}
		}

		// Send Timings - deliver the chunk only when something changed.
		if sendUpdate {
			msgStats.CalculateDerived()
			if err := cb(msgChunk); err != nil {
				return respContent, fmt.Errorf("chunk callback error: %w", err)
			}
		}
	}

	// Check Error
	if err := stream.Err(); err != nil {
		return respContent, fmt.Errorf("stream error: %w", err)
	}

	// Send Final Chunk - stats only, with end-of-stream timings finalized.
	msgStats.RecordLastToken()
	msgStats.CalculateDerived()
	if err := cb(&MessageChunk{Stats: &msgStats}); err != nil {
		return respContent, fmt.Errorf("chunk callback error: %w", err)
	}

	return respContent, nil
}
|
|
|
|
func (c *Client) CreateTitle(ctx context.Context, userMessage, model string) (string, error) {
|
|
prompt := "You are an agent responsible for creating titles for chats based on the initial message. " +
|
|
"Your titles should be succinct and short. Respond with JUST the chat title. Initial Message: \n\n" + userMessage
|
|
|
|
// Generate Text Stream
|
|
output, err := c.SendMessage(ctx, []*store.Message{{
|
|
Role: "user",
|
|
Content: prompt,
|
|
}}, model, nil)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to sent message: %w", err)
|
|
}
|
|
|
|
return output, nil
|
|
}
|
|
|
|
func (c *Client) StructuredOutput(ctx context.Context, target any, prompt, model string) error {
|
|
// Validate Target Pointer
|
|
v := reflect.ValueOf(target)
|
|
if v.Kind() != reflect.Pointer {
|
|
return fmt.Errorf("target must be a pointer, got %T", target)
|
|
}
|
|
if v.IsNil() {
|
|
return fmt.Errorf("target pointer is nil")
|
|
}
|
|
|
|
// Validate Target Struct
|
|
elem := v.Elem()
|
|
if elem.Kind() != reflect.Struct {
|
|
return fmt.Errorf("target must be a pointer to struct, got pointer to %s", elem.Kind())
|
|
}
|
|
|
|
// Build Schema
|
|
schema, err := buildJSONSchema(elem.Type())
|
|
if err != nil {
|
|
return fmt.Errorf("failed to build schema: %w", err)
|
|
}
|
|
|
|
// Perform Request
|
|
resp, err := c.oaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
|
|
Model: model,
|
|
Messages: []openai.ChatCompletionMessageParamUnion{
|
|
openai.UserMessage(prompt),
|
|
},
|
|
ResponseFormat: openai.ChatCompletionNewParamsResponseFormatUnion{
|
|
OfJSONSchema: schema,
|
|
},
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("API call failed: %w", err)
|
|
}
|
|
|
|
// Parse Response
|
|
content := resp.Choices[0].Message.Content
|
|
if err := json.Unmarshal([]byte(content), target); err != nil {
|
|
return fmt.Errorf("failed to unmarshal response: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func buildJSONSchema(rType reflect.Type) (*shared.ResponseFormatJSONSchemaParam, error) {
|
|
schema, err := jsonschema.ForType(rType, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &shared.ResponseFormatJSONSchemaParam{
|
|
JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
|
|
Name: rType.Name(),
|
|
Schema: map[string]any{
|
|
"type": schema.Type,
|
|
"properties": schema.Properties,
|
|
"required": schema.Required,
|
|
"additionalProperties": false,
|
|
},
|
|
Strict: openai.Bool(true),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
func populateLlamaCPPTimings(msgStats *types.MessageStats, extraFields map[string]respjson.Field) bool {
|
|
rawTimings, found := extraFields["timings"]
|
|
if !found {
|
|
return false
|
|
}
|
|
|
|
var llamaTimings llamaCPPTimings
|
|
if err := json.Unmarshal([]byte(rawTimings.Raw()), &llamaTimings); err != nil {
|
|
return false
|
|
}
|
|
|
|
if llamaTimings.PromptN != 0 {
|
|
msgStats.PromptTokens = ptr.Of(int32(llamaTimings.PromptN))
|
|
}
|
|
if llamaTimings.PredictedN != 0 {
|
|
msgStats.GeneratedTokens = ptr.Of(int32(llamaTimings.PredictedN))
|
|
}
|
|
|
|
msgStats.PromptPerSec = ptr.Of(float32(llamaTimings.PromptPerSecond))
|
|
msgStats.GeneratedPerSec = ptr.Of(float32(llamaTimings.PredictedPerSecond))
|
|
|
|
return true
|
|
}
|
|
|
|
func populateUsageTimings(msgStats *types.MessageStats, usage openai.CompletionUsage) (didChange bool) {
|
|
if usage.PromptTokens == 0 && usage.CompletionTokens == 0 {
|
|
return
|
|
}
|
|
|
|
if msgStats.PromptTokens == nil {
|
|
didChange = true
|
|
msgStats.PromptTokens = ptr.Of(int32(usage.PromptTokens))
|
|
}
|
|
|
|
if msgStats.GeneratedTokens == nil {
|
|
didChange = true
|
|
reasoningTokens := usage.CompletionTokensDetails.ReasoningTokens
|
|
msgStats.GeneratedTokens = ptr.Of(int32(usage.CompletionTokens + reasoningTokens))
|
|
}
|
|
|
|
return didChange
|
|
}
|
|
|
|
func NewClient(baseURL *url.URL) *Client {
|
|
oaiClient := openai.NewClient(option.WithBaseURL(baseURL.String()))
|
|
return &Client{oaiClient: &oaiClient}
|
|
}
|