This commit is contained in:
2026-01-28 12:43:17 -05:00
parent f9d3753806
commit 604178341d
10 changed files with 188 additions and 172 deletions

View File

@@ -19,7 +19,7 @@ func (t TestArgs) Validate() error {
}
func TestAsyncBuiltin(t *testing.T) {
RegisterAsyncBuiltin[TestArgs, string]("testAsync", func(_ context.Context, args TestArgs) (string, error) {
RegisterAsyncBuiltin("testAsync", func(_ context.Context, args TestArgs) (string, error) {
return "result: " + args.Field1, nil
})
@@ -28,11 +28,11 @@ func TestAsyncBuiltin(t *testing.T) {
registryMutex.RUnlock()
require.True(t, ok, "testAsync should be registered")
assert.Contains(t, builtin.Definition, "Promise<string>", "definition should include Promise<string>")
assert.Contains(t, builtin.Definition(), "Promise<string>", "definition should include Promise<string>")
}
func TestAsyncBuiltinResolution(t *testing.T) {
RegisterAsyncBuiltin[TestArgs, string]("resolveTest", func(_ context.Context, args TestArgs) (string, error) {
RegisterAsyncBuiltin("resolveTest", func(_ context.Context, args TestArgs) (string, error) {
return "test-result", nil
})
@@ -43,15 +43,15 @@ func TestAsyncBuiltinResolution(t *testing.T) {
}()
vm.SetCanBlock(true)
RegisterBuiltins(vm)
RegisterBuiltins(context.Background(), vm)
result, err := vm.Eval(`resolveTest({field1: "hello"})`, quickjs.EvalGlobal)
result, err := vm.Eval(`resolveTest("hello")`, quickjs.EvalGlobal)
require.NoError(t, err)
assert.NotNil(t, result)
}
func TestAsyncBuiltinRejection(t *testing.T) {
RegisterAsyncBuiltin[TestArgs, string]("rejectTest", func(_ context.Context, args TestArgs) (string, error) {
RegisterAsyncBuiltin("rejectTest", func(_ context.Context, args TestArgs) (string, error) {
return "", assert.AnError
})
@@ -62,7 +62,7 @@ func TestAsyncBuiltinRejection(t *testing.T) {
}()
vm.SetCanBlock(true)
RegisterBuiltins(vm)
RegisterBuiltins(context.Background(), vm)
result, err := vm.Eval(`rejectTest({field1: "hello"})`, quickjs.EvalGlobal)
require.NoError(t, err)
@@ -70,7 +70,7 @@ func TestAsyncBuiltinRejection(t *testing.T) {
}
func TestNonPromise(t *testing.T) {
RegisterBuiltin[TestArgs, string]("nonPromiseTest", func(_ context.Context, args TestArgs) (string, error) {
RegisterBuiltin("nonPromiseTest", func(_ context.Context, args TestArgs) (string, error) {
return "sync-result", nil
})
@@ -81,7 +81,7 @@ func TestNonPromise(t *testing.T) {
}()
vm.SetCanBlock(true)
RegisterBuiltins(vm)
RegisterBuiltins(context.Background(), vm)
result, err := vm.Eval(`nonPromiseTest({field1: "hello"})`, quickjs.EvalGlobal)
require.NoError(t, err)

View File

@@ -5,8 +5,6 @@ import (
"reflect"
"strings"
"sync"
"modernc.org/quickjs"
)
var (
@@ -15,34 +13,29 @@ var (
collector *typeCollector
)
func registerBuiltin[T Args, R any](name string, isAsync bool, fn Func[T, R]) {
func registerBuiltin[A Args, R any](name string, isAsync bool, fn Func[A, R]) {
registryMutex.Lock()
defer registryMutex.Unlock()
if collector == nil {
collector = newTypeCollector()
}
var zeroT T
tType := reflect.TypeOf(zeroT)
tType := reflect.TypeFor[A]()
if tType.Kind() != reflect.Struct {
panic(fmt.Sprintf("builtin %s: argument must be a struct type, got %v", name, tType))
}
fnType := reflect.TypeOf(fn)
wrapper := createWrapper(fn, isAsync)
types := collector.collectTypes(tType, fnType)
paramTypes := collector.getParamTypes()
registryMutex.Lock()
b := Builtin{
Name: name,
Function: wrapper,
Definition: generateTypeScriptDefinition(name, tType, fnType, isAsync, paramTypes),
Types: types,
ParamTypes: paramTypes,
builtinRegistry[name] = &builtinImpl[A, R]{
name: name,
fn: fn,
types: types,
definition: generateTypeScriptDefinition(name, tType, fnType, isAsync, paramTypes),
}
builtinRegistry[name] = b
registryMutex.Unlock()
}
func GetBuiltinsDeclarations() string {
@@ -54,13 +47,13 @@ func GetBuiltinsDeclarations() string {
var functionDecls []string
for _, builtin := range builtinRegistry {
for _, t := range builtin.Types {
for _, t := range builtin.Types() {
if !typeDefinitions[t] {
typeDefinitions[t] = true
typeDefs = append(typeDefs, t)
}
}
functionDecls = append(functionDecls, builtin.Definition)
functionDecls = append(functionDecls, builtin.Definition())
}
result := strings.Join(typeDefs, "\n\n")
@@ -80,14 +73,6 @@ func RegisterAsyncBuiltin[T Args, R any](name string, fn Func[T, R]) {
registerBuiltin(name, true, fn)
}
func RegisterBuiltins(vm *quickjs.VM) {
registryMutex.RLock()
defer registryMutex.RUnlock()
for name, builtin := range builtinRegistry {
err := vm.RegisterFunc(name, builtin.Function, false)
if err != nil {
panic(fmt.Sprintf("failed to register builtin %s: %v", name, err))
}
}
func GetBuiltins() map[string]Builtin {
return builtinRegistry
}

View File

@@ -2,24 +2,75 @@ package builtin
import (
"context"
"errors"
"reflect"
)
type Builtin struct {
Name string
Function interface{}
Definition string
Types []string
ParamTypes map[string]bool
type Builtin interface {
Name() string
Types() []string
Definition() string
WrapFn(context.Context) func(...any) (any, error)
}
func (b *Builtin) HasParamType(typeName string) bool {
return b.ParamTypes[typeName]
}
type EmptyArgs struct{}
type Func[A Args, R any] func(ctx context.Context, args A) (R, error)
type Args interface {
Validate() error
}
type Func[T Args, R any] func(ctx context.Context, args T) (R, error)
type builtinImpl[A Args, R any] struct {
name string
fn Func[A, R]
definition string
types []string
}
func (b *builtinImpl[A, R]) Name() string {
return b.name
}
func (b *builtinImpl[A, R]) Types() []string {
return b.types
}
func (b *builtinImpl[A, R]) Definition() string {
return b.definition
}
func (b *builtinImpl[A, R]) WrapFn(ctx context.Context) func(...any) (any, error) {
return func(allArgs ...any) (any, error) {
// Populate Arguments
var fnArgs A
aVal := reflect.ValueOf(&fnArgs).Elem()
// Populate Fields
for i := range min(aVal.NumField(), len(allArgs)) {
field := aVal.Field(i)
if !field.CanSet() {
return nil, errors.New("cannot set field")
}
argVal := reflect.ValueOf(allArgs[i])
if !argVal.Type().AssignableTo(field.Type()) {
return nil, errors.New("cannot assign field")
}
field.Set(argVal)
}
// Validate
if err := fnArgs.Validate(); err != nil {
return nil, errors.New("cannot validate args")
}
// Call Function
resp, err := b.fn(ctx, fnArgs)
if err != nil {
return nil, err
}
return resp, nil
}
}

View File

@@ -1,81 +1 @@
package builtin
import (
"context"
"encoding/json"
"fmt"
"modernc.org/quickjs"
)
func createWrapper[T Args, R any](fn Func[T, R], isAsync bool) interface{} {
if !isAsync {
return createSyncWrapper[T, R](fn)
}
return createAsyncWrapper[T, R](fn)
}
func createSyncWrapper[T Args, R any](fn Func[T, R]) interface{} {
return func(rawArgs any) (R, error) {
var zero R
var args T
obj, ok := rawArgs.(*quickjs.Object)
if ok {
jsonData, err := obj.MarshalJSON()
if err != nil {
return zero, fmt.Errorf("failed to marshal args: %w", err)
}
if err := json.Unmarshal(jsonData, &args); err != nil {
return zero, fmt.Errorf("failed to unmarshal args: %w", err)
}
} else if rawArgs != nil && rawArgs != quickjs.UndefinedValue {
jsonData, err := json.Marshal(rawArgs)
if err != nil {
return zero, fmt.Errorf("failed to marshal args: %w", err)
}
if err := json.Unmarshal(jsonData, &args); err != nil {
return zero, fmt.Errorf("failed to unmarshal args: %w", err)
}
}
if err := args.Validate(); err != nil {
return zero, fmt.Errorf("argument validation failed: %w", err)
}
ctx := context.Background()
return fn(ctx, args)
}
}
func createAsyncWrapper[T Args, R any](fn Func[T, R]) interface{} {
return func(rawArgs any) (any, error) {
var args T
obj, ok := rawArgs.(*quickjs.Object)
if ok {
jsonData, err := obj.MarshalJSON()
if err != nil {
return nil, fmt.Errorf("failed to marshal args: %w", err)
}
if err := json.Unmarshal(jsonData, &args); err != nil {
return nil, fmt.Errorf("failed to unmarshal args: %w", err)
}
} else if rawArgs != nil && rawArgs != quickjs.UndefinedValue {
jsonData, err := json.Marshal(rawArgs)
if err != nil {
return nil, fmt.Errorf("failed to marshal args: %w", err)
}
if err := json.Unmarshal(jsonData, &args); err != nil {
return nil, fmt.Errorf("failed to unmarshal args: %w", err)
}
}
if err := args.Validate(); err != nil {
return nil, fmt.Errorf("argument validation failed: %w", err)
}
ctx := context.Background()
return fn(ctx, args)
}
}

View File

@@ -1,6 +1,7 @@
package runtime
import (
"context"
"fmt"
"io"
"os"
@@ -11,36 +12,37 @@ import (
)
type Runtime struct {
vm *quickjs.VM
stdout io.Writer
stderr io.Writer
consoleSetup bool
vm *quickjs.VM
ctx context.Context
stdout io.Writer
stderr io.Writer
}
func New() *Runtime {
func New(ctx context.Context) (*Runtime, error) {
// Create VM
vm, err := quickjs.NewVM()
if err != nil {
panic(err)
return nil, err
}
vm.SetCanBlock(true)
r := &Runtime{vm: vm, stdout: os.Stdout, stderr: os.Stderr}
r.setupConsole()
// Create Runtime
r := &Runtime{vm: vm, ctx: ctx, stdout: os.Stdout, stderr: os.Stderr}
if err := r.populateGlobals(); err != nil {
return nil, err
}
builtin.RegisterBuiltins(vm)
return r
return r, nil
}
func (r *Runtime) setupConsole() {
if r.consoleSetup {
return
}
func (r *Runtime) populateGlobals() error {
// Add Helpers
if err := r.vm.StdAddHelpers(); err != nil {
panic(fmt.Sprintf("failed to add std helpers: %v", err))
return err
}
// Add Log Hook
if err := r.vm.RegisterFunc("customLog", func(args ...any) {
for i, arg := range args {
if i > 0 {
@@ -50,24 +52,44 @@ func (r *Runtime) setupConsole() {
}
_, _ = fmt.Fprintln(r.stdout)
}, false); err != nil {
panic(fmt.Sprintf("failed to register customLog: %v", err))
return err
}
if _, err := r.vm.Eval("console.log = customLog;", quickjs.EvalGlobal); err != nil {
return err
}
_, _ = r.vm.Eval("console.log = customLog;", quickjs.EvalGlobal)
// Register Custom Functions
for name, builtin := range builtin.GetBuiltins() {
// Register Main Function
if err := r.vm.RegisterFunc(name, builtin.WrapFn(r.ctx), false); err != nil {
return err
}
r.consoleSetup = true
}
// Wrap Exception - The QuickJS library does not allow us to throw exceptions, so we
// wrap the function with native JS to appropriately throw on error.
if _, err := r.vm.Eval(fmt.Sprintf(`
(function() {
const original = globalThis[%q];
globalThis[%q] = function(...args) {
const [result, error] = original.apply(this, args);
if (error) {
throw new Error(error);
}
return result;
};
})();
`, name, name), quickjs.EvalGlobal); err != nil {
return err
}
}
func (r *Runtime) SetOutput(stdout, stderr io.Writer) {
r.stdout = stdout
r.stderr = stderr
r.setupConsole()
return nil
}
func (r *Runtime) RunFile(filePath string, stdout, stderr io.Writer) error {
r.stdout = stdout
r.stderr = stderr
r.setupConsole()
content, err := r.transformFile(filePath)
if err != nil {
@@ -95,7 +117,6 @@ func (r *Runtime) RunFile(filePath string, stdout, stderr io.Writer) error {
func (r *Runtime) RunCode(tsCode string, stdout, stderr io.Writer) error {
r.stdout = stdout
r.stderr = stderr
r.setupConsole()
content := r.transformCode(tsCode)
@@ -131,6 +152,14 @@ func (r *Runtime) transformFile(filePath string) (*transformResult, error) {
}
func (r *Runtime) transformCode(tsCode string) *transformResult {
// wrappedCode := `(async () => {
// try {
// ` + tsCode + `
// } catch (err) {
// console.error(err);
// }
// })()`
result := api.Transform(tsCode, api.TransformOptions{
Loader: api.LoaderTS,
Target: api.ES2022,

View File

@@ -2,6 +2,7 @@ package runtime
import (
"bytes"
"context"
"strings"
"testing"
@@ -14,8 +15,10 @@ import (
func TestExecuteTypeScript(t *testing.T) {
var stdout, stderr bytes.Buffer
rt := New()
err := rt.RunFile("../../test_data/test.ts", &stdout, &stderr)
rt, err := New(context.Background())
assert.NoError(t, err, "Expected no error")
err = rt.RunFile("../../test_data/test.ts", &stdout, &stderr)
assert.NoError(t, err, "Expected no error")
assert.Empty(t, stderr.String(), "Expected no error output")
@@ -32,7 +35,8 @@ func TestExecuteTypeScript(t *testing.T) {
}
func TestFetchBuiltinIntegration(t *testing.T) {
rt := New()
rt, err := New(context.Background())
assert.NoError(t, err, "Expected no error")
tsContent := `
const result = add({a: 5, b: 10});
@@ -40,7 +44,7 @@ func TestFetchBuiltinIntegration(t *testing.T) {
`
var stdout, stderr bytes.Buffer
err := rt.RunCode(tsContent, &stdout, &stderr)
err = rt.RunCode(tsContent, &stdout, &stderr)
require.NoError(t, err)
assert.Contains(t, stdout.String(), "Result:")
}

View File

@@ -1,6 +1,7 @@
package standard
import (
"context"
"net/http"
"net/http/httptest"
"testing"
@@ -21,7 +22,7 @@ func TestFetchReturnsPromise(t *testing.T) {
}()
vm.SetCanBlock(true)
builtin.RegisterBuiltins(vm)
builtin.RegisterBuiltins(context.Background(), vm)
result, err := vm.Eval(`fetch({input: "https://example.com"})`, quickjs.EvalGlobal)
require.NoError(t, err)
@@ -43,7 +44,7 @@ func TestFetchAsyncAwait(t *testing.T) {
}()
vm.SetCanBlock(true)
builtin.RegisterBuiltins(vm)
builtin.RegisterBuiltins(context.Background(), vm)
result, err := vm.Eval(`fetch({input: "`+server.URL+`"})`, quickjs.EvalGlobal)
require.NoError(t, err)