feat: stream persistent
This commit is contained in:
194
backend/internal/api/generation.go
Normal file
194
backend/internal/api/generation.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"reichard.io/aethera/internal/store"
|
||||
"reichard.io/aethera/internal/types"
|
||||
)
|
||||
|
||||
// errGenerationActive is returned by generationManager.start when the chat
// already has a generation reserved.
var errGenerationActive = errors.New("generation already active")
|
||||
|
||||
type generationManager struct {
|
||||
mu sync.RWMutex
|
||||
generations map[uuid.UUID]*generation
|
||||
}
|
||||
|
||||
type generation struct {
|
||||
mu sync.RWMutex
|
||||
subscribers map[chan *MessageChunk]struct{}
|
||||
done chan struct{}
|
||||
closed bool
|
||||
}
|
||||
|
||||
func newGenerationManager() *generationManager {
|
||||
return &generationManager{generations: make(map[uuid.UUID]*generation)}
|
||||
}
|
||||
|
||||
func (m *generationManager) start(chatID uuid.UUID, prepare func(*generation) error, run func(*generation)) error {
|
||||
m.mu.Lock()
|
||||
if _, found := m.generations[chatID]; found {
|
||||
m.mu.Unlock()
|
||||
return errGenerationActive
|
||||
}
|
||||
|
||||
// Reserve Generation
|
||||
gen := &generation{
|
||||
subscribers: make(map[chan *MessageChunk]struct{}),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
m.generations[chatID] = gen
|
||||
m.mu.Unlock()
|
||||
|
||||
// Prepare Generation - This runs while the generation is reserved so a
|
||||
// concurrent request cannot persist duplicate user/assistant messages.
|
||||
if err := prepare(gen); err != nil {
|
||||
gen.close()
|
||||
m.mu.Lock()
|
||||
delete(m.generations, chatID)
|
||||
m.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
// Run Generation
|
||||
go func() {
|
||||
defer func() {
|
||||
gen.close()
|
||||
m.mu.Lock()
|
||||
delete(m.generations, chatID)
|
||||
m.mu.Unlock()
|
||||
}()
|
||||
run(gen)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *generationManager) subscribe(chatID uuid.UUID) (<-chan *MessageChunk, func(), bool) {
|
||||
m.mu.RLock()
|
||||
gen, found := m.generations[chatID]
|
||||
m.mu.RUnlock()
|
||||
if !found {
|
||||
return nil, func() {}, false
|
||||
}
|
||||
|
||||
ch := gen.subscribe()
|
||||
return ch, func() { gen.unsubscribe(ch) }, true
|
||||
}
|
||||
|
||||
func (g *generation) subscribe() chan *MessageChunk {
|
||||
ch := make(chan *MessageChunk, 64)
|
||||
|
||||
// Add Subscriber
|
||||
g.mu.Lock()
|
||||
if g.closed {
|
||||
close(ch)
|
||||
} else {
|
||||
g.subscribers[ch] = struct{}{}
|
||||
}
|
||||
g.mu.Unlock()
|
||||
|
||||
return ch
|
||||
}
|
||||
|
||||
func (g *generation) unsubscribe(ch chan *MessageChunk) {
|
||||
// Remove Subscriber
|
||||
g.mu.Lock()
|
||||
if _, found := g.subscribers[ch]; found {
|
||||
delete(g.subscribers, ch)
|
||||
close(ch)
|
||||
}
|
||||
g.mu.Unlock()
|
||||
}
|
||||
|
||||
func (g *generation) broadcast(chunk *MessageChunk) {
|
||||
g.mu.RLock()
|
||||
defer g.mu.RUnlock()
|
||||
|
||||
// Broadcast Chunk
|
||||
for subscriber := range g.subscribers {
|
||||
select {
|
||||
case subscriber <- cloneMessageChunk(chunk):
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *generation) close() {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
// Close Subscribers
|
||||
if g.closed {
|
||||
return
|
||||
}
|
||||
g.closed = true
|
||||
close(g.done)
|
||||
for subscriber := range g.subscribers {
|
||||
close(subscriber)
|
||||
delete(g.subscribers, subscriber)
|
||||
}
|
||||
}
|
||||
|
||||
func cloneMessageChunk(chunk *MessageChunk) *MessageChunk {
|
||||
if chunk == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clone Chunk
|
||||
cloned := &MessageChunk{
|
||||
Chat: chunk.Chat,
|
||||
UserMessage: cloneStoreMessage(chunk.UserMessage),
|
||||
AssistantMessage: cloneStoreMessage(chunk.AssistantMessage),
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneStoreMessage(msg *store.Message) *store.Message {
|
||||
if msg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clone Message
|
||||
cloned := *msg
|
||||
if msg.Stats != nil {
|
||||
stats := *msg.Stats
|
||||
cloned.Stats = &stats
|
||||
cloneMessageStatsPointers(msg.Stats, cloned.Stats)
|
||||
}
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func cloneMessageStatsPointers(src, dst *types.MessageStats) {
|
||||
// Clone Pointer Fields
|
||||
if src.EndTime != nil {
|
||||
v := *src.EndTime
|
||||
dst.EndTime = &v
|
||||
}
|
||||
if src.PromptTokens != nil {
|
||||
v := *src.PromptTokens
|
||||
dst.PromptTokens = &v
|
||||
}
|
||||
if src.GeneratedTokens != nil {
|
||||
v := *src.GeneratedTokens
|
||||
dst.GeneratedTokens = &v
|
||||
}
|
||||
if src.PromptPerSec != nil {
|
||||
v := *src.PromptPerSec
|
||||
dst.PromptPerSec = &v
|
||||
}
|
||||
if src.GeneratedPerSec != nil {
|
||||
v := *src.GeneratedPerSec
|
||||
dst.GeneratedPerSec = &v
|
||||
}
|
||||
if src.TimeToFirstToken != nil {
|
||||
v := *src.TimeToFirstToken
|
||||
dst.TimeToFirstToken = &v
|
||||
}
|
||||
if src.TimeToLastToken != nil {
|
||||
v := *src.TimeToLastToken
|
||||
dst.TimeToLastToken = &v
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user