diff --git a/codec/codec.go b/codec/codec.go index 0f8dd67..447eb83 100644 --- a/codec/codec.go +++ b/codec/codec.go @@ -1,17 +1,26 @@ package codec import ( - "github.com/pentops/j5/lib/j5schema" + "github.com/pentops/j5/lib/j5reflect" "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" ) +// MessageTypeResolver is a subset of protoregistry.MessageTypeResolver +type MessageTypeResolver interface { + FindMessageByName(message protoreflect.FullName) (protoreflect.MessageType, error) +} + type Codec struct { - schemaSet *j5schema.SchemaCache + refl *j5reflect.Reflector + resolver MessageTypeResolver } -func NewCodec() *Codec { +func NewCodec(resolver protoregistry.MessageTypeResolver) *Codec { + refl := j5reflect.New() return &Codec{ - schemaSet: j5schema.NewSchemaCache(), + refl: refl, + resolver: resolver, } } diff --git a/codec/codec_test.go b/codec/codec_test.go index 74db923..3f85975 100644 --- a/codec/codec_test.go +++ b/codec/codec_test.go @@ -12,7 +12,9 @@ import ( "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/dynamicpb" + "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/timestamppb" "github.com/pentops/flowtest/prototest" @@ -21,17 +23,18 @@ import ( "github.com/pentops/j5/j5types/decimal_j5t" ) -/* - func mustAny(t testing.TB, msg proto.Message) *anypb.Any { - a, err := anypb.New(msg) - if err != nil { - t.Fatal(err) - } - return a +func mustAny(t testing.TB, msg proto.Message) *anypb.Any { + a, err := anypb.New(msg) + if err != nil { + t.Fatal(err) } -*/ + return a +} + func TestUnmarshal(t *testing.T) { + codec := NewCodec(protoregistry.GlobalTypes) + testTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) for _, tc := range []struct { @@ -339,13 +342,12 @@ func TestUnmarshal(t *testing.T) { wantProto: &schema_testpb.FullSchema{ KeyString: "keyVal", }, - }, - /*{ + }, { name: "any", json: `{ "any": { - "!type": "test.v1.Bar", - "bar": { + "!type": "test.schema.v1.Bar", + "value": { "barId": "barId" } } @@ -355,20 +357,12 @@ func TestUnmarshal(t *testing.T) { BarId: "barId", }), }, - }*/} { + }} { t.Run(tc.name, func(t *testing.T) { allInputs := append(tc.altInputJSON, tc.json) - codec := NewCodec() for _, input := range allInputs { - /* - schema, err := codec.schemaSet.SchemaObject(tc.wantProto.ProtoReflect().Descriptor()) - if err != nil { - t.Fatal(err) - } - t.Logf("SCHEMA: %s", protojson.Format(schema)) - */ logIndent(t, "input", input) msg := tc.wantProto.ProtoReflect().New().Interface() @@ -376,16 +370,16 @@ func TestUnmarshal(t *testing.T) { t.Fatalf("JSONToProto: %s", err) } - t.Logf("got decoded proto: %s \n%v\n", msg.ProtoReflect().Descriptor().FullName(), prototext.Format(msg)) + t.Logf("GOT proto: %s \n%v\n", msg.ProtoReflect().Descriptor().FullName(), prototext.Format(msg)) if !proto.Equal(tc.wantProto, msg) { a := prototext.Format(tc.wantProto) - t.Fatalf("expected proto %s\n%v\n", tc.wantProto.ProtoReflect().Descriptor().FullName(), string(a)) + t.Fatalf("FATAL: Expected proto %s\n%v\n", tc.wantProto.ProtoReflect().Descriptor().FullName(), string(a)) } encoded, err := codec.ProtoToJSON(msg.ProtoReflect()) if err != nil { - t.Fatalf("ProtoToJSON: %s", err) + t.Fatalf("FATAL: ProtoToJSON: %s", err) } logIndent(t, "output", string(encoded)) @@ -419,7 +413,7 @@ func TestScalars(t *testing.T) { runTest := func(t testing.TB, tc testCase) { - codec := NewCodec() + codec := NewCodec(protoregistry.GlobalTypes) msgIn := dynamicpb.NewMessage(tc.desc) diff --git a/codec/decoder.go b/codec/decoder.go index 6ae3e2a..0e9786e 100644 --- a/codec/decoder.go +++ b/codec/decoder.go @@ -7,7 +7,10 @@ import ( "strings" "github.com/pentops/j5/lib/j5reflect" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/types/known/anypb" ) func (c *Codec) decode(jsonData []byte, msg protoreflect.Message) error { @@ -18,7 +21,7 @@ func (c *Codec) decode(jsonData []byte, msg protoreflect.Message) error { codec: c, } - root, err := j5reflect.NewWithCache(c.schemaSet).NewRoot(msg) + root, err := c.refl.NewRoot(msg) if err != nil { return err } @@ -33,6 +36,7 @@ func (c *Codec) decode(jsonData []byte, msg protoreflect.Message) error { } } +// decoder is an instance for decoding a single message, not reusable. type decoder struct { jd *json.Decoder codec *Codec @@ -180,16 +184,7 @@ func (dec *decoder) decodeValue(field j5reflect.Field) error { return dec.decodeEnum(ft) case j5reflect.ScalarField: - tok, err := dec.Token() - if err != nil { - return err - } - - if _, ok := tok.(json.Delim); ok { - return unexpectedTokenError(tok, "scalar") - } - - return ft.SetGoValue(tok) + return dec.decodeScalar(ft) case j5reflect.AnyField: return dec.decodeAny(ft) @@ -199,6 +194,19 @@ func (dec *decoder) decodeValue(field j5reflect.Field) error { } } +func (dec *decoder) decodeScalar(field j5reflect.ScalarField) error { + tok, err := dec.Token() + if err != nil { + return err + } + + if _, ok := tok.(json.Delim); ok { + return unexpectedTokenError(tok, "scalar") + } + + return field.SetGoValue(tok) +} + func (dec *decoder) decodeEnum(field j5reflect.EnumField) error { token, err := dec.Token() if err != nil { @@ -257,11 +265,37 @@ func (dec *decoder) decodeAny(field j5reflect.AnyField) error { return newFieldError("value", "no value found in Any") } - // This code assumes the schema has been pre-loaded + // takes the PROTO name, which should match the encoder. + innerDesc, err := dec.codec.resolver.FindMessageByName(protoreflect.FullName(*constrainType)) + if err != nil { + if err == protoregistry.NotFound { + return newFieldError(*constrainType, fmt.Sprintf("no type %q in registry", *constrainType)) + } + return newFieldError(*constrainType, err.Error()) + } + msg := innerDesc.New() + + if err := dec.codec.decode(valueBytes, msg); err != nil { + return newFieldError(*constrainType, err.Error()) + } + + protoBytes, err := proto.Marshal(msg.Interface()) + if err != nil { + return newFieldError(*constrainType, err.Error()) + } + + anyVal := &anypb.Any{ + Value: protoBytes, + TypeUrl: anyPrefix + *constrainType, + } + + field.SetProtoAny(anyVal) return nil } +const anyPrefix = "type.googleapis.com/" + func (dec *decoder) decodeOneof(oneof j5reflect.Oneof) error { foundKeys := []string{} diff --git a/codec/encoder.go b/codec/encoder.go index 1d491ec..d9295df 100644 --- a/codec/encoder.go +++ b/codec/encoder.go @@ -11,10 +11,11 @@ import ( func (c *Codec) encode(msg protoreflect.Message) ([]byte, error) { enc := &encoder{ - b: &bytes.Buffer{}, + codec: c, + b: &bytes.Buffer{}, } - root, err := j5reflect.NewWithCache(c.schemaSet).NewRoot(msg) + root, err := c.refl.NewRoot(msg) if err != nil { return nil, err } @@ -35,7 +36,8 @@ func (c *Codec) encode(msg protoreflect.Message) ([]byte, error) { } type encoder struct { - b *bytes.Buffer + b *bytes.Buffer + codec *Codec } func (enc *encoder) add(b []byte) { diff --git a/codec/structure_encode.go b/codec/structure_encode.go index cae3cd7..c1f882a 100644 --- a/codec/structure_encode.go +++ b/codec/structure_encode.go @@ -3,6 +3,7 @@ package codec import ( "encoding/base64" "fmt" + "strings" "time" "github.com/pentops/j5/j5types/date_j5t" @@ -73,6 +74,45 @@ func (enc *encoder) encodeObject(object j5reflect.Object) error { return enc.encodeObjectBody(object) } +func (enc *encoder) encodeAny(anyField j5reflect.AnyField) error { + protoAny := anyField.GetProtoAny() + msg, err := protoAny.UnmarshalNew() + if err != nil { + return err + } + + innerBytes, err := enc.codec.encode(msg.ProtoReflect()) + if err != nil { + return err + } + + enc.openObject() + defer enc.closeObject() + + err = enc.fieldLabel("!type") + if err != nil { + return err + } + + typeName := strings.TrimPrefix(protoAny.TypeUrl, anyPrefix) + err = enc.addString(typeName) + if err != nil { + return err + } + + enc.fieldSep() + + err = enc.fieldLabel("value") + if err != nil { + return err + } + + enc.add(innerBytes) + + return nil + +} + func (enc *encoder) encodeValue(field j5reflect.Field) error { switch ft := field.(type) { @@ -94,6 +134,9 @@ func (enc *encoder) encodeValue(field j5reflect.Field) error { case j5reflect.ScalarField: return enc.encodeScalarField(ft) + case j5reflect.AnyField: + return enc.encodeAny(ft) + default: return fmt.Errorf("encode value of type %q, unsupported", field.FullTypeName()) } diff --git a/internal/grpcreflect/protoread.go b/internal/grpcreflect/protoread.go deleted file mode 100644 index 7a927ac..0000000 --- a/internal/grpcreflect/protoread.go +++ /dev/null @@ -1,123 +0,0 @@ -package grpcreflect - -import ( - "context" - "errors" - "fmt" - "strings" - "time" - - "github.com/pentops/log.go/log" - "google.golang.org/grpc" - "google.golang.org/grpc/reflection/grpc_reflection_v1" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/reflect/protodesc" - "google.golang.org/protobuf/reflect/protoreflect" - "google.golang.org/protobuf/types/descriptorpb" -) - -type Client struct { - client grpc_reflection_v1.ServerReflectionClient -} - -func NewClient(conn *grpc.ClientConn) *Client { - client := grpc_reflection_v1.NewServerReflectionClient(conn) - return &Client{client: client} -} - -func (rc *Client) FetchServices(ctx context.Context) ([]protoreflect.ServiceDescriptor, error) { - - var stream grpc_reflection_v1.ServerReflection_ServerReflectionInfoClient - for { - cc, err := rc.client.ServerReflectionInfo(ctx) - if err == nil { - stream = cc - break - } - - log.WithError(ctx, err).Error("fetching services. Retrying") - - if errors.Is(err, context.Canceled) { - return nil, err - } - - select { - case <-time.After(time.Second): - case <-ctx.Done(): - return nil, ctx.Err() - } - } - - roundTrip := func(req *grpc_reflection_v1.ServerReflectionRequest) (*grpc_reflection_v1.ServerReflectionResponse, error) { - if err := stream.Send(req); err != nil { - return nil, err - } - return stream.Recv() - } - - resp, err := roundTrip(&grpc_reflection_v1.ServerReflectionRequest{ - MessageRequest: &grpc_reflection_v1.ServerReflectionRequest_ListServices{}, - }) - if err != nil { - return nil, err - } - - ds := &descriptorpb.FileDescriptorSet{} - serviceNames := make([]string, 0) - - fileSet := make(map[string]struct{}) - - for _, service := range resp.GetListServicesResponse().Service { - // don't register the reflection service - if strings.HasPrefix(service.Name, "grpc.reflection") { - continue - } - - serviceNames = append(serviceNames, service.Name) - - fileResp, err := roundTrip(&grpc_reflection_v1.ServerReflectionRequest{ - MessageRequest: &grpc_reflection_v1.ServerReflectionRequest_FileContainingSymbol{ - FileContainingSymbol: service.Name, - }, - }) - if err != nil { - return nil, err - } - - for _, rawFile := range fileResp.GetFileDescriptorResponse().FileDescriptorProto { - file := &descriptorpb.FileDescriptorProto{} - if err := proto.Unmarshal(rawFile, file); err != nil { - return nil, err - } - if _, ok := fileSet[file.GetName()]; ok { - continue - } - fileSet[file.GetName()] = struct{}{} - ds.File = append(ds.File, file) - } - } - - files, err := protodesc.NewFiles(ds) - if err != nil { - return nil, err - } - - services := make([]protoreflect.ServiceDescriptor, 0, len(serviceNames)) - - for _, serviceName := range serviceNames { - - ssI, err := files.FindDescriptorByName(protoreflect.FullName(serviceName)) - if err != nil { - return nil, err - } - - ss, ok := ssI.(protoreflect.ServiceDescriptor) - if !ok { - return nil, fmt.Errorf("not a service") - } - - services = append(services, ss) - } - - return services, nil -} diff --git a/internal/grpcreflect/protoread_test.go b/internal/grpcreflect/protoread_test.go deleted file mode 100644 index 982fd2e..0000000 --- a/internal/grpcreflect/protoread_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package grpcreflect - -import ( - "context" - "testing" - - "github.com/pentops/flowtest" - "github.com/pentops/j5/gen/test/foo/v1/foo_testspb" - "google.golang.org/grpc/reflection" - "google.golang.org/protobuf/reflect/protoreflect" -) - -type Service struct { - foo_testspb.UnimplementedFooCommandServiceServer - foo_testspb.UnimplementedFooQueryServiceServer - foo_testspb.UnimplementedFooDownloadServiceServer -} - -func TestProtoReadHappy(t *testing.T) { - grpcPair := flowtest.NewGRPCPair(t) - - service := &Service{} - foo_testspb.RegisterFooQueryServiceServer(grpcPair.Server, service) - foo_testspb.RegisterFooCommandServiceServer(grpcPair.Server, service) - foo_testspb.RegisterFooDownloadServiceServer(grpcPair.Server, service) - - reflection.Register(grpcPair.Server) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - grpcPair.ServeUntilDone(t, ctx) - - cl := NewClient(grpcPair.Client) - - desc, err := cl.FetchServices(ctx) - if err != nil { - t.Fatal(err) - } - - byName := make(map[protoreflect.FullName]protoreflect.ServiceDescriptor) - for _, d := range desc { - byName[d.FullName()] = d - } - - if len(byName) != 3 { - t.Fatalf("expected 3 services, got %d", len(byName)) - } - - _, ok := byName["test.foo.v1.service.FooQueryService"] - if !ok { - t.Fatal("missing FooQueryService") - } - -} diff --git a/lib/j5reflect/type_any.go b/lib/j5reflect/type_any.go index b6fc867..ec090e2 100644 --- a/lib/j5reflect/type_any.go +++ b/lib/j5reflect/type_any.go @@ -3,12 +3,16 @@ package j5reflect import ( "github.com/pentops/j5/lib/j5schema" "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/known/anypb" ) /*** Interface ***/ type AnyField interface { Field + + SetProtoAny(val *anypb.Any) + GetProtoAny() *anypb.Any } /*** Implementation ***/ @@ -26,6 +30,27 @@ func (field *anyField) IsSet() bool { return true } +var anyTypeField protoreflect.FieldDescriptor +var anyValueField protoreflect.FieldDescriptor + +func init() { + desc := (&anypb.Any{}).ProtoReflect().Descriptor() + anyTypeField = desc.Fields().ByName("type_url") + anyValueField = desc.Fields().ByName("value") +} + +func (field *anyField) SetProtoAny(val *anypb.Any) { + field.value.Set(anyTypeField, protoreflect.ValueOfString(val.TypeUrl)) + field.value.Set(anyValueField, protoreflect.ValueOfBytes(val.Value)) +} + +func (field *anyField) GetProtoAny() *anypb.Any { + return &anypb.Any{ + TypeUrl: field.value.Get(anyTypeField).String(), + Value: field.value.Get(anyValueField).Bytes(), + } +} + var _ AnyField = (*anyField)(nil) type anyFieldFactory struct { @@ -34,7 +59,8 @@ type anyFieldFactory struct { func (factory *anyFieldFactory) buildField(context fieldContext, value protoreflect.Message) Field { return &anyField{ - schema: factory.schema, - value: value, + schema: factory.schema, + value: value, + fieldContext: context, } } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index e60adec..76924f3 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -19,6 +19,7 @@ import ( "google.golang.org/genproto/googleapis/api/httpbody" "google.golang.org/grpc" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoregistry" ) func TestGetHandlerMapping(t *testing.T) { @@ -26,7 +27,7 @@ func TestGetHandlerMapping(t *testing.T) { serviceDesc := foo_testspb.File_test_foo_v1_service_foo_service_proto. Services().ByName("FooQueryService") - rr := NewRouter(codec.NewCodec()) + rr := NewRouter(codec.NewCodec(protoregistry.GlobalTypes)) rr.globalAuth = AuthHeadersFunc(func(ctx context.Context, req *http.Request) (map[string]string, error) { return map[string]string{}, nil }) @@ -150,7 +151,7 @@ func TestBodyHandlerMapping(t *testing.T) { Services().ByName("FooCommandService"). Methods().ByName("PostFoo") - rr := NewRouter(codec.NewCodec()) + rr := NewRouter(codec.NewCodec(protoregistry.GlobalTypes)) method, err := rr.buildMethod(fd, nil, &auth_j5pb.MethodAuthType_None{}) if err != nil { t.Fatal(err) @@ -299,7 +300,7 @@ func TestRawBodyHandler(t *testing.T) { fd := foo_testspb.File_test_foo_v1_service_foo_service_proto. Services().ByName("FooDownloadService") - rr := NewRouter(codec.NewCodec()) + rr := NewRouter(codec.NewCodec(protoregistry.GlobalTypes)) bodyData, err := proto.Marshal(&httpbody.HttpBody{ ContentType: "application/octet-stream", @@ -332,7 +333,7 @@ func TestAuthMethods(t *testing.T) { authHeaders := map[string]string{} called := false - rr := NewRouter(codec.NewCodec()) + rr := NewRouter(codec.NewCodec(protoregistry.GlobalTypes)) rr.globalAuth = AuthHeadersFunc(func(ctx context.Context, req *http.Request) (map[string]string, error) { called = true return authHeaders, nil