Files
aethera/backend/internal/api/generation.go

217 lines
4.2 KiB
Go

package api
import (
"context"
"errors"
"slices"
"sync"
"github.com/google/uuid"
"reichard.io/aethera/internal/store"
"reichard.io/aethera/internal/types"
)
// errGenerationActive is returned by generationManager.start when the chat
// already has a generation reserved or running.
var errGenerationActive = errors.New("generation already active")
// generationManager tracks the generation reserved or running for each chat;
// start refuses a second one for the same chat ID.
type generationManager struct {
	// mu guards the generations map only; each generation has its own lock.
	mu          sync.RWMutex
	generations map[uuid.UUID]*generation
}
// generation is one reserved/running response for a chat. Chunks passed to
// broadcast are fanned out to every registered subscriber channel.
type generation struct {
	// mu guards subscribers and closed.
	mu sync.RWMutex

	// ctx is canceled by generationManager.stop and by close.
	ctx    context.Context
	cancel context.CancelFunc

	// subscribers is the set of live subscriber channels.
	subscribers map[chan *MessageChunk]struct{}
	// done is closed exactly once, by close.
	done chan struct{}
	// closed is set by close; subscribe then hands out already-closed channels.
	closed bool
}
// newGenerationManager returns a manager with no generations registered.
func newGenerationManager() *generationManager {
	m := &generationManager{}
	m.generations = make(map[uuid.UUID]*generation)
	return m
}
// start reserves chatID for a new generation, runs prepare synchronously while
// the reservation is held, and then runs run on a new goroutine.
//
// It returns errGenerationActive if chatID already has a generation reserved,
// or prepare's error if preparation fails; in both cases no goroutine is
// started. Whenever the generation ends (prepare fails or panics, or run
// returns), it is closed (context canceled, subscriber channels closed) and
// the reservation is released.
func (m *generationManager) start(chatID uuid.UUID, prepare func(*generation) error, run func(*generation)) error {
	m.mu.Lock()
	if _, found := m.generations[chatID]; found {
		m.mu.Unlock()
		return errGenerationActive
	}

	// Reserve Generation
	ctx, cancel := context.WithCancel(context.Background())
	gen := &generation{
		ctx:         ctx,
		cancel:      cancel,
		subscribers: make(map[chan *MessageChunk]struct{}),
		done:        make(chan struct{}),
	}
	m.generations[chatID] = gen
	m.mu.Unlock()

	// release closes the generation and drops its reservation. The identity
	// check ensures we never remove a different generation for the same chat.
	release := func() {
		gen.close()
		m.mu.Lock()
		if m.generations[chatID] == gen {
			delete(m.generations, chatID)
		}
		m.mu.Unlock()
	}

	// Prepare Generation - This runs while the generation is reserved so a
	// concurrent request cannot persist duplicate user/assistant messages.
	// The deferred release also covers a panicking prepare: without it, a
	// caller that recovers the panic would find the chat reserved forever.
	launched := false
	defer func() {
		if !launched {
			release()
		}
	}()
	if err := prepare(gen); err != nil {
		return err
	}

	// Run Generation - from here on the goroutine owns the reservation.
	launched = true
	go func() {
		defer release()
		run(gen)
	}()
	return nil
}
// subscribe attaches a new subscriber to the generation for chatID.
//
// It returns the chunk channel, a function that detaches it, and true. If no
// generation is registered for chatID, it returns a nil channel, a no-op
// function, and false. The channel is closed either by the returned function
// or when the generation closes, whichever happens first.
func (m *generationManager) subscribe(chatID uuid.UUID) (<-chan *MessageChunk, func(), bool) {
	m.mu.RLock()
	gen, ok := m.generations[chatID]
	m.mu.RUnlock()

	if ok {
		ch := gen.subscribe()
		detach := func() { gen.unsubscribe(ch) }
		return ch, detach, true
	}
	return nil, func() {}, false
}
// stop cancels the context of the generation registered for chatID and
// reports whether one was found. It returns right after cancelling and does
// not wait for the generation's run function to finish.
func (m *generationManager) stop(chatID uuid.UUID) bool {
	m.mu.RLock()
	gen, ok := m.generations[chatID]
	m.mu.RUnlock()

	if ok {
		gen.cancel()
	}
	return ok
}
// subscribe registers and returns a new subscriber channel with a buffer of
// 64 chunks. If the generation is already closed, the returned channel is
// closed immediately and is not registered.
func (g *generation) subscribe() chan *MessageChunk {
	sub := make(chan *MessageChunk, 64)

	g.mu.Lock()
	defer g.mu.Unlock()

	if g.closed {
		close(sub)
		return sub
	}
	g.subscribers[sub] = struct{}{}
	return sub
}
// unsubscribe removes ch from the subscriber set and closes it. It does
// nothing if ch is not registered (for example because close already closed
// and removed it), so it is safe to call more than once.
func (g *generation) unsubscribe(ch chan *MessageChunk) {
	g.mu.Lock()
	defer g.mu.Unlock()

	if _, registered := g.subscribers[ch]; !registered {
		return
	}
	delete(g.subscribers, ch)
	close(ch)
}
// broadcast sends an independent clone of chunk to every current subscriber.
// Sends are non-blocking: a subscriber whose buffer is full misses this chunk.
// The read lock keeps unsubscribe/close from closing a channel mid-send.
func (g *generation) broadcast(chunk *MessageChunk) {
	g.mu.RLock()
	defer g.mu.RUnlock()

	for sub := range g.subscribers {
		// Each subscriber gets its own copy so receivers cannot observe each
		// other's mutations.
		clone := cloneMessageChunk(chunk)
		select {
		case sub <- clone:
		default:
		}
	}
}
// close marks the generation closed, cancels its context, closes done, and
// closes and removes every subscriber channel. Calls after the first are
// no-ops.
func (g *generation) close() {
	g.mu.Lock()
	defer g.mu.Unlock()

	if g.closed {
		return
	}
	g.closed = true

	g.cancel()
	close(g.done)

	for sub := range g.subscribers {
		close(sub)
	}
	clear(g.subscribers)
}
// cloneMessageChunk returns a copy of chunk whose UserMessage and
// AssistantMessage are deep-copied via cloneStoreMessage, so subscribers do
// not share mutable message state. It returns nil for a nil chunk.
//
// The chunk is copied by value first, so every field (including any added to
// MessageChunk later) is carried over instead of being silently dropped.
// Fields other than the two messages, such as Chat, are copied shallowly.
func cloneMessageChunk(chunk *MessageChunk) *MessageChunk {
	if chunk == nil {
		return nil
	}

	// Clone Chunk
	cloned := *chunk
	cloned.UserMessage = cloneStoreMessage(chunk.UserMessage)
	cloned.AssistantMessage = cloneStoreMessage(chunk.AssistantMessage)
	return &cloned
}
// cloneStoreMessage returns a copy of msg that does not share its Images
// slice or its Stats struct (including Stats' pointer fields) with the
// original. Other fields are copied by value. It returns nil for a nil msg.
func cloneStoreMessage(msg *store.Message) *store.Message {
	if msg == nil {
		return nil
	}

	out := new(store.Message)
	*out = *msg
	out.Images = slices.Clone(msg.Images)

	if src := msg.Stats; src != nil {
		statsCopy := *src
		cloneMessageStatsPointers(src, &statsCopy)
		out.Stats = &statsCopy
	}
	return out
}
// cloneMessageStatsPointers gives dst its own copy of every non-nil pointer
// field of src (EndTime, token counts, per-second rates, and token timings),
// so dst no longer aliases src's values. Fields that are nil in src are left
// unchanged in dst.
func cloneMessageStatsPointers(src, dst *types.MessageStats) {
	copyPointeeInto(src.EndTime, &dst.EndTime)
	copyPointeeInto(src.PromptTokens, &dst.PromptTokens)
	copyPointeeInto(src.GeneratedTokens, &dst.GeneratedTokens)
	copyPointeeInto(src.PromptPerSec, &dst.PromptPerSec)
	copyPointeeInto(src.GeneratedPerSec, &dst.GeneratedPerSec)
	copyPointeeInto(src.TimeToFirstToken, &dst.TimeToFirstToken)
	copyPointeeInto(src.TimeToLastToken, &dst.TimeToLastToken)
}

// copyPointeeInto points *dst at a fresh copy of *src. When src is nil, *dst
// is left as-is.
func copyPointeeInto[T any](src *T, dst **T) {
	if src == nil {
		return
	}
	v := *src
	*dst = &v
}