This commit is contained in:
2026-01-27 13:27:34 -05:00
parent fb6f260630
commit 28b1ad32f5
4 changed files with 302 additions and 169 deletions

View File

@@ -11,97 +11,79 @@ import (
type Builtin struct { type Builtin struct {
Name string Name string
Function any Function func(*goja.Runtime) func(goja.FunctionCall) goja.Value
Definition string Definition string
} }
type EmptyArgs struct{}
var ( var (
builtinRegistry = make(map[string]Builtin) builtinRegistry = make(map[string]Builtin)
registryMutex sync.RWMutex registryMutex sync.RWMutex
customConverters = make(map[reflect.Type]func(*goja.Runtime, reflect.Value) goja.Value)
) )
func RegisterBuiltin(name string, fn any) { func RegisterBuiltin[T any, R any](name string, fn any) {
fnValue := reflect.ValueOf(fn) var zeroT T
fnType := fnValue.Type() tType := reflect.TypeOf(zeroT)
wrapper := createGenericWrapper(fnValue, fnType) if tType.Kind() != reflect.Struct {
definition := generateTypeScriptDefinition(name, fnType) panic(fmt.Sprintf("builtin %s: argument must be a struct type, got %v", name, tType))
}
fnType := reflect.TypeOf(fn)
wrapper := createWrapper[T](fn, fnType)
registryMutex.Lock() registryMutex.Lock()
builtinRegistry[name] = Builtin{ builtinRegistry[name] = Builtin{
Name: name, Name: name,
Function: wrapper, Function: wrapper,
Definition: definition, Definition: generateTypeScriptDefinition(name, tType, fnType),
} }
registryMutex.Unlock() registryMutex.Unlock()
} }
func RegisterCustomConverter[T any](converter func(vm *goja.Runtime, value T) goja.Value) { func createWrapper[T any](fn any, fnType reflect.Type) func(*goja.Runtime) func(goja.FunctionCall) goja.Value {
var t T return func(vm *goja.Runtime) func(goja.FunctionCall) goja.Value {
typeOf := reflect.TypeOf(t)
registryMutex.Lock()
wrappedConverter := func(vm *goja.Runtime, value reflect.Value) goja.Value {
return converter(vm, value.Interface().(T))
}
customConverters[typeOf] = wrappedConverter
if typeOf.Kind() == reflect.Pointer {
elemType := typeOf.Elem()
customConverters[elemType] = func(vm *goja.Runtime, value reflect.Value) goja.Value {
if value.IsNil() {
return goja.Null()
}
return converter(vm, value.Interface().(T))
}
}
registryMutex.Unlock()
}
func createGenericWrapper(fnValue reflect.Value, fnType reflect.Type) any {
return func(vm *goja.Runtime) any {
return func(call goja.FunctionCall) goja.Value { return func(call goja.FunctionCall) goja.Value {
args := make([]reflect.Value, fnType.NumIn()) var args T
argsValue := reflect.ValueOf(&args).Elem()
for i := 0; i < fnType.NumIn(); i++ { for i := 0; i < argsValue.NumField() && i < len(call.Arguments); i++ {
argType := fnType.In(i) jsArg := call.Arguments[i]
var jsArg goja.Value field := argsValue.Field(i)
if i < len(call.Arguments) {
jsArg = call.Arguments[i]
} else {
jsArg = goja.Undefined()
}
if goja.IsUndefined(jsArg) || goja.IsNull(jsArg) { if goja.IsUndefined(jsArg) || goja.IsNull(jsArg) {
if argType.Kind() == reflect.Map { if field.Kind() == reflect.Pointer {
args[i] = reflect.MakeMap(argType)
continue
}
if argType.Kind() == reflect.Interface {
args[i] = reflect.Zero(argType)
continue continue
} }
} }
converted, err := convertJSValueToGo(vm, jsArg, argType) converted, err := convertJSValueToGo(vm, jsArg, field.Type())
if err != nil { if err != nil {
panic(fmt.Sprintf("argument %d: %v", i, err)) panic(fmt.Sprintf("argument %d (%s): %v", i, getFieldName(argsValue.Type().Field(i)), err))
}
if converted != nil {
field.Set(reflect.ValueOf(converted))
} }
args[i] = reflect.ValueOf(converted)
} }
results := fnValue.Call(args) if defaults, ok := any(args).(interface{ Defaults() T }); ok {
args = defaults.Defaults()
}
fnValue := reflect.ValueOf(fn)
firstParamType := fnType.In(0)
argValue := reflect.ValueOf(args).Convert(firstParamType)
results := fnValue.Call([]reflect.Value{argValue})
if len(results) == 0 { if len(results) == 0 {
return goja.Undefined() return goja.Undefined()
} }
lastResult := results[len(results)-1] if err, isError := results[len(results)-1].Interface().(error); isError {
if lastResult.Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) { if err != nil {
if !lastResult.IsNil() { panic(err)
panic(fmt.Sprintf("error: %v", lastResult.Interface()))
} }
if len(results) == 1 { if len(results) == 1 {
return goja.Undefined() return goja.Undefined()
@@ -115,8 +97,15 @@ func createGenericWrapper(fnValue reflect.Value, fnType reflect.Type) any {
} }
func convertJSValueToGo(vm *goja.Runtime, jsValue goja.Value, targetType reflect.Type) (any, error) { func convertJSValueToGo(vm *goja.Runtime, jsValue goja.Value, targetType reflect.Type) (any, error) {
if goja.IsUndefined(jsValue) || goja.IsNull(jsValue) { if goja.IsNull(jsValue) {
if targetType.Kind() == reflect.Interface || targetType.Kind() == reflect.Pointer { if targetType.Kind() == reflect.Pointer || targetType.Kind() == reflect.Map {
return nil, nil
}
return nil, fmt.Errorf("cannot convert null/undefined to %v", targetType)
}
if goja.IsUndefined(jsValue) {
if targetType.Kind() == reflect.Pointer || targetType.Kind() == reflect.Map {
return nil, nil return nil, nil
} }
return nil, fmt.Errorf("cannot convert null/undefined to %v", targetType) return nil, fmt.Errorf("cannot convert null/undefined to %v", targetType)
@@ -154,20 +143,88 @@ func convertJSValueToGo(vm *goja.Runtime, jsValue goja.Value, targetType reflect
return jsValue.Export(), nil return jsValue.Export(), nil
case reflect.Map: case reflect.Map:
if targetType.Key().Kind() == reflect.String && targetType.Elem().Kind() == reflect.Interface { if goja.IsUndefined(jsValue) || goja.IsNull(jsValue) {
return nil, nil
}
if targetType.Key().Kind() == reflect.String {
obj := jsValue.ToObject(vm) obj := jsValue.ToObject(vm)
if obj == nil { if obj == nil {
return nil, fmt.Errorf("not an object") return nil, fmt.Errorf("not an object")
} }
result := make(map[string]any) if targetType.Elem().Kind() == reflect.Interface {
for _, key := range obj.Keys() { result := make(map[string]any)
result[key] = obj.Get(key).Export() for _, key := range obj.Keys() {
result[key] = obj.Get(key).Export()
}
return result, nil
} else if targetType.Elem().Kind() == reflect.String {
result := make(map[string]string)
for _, key := range obj.Keys() {
v := obj.Get(key)
result[key] = v.String()
}
return result, nil
} }
return result, nil
} }
return nil, fmt.Errorf("unsupported map type: %v", targetType) return nil, fmt.Errorf("unsupported map type: %v", targetType)
case reflect.Struct:
obj := jsValue.ToObject(vm)
if obj == nil {
return nil, fmt.Errorf("not an object")
}
result := reflect.New(targetType).Elem()
for i := 0; i < targetType.NumField(); i++ {
field := targetType.Field(i)
fieldName := getFieldName(field)
jsField := obj.Get(fieldName)
var err error
var converted any
func() {
defer func() {
if r := recover(); r != nil {
// goja.Value was zero - treat as undefined
err = nil
converted = nil
}
}()
converted, err = convertJSValueToGo(vm, jsField, field.Type)
}()
if err != nil {
return nil, fmt.Errorf("field %s: %v", fieldName, err)
}
if converted == nil {
if field.Type.Kind() == reflect.Pointer || field.Type.Kind() == reflect.Map {
continue
}
} else {
result.Field(i).Set(reflect.ValueOf(converted))
}
}
return result.Interface(), nil
case reflect.Pointer:
if goja.IsNull(jsValue) || goja.IsUndefined(jsValue) {
return nil, nil
}
elemType := targetType.Elem()
converted, err := convertJSValueToGo(vm, jsValue, elemType)
if err != nil {
return nil, err
}
ptr := reflect.New(elemType)
ptr.Elem().Set(reflect.ValueOf(converted))
return ptr.Interface(), nil
default: default:
return nil, fmt.Errorf("unsupported type: %v", targetType) return nil, fmt.Errorf("unsupported type: %v", targetType)
} }
@@ -175,26 +232,6 @@ func convertJSValueToGo(vm *goja.Runtime, jsValue goja.Value, targetType reflect
func convertGoValueToJS(vm *goja.Runtime, goValue reflect.Value) goja.Value { func convertGoValueToJS(vm *goja.Runtime, goValue reflect.Value) goja.Value {
value := goValue.Interface() value := goValue.Interface()
valueType := goValue.Type()
registryMutex.RLock()
converter, ok := customConverters[valueType]
registryMutex.RUnlock()
if ok {
return converter(vm, goValue)
}
if goValue.Kind() == reflect.Pointer && !goValue.IsNil() {
elemType := goValue.Type().Elem()
registryMutex.RLock()
converter, ok := customConverters[elemType]
registryMutex.RUnlock()
if ok {
return converter(vm, goValue.Elem())
}
}
switch v := value.(type) { switch v := value.(type) {
case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool: case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
@@ -213,7 +250,7 @@ func convertGoValueToJS(vm *goja.Runtime, goValue reflect.Value) goja.Value {
case map[string]any: case map[string]any:
obj := vm.NewObject() obj := vm.NewObject()
for key, val := range v { for key, val := range v {
_ = obj.Set(key, val) _ = obj.Set(key, convertGoValueToJS(vm, reflect.ValueOf(val)))
} }
return obj return obj
@@ -225,56 +262,76 @@ func convertGoValueToJS(vm *goja.Runtime, goValue reflect.Value) goja.Value {
return vm.ToValue(arr) return vm.ToValue(arr)
default: default:
if goValue.Kind() == reflect.Pointer {
if goValue.IsNil() {
return goja.Null()
}
return convertGoValueToJS(vm, goValue.Elem())
}
if goValue.Kind() == reflect.Struct {
obj := vm.NewObject()
for i := 0; i < goValue.NumField(); i++ {
field := goValue.Type().Field(i)
fieldName := getFieldName(field)
_ = obj.Set(fieldName, convertGoValueToJS(vm, goValue.Field(i)))
}
return obj
}
return vm.ToValue(v) return vm.ToValue(v)
} }
} }
func generateTypeScriptDefinition(name string, fnType reflect.Type) string { func getFieldName(field reflect.StructField) string {
if fnType.Kind() != reflect.Func { jsonTag := field.Tag.Get("json")
if jsonTag != "" && jsonTag != "-" {
name, _, _ := strings.Cut(jsonTag, ",")
return name
}
return field.Name
}
func generateTypeScriptDefinition(name string, argsType reflect.Type, fnType reflect.Type) string {
if argsType.Kind() != reflect.Struct {
return "" return ""
} }
var params []string var params []string
for i := 0; i < fnType.NumIn(); i++ { for i := 0; i < argsType.NumField(); i++ {
paramName := fmt.Sprintf("arg%d", i) field := argsType.Field(i)
if fnType.In(i).Kind() == reflect.Pointer { fieldName := getFieldName(field)
ptrType := fnType.In(i).Elem() goType := field.Type
if ptrType.Kind() == reflect.Struct {
if s, ok := extractStructParamName(ptrType); ok { tsType := goTypeToTSType(goType, goType.Kind() == reflect.Pointer)
paramName = s params = append(params, fmt.Sprintf("%s: %s", fieldName, tsType))
}
}
}
params = append(params, fmt.Sprintf("%s: %s", paramName, goTypeToTSType(fnType.In(i))))
} }
returnSignature := "void" returnSignature := "any"
if fnType.NumOut() > 0 { if fnType.Kind() == reflect.Func && fnType.NumOut() > 0 {
lastIndex := fnType.NumOut() - 1 lastIndex := fnType.NumOut() - 1
lastType := fnType.Out(lastIndex) lastType := fnType.Out(lastIndex)
if lastType.Implements(reflect.TypeOf((*error)(nil)).Elem()) { if lastType.Implements(reflect.TypeOf((*error)(nil)).Elem()) {
if fnType.NumOut() > 1 { if fnType.NumOut() > 1 {
returnSignature = goTypeToTSType(fnType.Out(0)) returnSignature = goTypeToTSType(fnType.Out(0), false)
} else {
returnSignature = "void"
} }
} else { } else {
returnSignature = goTypeToTSType(lastType) returnSignature = goTypeToTSType(lastType, false)
} }
} }
return fmt.Sprintf("declare function %s(%s): %s;", name, strings.Join(params, ", "), returnSignature) return fmt.Sprintf("declare function %s(%s): %s;", name, strings.Join(params, ", "), returnSignature)
} }
func extractStructParamName(structType reflect.Type) (string, bool) { func goTypeToTSType(t reflect.Type, isPointer bool) string {
if structType.Name() != "" { if isPointer {
return strings.ToLower(structType.Name()), true if t.Kind() == reflect.Pointer {
return goTypeToTSType(t.Elem(), false) + " | null"
}
return goTypeToTSType(t, false) + " | null"
} }
return "", false
}
func goTypeToTSType(t reflect.Type) string {
switch t.Kind() { switch t.Kind() {
case reflect.String: case reflect.String:
return "string" return "string"
@@ -286,17 +343,30 @@ func goTypeToTSType(t reflect.Type) string {
return "number" return "number"
case reflect.Bool: case reflect.Bool:
return "boolean" return "boolean"
case reflect.Interface, reflect.Pointer: case reflect.Interface:
return "any" return "any"
case reflect.Slice: case reflect.Slice:
return fmt.Sprintf("%s[]", goTypeToTSType(t.Elem())) return fmt.Sprintf("%s[]", goTypeToTSType(t.Elem(), false))
case reflect.Map: case reflect.Map:
if t.Key().Kind() == reflect.String && t.Elem().Kind() == reflect.Interface { if t.Key().Kind() == reflect.String && t.Elem().Kind() == reflect.Interface {
return "Record<string, any>" return "Record<string, any>"
} }
return "Record<string, any>" return "Record<string, any>"
case reflect.Struct: case reflect.Struct:
return "any" fields := make([]string, 0, t.NumField())
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
name := getFieldName(field)
tsType := goTypeToTSType(field.Type, field.Type.Kind() == reflect.Pointer)
if field.Type.Kind() == reflect.Pointer {
tsType = strings.TrimSuffix(tsType, " | null")
tsType += "?"
} else if strings.Contains(field.Tag.Get("json"), ",omitempty") {
tsType += "?"
}
fields = append(fields, fmt.Sprintf("%s: %s", name, tsType))
}
return fmt.Sprintf("{ %s }", strings.Join(fields, "; "))
default: default:
return "any" return "any"
} }
@@ -318,10 +388,6 @@ func RegisterBuiltins(vm *goja.Runtime) {
defer registryMutex.RUnlock() defer registryMutex.RUnlock()
for name, builtin := range builtinRegistry { for name, builtin := range builtinRegistry {
if wrapperFactory, ok := builtin.Function.(func(*goja.Runtime) any); ok { _ = vm.Set(name, builtin.Function(vm))
_ = vm.Set(name, wrapperFactory(vm))
} else {
_ = vm.Set(name, builtin.Function)
}
} }
} }

View File

@@ -3,26 +3,66 @@ package standard
import ( import (
"fmt" "fmt"
"io" "io"
"maps"
"net/http" "net/http"
"strings" "strings"
"github.com/dop251/goja"
"reichard.io/poiesis/internal/runtime/pkg/builtin" "reichard.io/poiesis/internal/runtime/pkg/builtin"
) )
type FetchResult struct { type FetchArgs struct {
OK bool URL string `json:"url"`
Status int Options *FetchOptions `json:"options"`
Body string
Headers map[string]string
} }
func Fetch(url string, options map[string]any) (*FetchResult, error) { type FetchOptions struct {
req, err := http.NewRequest("GET", url, nil) Method string `json:"method"`
Headers *map[string]string `json:"headers"`
}
func (o *FetchOptions) Defaults() *FetchOptions {
if o.Method == "" {
o.Method = "GET"
}
return o
}
type FetchResult struct {
OK bool `json:"ok"`
Status int `json:"status"`
Body string `json:"body"`
Headers map[string]string `json:"headers"`
}
type AddArgs struct {
A int `json:"a"`
B int `json:"b"`
}
type GreetArgs struct {
Name string `json:"name"`
}
func Fetch(args FetchArgs) (*FetchResult, error) {
method := "GET"
headers := make(map[string]string)
if args.Options != nil {
method = args.Options.Method
if args.Options.Headers != nil {
maps.Copy(headers, *args.Options.Headers)
}
}
req, err := http.NewRequest(method, args.URL, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err) return nil, fmt.Errorf("failed to create request: %w", err)
} }
for k, v := range headers {
req.Header.Set(k, v)
}
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to fetch: %w", err) return nil, fmt.Errorf("failed to fetch: %w", err)
@@ -36,12 +76,12 @@ func Fetch(url string, options map[string]any) (*FetchResult, error) {
return nil, fmt.Errorf("failed to read body: %w", err) return nil, fmt.Errorf("failed to read body: %w", err)
} }
headers := make(map[string]string) resultHeaders := make(map[string]string)
for key, values := range resp.Header { for key, values := range resp.Header {
if len(values) > 0 { if len(values) > 0 {
val := values[0] val := values[0]
headers[key] = val resultHeaders[key] = val
headers[strings.ToLower(key)] = val resultHeaders[strings.ToLower(key)] = val
} }
} }
@@ -49,44 +89,20 @@ func Fetch(url string, options map[string]any) (*FetchResult, error) {
OK: resp.StatusCode >= 200 && resp.StatusCode < 300, OK: resp.StatusCode >= 200 && resp.StatusCode < 300,
Status: resp.StatusCode, Status: resp.StatusCode,
Body: string(body), Body: string(body),
Headers: headers, Headers: resultHeaders,
}, nil }, nil
} }
func convertFetchResult(vm *goja.Runtime, result *FetchResult) goja.Value { func add(args AddArgs) int {
if result == nil { return args.A + args.B
return goja.Null() }
}
obj := vm.NewObject() func greet(args GreetArgs) string {
_ = obj.Set("ok", result.OK) return fmt.Sprintf("Hello, %s!", args.Name)
_ = obj.Set("status", result.Status)
_ = obj.Set("text", func() string {
return result.Body
})
headersObj := vm.NewObject()
headers := result.Headers
_ = headersObj.Set("get", func(c goja.FunctionCall) goja.Value {
if len(c.Arguments) < 1 {
return goja.Undefined()
}
key := c.Arguments[0].String()
return vm.ToValue(headers[key])
})
_ = obj.Set("headers", headersObj)
return obj
} }
func init() { func init() {
builtin.RegisterCustomConverter(convertFetchResult) builtin.RegisterBuiltin[FetchArgs, *FetchResult]("fetch", Fetch)
builtin.RegisterBuiltin[AddArgs, int]("add", add)
builtin.RegisterBuiltin("fetch", Fetch) builtin.RegisterBuiltin[GreetArgs, string]("greet", greet)
builtin.RegisterBuiltin("add", func(a, b int) int {
return a + b
})
builtin.RegisterBuiltin("greet", func(name string) string {
return fmt.Sprintf("Hello, %s!", name)
})
} }

View File

@@ -18,7 +18,7 @@ func TestFetch(t *testing.T) {
})) }))
defer server.Close() defer server.Close()
result, err := Fetch(server.URL, nil) result, err := Fetch(FetchArgs{URL: server.URL})
require.NoError(t, err) require.NoError(t, err)
assert.True(t, result.OK) assert.True(t, result.OK)
@@ -32,7 +32,7 @@ func TestFetch(t *testing.T) {
func TestFetchHTTPBin(t *testing.T) { func TestFetchHTTPBin(t *testing.T) {
t.Skip("httpbin.org test is flaky") t.Skip("httpbin.org test is flaky")
result, err := Fetch("https://httpbin.org/get", nil) result, err := Fetch(FetchArgs{URL: "https://httpbin.org/get"})
require.NoError(t, err) require.NoError(t, err)
assert.True(t, result.OK) assert.True(t, result.OK)
@@ -42,7 +42,7 @@ func TestFetchHTTPBin(t *testing.T) {
} }
func TestFetchWith404(t *testing.T) { func TestFetchWith404(t *testing.T) {
result, err := Fetch("https://httpbin.org/status/404", nil) result, err := Fetch(FetchArgs{URL: "https://httpbin.org/status/404"})
require.NoError(t, err) require.NoError(t, err)
assert.False(t, result.OK) assert.False(t, result.OK)
@@ -50,7 +50,58 @@ func TestFetchWith404(t *testing.T) {
} }
func TestFetchWithInvalidURL(t *testing.T) { func TestFetchWithInvalidURL(t *testing.T) {
_, err := Fetch("http://this-domain-does-not-exist-12345.com", nil) _, err := Fetch(FetchArgs{URL: "http://this-domain-does-not-exist-12345.com"})
assert.Error(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to fetch") assert.Contains(t, err.Error(), "failed to fetch")
} }
func TestFetchWithHeaders(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "Bearer test-token", r.Header.Get("Authorization"))
assert.Equal(t, "GET", r.Method)
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`ok`))
}))
defer server.Close()
headers := map[string]string{
"Authorization": "Bearer test-token",
}
options := &FetchOptions{
Method: "GET",
Headers: &headers,
}
result, err := Fetch(FetchArgs{URL: server.URL, Options: options})
require.NoError(t, err)
assert.True(t, result.OK)
}
func TestFetchDefaults(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method, "default method should be GET")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`ok`))
}))
defer server.Close()
options := &FetchOptions{}
result, err := Fetch(FetchArgs{URL: server.URL, Options: options})
require.NoError(t, err)
assert.True(t, result.OK)
}
func TestAdd(t *testing.T) {
result := add(AddArgs{A: 5, B: 10})
assert.Equal(t, 15, result)
result = add(AddArgs{A: -3, B: 7})
assert.Equal(t, 4, result)
}
func TestGreet(t *testing.T) {
result := greet(GreetArgs{Name: "World"})
assert.Equal(t, "Hello, World!", result)
result = greet(GreetArgs{Name: "Alice"})
assert.Equal(t, "Hello, Alice!", result)
}

View File

@@ -2,5 +2,5 @@ const response = fetch("https://httpbin.org/get");
console.log("OK:", response.ok); console.log("OK:", response.ok);
console.log("Status:", response.status); console.log("Status:", response.status);
console.log("Body:", response.text()); console.log("Body:", response.body);
console.log("Content-Type:", response.headers.get("content-type")); console.log("Content-Type:", response.headers["content-type"]);