diff --git a/bigtable/type.go b/bigtable/type.go index 59f954f081f7..b99bb4a251d1 100644 --- a/bigtable/type.go +++ b/bigtable/type.go @@ -16,7 +16,11 @@ limitations under the License. package bigtable -import btapb "cloud.google.com/go/bigtable/admin/apiv2/adminpb" +import ( + btapb "cloud.google.com/go/bigtable/admin/apiv2/adminpb" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" +) // Type wraps the protobuf representation of a type. See the protobuf definition // for more details on types. @@ -24,6 +28,28 @@ type Type interface { proto() *btapb.Type } +var marshalOptions = protojson.MarshalOptions{AllowPartial: true, UseEnumNumbers: true} +var unmarshalOptions = protojson.UnmarshalOptions{AllowPartial: true} + +// MarshalJSON returns the string representation of the Type protobuf. +func MarshalJSON(t Type) ([]byte, error) { + return marshalOptions.Marshal(t.proto()) +} + +// UnmarshalJSON returns a Type object from json bytes. +func UnmarshalJSON(data []byte) (Type, error) { + result := &btapb.Type{} + if err := unmarshalOptions.Unmarshal(data, result); err != nil { + return nil, err + } + return ProtoToType(result), nil +} + +// Equal compares Type objects. +func Equal(a, b Type) bool { + return proto.Equal(a.proto(), b.proto()) +} + type unknown[T interface{}] struct { wrapped *T } @@ -205,6 +231,8 @@ func ProtoToType(pb *btapb.Type) Type { return int64ProtoToType(t.Int64Type) case *btapb.Type_BytesType: return bytesProtoToType(t.BytesType) + case *btapb.Type_StringType: + return stringProtoToType(t.StringType) case *btapb.Type_AggregateType: return aggregateProtoToType(t.AggregateType) default: @@ -229,6 +257,23 @@ func bytesProtoToType(b *btapb.Type_Bytes) BytesType { return BytesType{Encoding: bytesEncodingProtoToType(b.Encoding)} } +func stringEncodingProtoToType(se *btapb.Type_String_Encoding) StringEncoding { + if se == nil { + return unknown[btapb.Type_String_Encoding]{wrapped: se} + } + + switch se.Encoding.(type) { + case *btapb.Type_String_Encoding_Utf8Raw_: + return StringUtf8Encoding{} + default: + return unknown[btapb.Type_String_Encoding]{wrapped: se} + } +} + +func stringProtoToType(s *btapb.Type_String) Type { + return StringType{Encoding: stringEncodingProtoToType(s.Encoding)} +} + func int64EncodingProtoToEncoding(ie *btapb.Type_Int64_Encoding) Int64Encoding { if ie == nil { return unknown[btapb.Type_Int64_Encoding]{wrapped: ie} @@ -246,7 +291,7 @@ func int64ProtoToType(i *btapb.Type_Int64) Type { return Int64Type{Encoding: int64EncodingProtoToEncoding(i.Encoding)} } -func aggregateProtoToType(agg *btapb.Type_Aggregate) Type { +func aggregateProtoToType(agg *btapb.Type_Aggregate) AggregateType { if agg == nil { return AggregateType{Input: nil, Aggregator: unknownAggregator{wrapped: agg}} } diff --git a/bigtable/type_test.go b/bigtable/type_test.go index e80525a698c4..9961ee410c7f 100644 --- a/bigtable/type_test.go +++ b/bigtable/type_test.go @@ -20,6 +20,8 @@ import ( "testing" btapb "cloud.google.com/go/bigtable/admin/apiv2/adminpb" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/protobuf/proto" ) @@ -37,12 +39,25 @@ func aggregateProto() *btapb.Type { } } +func TestUnknown(t *testing.T) { + unsupportedType := &btapb.Type{ + Kind: &btapb.Type_Float64Type{ + Float64Type: &btapb.Type_Float64{}, + }, + } + got, ok := ProtoToType(unsupportedType).(unknown[btapb.Type]) + if !ok { + t.Errorf("got: %T, wanted unknown[btapb.Type]", got) + } + + assertType(t, got, unsupportedType) +} + func TestInt64Proto(t *testing.T) { want := aggregateProto() - got := Int64Type{}.proto() - if !proto.Equal(got, want) { - t.Errorf("got type %v, want: %v", got, want) - } + it := Int64Type{Encoding: BigEndianBytesEncoding{}} + + assertType(t, it, want) } func TestStringProto(t *testing.T) { @@ -55,39 +70,9 @@ func TestStringProto(t *testing.T) { }, }, } + st := StringType{Encoding: StringUtf8Encoding{}} - got := StringType{}.proto() - if !proto.Equal(got, want) { - t.Errorf("got type %v, want: %v", got, want) - } -} - -func TestSumAggregateProto(t *testing.T) { - want := &btapb.Type{ - Kind: &btapb.Type_AggregateType{ - AggregateType: &btapb.Type_Aggregate{ - InputType: &btapb.Type{ - Kind: &btapb.Type_Int64Type{ - Int64Type: &btapb.Type_Int64{ - Encoding: &btapb.Type_Int64_Encoding{ - Encoding: &btapb.Type_Int64_Encoding_BigEndianBytes_{ - BigEndianBytes: &btapb.Type_Int64_Encoding_BigEndianBytes{}, - }, - }, - }, - }, - }, - Aggregator: &btapb.Type_Aggregate_Sum_{ - Sum: &btapb.Type_Aggregate_Sum{}, - }, - }, - }, - } - - got := AggregateType{Input: Int64Type{}, Aggregator: SumAggregator{}}.proto() - if !proto.Equal(got, want) { - t.Errorf("got type %v, want: %v", got, want) - } + assertType(t, st, want) } func TestProtoBijection(t *testing.T) { @@ -98,88 +83,100 @@ func TestProtoBijection(t *testing.T) { } } -func TestMinAggregateProto(t *testing.T) { - want := &btapb.Type{ - Kind: &btapb.Type_AggregateType{ - AggregateType: &btapb.Type_Aggregate{ - InputType: &btapb.Type{ - Kind: &btapb.Type_Int64Type{ - Int64Type: &btapb.Type_Int64{ - Encoding: &btapb.Type_Int64_Encoding{ - Encoding: &btapb.Type_Int64_Encoding_BigEndianBytes_{ - BigEndianBytes: &btapb.Type_Int64_Encoding_BigEndianBytes{}, - }, - }, - }, +func TestAggregateProto(t *testing.T) { + intType := &btapb.Type{ + Kind: &btapb.Type_Int64Type{ + Int64Type: &btapb.Type_Int64{ + Encoding: &btapb.Type_Int64_Encoding{ + Encoding: &btapb.Type_Int64_Encoding_BigEndianBytes_{ + BigEndianBytes: &btapb.Type_Int64_Encoding_BigEndianBytes{}, }, }, - Aggregator: &btapb.Type_Aggregate_Min_{ - Min: &btapb.Type_Aggregate_Min{}, - }, }, }, } - got := AggregateType{Input: Int64Type{}, Aggregator: MinAggregator{}}.proto() - if !proto.Equal(got, want) { - t.Errorf("got type %v, want: %v", got, want) - } -} - -func TestMaxAggregateProto(t *testing.T) { - want := &btapb.Type{ - Kind: &btapb.Type_AggregateType{ - AggregateType: &btapb.Type_Aggregate{ - InputType: &btapb.Type{ - Kind: &btapb.Type_Int64Type{ - Int64Type: &btapb.Type_Int64{ - Encoding: &btapb.Type_Int64_Encoding{ - Encoding: &btapb.Type_Int64_Encoding_BigEndianBytes_{ - BigEndianBytes: &btapb.Type_Int64_Encoding_BigEndianBytes{}, - }, - }, - }, - }, + testCases := []struct { + name string + agg Aggregator + protoAgg btapb.Type_Aggregate + }{ + { + name: "hll", + agg: HllppUniqueCountAggregator{}, + protoAgg: btapb.Type_Aggregate{ + InputType: intType, + Aggregator: &btapb.Type_Aggregate_HllppUniqueCount{ + HllppUniqueCount: &btapb.Type_Aggregate_HyperLogLogPlusPlusUniqueCount{}, + }, + }, + }, + { + name: "min", + agg: MinAggregator{}, + protoAgg: btapb.Type_Aggregate{ + InputType: intType, + Aggregator: &btapb.Type_Aggregate_Min_{ + Min: &btapb.Type_Aggregate_Min{}, }, + }, + }, + { + name: "max", + agg: MaxAggregator{}, + protoAgg: btapb.Type_Aggregate{ + InputType: intType, Aggregator: &btapb.Type_Aggregate_Max_{ Max: &btapb.Type_Aggregate_Max{}, }, }, }, - } + { + name: "sum", + agg: SumAggregator{}, + protoAgg: btapb.Type_Aggregate{ + InputType: intType, + Aggregator: &btapb.Type_Aggregate_Sum_{ + Sum: &btapb.Type_Aggregate_Sum{}, + }, + }, + }} + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + want := &btapb.Type{ + Kind: &btapb.Type_AggregateType{ + AggregateType: &tc.protoAgg, + }, + } + at := AggregateType{Input: Int64Type{Encoding: BigEndianBytesEncoding{}}, Aggregator: tc.agg} - got := AggregateType{Input: Int64Type{}, Aggregator: MaxAggregator{}}.proto() - if !proto.Equal(got, want) { - t.Errorf("got type %v, want: %v", got, want) + assertType(t, at, want) + }) } } -func TestHllAggregateProto(t *testing.T) { - want := &btapb.Type{ - Kind: &btapb.Type_AggregateType{ - AggregateType: &btapb.Type_Aggregate{ - InputType: &btapb.Type{ - Kind: &btapb.Type_Int64Type{ - Int64Type: &btapb.Type_Int64{ - Encoding: &btapb.Type_Int64_Encoding{ - Encoding: &btapb.Type_Int64_Encoding_BigEndianBytes_{ - BigEndianBytes: &btapb.Type_Int64_Encoding_BigEndianBytes{}, - }, - }, - }, - }, - }, - Aggregator: &btapb.Type_Aggregate_HllppUniqueCount{ - HllppUniqueCount: &btapb.Type_Aggregate_HyperLogLogPlusPlusUniqueCount{}, - }, - }, - }, - } +func assertType(t *testing.T, ty Type, want *btapb.Type) { + t.Helper() - got := AggregateType{Input: Int64Type{}, Aggregator: HllppUniqueCountAggregator{}}.proto() + got := ty.proto() if !proto.Equal(got, want) { t.Errorf("got type %v, want: %v", got, want) } + + gotJSON, err := MarshalJSON(ty) + if err != nil { + t.Fatalf("Error calling MarshalJSON: %v", err) + } + result, err := UnmarshalJSON(gotJSON) + if err != nil { + t.Fatalf("Error calling UnmarshalJSON: %v", err) + } + if diff := cmp.Diff(result, ty, cmpopts.IgnoreUnexported(unknown[btapb.Type]{})); diff != "" { + t.Errorf("Unexpected diff: \n%s", diff) + } + if !Equal(result, ty) { + t.Errorf("Unexpected result. Got %#v, want %#v", result, ty) + } } func TestNilChecks(t *testing.T) { @@ -208,10 +205,7 @@ func TestNilChecks(t *testing.T) { } // aggregateProtoToType - aggType1, ok := aggregateProtoToType(nil).(AggregateType) - if !ok { - t.Fatalf("got: %T, wanted AggregateType", aggType1) - } + aggType1 := aggregateProtoToType(nil) if val, ok := aggType1.Aggregator.(unknownAggregator); !ok { t.Errorf("got: %T, wanted unknownAggregator", val) } @@ -219,10 +213,7 @@ func TestNilChecks(t *testing.T) { t.Errorf("got: %v, wanted nil", aggType1.Input) } - aggType2, ok := aggregateProtoToType(&btapb.Type_Aggregate{}).(AggregateType) - if !ok { - t.Fatalf("got: %T, wanted AggregateType", aggType2) - } + aggType2 := aggregateProtoToType(&btapb.Type_Aggregate{}) if val, ok := aggType2.Aggregator.(unknownAggregator); !ok { t.Errorf("got: %T, wanted unknownAggregator", val) }