diff --git a/cmd/poiesis/main.go b/cmd/poiesis/main.go index 64a6689..8caf775 100644 --- a/cmd/poiesis/main.go +++ b/cmd/poiesis/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "os" @@ -22,12 +23,15 @@ func main() { return } - // Get File - filePath := os.Args[1] + // Create Runtime + rt, err := runtime.New(context.Background()) + if err != nil { + panic(err) + } // Run File - rt := runtime.New() + filePath := os.Args[1] if err := rt.RunFile(filePath, os.Stdout, os.Stderr); err != nil { - os.Exit(1) + panic(err) } } diff --git a/internal/builtin/builtin_test.go b/internal/builtin/builtin_test.go index ac33bbd..254fbb4 100644 --- a/internal/builtin/builtin_test.go +++ b/internal/builtin/builtin_test.go @@ -19,7 +19,7 @@ func (t TestArgs) Validate() error { } func TestAsyncBuiltin(t *testing.T) { - RegisterAsyncBuiltin[TestArgs, string]("testAsync", func(_ context.Context, args TestArgs) (string, error) { + RegisterAsyncBuiltin("testAsync", func(_ context.Context, args TestArgs) (string, error) { return "result: " + args.Field1, nil }) @@ -28,11 +28,11 @@ func TestAsyncBuiltin(t *testing.T) { registryMutex.RUnlock() require.True(t, ok, "testAsync should be registered") - assert.Contains(t, builtin.Definition, "Promise", "definition should include Promise") + assert.Contains(t, builtin.Definition(), "Promise", "definition should include Promise") } func TestAsyncBuiltinResolution(t *testing.T) { - RegisterAsyncBuiltin[TestArgs, string]("resolveTest", func(_ context.Context, args TestArgs) (string, error) { + RegisterAsyncBuiltin("resolveTest", func(_ context.Context, args TestArgs) (string, error) { return "test-result", nil }) @@ -43,15 +43,15 @@ func TestAsyncBuiltinResolution(t *testing.T) { }() vm.SetCanBlock(true) - RegisterBuiltins(vm) + RegisterBuiltins(context.Background(), vm) - result, err := vm.Eval(`resolveTest({field1: "hello"})`, quickjs.EvalGlobal) + result, err := vm.Eval(`resolveTest("hello")`, quickjs.EvalGlobal) require.NoError(t, err) assert.NotNil(t, result) } func TestAsyncBuiltinRejection(t *testing.T) { - RegisterAsyncBuiltin[TestArgs, string]("rejectTest", func(_ context.Context, args TestArgs) (string, error) { + RegisterAsyncBuiltin("rejectTest", func(_ context.Context, args TestArgs) (string, error) { return "", assert.AnError }) @@ -62,7 +62,7 @@ func TestAsyncBuiltinRejection(t *testing.T) { }() vm.SetCanBlock(true) - RegisterBuiltins(vm) + RegisterBuiltins(context.Background(), vm) result, err := vm.Eval(`rejectTest({field1: "hello"})`, quickjs.EvalGlobal) require.NoError(t, err) @@ -70,7 +70,7 @@ func TestAsyncBuiltinRejection(t *testing.T) { } func TestNonPromise(t *testing.T) { - RegisterBuiltin[TestArgs, string]("nonPromiseTest", func(_ context.Context, args TestArgs) (string, error) { + RegisterBuiltin("nonPromiseTest", func(_ context.Context, args TestArgs) (string, error) { return "sync-result", nil }) @@ -81,7 +81,7 @@ func TestNonPromise(t *testing.T) { }() vm.SetCanBlock(true) - RegisterBuiltins(vm) + RegisterBuiltins(context.Background(), vm) result, err := vm.Eval(`nonPromiseTest({field1: "hello"})`, quickjs.EvalGlobal) require.NoError(t, err) diff --git a/internal/builtin/registry.go b/internal/builtin/registry.go index c45b4b5..63e656d 100644 --- a/internal/builtin/registry.go +++ b/internal/builtin/registry.go @@ -5,8 +5,6 @@ import ( "reflect" "strings" "sync" - - "modernc.org/quickjs" ) var ( @@ -15,34 +13,29 @@ var ( collector *typeCollector ) -func registerBuiltin[T Args, R any](name string, isAsync bool, fn Func[T, R]) { +func registerBuiltin[A Args, R any](name string, isAsync bool, fn Func[A, R]) { + registryMutex.Lock() + defer registryMutex.Unlock() + if collector == nil { collector = newTypeCollector() } - var zeroT T - tType := reflect.TypeOf(zeroT) - + tType := reflect.TypeFor[A]() if tType.Kind() != reflect.Struct { panic(fmt.Sprintf("builtin %s: argument must be a struct type, got %v", name, tType)) } fnType := reflect.TypeOf(fn) - - wrapper := createWrapper(fn, isAsync) types := collector.collectTypes(tType, fnType) paramTypes := collector.getParamTypes() - registryMutex.Lock() - b := Builtin{ - Name: name, - Function: wrapper, - Definition: generateTypeScriptDefinition(name, tType, fnType, isAsync, paramTypes), - Types: types, - ParamTypes: paramTypes, + builtinRegistry[name] = &builtinImpl[A, R]{ + name: name, + fn: fn, + types: types, + definition: generateTypeScriptDefinition(name, tType, fnType, isAsync, paramTypes), } - builtinRegistry[name] = b - registryMutex.Unlock() } func GetBuiltinsDeclarations() string { @@ -54,13 +47,13 @@ func GetBuiltinsDeclarations() string { var functionDecls []string for _, builtin := range builtinRegistry { - for _, t := range builtin.Types { + for _, t := range builtin.Types() { if !typeDefinitions[t] { typeDefinitions[t] = true typeDefs = append(typeDefs, t) } } - functionDecls = append(functionDecls, builtin.Definition) + functionDecls = append(functionDecls, builtin.Definition()) } result := strings.Join(typeDefs, "\n\n") @@ -80,14 +73,6 @@ func RegisterAsyncBuiltin[T Args, R any](name string, fn Func[T, R]) { registerBuiltin(name, true, fn) } -func RegisterBuiltins(vm *quickjs.VM) { - registryMutex.RLock() - defer registryMutex.RUnlock() - - for name, builtin := range builtinRegistry { - err := vm.RegisterFunc(name, builtin.Function, false) - if err != nil { - panic(fmt.Sprintf("failed to register builtin %s: %v", name, err)) - } - } +func GetBuiltins() map[string]Builtin { + return builtinRegistry } diff --git a/internal/builtin/types.go b/internal/builtin/types.go index 094c533..d2388b2 100644 --- a/internal/builtin/types.go +++ b/internal/builtin/types.go @@ -2,24 +2,75 @@ package builtin import ( "context" + "errors" + "reflect" ) -type Builtin struct { - Name string - Function interface{} - Definition string - Types []string - ParamTypes map[string]bool +type Builtin interface { + Name() string + Types() []string + Definition() string + WrapFn(context.Context) func(...any) (any, error) } -func (b *Builtin) HasParamType(typeName string) bool { - return b.ParamTypes[typeName] -} - -type EmptyArgs struct{} +type Func[A Args, R any] func(ctx context.Context, args A) (R, error) type Args interface { Validate() error } -type Func[T Args, R any] func(ctx context.Context, args T) (R, error) +type builtinImpl[A Args, R any] struct { + name string + fn Func[A, R] + definition string + types []string +} + +func (b *builtinImpl[A, R]) Name() string { + return b.name +} + +func (b *builtinImpl[A, R]) Types() []string { + return b.types +} + +func (b *builtinImpl[A, R]) Definition() string { + return b.definition +} + +func (b *builtinImpl[A, R]) WrapFn(ctx context.Context) func(...any) (any, error) { + return func(allArgs ...any) (any, error) { + // Populate Arguments + var fnArgs A + aVal := reflect.ValueOf(&fnArgs).Elem() + + // Populate Fields + for i := range min(aVal.NumField(), len(allArgs)) { + field := aVal.Field(i) + + if !field.CanSet() { + return nil, errors.New("cannot set field") + } + + argVal := reflect.ValueOf(allArgs[i]) + if !argVal.Type().AssignableTo(field.Type()) { + return nil, errors.New("cannot assign field") + } + + field.Set(argVal) + } + + // Validate + if err := fnArgs.Validate(); err != nil { + return nil, errors.New("cannot validate args") + } + + // Call Function + resp, err := b.fn(ctx, fnArgs) + if err != nil { + return nil, err + } + + return resp, nil + } +} diff --git a/internal/builtin/wrapper.go b/internal/builtin/wrapper.go index 588024f..5d9ee61 100644 --- a/internal/builtin/wrapper.go +++ b/internal/builtin/wrapper.go @@ -1,81 +1 @@ package builtin - -import ( - "context" - "encoding/json" - "fmt" - - "modernc.org/quickjs" -) - -func createWrapper[T Args, R any](fn Func[T, R], isAsync bool) interface{} { - if !isAsync { - return createSyncWrapper[T, R](fn) - } - return createAsyncWrapper[T, R](fn) -} - -func createSyncWrapper[T Args, R any](fn Func[T, R]) interface{} { - return func(rawArgs any) (R, error) { - var zero R - var args T - - obj, ok := rawArgs.(*quickjs.Object) - if ok { - jsonData, err := obj.MarshalJSON() - if err != nil { - return zero, fmt.Errorf("failed to marshal args: %w", err) - } - if err := json.Unmarshal(jsonData, &args); err != nil { - return zero, fmt.Errorf("failed to unmarshal args: %w", err) - } - } else if rawArgs != nil && rawArgs != quickjs.UndefinedValue { - jsonData, err := json.Marshal(rawArgs) - if err != nil { - return zero, fmt.Errorf("failed to marshal args: %w", err) - } - if err := json.Unmarshal(jsonData, &args); err != nil { - return zero, fmt.Errorf("failed to unmarshal args: %w", err) - } - } - - if err := args.Validate(); err != nil { - return zero, fmt.Errorf("argument validation failed: %w", err) - } - - ctx := context.Background() - return fn(ctx, args) - } -} - -func createAsyncWrapper[T Args, R any](fn Func[T, R]) interface{} { - return func(rawArgs any) (any, error) { - var args T - - obj, ok := rawArgs.(*quickjs.Object) - if ok { - jsonData, err := obj.MarshalJSON() - if err != nil { - return nil, fmt.Errorf("failed to marshal args: %w", err) - } - if err := json.Unmarshal(jsonData, &args); err != nil { - return nil, fmt.Errorf("failed to unmarshal args: %w", err) - } - } else if rawArgs != nil && rawArgs != quickjs.UndefinedValue { - jsonData, err := json.Marshal(rawArgs) - if err != nil { - return nil, fmt.Errorf("failed to marshal args: %w", err) - } - if err := json.Unmarshal(jsonData, &args); err != nil { - return nil, fmt.Errorf("failed to unmarshal args: %w", err) - } - } - - if err := args.Validate(); err != nil { - return nil, fmt.Errorf("argument validation failed: %w", err) - } - - ctx := context.Background() - return fn(ctx, args) - } -} diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index 1cdc36f..cab9d71 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -1,6 +1,7 @@ package runtime import ( + "context" "fmt" "io" "os" @@ -11,36 +12,37 @@ import ( ) type Runtime struct { - vm *quickjs.VM - stdout io.Writer - stderr io.Writer - consoleSetup bool + vm *quickjs.VM + ctx context.Context + + stdout io.Writer + stderr io.Writer } -func New() *Runtime { +func New(ctx context.Context) (*Runtime, error) { + // Create VM vm, err := quickjs.NewVM() if err != nil { - panic(err) + return nil, err } - vm.SetCanBlock(true) - r := &Runtime{vm: vm, stdout: os.Stdout, stderr: os.Stderr} - r.setupConsole() + // Create Runtime + r := &Runtime{vm: vm, ctx: ctx, stdout: os.Stdout, stderr: os.Stderr} + if err := r.populateGlobals(); err != nil { + return nil, err + } - builtin.RegisterBuiltins(vm) - return r + return r, nil } -func (r *Runtime) setupConsole() { - if r.consoleSetup { - return - } - +func (r *Runtime) populateGlobals() error { + // Add Helpers if err := r.vm.StdAddHelpers(); err != nil { - panic(fmt.Sprintf("failed to add std helpers: %v", err)) + return err } + // Add Log Hook if err := r.vm.RegisterFunc("customLog", func(args ...any) { for i, arg := range args { if i > 0 { @@ -50,24 +52,44 @@ func (r *Runtime) setupConsole() { } _, _ = fmt.Fprintln(r.stdout) }, false); err != nil { - panic(fmt.Sprintf("failed to register customLog: %v", err)) + return err + } + if _, err := r.vm.Eval("console.log = customLog;", quickjs.EvalGlobal); err != nil { + return err } - _, _ = r.vm.Eval("console.log = customLog;", quickjs.EvalGlobal) + // Register Custom Functions + for name, builtin := range builtin.GetBuiltins() { + // Register Main Function + if err := r.vm.RegisterFunc(name, builtin.WrapFn(r.ctx), false); err != nil { + return err + } - r.consoleSetup = true -} + // Wrap Exception - The QuickJS library does not allow us to throw exceptions, so we + // wrap the function with native JS to appropriately throw on error. + if _, err := r.vm.Eval(fmt.Sprintf(` + (function() { + const original = globalThis[%q]; + + globalThis[%q] = function(...args) { + const [result, error] = original.apply(this, args); + if (error) { + throw new Error(error); + } + return result; + }; + })(); + `, name, name), quickjs.EvalGlobal); err != nil { + return err + } + } -func (r *Runtime) SetOutput(stdout, stderr io.Writer) { - r.stdout = stdout - r.stderr = stderr - r.setupConsole() + return nil } func (r *Runtime) RunFile(filePath string, stdout, stderr io.Writer) error { r.stdout = stdout r.stderr = stderr - r.setupConsole() content, err := r.transformFile(filePath) if err != nil { @@ -95,7 +117,6 @@ func (r *Runtime) RunFile(filePath string, stdout, stderr io.Writer) error { func (r *Runtime) RunCode(tsCode string, stdout, stderr io.Writer) error { r.stdout = stdout r.stderr = stderr - r.setupConsole() content := r.transformCode(tsCode) @@ -131,6 +152,14 @@ func (r *Runtime) transformFile(filePath string) (*transformResult, error) { } func (r *Runtime) transformCode(tsCode string) *transformResult { + // wrappedCode := `(async () => { + // try { + // ` + tsCode + ` + // } catch (err) { + // console.error(err); + // } + // })()` + result := api.Transform(tsCode, api.TransformOptions{ Loader: api.LoaderTS, Target: api.ES2022, diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index 65af22f..219215d 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -2,6 +2,7 @@ package runtime import ( "bytes" + "context" "strings" "testing" @@ -14,8 +15,10 @@ import ( func TestExecuteTypeScript(t *testing.T) { var stdout, stderr bytes.Buffer - rt := New() - err := rt.RunFile("../../test_data/test.ts", &stdout, &stderr) + rt, err := New(context.Background()) + assert.NoError(t, err, "Expected no error") + + err = rt.RunFile("../../test_data/test.ts", &stdout, &stderr) assert.NoError(t, err, "Expected no error") assert.Empty(t, stderr.String(), "Expected no error output") @@ -32,7 +35,8 @@ func TestExecuteTypeScript(t *testing.T) { } func TestFetchBuiltinIntegration(t *testing.T) { - rt := New() + rt, err := New(context.Background()) + assert.NoError(t, err, "Expected no error") tsContent := ` const result = add({a: 5, b: 10}); @@ -40,7 +44,7 @@ func TestFetchBuiltinIntegration(t *testing.T) { ` var stdout, stderr bytes.Buffer - err := rt.RunCode(tsContent, &stdout, &stderr) + err = rt.RunCode(tsContent, &stdout, &stderr) require.NoError(t, err) assert.Contains(t, stdout.String(), "Result:") } diff --git a/internal/standard/fetch_promise_test.go b/internal/standard/fetch_promise_test.go index aa2442f..d550653 100644 --- a/internal/standard/fetch_promise_test.go +++ b/internal/standard/fetch_promise_test.go @@ -1,6 +1,7 @@ package standard import ( + "context" "net/http" "net/http/httptest" "testing" @@ -21,7 +22,7 @@ func TestFetchReturnsPromise(t *testing.T) { }() vm.SetCanBlock(true) - builtin.RegisterBuiltins(vm) + builtin.RegisterBuiltins(context.Background(), vm) result, err := vm.Eval(`fetch({input: "https://example.com"})`, quickjs.EvalGlobal) require.NoError(t, err) @@ -43,7 +44,7 @@ func TestFetchAsyncAwait(t *testing.T) { }() vm.SetCanBlock(true) - builtin.RegisterBuiltins(vm) + builtin.RegisterBuiltins(context.Background(), vm) result, err := vm.Eval(`fetch({input: "`+server.URL+`"})`, quickjs.EvalGlobal) require.NoError(t, err) diff --git a/test_data/fetch-new.ts b/test_data/fetch-new.ts new file mode 100644 index 0000000..e299263 --- /dev/null +++ b/test_data/fetch-new.ts @@ -0,0 +1,14 @@ +try { + console.log(1); + const response = fetch("https://httpbin.org/get"); + console.log(2); + console.log(response); + + console.log("OK:", response.ok); + console.log("Status:", response.status); + console.log("Body:", response.body); + console.log("Content-Type:", response.headers["content-type"]); +} catch (e) { + console.log(e.message); + console.log("exception"); +} diff --git a/test_data/fetch.ts b/test_data/fetch.ts index 22d0300..cf00e81 100644 --- a/test_data/fetch.ts +++ b/test_data/fetch.ts @@ -1,10 +1,18 @@ +var done = false; async function main() { - const response = await fetch("https://httpbin.org/get"); + try { + console.log(11); + const response = fetch("https://httpbin.org/get"); + console.log(response); - console.log("OK:", response.ok); - console.log("Status:", response.status); - console.log("Body:", response.body); - console.log("Content-Type:", response.headers["content-type"]); + console.log("OK:", response.ok); + console.log("Status:", response.status); + console.log("Body:", response.body); + console.log("Content-Type:", response.headers["content-type"]); + } catch (e) { + console.log(e); + } + done = true; } console.log(1);