315 lines
8.1 KiB
Go
315 lines
8.1 KiB
Go
package indexer
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"database/sql"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/odvcencio/gotreesitter"
|
|
"github.com/odvcencio/gotreesitter/grammars"
|
|
|
|
"codexis/db"
|
|
)
|
|
|
|
// ProgressFunc receives indexing progress notifications. It is invoked once
// per file, just before that file is processed: current is the file's 1-based
// position across the whole run, total is the number of files found, and path
// is the file's path relative to the indexed root.
type ProgressFunc func(current, total int, path string)
|
|
|
|
// defaultBatchSize is the number of files committed per transaction when
// Indexer.BatchSize is not a positive number.
const (
	defaultBatchSize = 100
)
|
|
|
|
// Indexer walks a codebase, extracts symbols via tree-sitter, and stores them in SQLite.
|
|
type Indexer struct {
|
|
db *sql.DB
|
|
queries *db.Queries
|
|
root string
|
|
force bool
|
|
BatchSize int
|
|
OnProgress ProgressFunc
|
|
}
|
|
|
|
// New creates a new Indexer.
|
|
func New(sqlDB *sql.DB, queries *db.Queries, root string, force bool) *Indexer {
|
|
return &Indexer{
|
|
db: sqlDB,
|
|
queries: queries,
|
|
root: root,
|
|
force: force,
|
|
BatchSize: defaultBatchSize,
|
|
}
|
|
}
|
|
|
|
// Stats summarizes a single Index run. Files that fail to index are counted
// only in FilesTotal.
type Stats struct {
	FilesTotal   int // number of files discovered by the walk
	FilesIndexed int // files whose records were (re)written
	FilesSkipped int // files left alone: unchanged hash or unrecognized language
	SymbolsTotal int // symbols inserted across all indexed files
}
|
|
|
|
// Index walks the codebase and indexes all recognized files.
|
|
func (idx *Indexer) Index(ctx context.Context) (*Stats, error) {
|
|
files, err := WalkFiles(idx.root)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("walking files: %w", err)
|
|
}
|
|
|
|
stats := &Stats{FilesTotal: len(files)}
|
|
batchSize := idx.BatchSize
|
|
if batchSize <= 0 {
|
|
batchSize = defaultBatchSize
|
|
}
|
|
|
|
// Process files in transaction batches
|
|
for batchStart := 0; batchStart < len(files); batchStart += batchSize {
|
|
batchEnd := batchStart + batchSize
|
|
if batchEnd > len(files) {
|
|
batchEnd = len(files)
|
|
}
|
|
batch := files[batchStart:batchEnd]
|
|
|
|
if err := idx.indexBatch(ctx, batch, batchStart, stats); err != nil {
|
|
return nil, fmt.Errorf("indexing batch: %w", err)
|
|
}
|
|
}
|
|
|
|
// Clean up files that no longer exist (in its own transaction)
|
|
tx, err := idx.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("begin cleanup tx: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
txQueries := idx.queries.WithTx(tx)
|
|
if err := txQueries.DeleteStaleFileContents(ctx, files); err != nil {
|
|
return nil, fmt.Errorf("cleaning stale file contents: %w", err)
|
|
}
|
|
if err := txQueries.DeleteStaleFiles(ctx, files); err != nil {
|
|
return nil, fmt.Errorf("cleaning stale files: %w", err)
|
|
}
|
|
if err := tx.Commit(); err != nil {
|
|
return nil, fmt.Errorf("commit cleanup tx: %w", err)
|
|
}
|
|
|
|
return stats, nil
|
|
}
|
|
|
|
// indexBatch processes a slice of files within a single transaction.
|
|
func (idx *Indexer) indexBatch(ctx context.Context, batch []string, offset int, stats *Stats) error {
|
|
tx, err := idx.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("begin tx: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
txQueries := idx.queries.WithTx(tx)
|
|
|
|
for i, relPath := range batch {
|
|
if idx.OnProgress != nil {
|
|
idx.OnProgress(offset+i+1, stats.FilesTotal, relPath)
|
|
}
|
|
indexed, symbolCount, err := idx.indexFile(ctx, txQueries, relPath)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "\rwarn: %s: %v\033[K\n", relPath, err)
|
|
continue
|
|
}
|
|
if indexed {
|
|
stats.FilesIndexed++
|
|
stats.SymbolsTotal += symbolCount
|
|
} else {
|
|
stats.FilesSkipped++
|
|
}
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("commit tx: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (idx *Indexer) indexFile(ctx context.Context, q *db.Queries, relPath string) (indexed bool, symbolCount int, err error) {
|
|
absPath := filepath.Join(idx.root, relPath)
|
|
|
|
src, err := os.ReadFile(absPath)
|
|
if err != nil {
|
|
return false, 0, fmt.Errorf("reading file: %w", err)
|
|
}
|
|
|
|
hash := fmt.Sprintf("%x", sha256.Sum256(src))
|
|
|
|
// Check if file has changed
|
|
if !idx.force {
|
|
existing, err := q.GetFileByPath(ctx, relPath)
|
|
if err == nil && existing.Hash == hash {
|
|
return false, 0, nil // unchanged
|
|
}
|
|
}
|
|
|
|
// Detect language
|
|
entry := grammars.DetectLanguage(filepath.Base(relPath))
|
|
if entry == nil {
|
|
return false, 0, nil
|
|
}
|
|
|
|
// Check if this language has a tags query — skip parsing if not
|
|
tagsQuery := grammars.ResolveTagsQuery(*entry)
|
|
hasTagsQuery := tagsQuery != ""
|
|
|
|
var tree *gotreesitter.Tree
|
|
if hasTagsQuery {
|
|
// Parse once, reuse tree for package extraction and tagging
|
|
lang := entry.Language()
|
|
parser := gotreesitter.NewParser(lang)
|
|
parsedTree, parseErr := parser.Parse(src)
|
|
if parseErr != nil {
|
|
return false, 0, fmt.Errorf("parsing: %w", parseErr)
|
|
}
|
|
tree = parsedTree
|
|
if tree != nil {
|
|
defer tree.Release()
|
|
}
|
|
}
|
|
|
|
// Extract package (uses tree if available, falls back to dir name)
|
|
pkg := ExtractPackage(src, relPath, entry, tree)
|
|
|
|
// Upsert file record
|
|
file, err := q.UpsertFile(ctx, db.UpsertFileParams{
|
|
Path: relPath,
|
|
Language: entry.Name,
|
|
Package: sql.NullString{String: pkg, Valid: pkg != ""},
|
|
Hash: hash,
|
|
})
|
|
if err != nil {
|
|
return false, 0, fmt.Errorf("upserting file: %w", err)
|
|
}
|
|
|
|
// Store file content for FTS
|
|
if err := q.UpsertFileContent(ctx, file.ID, string(src)); err != nil {
|
|
return false, 0, fmt.Errorf("upserting file content: %w", err)
|
|
}
|
|
|
|
if !hasTagsQuery {
|
|
return true, 0, nil
|
|
}
|
|
|
|
// Clear old symbols
|
|
if err := q.DeleteSymbolsByFileID(ctx, file.ID); err != nil {
|
|
return false, 0, fmt.Errorf("deleting old symbols: %w", err)
|
|
}
|
|
|
|
// Extract and store symbols
|
|
tags := extractTags(src, entry, tree)
|
|
defs := buildSymbolDefs(tags, file.ID, entry.Name)
|
|
|
|
// Insert symbols in order, tracking DB IDs for parent resolution
|
|
dbIDs := make([]int64, len(defs))
|
|
for i, def := range defs {
|
|
// Resolve parent_id from local index to actual DB ID
|
|
params := def.params
|
|
if params.ParentID.Valid {
|
|
parentIdx := params.ParentID.Int64
|
|
params.ParentID = sql.NullInt64{Int64: dbIDs[parentIdx], Valid: true}
|
|
}
|
|
|
|
id, err := q.InsertSymbol(ctx, params)
|
|
if err != nil {
|
|
return false, 0, fmt.Errorf("inserting symbol %q: %w", params.Name, err)
|
|
}
|
|
dbIDs[i] = id
|
|
}
|
|
|
|
return true, len(defs), nil
|
|
}
|
|
|
|
func extractTags(src []byte, entry *grammars.LangEntry, tree *gotreesitter.Tree) []gotreesitter.Tag {
|
|
if tree == nil {
|
|
return nil
|
|
}
|
|
|
|
lang := entry.Language()
|
|
|
|
// ResolveTagsQuery returns the explicit TagsQuery if set, otherwise infers
|
|
// one from the grammar's symbol table.
|
|
tagsQuery := grammars.ResolveTagsQuery(*entry)
|
|
if tagsQuery == "" {
|
|
return nil
|
|
}
|
|
|
|
tagger, err := gotreesitter.NewTagger(lang, tagsQuery)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
return tagger.TagTree(tree)
|
|
}
|
|
|
|
type symbolDef struct {
|
|
tag gotreesitter.Tag
|
|
params db.InsertSymbolParams
|
|
}
|
|
|
|
func buildSymbolDefs(tags []gotreesitter.Tag, fileID int64, langName string) []symbolDef {
|
|
// First pass: collect all definition tags
|
|
var defs []symbolDef
|
|
|
|
for _, tag := range tags {
|
|
kind := tagKind(tag.Kind)
|
|
if kind == "" {
|
|
continue
|
|
}
|
|
|
|
exported := IsExported(tag.Name, langName)
|
|
|
|
params := db.InsertSymbolParams{
|
|
FileID: fileID,
|
|
Name: tag.Name,
|
|
Kind: kind,
|
|
Line: int64(tag.NameRange.StartPoint.Row) + 1, // 1-indexed
|
|
LineEnd: sql.NullInt64{Int64: int64(tag.Range.EndPoint.Row) + 1, Valid: true},
|
|
Col: sql.NullInt64{Int64: int64(tag.NameRange.StartPoint.Column), Valid: true},
|
|
ColEnd: sql.NullInt64{Int64: int64(tag.NameRange.EndPoint.Column), Valid: true},
|
|
Exported: sql.NullBool{Bool: exported, Valid: true},
|
|
ParentID: sql.NullInt64{Valid: false},
|
|
}
|
|
|
|
defs = append(defs, symbolDef{tag: tag, params: params})
|
|
}
|
|
|
|
// Second pass: determine parent relationships based on range containment.
|
|
// ParentID stores the local index — resolved to DB ID during insert.
|
|
// Tree-sitter returns tags in document order (outer before inner),
|
|
// so scanning backwards finds the nearest enclosing definition.
|
|
for i := range defs {
|
|
for j := i - 1; j >= 0; j-- {
|
|
if containsRange(defs[j].tag.Range, defs[i].tag.Range) {
|
|
defs[i].params.ParentID = sql.NullInt64{Int64: int64(j), Valid: true}
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
return defs
|
|
}
|
|
|
|
func containsRange(outer, inner gotreesitter.Range) bool {
|
|
return outer.StartByte <= inner.StartByte && outer.EndByte >= inner.EndByte
|
|
}
|
|
|
|
// tagKind maps a tree-sitter tag kind to the symbol kind stored in the index.
//
// Kinds of the form "definition.<x>" map to "<x>" (e.g. "definition.function"
// becomes "function"), and "reference.call" maps to "reference". Every other
// kind yields "", which callers treat as "do not store this tag". The function
// has no side effects.
func tagKind(kind string) string {
	const definitionPrefix = "definition."

	switch {
	case strings.HasPrefix(kind, definitionPrefix):
		return kind[len(definitionPrefix):]
	case kind == "reference.call":
		return "reference"
	default:
		return ""
	}
}
|