Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(bigtable): Add MarshalJSON to allow clients to get a stringified version of the protobuf #10679

Merged
merged 22 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 47 additions & 2 deletions bigtable/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,40 @@ limitations under the License.

package bigtable

import btapb "cloud.google.com/go/bigtable/admin/apiv2/adminpb"
import (
btapb "cloud.google.com/go/bigtable/admin/apiv2/adminpb"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
)

// Type wraps the protobuf representation of a type. See the protobuf definition
// for more details on types.
type Type interface {
proto() *btapb.Type
}

var marshalOptions = protojson.MarshalOptions{AllowPartial: true, UseEnumNumbers: true}
var unmarshalOptions = protojson.UnmarshalOptions{AllowPartial: true}

// MarshalJSON returns the string representation of the Type protobuf.
func MarshalJSON(t Type) ([]byte, error) {
return marshalOptions.Marshal(t.proto())
}

// UnmarshalJSON returns a Type object from json bytes.
func UnmarshalJSON(data []byte) (Type, error) {
result := &btapb.Type{}
if err := unmarshalOptions.Unmarshal(data, result); err != nil {
return nil, err
}
return ProtoToType(result), nil
}

// Equal compares Type objects.
func Equal(a, b Type) bool {
return proto.Equal(a.proto(), b.proto())
}

type unknown[T interface{}] struct {
wrapped *T
}
Expand Down Expand Up @@ -205,6 +231,8 @@ func ProtoToType(pb *btapb.Type) Type {
return int64ProtoToType(t.Int64Type)
case *btapb.Type_BytesType:
return bytesProtoToType(t.BytesType)
case *btapb.Type_StringType:
return stringProtoToType(t.StringType)
case *btapb.Type_AggregateType:
return aggregateProtoToType(t.AggregateType)
default:
Expand All @@ -229,6 +257,23 @@ func bytesProtoToType(b *btapb.Type_Bytes) BytesType {
return BytesType{Encoding: bytesEncodingProtoToType(b.Encoding)}
}

func stringEncodingProtoToType(se *btapb.Type_String_Encoding) StringEncoding {
if se == nil {
return unknown[btapb.Type_String_Encoding]{wrapped: se}
}

switch se.Encoding.(type) {
case *btapb.Type_String_Encoding_Utf8Raw_:
return StringUtf8Encoding{}
default:
return unknown[btapb.Type_String_Encoding]{wrapped: se}
}
}

func stringProtoToType(s *btapb.Type_String) Type {
return StringType{Encoding: stringEncodingProtoToType(s.Encoding)}
}

func int64EncodingProtoToEncoding(ie *btapb.Type_Int64_Encoding) Int64Encoding {
if ie == nil {
return unknown[btapb.Type_Int64_Encoding]{wrapped: ie}
Expand All @@ -246,7 +291,7 @@ func int64ProtoToType(i *btapb.Type_Int64) Type {
return Int64Type{Encoding: int64EncodingProtoToEncoding(i.Encoding)}
}

func aggregateProtoToType(agg *btapb.Type_Aggregate) Type {
func aggregateProtoToType(agg *btapb.Type_Aggregate) AggregateType {
if agg == nil {
return AggregateType{Input: nil, Aggregator: unknownAggregator{wrapped: agg}}
}
Expand Down
203 changes: 97 additions & 106 deletions bigtable/type_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"testing"

btapb "cloud.google.com/go/bigtable/admin/apiv2/adminpb"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/protobuf/proto"
)

Expand All @@ -37,12 +39,25 @@ func aggregateProto() *btapb.Type {
}
}

func TestUnknown(t *testing.T) {
unsupportedType := &btapb.Type{
Kind: &btapb.Type_Float64Type{
Float64Type: &btapb.Type_Float64{},
},
}
got, ok := ProtoToType(unsupportedType).(unknown[btapb.Type])
if !ok {
t.Errorf("got: %T, wanted unknown[btapb.Type]", got)
}

assertType(t, got, unsupportedType)
}

func TestInt64Proto(t *testing.T) {
want := aggregateProto()
got := Int64Type{}.proto()
if !proto.Equal(got, want) {
t.Errorf("got type %v, want: %v", got, want)
}
it := Int64Type{Encoding: BigEndianBytesEncoding{}}

assertType(t, it, want)
}

func TestStringProto(t *testing.T) {
Expand All @@ -55,39 +70,9 @@ func TestStringProto(t *testing.T) {
},
},
}
st := StringType{Encoding: StringUtf8Encoding{}}

got := StringType{}.proto()
if !proto.Equal(got, want) {
t.Errorf("got type %v, want: %v", got, want)
}
}

func TestSumAggregateProto(t *testing.T) {
want := &btapb.Type{
Kind: &btapb.Type_AggregateType{
AggregateType: &btapb.Type_Aggregate{
InputType: &btapb.Type{
Kind: &btapb.Type_Int64Type{
Int64Type: &btapb.Type_Int64{
Encoding: &btapb.Type_Int64_Encoding{
Encoding: &btapb.Type_Int64_Encoding_BigEndianBytes_{
BigEndianBytes: &btapb.Type_Int64_Encoding_BigEndianBytes{},
},
},
},
},
},
Aggregator: &btapb.Type_Aggregate_Sum_{
Sum: &btapb.Type_Aggregate_Sum{},
},
},
},
}

got := AggregateType{Input: Int64Type{}, Aggregator: SumAggregator{}}.proto()
if !proto.Equal(got, want) {
t.Errorf("got type %v, want: %v", got, want)
}
assertType(t, st, want)
}

func TestProtoBijection(t *testing.T) {
Expand All @@ -98,88 +83,100 @@ func TestProtoBijection(t *testing.T) {
}
}

func TestMinAggregateProto(t *testing.T) {
want := &btapb.Type{
Kind: &btapb.Type_AggregateType{
AggregateType: &btapb.Type_Aggregate{
InputType: &btapb.Type{
Kind: &btapb.Type_Int64Type{
Int64Type: &btapb.Type_Int64{
Encoding: &btapb.Type_Int64_Encoding{
Encoding: &btapb.Type_Int64_Encoding_BigEndianBytes_{
BigEndianBytes: &btapb.Type_Int64_Encoding_BigEndianBytes{},
},
},
},
func TestAggregateProto(t *testing.T) {
intType := &btapb.Type{
Kind: &btapb.Type_Int64Type{
Int64Type: &btapb.Type_Int64{
Encoding: &btapb.Type_Int64_Encoding{
Encoding: &btapb.Type_Int64_Encoding_BigEndianBytes_{
BigEndianBytes: &btapb.Type_Int64_Encoding_BigEndianBytes{},
},
},
Aggregator: &btapb.Type_Aggregate_Min_{
Min: &btapb.Type_Aggregate_Min{},
},
},
},
}

got := AggregateType{Input: Int64Type{}, Aggregator: MinAggregator{}}.proto()
if !proto.Equal(got, want) {
t.Errorf("got type %v, want: %v", got, want)
}
}

func TestMaxAggregateProto(t *testing.T) {
want := &btapb.Type{
Kind: &btapb.Type_AggregateType{
AggregateType: &btapb.Type_Aggregate{
InputType: &btapb.Type{
Kind: &btapb.Type_Int64Type{
Int64Type: &btapb.Type_Int64{
Encoding: &btapb.Type_Int64_Encoding{
Encoding: &btapb.Type_Int64_Encoding_BigEndianBytes_{
BigEndianBytes: &btapb.Type_Int64_Encoding_BigEndianBytes{},
},
},
},
},
testCases := []struct {
name string
agg Aggregator
protoAgg btapb.Type_Aggregate
}{
{
name: "hll",
agg: HllppUniqueCountAggregator{},
protoAgg: btapb.Type_Aggregate{
InputType: intType,
Aggregator: &btapb.Type_Aggregate_HllppUniqueCount{
HllppUniqueCount: &btapb.Type_Aggregate_HyperLogLogPlusPlusUniqueCount{},
},
},
},
{
name: "min",
agg: MinAggregator{},
protoAgg: btapb.Type_Aggregate{
InputType: intType,
Aggregator: &btapb.Type_Aggregate_Min_{
Min: &btapb.Type_Aggregate_Min{},
},
},
},
{
name: "max",
agg: MaxAggregator{},
protoAgg: btapb.Type_Aggregate{
InputType: intType,
Aggregator: &btapb.Type_Aggregate_Max_{
Max: &btapb.Type_Aggregate_Max{},
},
},
},
}
{
name: "sum",
agg: SumAggregator{},
protoAgg: btapb.Type_Aggregate{
InputType: intType,
Aggregator: &btapb.Type_Aggregate_Sum_{
Sum: &btapb.Type_Aggregate_Sum{},
},
},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
want := &btapb.Type{
Kind: &btapb.Type_AggregateType{
AggregateType: &tc.protoAgg,
},
}
at := AggregateType{Input: Int64Type{Encoding: BigEndianBytesEncoding{}}, Aggregator: tc.agg}

got := AggregateType{Input: Int64Type{}, Aggregator: MaxAggregator{}}.proto()
if !proto.Equal(got, want) {
t.Errorf("got type %v, want: %v", got, want)
assertType(t, at, want)
})
}
}

func TestHllAggregateProto(t *testing.T) {
want := &btapb.Type{
Kind: &btapb.Type_AggregateType{
AggregateType: &btapb.Type_Aggregate{
InputType: &btapb.Type{
Kind: &btapb.Type_Int64Type{
Int64Type: &btapb.Type_Int64{
Encoding: &btapb.Type_Int64_Encoding{
Encoding: &btapb.Type_Int64_Encoding_BigEndianBytes_{
BigEndianBytes: &btapb.Type_Int64_Encoding_BigEndianBytes{},
},
},
},
},
},
Aggregator: &btapb.Type_Aggregate_HllppUniqueCount{
HllppUniqueCount: &btapb.Type_Aggregate_HyperLogLogPlusPlusUniqueCount{},
},
},
},
}
func assertType(t *testing.T, ty Type, want *btapb.Type) {
t.Helper()

got := AggregateType{Input: Int64Type{}, Aggregator: HllppUniqueCountAggregator{}}.proto()
got := ty.proto()
if !proto.Equal(got, want) {
t.Errorf("got type %v, want: %v", got, want)
}

gotJSON, err := MarshalJSON(ty)
if err != nil {
t.Fatalf("Error calling MarshalJSON: %v", err)
}
result, err := UnmarshalJSON(gotJSON)
if err != nil {
t.Fatalf("Error calling UnmarshalJSON: %v", err)
}
if diff := cmp.Diff(result, ty, cmpopts.IgnoreUnexported(unknown[btapb.Type]{})); diff != "" {
t.Errorf("Unexpected diff: \n%s", diff)
}
if !Equal(result, ty) {
t.Errorf("Unexpected result. Got %#v, want %#v", result, ty)
}
}

func TestNilChecks(t *testing.T) {
Expand Down Expand Up @@ -208,21 +205,15 @@ func TestNilChecks(t *testing.T) {
}

// aggregateProtoToType
aggType1, ok := aggregateProtoToType(nil).(AggregateType)
if !ok {
t.Fatalf("got: %T, wanted AggregateType", aggType1)
}
aggType1 := aggregateProtoToType(nil)
if val, ok := aggType1.Aggregator.(unknownAggregator); !ok {
t.Errorf("got: %T, wanted unknownAggregator", val)
}
if aggType1.Input != nil {
t.Errorf("got: %v, wanted nil", aggType1.Input)
}

aggType2, ok := aggregateProtoToType(&btapb.Type_Aggregate{}).(AggregateType)
if !ok {
t.Fatalf("got: %T, wanted AggregateType", aggType2)
}
aggType2 := aggregateProtoToType(&btapb.Type_Aggregate{})
if val, ok := aggType2.Aggregator.(unknownAggregator); !ok {
t.Errorf("got: %T, wanted unknownAggregator", val)
}
Expand Down
Loading