package api

import (
	"context"
	"errors"
	"slices"
	"sync"

	"github.com/google/uuid"

	"reichard.io/aethera/internal/store"
	"reichard.io/aethera/internal/types"
)

// errGenerationActive is returned by generationManager.start when a
// generation is already registered for the requested chat.
var errGenerationActive = errors.New("generation already active")

// generationManager tracks at most one in-flight generation per chat ID.
type generationManager struct {
	mu          sync.RWMutex
	generations map[uuid.UUID]*generation
}

// generation is a single in-flight generation. Chunks passed to broadcast are
// fanned out to every subscriber; close cancels ctx, closes done, and closes
// every remaining subscriber channel exactly once.
type generation struct {
	mu sync.RWMutex
	// ctx scopes the lifetime of the whole generation (not a single call),
	// which is why it is kept on the struct. It is cancelled by stop and by
	// close.
	ctx         context.Context
	cancel      context.CancelFunc
	subscribers map[chan *MessageChunk]struct{} // guarded by mu
	// done is closed by close(). NOTE(review): presumably waited on by code
	// outside this file to observe completion — confirm.
	done   chan struct{}
	closed bool // guarded by mu; true once close() has run
}

// newGenerationManager returns an empty, ready-to-use generationManager.
func newGenerationManager() *generationManager {
	return &generationManager{generations: make(map[uuid.UUID]*generation)}
}

// start reserves chatID, runs prepare synchronously, and on success runs run
// in a new goroutine.
//
// It returns errGenerationActive if chatID already has a generation, or the
// error returned by prepare. The reservation is released (the generation is
// closed and unregistered) when prepare fails or panics, or after run returns.
func (m *generationManager) start(chatID uuid.UUID, prepare func(*generation) error, run func(*generation)) error {
	m.mu.Lock()
	if _, found := m.generations[chatID]; found {
		m.mu.Unlock()
		return errGenerationActive
	}

	// Reserve Generation
	ctx, cancel := context.WithCancel(context.Background())
	gen := &generation{
		ctx:         ctx,
		cancel:      cancel,
		subscribers: make(map[chan *MessageChunk]struct{}),
		done:        make(chan struct{}),
	}
	m.generations[chatID] = gen
	m.mu.Unlock()

	// Release the reservation unless ownership is handed to the run goroutine.
	// Deferring (instead of only handling the error return) also covers a
	// panicking prepare: if the panic is recovered upstream (net/http does this
	// for handlers), the chat would otherwise stay "active" forever.
	handedOff := false
	defer func() {
		if !handedOff {
			m.release(chatID, gen)
		}
	}()

	// Prepare Generation - This runs while the generation is reserved so a
	// concurrent request cannot persist duplicate user/assistant messages.
	if err := prepare(gen); err != nil {
		return err
	}

	// Run Generation - the goroutine now owns the reservation.
	handedOff = true
	go func() {
		defer m.release(chatID, gen)
		run(gen)
	}()
	return nil
}

// release closes gen and unregisters it from the manager, provided it is
// still the generation registered for chatID.
func (m *generationManager) release(chatID uuid.UUID, gen *generation) {
	gen.close()

	m.mu.Lock()
	if m.generations[chatID] == gen {
		delete(m.generations, chatID)
	}
	m.mu.Unlock()
}

// subscribe attaches a new subscriber to the active generation for chatID.
//
// It returns the chunk channel, a function that unsubscribes (and closes) that
// channel, and whether a generation was found. When none is found the channel
// is nil and the returned function is a no-op. The channel is also closed when
// the generation closes.
func (m *generationManager) subscribe(chatID uuid.UUID) (<-chan *MessageChunk, func(), bool) {
	m.mu.RLock()
	gen, found := m.generations[chatID]
	m.mu.RUnlock()
	if !found {
		return nil, func() {}, false
	}

	ch := gen.subscribe()
	return ch, func() { gen.unsubscribe(ch) }, true
}

// stop cancels the context of the active generation for chatID and reports
// whether one was found. It does not wait for the generation to finish.
func (m *generationManager) stop(chatID uuid.UUID) bool {
	m.mu.RLock()
	gen, found := m.generations[chatID]
	m.mu.RUnlock()
	if !found {
		return false
	}

	// Cancel Generation
	gen.cancel()
	return true
}

// subscribe returns a new channel with a 64-chunk buffer that receives
// broadcast chunks. If the generation has already closed, the channel is
// returned already closed so receivers terminate immediately.
func (g *generation) subscribe() chan *MessageChunk {
	ch := make(chan *MessageChunk, 64)

	// Add Subscriber
	g.mu.Lock()
	if g.closed {
		close(ch)
	} else {
		g.subscribers[ch] = struct{}{}
	}
	g.mu.Unlock()

	return ch
}

// unsubscribe removes ch and closes it. It is a no-op if ch is not (or no
// longer) registered, e.g. because close already closed it.
func (g *generation) unsubscribe(ch chan *MessageChunk) {
	// Remove Subscriber
	g.mu.Lock()
	if _, found := g.subscribers[ch]; found {
		delete(g.subscribers, ch)
		close(ch)
	}
	g.mu.Unlock()
}

// broadcast sends an independent clone of chunk to each subscriber without
// blocking. A subscriber whose buffer is full does not receive this chunk.
func (g *generation) broadcast(chunk *MessageChunk) {
	g.mu.RLock()
	defer g.mu.RUnlock()

	// Broadcast Chunk
	for subscriber := range g.subscribers {
		select {
		case subscriber <- cloneMessageChunk(chunk):
		default:
		}
	}
}

// close cancels the generation's context, closes done, and closes and removes
// every subscriber channel. Calls after the first are no-ops.
func (g *generation) close() {
	g.mu.Lock()
	defer g.mu.Unlock()

	// Close Subscribers
	if g.closed {
		return
	}
	g.closed = true
	g.cancel()
	close(g.done)
	for subscriber := range g.subscribers {
		close(subscriber)
		delete(g.subscribers, subscriber)
	}
}

// cloneMessageChunk returns a copy of chunk whose UserMessage and
// AssistantMessage are cloned with cloneStoreMessage. It returns nil for nil.
//
// NOTE(review): Chat is copied by plain assignment; if its type is a pointer
// or holds reference types it is shared with the original — confirm intended.
func cloneMessageChunk(chunk *MessageChunk) *MessageChunk {
	if chunk == nil {
		return nil
	}

	// Clone Chunk
	cloned := &MessageChunk{
		Chat:             chunk.Chat,
		UserMessage:      cloneStoreMessage(chunk.UserMessage),
		AssistantMessage: cloneStoreMessage(chunk.AssistantMessage),
	}
	return cloned
}

// cloneStoreMessage returns a copy of msg with its own Images slice (elements
// copied by assignment) and its own Stats, including Stats' pointer fields.
// It returns nil for nil.
func cloneStoreMessage(msg *store.Message) *store.Message {
	if msg == nil {
		return nil
	}

	// Clone Message
	cloned := *msg
	cloned.Images = slices.Clone(msg.Images)
	if msg.Stats != nil {
		stats := *msg.Stats
		cloned.Stats = &stats
		cloneMessageStatsPointers(msg.Stats, cloned.Stats)
	}
	return &cloned
}

// cloneMessageStatsPointers gives dst its own copies of src's non-nil pointer
// fields. Fields that are nil in src are left untouched in dst.
func cloneMessageStatsPointers(src, dst *types.MessageStats) {
	// Clone Pointer Fields
	clonePtrInto(&dst.EndTime, src.EndTime)
	clonePtrInto(&dst.PromptTokens, src.PromptTokens)
	clonePtrInto(&dst.GeneratedTokens, src.GeneratedTokens)
	clonePtrInto(&dst.PromptPerSec, src.PromptPerSec)
	clonePtrInto(&dst.GeneratedPerSec, src.GeneratedPerSec)
	clonePtrInto(&dst.TimeToFirstToken, src.TimeToFirstToken)
	clonePtrInto(&dst.TimeToLastToken, src.TimeToLastToken)
}

// clonePtrInto stores a pointer to a fresh copy of *src into *dst when src is
// non-nil; when src is nil, *dst is left unchanged.
func clonePtrInto[T any](dst **T, src *T) {
	if src == nil {
		return
	}
	v := *src
	*dst = &v
}