package indexer import ( "context" "crypto/sha256" "database/sql" "fmt" "os" "path/filepath" "strings" "github.com/odvcencio/gotreesitter" "github.com/odvcencio/gotreesitter/grammars" "codexis/db" ) // ProgressFunc is called for each file being processed. // current is the 1-based index, total is the total file count, path is the file being processed. type ProgressFunc func(current, total int, path string) const defaultBatchSize = 100 // Indexer walks a codebase, extracts symbols via tree-sitter, and stores them in SQLite. type Indexer struct { db *sql.DB queries *db.Queries root string force bool BatchSize int OnProgress ProgressFunc } // New creates a new Indexer. func New(sqlDB *sql.DB, queries *db.Queries, root string, force bool) *Indexer { return &Indexer{ db: sqlDB, queries: queries, root: root, force: force, BatchSize: defaultBatchSize, } } // Stats holds indexing statistics. type Stats struct { FilesTotal int FilesIndexed int FilesSkipped int SymbolsTotal int } // Index walks the codebase and indexes all recognized files. func (idx *Indexer) Index(ctx context.Context) (*Stats, error) { files, err := WalkFiles(idx.root) if err != nil { return nil, fmt.Errorf("walking files: %w", err) } stats := &Stats{FilesTotal: len(files)} batchSize := idx.BatchSize if batchSize <= 0 { batchSize = defaultBatchSize } // Process files in transaction batches for batchStart := 0; batchStart < len(files); batchStart += batchSize { batchEnd := batchStart + batchSize if batchEnd > len(files) { batchEnd = len(files) } batch := files[batchStart:batchEnd] if err := idx.indexBatch(ctx, batch, batchStart, stats); err != nil { return nil, fmt.Errorf("indexing batch: %w", err) } } // Clean up files that no longer exist (in its own transaction) tx, err := idx.db.BeginTx(ctx, nil) if err != nil { return nil, fmt.Errorf("begin cleanup tx: %w", err) } defer tx.Rollback() txQueries := idx.queries.WithTx(tx) if err := txQueries.DeleteStaleFileContents(ctx, files); err != nil { return nil, fmt.Errorf("cleaning stale file contents: %w", err) } if err := txQueries.DeleteStaleFiles(ctx, files); err != nil { return nil, fmt.Errorf("cleaning stale files: %w", err) } if err := tx.Commit(); err != nil { return nil, fmt.Errorf("commit cleanup tx: %w", err) } return stats, nil } // indexBatch processes a slice of files within a single transaction. func (idx *Indexer) indexBatch(ctx context.Context, batch []string, offset int, stats *Stats) error { tx, err := idx.db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("begin tx: %w", err) } defer tx.Rollback() txQueries := idx.queries.WithTx(tx) for i, relPath := range batch { if idx.OnProgress != nil { idx.OnProgress(offset+i+1, stats.FilesTotal, relPath) } indexed, symbolCount, err := idx.indexFile(ctx, txQueries, relPath) if err != nil { fmt.Fprintf(os.Stderr, "\rwarn: %s: %v\033[K\n", relPath, err) continue } if indexed { stats.FilesIndexed++ stats.SymbolsTotal += symbolCount } else { stats.FilesSkipped++ } } if err := tx.Commit(); err != nil { return fmt.Errorf("commit tx: %w", err) } return nil } func (idx *Indexer) indexFile(ctx context.Context, q *db.Queries, relPath string) (indexed bool, symbolCount int, err error) { absPath := filepath.Join(idx.root, relPath) src, err := os.ReadFile(absPath) if err != nil { return false, 0, fmt.Errorf("reading file: %w", err) } hash := fmt.Sprintf("%x", sha256.Sum256(src)) // Check if file has changed if !idx.force { existing, err := q.GetFileByPath(ctx, relPath) if err == nil && existing.Hash == hash { return false, 0, nil // unchanged } } // Detect language entry := grammars.DetectLanguage(filepath.Base(relPath)) if entry == nil { return false, 0, nil } // Check if this language has a tags query — skip parsing if not tagsQuery := grammars.ResolveTagsQuery(*entry) hasTagsQuery := tagsQuery != "" var tree *gotreesitter.Tree if hasTagsQuery { // Parse once, reuse tree for package extraction and tagging lang := entry.Language() parser := gotreesitter.NewParser(lang) parsedTree, parseErr := parser.Parse(src) if parseErr != nil { return false, 0, fmt.Errorf("parsing: %w", parseErr) } tree = parsedTree if tree != nil { defer tree.Release() } } // Extract package (uses tree if available, falls back to dir name) pkg := ExtractPackage(src, relPath, entry, tree) // Upsert file record file, err := q.UpsertFile(ctx, db.UpsertFileParams{ Path: relPath, Language: entry.Name, Package: sql.NullString{String: pkg, Valid: pkg != ""}, Hash: hash, }) if err != nil { return false, 0, fmt.Errorf("upserting file: %w", err) } // Store file content for FTS if err := q.UpsertFileContent(ctx, file.ID, string(src)); err != nil { return false, 0, fmt.Errorf("upserting file content: %w", err) } if !hasTagsQuery { return true, 0, nil } // Clear old symbols if err := q.DeleteSymbolsByFileID(ctx, file.ID); err != nil { return false, 0, fmt.Errorf("deleting old symbols: %w", err) } // Extract and store symbols tags := extractTags(src, entry, tree) defs := buildSymbolDefs(tags, file.ID, entry.Name) // Insert symbols in order, tracking DB IDs for parent resolution dbIDs := make([]int64, len(defs)) for i, def := range defs { // Resolve parent_id from local index to actual DB ID params := def.params if params.ParentID.Valid { parentIdx := params.ParentID.Int64 params.ParentID = sql.NullInt64{Int64: dbIDs[parentIdx], Valid: true} } id, err := q.InsertSymbol(ctx, params) if err != nil { return false, 0, fmt.Errorf("inserting symbol %q: %w", params.Name, err) } dbIDs[i] = id } return true, len(defs), nil } func extractTags(src []byte, entry *grammars.LangEntry, tree *gotreesitter.Tree) []gotreesitter.Tag { if tree == nil { return nil } lang := entry.Language() // ResolveTagsQuery returns the explicit TagsQuery if set, otherwise infers // one from the grammar's symbol table. tagsQuery := grammars.ResolveTagsQuery(*entry) if tagsQuery == "" { return nil } tagger, err := gotreesitter.NewTagger(lang, tagsQuery) if err != nil { return nil } return tagger.TagTree(tree) } type symbolDef struct { tag gotreesitter.Tag params db.InsertSymbolParams } func buildSymbolDefs(tags []gotreesitter.Tag, fileID int64, langName string) []symbolDef { // First pass: collect all definition tags var defs []symbolDef for _, tag := range tags { kind := tagKind(tag.Kind) if kind == "" { continue } exported := IsExported(tag.Name, langName) params := db.InsertSymbolParams{ FileID: fileID, Name: tag.Name, Kind: kind, Line: int64(tag.NameRange.StartPoint.Row) + 1, // 1-indexed LineEnd: sql.NullInt64{Int64: int64(tag.Range.EndPoint.Row) + 1, Valid: true}, Col: sql.NullInt64{Int64: int64(tag.NameRange.StartPoint.Column), Valid: true}, ColEnd: sql.NullInt64{Int64: int64(tag.NameRange.EndPoint.Column), Valid: true}, Exported: sql.NullBool{Bool: exported, Valid: true}, ParentID: sql.NullInt64{Valid: false}, } defs = append(defs, symbolDef{tag: tag, params: params}) } // Second pass: determine parent relationships based on range containment. // ParentID stores the local index — resolved to DB ID during insert. // Tree-sitter returns tags in document order (outer before inner), // so scanning backwards finds the nearest enclosing definition. for i := range defs { for j := i - 1; j >= 0; j-- { if containsRange(defs[j].tag.Range, defs[i].tag.Range) { defs[i].params.ParentID = sql.NullInt64{Int64: int64(j), Valid: true} break } } } return defs } func containsRange(outer, inner gotreesitter.Range) bool { return outer.StartByte <= inner.StartByte && outer.EndByte >= inner.EndByte } // tagKind converts a tree-sitter tag kind like "definition.function" to "function". // Returns empty string for non-definition tags. func tagKind(kind string) string { const prefix = "definition." if strings.HasPrefix(kind, prefix) { return kind[len(prefix):] } if kind == "reference.call" { return "reference" } fmt.Println(kind) return "" }