feat(spanner): implement generation and propagation of "x-goog-spanner-request-id" Header #11048

Open
wants to merge 1 commit into
base: main
6 changes: 6 additions & 0 deletions spanner/batch.go
@@ -209,7 +209,13 @@ func (t *BatchReadOnlyTransaction) partitionQuery(ctx context.Context, statement
ParamTypes: paramTypes,
}
sh.updateLastUseTime()

// PartitionQuery does not retry automatically so we don't need to retrieve
Contributor

What do you mean by this? PartitionQuery is retried if it receives an UNAVAILABLE error (same as most unary RPCs). See https://github.com/googleapis/googleapis/blob/master/google/spanner/v1/spanner_grpc_service_config.json for the default RPC configuration.
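
To illustrate the retry point, here is a minimal sketch that uses only the standard library. It assumes a hypothetical `errUnavailable` sentinel in place of a real gRPC UNAVAILABLE status and a hand-written loop in place of the actual retry machinery; `callWithRetry` is an illustrative name, not something in the Spanner client. It shows the attempt number being tracked per call, so it can change on every retry without touching shared state.

```go
package main

import (
	"errors"
	"fmt"
)

// errUnavailable is a hypothetical stand-in for a gRPC UNAVAILABLE status.
var errUnavailable = errors.New("unavailable")

// callWithRetry invokes rpc up to maxAttempts times and retries only when
// rpc returns errUnavailable. The attempt number is local to this call, so
// it could be embedded in a per-attempt request ID. It returns the number of
// the last attempt made and that attempt's error.
func callWithRetry(maxAttempts int, rpc func(attempt int) error) (int, error) {
	var err error
	for attempt := 1; attempt <= maxAttempts; attempt++ {
		err = rpc(attempt)
		if !errors.Is(err, errUnavailable) {
			return attempt, err
		}
	}
	return maxAttempts, err
}

func main() {
	var seen []int
	attempts, err := callWithRetry(5, func(attempt int) error {
		seen = append(seen, attempt)
		if attempt < 3 {
			return errUnavailable // the first two attempts fail transiently
		}
		return nil
	})
	fmt.Println(attempts, err, seen)
}
```

Under these assumptions, any RPC that can be retried on UNAVAILABLE would need the attempt number carried with the call rather than reset on the client afterwards.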

// the injected requestID to increment the RPC number on retries.
resp, err := client.PartitionQuery(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), req, gax.WithGRPCOptions(grpc.Header(&md)))
if gcl, ok := client.(*grpcSpannerClient); ok {
gcl.setOrResetRPCID()
}

if getGFELatencyMetricsFlag() && md != nil && t.ct != nil {
if err := createContextAndCaptureGFELatencyMetrics(ctx, t.ct, md, "partitionQuery"); err != nil {
13 changes: 12 additions & 1 deletion spanner/client.go
@@ -1318,10 +1318,21 @@ func (c *Client) BatchWriteWithOptions(ctx context.Context, mgs []*MutationGroup
return &BatchWriteResponseIterator{meterTracerFactory: c.metricsTracerFactory, err: err}
}

nRPCs := uint64(0)
rpc := func(ct context.Context) (sppb.Spanner_BatchWriteClient, error) {
var md metadata.MD
sh.updateLastUseTime()
stream, rpcErr := sh.getClient().BatchWrite(contextWithOutgoingMetadata(ct, sh.getMetadata(), c.disableRouteToLeader), &sppb.BatchWriteRequest{
nRPCs++

// Firstly set the number of retries as the RPCID.
client := sh.getClient()
gcl, ok := client.(*grpcSpannerClient)
if ok {
gcl.setRPCID(nRPCs)
Contributor

This seems to assume that a gRPC client has only one active request at any given time. That does not seem correct for two reasons:

  1. For multiplexed sessions, we keep a pool of 4 (or, more precisely, numChannels) grpcSpannerClients. These clients are shared across all goroutines that use multiplexed sessions.
  2. Regular sessions can also execute requests in parallel. Those requests would also use the same grpcSpannerClient, so tracking state such as the number of (retry) attempts at the grpcSpannerClient level won't work.
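
One way to picture the alternative is the sketch below, which uses only the standard library. All names (`sharedClient`, `call`, `runConcurrently`) are hypothetical and the three-field ID is shortened for illustration. The shared client keeps only an atomic request counter, while the attempt counter lives on a per-call value, so concurrent goroutines cannot overwrite each other's attempt numbers.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// sharedClient is a hypothetical stand-in for a client that many goroutines
// share. It holds only state that is safe to share.
type sharedClient struct {
	id         uint64
	nthRequest atomic.Uint32
}

// call models one logical request. The attempt counter lives here, not on
// the shared client.
type call struct {
	clientID   uint64
	nthRequest uint32
	attempt    uint32
}

// newCall reserves the next request number for a fresh logical request.
func (c *sharedClient) newCall() *call {
	return &call{clientID: c.id, nthRequest: c.nthRequest.Add(1)}
}

// nextAttemptID bumps the per-call attempt counter and formats an ID.
func (rc *call) nextAttemptID() string {
	rc.attempt++
	return fmt.Sprintf("%d.%d.%d", rc.clientID, rc.nthRequest, rc.attempt)
}

// runConcurrently starts n goroutines that each issue one call with the
// given number of attempts, and returns every ID that was produced.
func runConcurrently(c *sharedClient, n, attempts int) []string {
	var (
		mu  sync.Mutex
		ids []string
		wg  sync.WaitGroup
	)
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			rc := c.newCall()
			for a := 0; a < attempts; a++ {
				id := rc.nextAttemptID()
				mu.Lock()
				ids = append(ids, id)
				mu.Unlock()
			}
		}()
	}
	wg.Wait()
	return ids
}

func main() {
	ids := runConcurrently(&sharedClient{id: 1}, 4, 3)
	fmt.Println(len(ids), "request IDs generated")
}
```

In this sketch every goroutine gets a distinct request number and its own attempt sequence, regardless of how the goroutines interleave.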

defer gcl.setOrResetRPCID()
}

stream, rpcErr := client.BatchWrite(contextWithOutgoingMetadata(ct, sh.getMetadata(), c.disableRouteToLeader), &sppb.BatchWriteRequest{
Session: sh.getID(),
MutationGroups: mgsPb,
RequestOptions: createRequestOptions(opts.Priority, "", opts.TransactionTag),
4 changes: 4 additions & 0 deletions spanner/errors.go
@@ -123,6 +123,10 @@ func (e *Error) decorate(info string) {
// APIError error having given error code as its status.
func spannerErrorf(code codes.Code, format string, args ...interface{}) error {
msg := fmt.Sprintf(format, args...)
return spannerError(code, msg)
}

func spannerError(code codes.Code, msg string) error {
wrapped, _ := apierror.FromError(status.Error(code, msg))
return &Error{
Code: code,
94 changes: 78 additions & 16 deletions spanner/grpc_client.go
@@ -18,14 +18,19 @@

import (
"context"
"fmt"
"math/rand"
"strings"
"sync/atomic"
"time"

vkit "cloud.google.com/go/spanner/apiv1"
"cloud.google.com/go/spanner/apiv1/spannerpb"
"cloud.google.com/go/spanner/internal"
"github.com/googleapis/gax-go/v2"
"google.golang.org/api/option"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
)
@@ -65,10 +70,43 @@
// grpcSpannerClient is the gRPC API implementation of the transport-agnostic
// spannerClient interface.
type grpcSpannerClient struct {
id uint64
raw *vkit.Client
metricsTracerFactory *builtinMetricsTracerFactory

// These fields are used to uniquely track x-goog-spanner-request-id
// grpc.ClientConn is presumed to be the channel, hence channelID
Contributor

The raw *vkit.Client is the channel, so in that sense this field could be considered redundant. However, raw has no simple numeric or string representation, so it is probably easier to use the channel pool index that was used to fetch the channel as the channel ID. That in turn means this field is not redundant and should be assigned a value.

// is redundant. However, is it correct to presume that raw.Connection()
// will always be the same throughout the lifetime of a grpcSpannerClient?
channelID uint64
// nthRequest shall always be incremented on every fresh request.
nthRequest *atomic.Uint32
// This id uniquely defines the RPC being issued and in
// the case of retries it should be incremented.
rpcID *atomic.Uint64
}

func (g *grpcSpannerClient) setOrResetRPCID() {
if g.rpcID == nil {
g.rpcID = new(atomic.Uint64)
}
g.rpcID.Store(1)
}

func (g *grpcSpannerClient) setRPCID(rpcID uint64) {
g.rpcID.Store(rpcID)
}

func (g *grpcSpannerClient) prepareRequestIDTrackers() {
g.id = nGRPCClient.Add(1)
g.nthRequest = new(atomic.Uint32)
g.channelID = 1 // Assuming that .raw.Connection() never changes.
Contributor

This should not be fixed at 1. For regular sessions, we are setting the channel that should be used here: https://github.com/googleapis/google-cloud-go/blob/main/spanner/sessionclient.go#L404

For multiplexed sessions, we do that here:

p.multiplexSessionClientCounter = p.multiplexSessionClientCounter % len(p.clientPool)

Contributor

The above also shows why the current strategy of assuming that a grpcSpannerClient is not used in parallel by multiple goroutines is incorrect: the client library keeps a pool of numChannels (default: 4) grpcSpannerClient instances for multiplexed sessions. These are handed out in round-robin fashion to application goroutines that want to execute a query or transaction.
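
A rough sketch of the round-robin hand-out described above, using only the standard library. `pooledClient`, `clientPool`, `newClientPool`, and `take` are hypothetical names, and the choice of a 1-based pool index as the channel ID is an assumption made for illustration, not something the client library prescribes.

```go
package main

import (
	"fmt"
	"sync"
)

// pooledClient is a hypothetical stand-in for one client in the pool.
type pooledClient struct {
	channelID uint64 // assumed here: 1-based index of the client in the pool
}

// clientPool hands out a fixed set of clients in round-robin order.
type clientPool struct {
	mu      sync.Mutex
	clients []*pooledClient
	next    int
}

// newClientPool creates numChannels clients (numChannels must be > 0) and
// assigns each one a channel ID derived from its position in the pool.
func newClientPool(numChannels int) *clientPool {
	p := &clientPool{}
	for i := 0; i < numChannels; i++ {
		p.clients = append(p.clients, &pooledClient{channelID: uint64(i + 1)})
	}
	return p
}

// take returns the next client in round-robin order. The same client is
// returned again after len(p.clients) calls, so callers share clients.
func (p *clientPool) take() *pooledClient {
	p.mu.Lock()
	defer p.mu.Unlock()
	c := p.clients[p.next]
	p.next = (p.next + 1) % len(p.clients)
	return c
}

func main() {
	p := newClientPool(4)
	for i := 0; i < 6; i++ {
		fmt.Print(p.take().channelID, " ")
	}
	fmt.Println()
}
```

With four clients, the fifth caller receives the same client as the first, which is why per-request state cannot be stored on the client itself.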

g.nthRequest = new(atomic.Uint32)
g.setOrResetRPCID()
}

var nGRPCClient = new(atomic.Uint64)

var (
// Ensure that grpcSpannerClient implements spannerClient.
_ spannerClient = (*grpcSpannerClient)(nil)
@@ -83,6 +121,8 @@
}

g := &grpcSpannerClient{raw: raw, metricsTracerFactory: sc.metricsTracerFactory}
g.prepareRequestIDTrackers()

clientInfo := []string{"gccl", internal.Version}
if sc.userAgent != "" {
agentWithVersion := strings.SplitN(sc.userAgent, "/", 2)
@@ -118,7 +158,7 @@
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.CreateSession(ctx, req, opts...)
resp, err := g.raw.CreateSession(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
@@ -128,7 +168,7 @@
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.BatchCreateSessions(ctx, req, opts...)
resp, err := g.raw.BatchCreateSessions(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
@@ -138,45 +178,67 @@
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.GetSession(ctx, req, opts...)
resp, err := g.raw.GetSession(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
}

func (g *grpcSpannerClient) ListSessions(ctx context.Context, req *spannerpb.ListSessionsRequest, opts ...gax.CallOption) *vkit.SessionIterator {
return g.raw.ListSessions(ctx, req, opts...)
return g.raw.ListSessions(ctx, req, g.optsWithNextRequestID(opts)...)
}

func (g *grpcSpannerClient) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest, opts ...gax.CallOption) error {
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
err := g.raw.DeleteSession(ctx, req, opts...)
err := g.raw.DeleteSession(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return err
}

var randIdForProcess uint32

Check failure on line 201 in spanner/grpc_client.go (GitHub Actions / vet): var randIdForProcess should be randIDForProcess

func init() {
randIdForProcess = rand.New(rand.NewSource(time.Now().UnixNano())).Uint32()
}

const xSpannerRequestIDHeader = "x-goog-spanner-request-id"

// optsWithNextRequestID bundles priors with a new header "x-goog-spanner-request-id"
func (g *grpcSpannerClient) optsWithNextRequestID(priors []gax.CallOption) []gax.CallOption {
// TODO: Decide if each field should be padded and to what width or
// should we just let fields fill up so as to reduce bandwidth?
// Go creates grpc.ClientConn which is presumed to be a channel, so channelID is going to be redundant.
requestID := fmt.Sprintf("%d.%d.%d.%d.%d", randIdForProcess, g.id, g.nextNthRequest(), g.channelID, g.rpcID.Load())
md := metadata.MD{xSpannerRequestIDHeader: []string{requestID}}
return append(priors, gax.WithGRPCOptions(grpc.Header(&md)))
}

func (g *grpcSpannerClient) nextNthRequest() uint32 {
return g.nthRequest.Add(1)
}
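
For readers who want to inspect the generated header value, here is a small standard-library sketch that formats and parses the five dot-separated fields in the same order as the Sprintf call above. The `requestID` type and `parseRequestID` function are hypothetical helpers, and the field names and widths are taken from this diff, not from any specification.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// requestID mirrors the five fields joined by the Sprintf call in the diff.
type requestID struct {
	processID  uint32
	clientID   uint64
	nthRequest uint32
	channelID  uint64
	rpcID      uint64
}

// String renders the fields as dot-separated decimal numbers.
func (r requestID) String() string {
	return fmt.Sprintf("%d.%d.%d.%d.%d",
		r.processID, r.clientID, r.nthRequest, r.channelID, r.rpcID)
}

// parseRequestID is the inverse of String. It rejects values that do not
// have exactly five unsigned decimal fields of the expected width.
func parseRequestID(s string) (requestID, error) {
	parts := strings.Split(s, ".")
	if len(parts) != 5 {
		return requestID{}, fmt.Errorf("want 5 fields, got %d", len(parts))
	}
	var n [5]uint64
	for i, p := range parts {
		bits := 64
		if i == 0 || i == 2 {
			bits = 32 // processID and nthRequest are 32-bit in the diff
		}
		v, err := strconv.ParseUint(p, 10, bits)
		if err != nil {
			return requestID{}, fmt.Errorf("field %d: %w", i, err)
		}
		n[i] = v
	}
	return requestID{
		processID:  uint32(n[0]),
		clientID:   n[1],
		nthRequest: uint32(n[2]),
		channelID:  n[3],
		rpcID:      n[4],
	}, nil
}

func main() {
	id := requestID{processID: 42, clientID: 1, nthRequest: 7, channelID: 3, rpcID: 2}
	parsed, err := parseRequestID(id.String())
	fmt.Println(id, parsed == id, err)
}
```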

func (g *grpcSpannerClient) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest, opts ...gax.CallOption) (*spannerpb.ResultSet, error) {
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.ExecuteSql(ctx, req, opts...)
resp, err := g.raw.ExecuteSql(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
}

func (g *grpcSpannerClient) ExecuteStreamingSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest, opts ...gax.CallOption) (spannerpb.Spanner_ExecuteStreamingSqlClient, error) {
return g.raw.ExecuteStreamingSql(peer.NewContext(ctx, &peer.Peer{}), req, opts...)
return g.raw.ExecuteStreamingSql(peer.NewContext(ctx, &peer.Peer{}), req, g.optsWithNextRequestID(opts)...)
}

func (g *grpcSpannerClient) ExecuteBatchDml(ctx context.Context, req *spannerpb.ExecuteBatchDmlRequest, opts ...gax.CallOption) (*spannerpb.ExecuteBatchDmlResponse, error) {
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.ExecuteBatchDml(ctx, req, opts...)
resp, err := g.raw.ExecuteBatchDml(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
@@ -186,21 +248,21 @@
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.Read(ctx, req, opts...)
resp, err := g.raw.Read(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
}

func (g *grpcSpannerClient) StreamingRead(ctx context.Context, req *spannerpb.ReadRequest, opts ...gax.CallOption) (spannerpb.Spanner_StreamingReadClient, error) {
return g.raw.StreamingRead(peer.NewContext(ctx, &peer.Peer{}), req, opts...)
return g.raw.StreamingRead(peer.NewContext(ctx, &peer.Peer{}), req, g.optsWithNextRequestID(opts)...)
}

func (g *grpcSpannerClient) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest, opts ...gax.CallOption) (*spannerpb.Transaction, error) {
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.BeginTransaction(ctx, req, opts...)
resp, err := g.raw.BeginTransaction(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
@@ -210,7 +272,7 @@
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.Commit(ctx, req, opts...)
resp, err := g.raw.Commit(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
@@ -220,7 +282,7 @@
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
err := g.raw.Rollback(ctx, req, opts...)
err := g.raw.Rollback(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return err
@@ -230,7 +292,7 @@
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.PartitionQuery(ctx, req, opts...)
resp, err := g.raw.PartitionQuery(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
@@ -240,12 +302,12 @@
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.PartitionRead(ctx, req, opts...)
resp, err := g.raw.PartitionRead(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
}

func (g *grpcSpannerClient) BatchWrite(ctx context.Context, req *spannerpb.BatchWriteRequest, opts ...gax.CallOption) (spannerpb.Spanner_BatchWriteClient, error) {
return g.raw.BatchWrite(peer.NewContext(ctx, &peer.Peer{}), req, opts...)
return g.raw.BatchWrite(peer.NewContext(ctx, &peer.Peer{}), req, g.optsWithNextRequestID(opts)...)
}
4 changes: 4 additions & 0 deletions spanner/internal/testutil/inmem_spanner_server.go
@@ -90,6 +90,7 @@ const (
MethodExecuteBatchDml string = "EXECUTE_BATCH_DML"
MethodStreamingRead string = "EXECUTE_STREAMING_READ"
MethodBatchWrite string = "BATCH_WRITE"
MethodPartitionQuery string = "PARTITION_QUERY"
)

// StatementResult represents a mocked result on the test server. The result is
@@ -1107,6 +1108,9 @@ func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.Rollba
}

func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb.PartitionQueryRequest) (*spannerpb.PartitionResponse, error) {
if err := s.simulateExecutionTime(MethodPartitionQuery, req); err != nil {
return nil, err
}
s.mu.Lock()
if s.stopped {
s.mu.Unlock()
11 changes: 9 additions & 2 deletions spanner/pdml.go
@@ -110,13 +110,19 @@ func executePdml(ctx context.Context, sh *sessionHandle, req *sppb.ExecuteSqlReq
var md metadata.MD
sh.updateLastUseTime()
// Begin transaction.
res, err := sh.getClient().BeginTransaction(ctx, &sppb.BeginTransactionRequest{
client := sh.getClient()
res, err := client.BeginTransaction(ctx, &sppb.BeginTransactionRequest{
Session: sh.getID(),
Options: &sppb.TransactionOptions{
Mode: &sppb.TransactionOptions_PartitionedDml_{PartitionedDml: &sppb.TransactionOptions_PartitionedDml{}},
ExcludeTxnFromChangeStreams: options.ExcludeTxnFromChangeStreams,
},
})
// This function is invoked afresh on every retry and it retrieves a fresh client
// each time hence does not need an extraction and increment of the injected spanner requestId.
if gcl, ok := client.(*grpcSpannerClient); ok {
defer gcl.setOrResetRPCID()
}
if err != nil {
return 0, ToSpannerError(err)
}
@@ -126,7 +132,8 @@ func executePdml(ctx context.Context, sh *sessionHandle, req *sppb.ExecuteSqlReq
}

sh.updateLastUseTime()
resultSet, err := sh.getClient().ExecuteSql(ctx, req, gax.WithGRPCOptions(grpc.Header(&md)))

resultSet, err := client.ExecuteSql(ctx, req, gax.WithGRPCOptions(grpc.Header(&md)))
if getGFELatencyMetricsFlag() && md != nil && sh.session.pool != nil {
err := captureGFELatencyStats(tag.NewContext(ctx, sh.session.pool.tagMap), md, "executePdml_ExecuteSql")
if err != nil {