diff --git a/inspector.go b/inspector.go
index 6cc2c28b..ece8b174 100644
--- a/inspector.go
+++ b/inspector.go
@@ -1,6 +1,7 @@
 package asynq
 
 import (
+	"context"
 	"fmt"
 	"strconv"
 	"strings"
@@ -18,7 +19,7 @@ type Inspector struct {
 	rdb *rdb.RDB
 }
 
-// New returns a new instance of Inspector.
+// NewInspector returns a new instance of Inspector.
 func NewInspector(r RedisConnOpt) *Inspector {
 	c, ok := r.MakeRedisClient().(redis.UniversalClient)
 	if !ok {
@@ -295,13 +296,13 @@ func Page(n int) ListOption {
 // ListPendingTasks retrieves pending tasks from the specified queue.
 //
 // By default, it retrieves the first 30 tasks.
-func (i *Inspector) ListPendingTasks(queue string, opts ...ListOption) ([]*TaskInfo, error) {
+func (i *Inspector) ListPendingTasks(ctx context.Context, queue string, opts ...ListOption) ([]*TaskInfo, error) {
 	if err := base.ValidateQueueName(queue); err != nil {
 		return nil, fmt.Errorf("asynq: %v", err)
 	}
 	opt := composeListOptions(opts...)
 	pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1}
-	infos, err := i.rdb.ListPending(queue, pgn)
+	infos, err := i.rdb.ListPending(ctx, queue, pgn)
 	switch {
 	case errors.IsQueueNotFound(err):
 		return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound)
@@ -323,20 +324,20 @@ func (i *Inspector) ListPendingTasks(queue string, opts ...ListOption) ([]*TaskI
 // ListActiveTasks retrieves active tasks from the specified queue.
 //
 // By default, it retrieves the first 30 tasks.
-func (i *Inspector) ListActiveTasks(queue string, opts ...ListOption) ([]*TaskInfo, error) {
+func (i *Inspector) ListActiveTasks(ctx context.Context, queue string, opts ...ListOption) ([]*TaskInfo, error) {
 	if err := base.ValidateQueueName(queue); err != nil {
 		return nil, fmt.Errorf("asynq: %v", err)
 	}
 	opt := composeListOptions(opts...)
 	pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1}
-	infos, err := i.rdb.ListActive(queue, pgn)
+	infos, err := i.rdb.ListActive(ctx, queue, pgn)
 	switch {
 	case errors.IsQueueNotFound(err):
 		return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound)
 	case err != nil:
 		return nil, fmt.Errorf("asynq: %v", err)
 	}
-	expired, err := i.rdb.ListLeaseExpired(time.Now(), queue)
+	expired, err := i.rdb.ListLeaseExpired(ctx, time.Now(), queue)
 	if err != nil {
 		return nil, fmt.Errorf("asynq: %v", err)
 	}
@@ -363,13 +364,13 @@ func (i *Inspector) ListActiveTasks(queue string, opts ...ListOption) ([]*TaskIn
 // ListAggregatingTasks retrieves scheduled tasks from the specified group.
 //
 // By default, it retrieves the first 30 tasks.
-func (i *Inspector) ListAggregatingTasks(queue, group string, opts ...ListOption) ([]*TaskInfo, error) {
+func (i *Inspector) ListAggregatingTasks(ctx context.Context, queue, group string, opts ...ListOption) ([]*TaskInfo, error) {
 	if err := base.ValidateQueueName(queue); err != nil {
 		return nil, fmt.Errorf("asynq: %v", err)
 	}
 	opt := composeListOptions(opts...)
 	pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1}
-	infos, err := i.rdb.ListAggregating(queue, group, pgn)
+	infos, err := i.rdb.ListAggregating(ctx, queue, group, pgn)
 	switch {
 	case errors.IsQueueNotFound(err):
 		return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound)
@@ -392,13 +393,13 @@ func (i *Inspector) ListAggregatingTasks(queue, group string, opts ...ListOption
 // Tasks are sorted by NextProcessAt in ascending order.
 //
 // By default, it retrieves the first 30 tasks.
-func (i *Inspector) ListScheduledTasks(queue string, opts ...ListOption) ([]*TaskInfo, error) {
+func (i *Inspector) ListScheduledTasks(ctx context.Context, queue string, opts ...ListOption) ([]*TaskInfo, error) {
 	if err := base.ValidateQueueName(queue); err != nil {
 		return nil, fmt.Errorf("asynq: %v", err)
 	}
 	opt := composeListOptions(opts...)
 	pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1}
-	infos, err := i.rdb.ListScheduled(queue, pgn)
+	infos, err := i.rdb.ListScheduled(ctx, queue, pgn)
 	switch {
 	case errors.IsQueueNotFound(err):
 		return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound)
@@ -421,13 +422,13 @@ func (i *Inspector) ListScheduledTasks(queue string, opts ...ListOption) ([]*Tas
 // Tasks are sorted by NextProcessAt in ascending order.
 //
 // By default, it retrieves the first 30 tasks.
-func (i *Inspector) ListRetryTasks(queue string, opts ...ListOption) ([]*TaskInfo, error) {
+func (i *Inspector) ListRetryTasks(ctx context.Context, queue string, opts ...ListOption) ([]*TaskInfo, error) {
 	if err := base.ValidateQueueName(queue); err != nil {
 		return nil, fmt.Errorf("asynq: %v", err)
 	}
 	opt := composeListOptions(opts...)
 	pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1}
-	infos, err := i.rdb.ListRetry(queue, pgn)
+	infos, err := i.rdb.ListRetry(ctx, queue, pgn)
 	switch {
 	case errors.IsQueueNotFound(err):
 		return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound)
@@ -450,13 +451,13 @@ func (i *Inspector) ListRetryTasks(queue string, opts ...ListOption) ([]*TaskInf
 // Tasks are sorted by LastFailedAt in descending order.
 //
 // By default, it retrieves the first 30 tasks.
-func (i *Inspector) ListArchivedTasks(queue string, opts ...ListOption) ([]*TaskInfo, error) {
+func (i *Inspector) ListArchivedTasks(ctx context.Context, queue string, opts ...ListOption) ([]*TaskInfo, error) {
 	if err := base.ValidateQueueName(queue); err != nil {
 		return nil, fmt.Errorf("asynq: %v", err)
 	}
 	opt := composeListOptions(opts...)
 	pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1}
-	infos, err := i.rdb.ListArchived(queue, pgn)
+	infos, err := i.rdb.ListArchived(ctx, queue, pgn)
 	switch {
 	case errors.IsQueueNotFound(err):
 		return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound)
@@ -479,13 +480,13 @@ func (i *Inspector) ListArchivedTasks(queue string, opts ...ListOption) ([]*Task
 // Tasks are sorted by expiration time (i.e. CompletedAt + Retention) in descending order.
 //
 // By default, it retrieves the first 30 tasks.
-func (i *Inspector) ListCompletedTasks(queue string, opts ...ListOption) ([]*TaskInfo, error) {
+func (i *Inspector) ListCompletedTasks(ctx context.Context, queue string, opts ...ListOption) ([]*TaskInfo, error) {
 	if err := base.ValidateQueueName(queue); err != nil {
 		return nil, fmt.Errorf("asynq: %v", err)
 	}
 	opt := composeListOptions(opts...)
 	pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1}
-	infos, err := i.rdb.ListCompleted(queue, pgn)
+	infos, err := i.rdb.ListCompleted(ctx, queue, pgn)
 	switch {
 	case errors.IsQueueNotFound(err):
 		return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound)
diff --git a/inspector_test.go b/inspector_test.go
index 428058f3..4537dc56 100644
--- a/inspector_test.go
+++ b/inspector_test.go
@@ -671,6 +671,7 @@ func TestInspectorGetTaskInfoError(t *testing.T) {
 func TestInspectorListPendingTasks(t *testing.T) {
 	r := setup(t)
 	defer r.Close()
+
 	m1 := h.NewTaskMessage("task1", nil)
 	m2 := h.NewTaskMessage("task2", nil)
 	m3 := h.NewTaskMessageWithQueue("task3", nil, "critical")
@@ -718,12 +719,14 @@ func TestInspectorListPendingTasks(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r)
 		for q, msgs := range tc.pending {
 			h.SeedPendingQueue(t, r, msgs, q)
 		}
 
-		got, err := inspector.ListPendingTasks(tc.qname)
+		got, err := inspector.ListPendingTasks(ctx, tc.qname)
 		if err != nil {
 			t.Errorf("%s; ListPendingTasks(%q) returned error: %v",
 				tc.desc, tc.qname, err)
@@ -811,11 +814,13 @@ func TestInspectorListActiveTasks(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r)
 		h.SeedAllActiveQueues(t, r, tc.active)
 		h.SeedAllLease(t, r, tc.lease)
 
-		got, err := inspector.ListActiveTasks(tc.qname)
+		got, err := inspector.ListActiveTasks(ctx, tc.qname)
 		if err != nil {
 			t.Errorf("%s; ListActiveTasks(%q) returned error: %v", tc.qname, tc.desc, err)
 			continue
@@ -882,10 +887,12 @@ func TestInspectorListScheduledTasks(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r)
 		h.SeedAllScheduledQueues(t, r, tc.scheduled)
 
-		got, err := inspector.ListScheduledTasks(tc.qname)
+		got, err := inspector.ListScheduledTasks(ctx, tc.qname)
 		if err != nil {
 			t.Errorf("%s; ListScheduledTasks(%q) returned error: %v", tc.desc, tc.qname, err)
 			continue
@@ -953,10 +960,12 @@ func TestInspectorListRetryTasks(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r)
 		h.SeedAllRetryQueues(t, r, tc.retry)
 
-		got, err := inspector.ListRetryTasks(tc.qname)
+		got, err := inspector.ListRetryTasks(ctx, tc.qname)
 		if err != nil {
 			t.Errorf("%s; ListRetryTasks(%q) returned error: %v", tc.desc, tc.qname, err)
 			continue
@@ -1023,10 +1032,12 @@ func TestInspectorListArchivedTasks(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r)
 		h.SeedAllArchivedQueues(t, r, tc.archived)
 
-		got, err := inspector.ListArchivedTasks(tc.qname)
+		got, err := inspector.ListArchivedTasks(ctx, tc.qname)
 		if err != nil {
 			t.Errorf("%s; ListArchivedTasks(%q) returned error: %v", tc.desc, tc.qname, err)
 			continue
@@ -1100,10 +1111,12 @@ func TestInspectorListCompletedTasks(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r)
 		h.SeedAllCompletedQueues(t, r, tc.completed)
 
-		got, err := inspector.ListCompletedTasks(tc.qname)
+		got, err := inspector.ListCompletedTasks(ctx, tc.qname)
 		if err != nil {
 			t.Errorf("%s; ListCompletedTasks(%q) returned error: %v", tc.desc, tc.qname, err)
 			continue
@@ -1188,6 +1201,8 @@ func TestInspectorListAggregatingTasks(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r)
 		h.SeedTasks(t, r, fxt.tasks)
 		h.SeedRedisSet(t, r, base.AllQueues, fxt.allQueues)
@@ -1195,7 +1210,7 @@
 		h.SeedRedisZSets(t, r, fxt.groups)
 
 		t.Run(tc.desc, func(t *testing.T) {
-			got, err := inspector.ListAggregatingTasks(tc.qname, tc.gname)
+			got, err := inspector.ListAggregatingTasks(ctx, tc.qname, tc.gname)
 			if err != nil {
 				t.Fatalf("ListAggregatingTasks returned error: %v", err)
 			}
@@ -1263,7 +1278,9 @@ func TestInspectorListPagination(t *testing.T) {
 	}
 
 	for _, tc := range tests {
-		got, err := inspector.ListPendingTasks("default", Page(tc.page), PageSize(tc.pageSize))
+		ctx := context.Background()
+
+		got, err := inspector.ListPendingTasks(ctx, "default", Page(tc.page), PageSize(tc.pageSize))
 		if err != nil {
 			t.Errorf("ListPendingTask('default') returned error: %v", err)
 			continue
@@ -1296,27 +1313,29 @@ func TestInspectorListTasksQueueNotFoundError(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r)
-		if _, err := inspector.ListActiveTasks(tc.qname); !errors.Is(err, tc.wantErr) {
+		if _, err := inspector.ListActiveTasks(ctx, tc.qname); !errors.Is(err, tc.wantErr) {
 			t.Errorf("ListActiveTasks(%q) returned error %v, want %v", tc.qname, err, tc.wantErr)
 		}
-		if _, err := inspector.ListPendingTasks(tc.qname); !errors.Is(err, tc.wantErr) {
+		if _, err := inspector.ListPendingTasks(ctx, tc.qname); !errors.Is(err, tc.wantErr) {
 			t.Errorf("ListPendingTasks(%q) returned error %v, want %v", tc.qname, err, tc.wantErr)
 		}
-		if _, err := inspector.ListScheduledTasks(tc.qname); !errors.Is(err, tc.wantErr) {
+		if _, err := inspector.ListScheduledTasks(ctx, tc.qname); !errors.Is(err, tc.wantErr) {
 			t.Errorf("ListScheduledTasks(%q) returned error %v, want %v", tc.qname, err, tc.wantErr)
 		}
-		if _, err := inspector.ListRetryTasks(tc.qname); !errors.Is(err, tc.wantErr) {
+		if _, err := inspector.ListRetryTasks(ctx, tc.qname); !errors.Is(err, tc.wantErr) {
 			t.Errorf("ListRetryTasks(%q) returned error %v, want %v", tc.qname, err, tc.wantErr)
 		}
-		if _, err := inspector.ListArchivedTasks(tc.qname); !errors.Is(err, tc.wantErr) {
+		if _, err := inspector.ListArchivedTasks(ctx, tc.qname); !errors.Is(err, tc.wantErr) {
 			t.Errorf("ListArchivedTasks(%q) returned error %v, want %v", tc.qname, err, tc.wantErr)
 		}
-		if _, err := inspector.ListCompletedTasks(tc.qname); !errors.Is(err, tc.wantErr) {
+		if _, err := inspector.ListCompletedTasks(ctx, tc.qname); !errors.Is(err, tc.wantErr) {
 			t.Errorf("ListCompletedTasks(%q) returned error %v, want %v", tc.qname, err, tc.wantErr)
 		}
-		if _, err := inspector.ListAggregatingTasks(tc.qname, "mygroup"); !errors.Is(err, tc.wantErr) {
+		if _, err := inspector.ListAggregatingTasks(ctx, tc.qname, "mygroup"); !errors.Is(err, tc.wantErr) {
			t.Errorf("ListAggregatingTasks(%q, \"mygroup\") returned error %v, want %v", tc.qname, err, tc.wantErr)
 		}
 	}
diff --git a/internal/rdb/inspect.go b/internal/rdb/inspect.go
index a232ceb9..21dbcbcc 100644
--- a/internal/rdb/inspect.go
+++ b/internal/rdb/inspect.go
@@ -612,7 +612,7 @@ func (p Pagination) stop() int64 {
 }
 
 // ListPending returns pending tasks that are ready to be processed.
-func (r *RDB) ListPending(qname string, pgn Pagination) ([]*base.TaskInfo, error) {
+func (r *RDB) ListPending(ctx context.Context, qname string, pgn Pagination) ([]*base.TaskInfo, error) {
 	var op errors.Op = "rdb.ListPending"
 	exists, err := r.queueExists(qname)
 	if err != nil {
@@ -621,7 +621,7 @@ func (r *RDB) ListPending(qname string, pgn Pagination) ([]*base.TaskInfo, error
 	if !exists {
 		return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname})
 	}
-	res, err := r.listMessages(qname, base.TaskStatePending, pgn)
+	res, err := r.listMessages(ctx, qname, base.TaskStatePending, pgn)
 	if err != nil {
 		return nil, errors.E(op, errors.CanonicalCode(err), err)
 	}
@@ -629,7 +629,7 @@ func (r *RDB) ListPending(qname string, pgn Pagination) ([]*base.TaskInfo, error
 }
 
 // ListActive returns all tasks that are currently being processed for the given queue.
-func (r *RDB) ListActive(qname string, pgn Pagination) ([]*base.TaskInfo, error) {
+func (r *RDB) ListActive(ctx context.Context, qname string, pgn Pagination) ([]*base.TaskInfo, error) {
 	var op errors.Op = "rdb.ListActive"
 	exists, err := r.queueExists(qname)
 	if err != nil {
@@ -638,7 +638,7 @@ func (r *RDB) ListActive(qname string, pgn Pagination) ([]*base.TaskInfo, error)
 	if !exists {
 		return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname})
 	}
-	res, err := r.listMessages(qname, base.TaskStateActive, pgn)
+	res, err := r.listMessages(ctx, qname, base.TaskStateActive, pgn)
 	if err != nil {
 		return nil, errors.E(op, errors.CanonicalCode(err), err)
 	}
@@ -662,7 +662,9 @@ return data
 `)
 
 // listMessages returns a list of TaskInfo in Redis list with the given key.
-func (r *RDB) listMessages(qname string, state base.TaskState, pgn Pagination) ([]*base.TaskInfo, error) {
+func (r *RDB) listMessages(ctx context.Context, qname string, state base.TaskState, pgn Pagination) (
+	[]*base.TaskInfo, error,
+) {
 	var key string
 	switch state {
 	case base.TaskStateActive:
@@ -676,8 +678,7 @@ func (r *RDB) listMessages(qname string, state base.TaskState, pgn Pagination) (
 	// correct range and reverse the list to get the tasks with pagination.
 	stop := -pgn.start() - 1
 	start := -pgn.stop() - 1
-	res, err := listMessagesCmd.Run(context.Background(), r.client,
-		[]string{key}, start, stop, base.TaskKeyPrefix(qname)).Result()
+	res, err := listMessagesCmd.Run(ctx, r.client, []string{key}, start, stop, base.TaskKeyPrefix(qname)).Result()
 	if err != nil {
 		return nil, errors.E(errors.Unknown, err)
 	}
@@ -712,7 +713,7 @@ func (r *RDB) listMessages(qname string, state base.TaskState, pgn Pagination) (
 
 // ListScheduled returns all tasks from the given queue that are scheduled
 // to be processed in the future.
-func (r *RDB) ListScheduled(qname string, pgn Pagination) ([]*base.TaskInfo, error) {
+func (r *RDB) ListScheduled(ctx context.Context, qname string, pgn Pagination) ([]*base.TaskInfo, error) {
 	var op errors.Op = "rdb.ListScheduled"
 	exists, err := r.queueExists(qname)
 	if err != nil {
@@ -721,7 +722,7 @@ func (r *RDB) ListScheduled(qname string, pgn Pagination) ([]*base.TaskInfo, err
 	if !exists {
 		return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname})
 	}
-	res, err := r.listZSetEntries(qname, base.TaskStateScheduled, base.ScheduledKey(qname), pgn)
+	res, err := r.listZSetEntries(ctx, qname, base.TaskStateScheduled, base.ScheduledKey(qname), pgn)
 	if err != nil {
 		return nil, errors.E(op, errors.CanonicalCode(err), err)
 	}
@@ -730,7 +731,7 @@ func (r *RDB) ListScheduled(qname string, pgn Pagination) ([]*base.TaskInfo, err
 
 // ListRetry returns all tasks from the given queue that have failed before
 // and willl be retried in the future.
-func (r *RDB) ListRetry(qname string, pgn Pagination) ([]*base.TaskInfo, error) {
+func (r *RDB) ListRetry(ctx context.Context, qname string, pgn Pagination) ([]*base.TaskInfo, error) {
 	var op errors.Op = "rdb.ListRetry"
 	exists, err := r.queueExists(qname)
 	if err != nil {
@@ -739,7 +740,7 @@ func (r *RDB) ListRetry(qname string, pgn Pagination) ([]*base.TaskInfo, error)
 	if !exists {
 		return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname})
 	}
-	res, err := r.listZSetEntries(qname, base.TaskStateRetry, base.RetryKey(qname), pgn)
+	res, err := r.listZSetEntries(ctx, qname, base.TaskStateRetry, base.RetryKey(qname), pgn)
 	if err != nil {
 		return nil, errors.E(op, errors.CanonicalCode(err), err)
 	}
@@ -747,7 +748,7 @@ func (r *RDB) ListRetry(qname string, pgn Pagination) ([]*base.TaskInfo, error)
 }
 
 // ListArchived returns all tasks from the given queue that have exhausted its retry limit.
-func (r *RDB) ListArchived(qname string, pgn Pagination) ([]*base.TaskInfo, error) {
+func (r *RDB) ListArchived(ctx context.Context, qname string, pgn Pagination) ([]*base.TaskInfo, error) {
 	var op errors.Op = "rdb.ListArchived"
 	exists, err := r.queueExists(qname)
 	if err != nil {
@@ -756,7 +757,7 @@ func (r *RDB) ListArchived(qname string, pgn Pagination) ([]*base.TaskInfo, erro
 	if !exists {
 		return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname})
 	}
-	zs, err := r.listZSetEntries(qname, base.TaskStateArchived, base.ArchivedKey(qname), pgn)
+	zs, err := r.listZSetEntries(ctx, qname, base.TaskStateArchived, base.ArchivedKey(qname), pgn)
 	if err != nil {
 		return nil, errors.E(op, errors.CanonicalCode(err), err)
 	}
@@ -764,7 +765,7 @@ func (r *RDB) ListArchived(qname string, pgn Pagination) ([]*base.TaskInfo, erro
 }
 
 // ListCompleted returns all tasks from the given queue that have completed successfully.
-func (r *RDB) ListCompleted(qname string, pgn Pagination) ([]*base.TaskInfo, error) {
+func (r *RDB) ListCompleted(ctx context.Context, qname string, pgn Pagination) ([]*base.TaskInfo, error) {
 	var op errors.Op = "rdb.ListCompleted"
 	exists, err := r.queueExists(qname)
 	if err != nil {
@@ -773,7 +774,7 @@ func (r *RDB) ListCompleted(qname string, pgn Pagination) ([]*base.TaskInfo, err
 	if !exists {
 		return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname})
 	}
-	zs, err := r.listZSetEntries(qname, base.TaskStateCompleted, base.CompletedKey(qname), pgn)
+	zs, err := r.listZSetEntries(ctx, qname, base.TaskStateCompleted, base.CompletedKey(qname), pgn)
 	if err != nil {
 		return nil, errors.E(op, errors.CanonicalCode(err), err)
 	}
@@ -781,7 +782,7 @@ func (r *RDB) ListCompleted(qname string, pgn Pagination) ([]*base.TaskInfo, err
 }
 
 // ListAggregating returns all tasks from the given group.
-func (r *RDB) ListAggregating(qname, gname string, pgn Pagination) ([]*base.TaskInfo, error) {
+func (r *RDB) ListAggregating(ctx context.Context, qname, gname string, pgn Pagination) ([]*base.TaskInfo, error) {
 	var op errors.Op = "rdb.ListAggregating"
 	exists, err := r.queueExists(qname)
 	if err != nil {
@@ -790,7 +791,7 @@ func (r *RDB) ListAggregating(qname, gname string, pgn Pagination) ([]*base.Task
 	if !exists {
 		return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname})
 	}
-	zs, err := r.listZSetEntries(qname, base.TaskStateAggregating, base.GroupKey(qname, gname), pgn)
+	zs, err := r.listZSetEntries(ctx, qname, base.TaskStateAggregating, base.GroupKey(qname, gname), pgn)
 	if err != nil {
 		return nil, errors.E(op, errors.CanonicalCode(err), err)
 	}
@@ -826,8 +827,8 @@ return data
 
 // listZSetEntries returns a list of message and score pairs in Redis sorted-set
 // with the given key.
-func (r *RDB) listZSetEntries(qname string, state base.TaskState, key string, pgn Pagination) ([]*base.TaskInfo, error) {
-	res, err := listZSetEntriesCmd.Run(context.Background(), r.client, []string{key},
+func (r *RDB) listZSetEntries(ctx context.Context, qname string, state base.TaskState, key string, pgn Pagination) ([]*base.TaskInfo, error) {
+	res, err := listZSetEntriesCmd.Run(ctx, r.client, []string{key},
 		pgn.start(), pgn.stop(), base.TaskKeyPrefix(qname)).Result()
 	if err != nil {
 		return nil, errors.E(errors.Unknown, err)
diff --git a/internal/rdb/inspect_test.go b/internal/rdb/inspect_test.go
index 863c1328..bf509f22 100644
--- a/internal/rdb/inspect_test.go
+++ b/internal/rdb/inspect_test.go
@@ -32,9 +32,11 @@ func TestAllQueues(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r.client)
 		for _, qname := range tc.queues {
-			if err := r.client.SAdd(context.Background(), base.AllQueues, qname).Err(); err != nil {
+			if err := r.client.SAdd(ctx, base.AllQueues, qname).Err(); err != nil {
 				t.Fatalf("could not initialize all queue set: %v", err)
 			}
 		}
@@ -274,6 +276,8 @@ func TestCurrentStats(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r.client) // clean up db before each test case
 		for _, qname := range tc.paused {
 			if err := r.Pause(qname); err != nil {
@@ -290,7 +294,7 @@ func TestCurrentStats(t *testing.T) {
 		h.SeedRedisZSets(t, r.client, tc.archived)
 		h.SeedRedisZSets(t, r.client, tc.completed)
 		h.SeedRedisZSets(t, r.client, tc.groups)
-		ctx := context.Background()
+
 		for qname, n := range tc.processed {
 			r.client.Set(ctx, base.ProcessedKey(qname, now), n, 0)
 		}
@@ -351,16 +355,18 @@ func TestHistoricalStats(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r.client)
-		r.client.SAdd(context.Background(), base.AllQueues, tc.qname)
+		r.client.SAdd(ctx, base.AllQueues, tc.qname)
 
 		// populate last n days data
 		for i := 0; i < tc.n; i++ {
 			ts := now.Add(-time.Duration(i) * 24 * time.Hour)
 			processedKey := base.ProcessedKey(tc.qname, ts)
 			failedKey := base.FailedKey(tc.qname, ts)
-			r.client.Set(context.Background(), processedKey, (i+1)*1000, 0)
-			r.client.Set(context.Background(), failedKey, (i+1)*10, 0)
+			r.client.Set(ctx, processedKey, (i+1)*1000, 0)
+			r.client.Set(ctx, failedKey, (i+1)*10, 0)
 		}
 
 		got, err := r.HistoricalStats(tc.qname, tc.n)
@@ -562,14 +568,17 @@ func TestGetTaskInfo(t *testing.T) {
 		},
 	}
 
+	ctx := context.Background()
+
 	h.SeedAllActiveQueues(t, r.client, fixtures.active)
 	h.SeedAllPendingQueues(t, r.client, fixtures.pending)
 	h.SeedAllScheduledQueues(t, r.client, fixtures.scheduled)
 	h.SeedAllRetryQueues(t, r.client, fixtures.retry)
 	h.SeedAllArchivedQueues(t, r.client, fixtures.archived)
 	h.SeedAllCompletedQueues(t, r.client, fixtures.completed)
 
+	// Write result data for the completed task.
-	if err := r.client.HSet(context.Background(), base.TaskKey(m6.Queue, m6.ID), "result", "foobar").Err(); err != nil {
+	if err := r.client.HSet(ctx, base.TaskKey(m6.Queue, m6.ID), "result", "foobar").Err(); err != nil {
 		t.Fatalf("Failed to write result data under task key: %v", err)
 	}
@@ -788,10 +797,12 @@ func TestListPending(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r.client) // clean up db before each test case
 		h.SeedAllPendingQueues(t, r.client, tc.pending)
 
-		got, err := r.ListPending(tc.qname, Pagination{Size: 20, Page: 0})
+		got, err := r.ListPending(ctx, tc.qname, Pagination{Size: 20, Page: 0})
 		op := fmt.Sprintf("r.ListPending(%q, Pagination{Size: 20, Page: 0})", tc.qname)
 		if err != nil {
 			t.Errorf("%s = %v, %v, want %v, nil", op, got, err, tc.want)
@@ -807,6 +818,7 @@ func TestListPending(t *testing.T) {
 func TestListPendingPagination(t *testing.T) {
 	r := setup(t)
 	defer r.Close()
+
 	var msgs []*base.TaskMessage
 	for i := 0; i < 100; i++ {
 		msg := h.NewTaskMessage(fmt.Sprintf("task %d", i), nil)
@@ -841,7 +853,9 @@ func TestListPendingPagination(t *testing.T) {
 	}
 
 	for _, tc := range tests {
-		got, err := r.ListPending(tc.qname, Pagination{Size: tc.size, Page: tc.page})
+		ctx := context.Background()
+
+		got, err := r.ListPending(ctx, tc.qname, Pagination{Size: tc.size, Page: tc.page})
 		op := fmt.Sprintf("r.ListPending(%q, Pagination{Size: %d, Page: %d})", tc.qname, tc.size, tc.page)
 		if err != nil {
 			t.Errorf("%s; %s returned error %v", tc.desc, op, err)
@@ -907,10 +921,12 @@ func TestListActive(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r.client) // clean up db before each test case
 		h.SeedAllActiveQueues(t, r.client, tc.inProgress)
 
-		got, err := r.ListActive(tc.qname, Pagination{Size: 20, Page: 0})
+		got, err := r.ListActive(ctx, tc.qname, Pagination{Size: 20, Page: 0})
 		op := fmt.Sprintf("r.ListActive(%q, Pagination{Size: 20, Page: 0})", tc.qname)
 		if err != nil {
 			t.Errorf("%s = %v, %v, want %v, nil", op, got, err, tc.inProgress)
@@ -926,6 +942,7 @@ func TestListActive(t *testing.T) {
 func TestListActivePagination(t *testing.T) {
 	r := setup(t)
 	defer r.Close()
+
 	var msgs []*base.TaskMessage
 	for i := 0; i < 100; i++ {
 		msg := h.NewTaskMessage(fmt.Sprintf("task %d", i), nil)
@@ -950,7 +967,9 @@ func TestListActivePagination(t *testing.T) {
 	}
 
 	for _, tc := range tests {
-		got, err := r.ListActive(tc.qname, Pagination{Size: tc.size, Page: tc.page})
+		ctx := context.Background()
+
+		got, err := r.ListActive(ctx, tc.qname, Pagination{Size: tc.size, Page: tc.page})
 		op := fmt.Sprintf("r.ListActive(%q, Pagination{Size: %d, Page: %d})", tc.qname, tc.size, tc.page)
 		if err != nil {
 			t.Errorf("%s; %s returned error %v", tc.desc, op, err)
@@ -1042,10 +1061,12 @@ func TestListScheduled(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r.client) // clean up db before each test case
 		h.SeedAllScheduledQueues(t, r.client, tc.scheduled)
 
-		got, err := r.ListScheduled(tc.qname, Pagination{Size: 20, Page: 0})
+		got, err := r.ListScheduled(ctx, tc.qname, Pagination{Size: 20, Page: 0})
 		op := fmt.Sprintf("r.ListScheduled(%q, Pagination{Size: 20, Page: 0})", tc.qname)
 		if err != nil {
 			t.Errorf("%s = %v, %v, want %v, nil", op, got, err, tc.want)
@@ -1063,8 +1084,10 @@ func TestListScheduledPagination(t *testing.T) {
 	defer r.Close()
 	// create 100 tasks with an increasing number of wait time.
 	for i := 0; i < 100; i++ {
+		ctx := context.Background()
+
 		msg := h.NewTaskMessage(fmt.Sprintf("task %d", i), nil)
-		if err := r.Schedule(context.Background(), msg, time.Now().Add(time.Duration(i)*time.Second)); err != nil {
+		if err := r.Schedule(ctx, msg, time.Now().Add(time.Duration(i)*time.Second)); err != nil {
 			t.Fatal(err)
 		}
 	}
@@ -1086,7 +1109,9 @@ func TestListScheduledPagination(t *testing.T) {
 	}
 
 	for _, tc := range tests {
-		got, err := r.ListScheduled(tc.qname, Pagination{Size: tc.size, Page: tc.page})
+		ctx := context.Background()
+
+		got, err := r.ListScheduled(ctx, tc.qname, Pagination{Size: tc.size, Page: tc.page})
 		op := fmt.Sprintf("r.ListScheduled(%q, Pagination{Size: %d, Page: %d})", tc.qname, tc.size, tc.page)
 		if err != nil {
 			t.Errorf("%s; %s returned error %v", tc.desc, op, err)
@@ -1196,10 +1221,12 @@ func TestListRetry(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r.client) // clean up db before each test case
 		h.SeedAllRetryQueues(t, r.client, tc.retry)
 
-		got, err := r.ListRetry(tc.qname, Pagination{Size: 20, Page: 0})
+		got, err := r.ListRetry(ctx, tc.qname, Pagination{Size: 20, Page: 0})
 		op := fmt.Sprintf("r.ListRetry(%q, Pagination{Size: 20, Page: 0})", tc.qname)
 		if err != nil {
 			t.Errorf("%s = %v, %v, want %v, nil", op, got, err, tc.want)
@@ -1216,6 +1243,7 @@ func TestListRetry(t *testing.T) {
 func TestListRetryPagination(t *testing.T) {
 	r := setup(t)
 	defer r.Close()
+
 	// create 100 tasks with an increasing number of wait time.
 	now := time.Now()
 	var seed []base.Z
@@ -1243,7 +1271,9 @@ func TestListRetryPagination(t *testing.T) {
 	}
 
 	for _, tc := range tests {
-		got, err := r.ListRetry(tc.qname, Pagination{Size: tc.size, Page: tc.page})
+		ctx := context.Background()
+
+		got, err := r.ListRetry(ctx, tc.qname, Pagination{Size: tc.size, Page: tc.page})
 		op := fmt.Sprintf("r.ListRetry(%q, Pagination{Size: %d, Page: %d})",
 			tc.qname, tc.size, tc.page)
 		if err != nil {
@@ -1349,10 +1379,12 @@ func TestListArchived(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r.client) // clean up db before each test case
 		h.SeedAllArchivedQueues(t, r.client, tc.archived)
 
-		got, err := r.ListArchived(tc.qname, Pagination{Size: 20, Page: 0})
+		got, err := r.ListArchived(ctx, tc.qname, Pagination{Size: 20, Page: 0})
 		op := fmt.Sprintf("r.ListArchived(%q, Pagination{Size: 20, Page: 0})", tc.qname)
 		if err != nil {
 			t.Errorf("%s = %v, %v, want %v, nil", op, got, err, tc.want)
@@ -1393,7 +1425,9 @@ func TestListArchivedPagination(t *testing.T) {
 	}
 
 	for _, tc := range tests {
-		got, err := r.ListArchived(tc.qname, Pagination{Size: tc.size, Page: tc.page})
+		ctx := context.Background()
+
+		got, err := r.ListArchived(ctx, tc.qname, Pagination{Size: tc.size, Page: tc.page})
 		op := fmt.Sprintf("r.ListArchived(Pagination{Size: %d, Page: %d})",
 			tc.size, tc.page)
 		if err != nil {
@@ -1489,10 +1523,12 @@ func TestListCompleted(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r.client) // clean up db before each test case
 		h.SeedAllCompletedQueues(t, r.client, tc.completed)
 
-		got, err := r.ListCompleted(tc.qname, Pagination{Size: 20, Page: 0})
+		got, err := r.ListCompleted(ctx, tc.qname, Pagination{Size: 20, Page: 0})
 		op := fmt.Sprintf("r.ListCompleted(%q, Pagination{Size: 20, Page: 0})", tc.qname)
 		if err != nil {
 			t.Errorf("%s = %v, %v, want %v, nil", op, got, err, tc.want)
@@ -1533,7 +1569,9 @@ func TestListCompletedPagination(t *testing.T) {
 	}
 
 	for _, tc := range tests {
-		got, err := r.ListCompleted(tc.qname, Pagination{Size: tc.size, Page: tc.page})
+		ctx := context.Background()
+
+		got, err := r.ListCompleted(ctx, tc.qname, Pagination{Size: tc.size, Page: tc.page})
 		op := fmt.Sprintf("r.ListCompleted(Pagination{Size: %d, Page: %d})",
 			tc.size, tc.page)
 		if err != nil {
@@ -1632,6 +1670,8 @@ func TestListAggregating(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r.client)
 		h.SeedRedisSet(t, r.client, base.AllQueues, fxt.allQueues)
 		h.SeedRedisSets(t, r.client, fxt.allGroups)
@@ -1639,7 +1679,7 @@ func TestListAggregating(t *testing.T) {
 		h.SeedRedisZSets(t, r.client, fxt.groups)
 
 		t.Run(tc.desc, func(t *testing.T) {
-			got, err := r.ListAggregating(tc.qname, tc.gname, Pagination{})
+			got, err := r.ListAggregating(ctx, tc.qname, tc.gname, Pagination{})
 			if err != nil {
 				t.Fatalf("ListAggregating returned error: %v", err)
 			}
@@ -1753,7 +1793,9 @@ func TestListAggregatingPagination(t *testing.T) {
 		h.SeedRedisZSets(t, r.client, fxt.groups)
 
 		t.Run(tc.desc, func(t *testing.T) {
-			got, err := r.ListAggregating(tc.qname, tc.gname, Pagination{Page: tc.page, Size: tc.size})
+			ctx := context.Background()
+
+			got, err := r.ListAggregating(ctx, tc.qname, tc.gname, Pagination{Page: tc.page, Size: tc.size})
 			if err != nil {
 				t.Fatalf("ListAggregating returned error: %v", err)
 			}
@@ -1796,20 +1838,22 @@ func TestListTasksError(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		pgn := Pagination{Page: 0, Size: 20}
-		if _, got := r.ListActive(tc.qname, pgn); !tc.match(got) {
+		if _, got := r.ListActive(ctx, tc.qname, pgn); !tc.match(got) {
 			t.Errorf("%s: ListActive returned %v", tc.desc, got)
 		}
-		if _, got := r.ListPending(tc.qname, pgn); !tc.match(got) {
+		if _, got := r.ListPending(ctx, tc.qname, pgn); !tc.match(got) {
 			t.Errorf("%s: ListPending returned %v", tc.desc, got)
 		}
-		if _, got := r.ListScheduled(tc.qname, pgn); !tc.match(got) {
+		if _, got := r.ListScheduled(ctx, tc.qname, pgn); !tc.match(got) {
 			t.Errorf("%s: ListScheduled returned %v", tc.desc, got)
 		}
-		if _, got := r.ListRetry(tc.qname, pgn); !tc.match(got) {
+		if _, got := r.ListRetry(ctx, tc.qname, pgn); !tc.match(got) {
 			t.Errorf("%s: ListRetry returned %v", tc.desc, got)
 		}
-		if _, got := r.ListArchived(tc.qname, pgn); !tc.match(got) {
+		if _, got := r.ListArchived(ctx, tc.qname, pgn); !tc.match(got) {
 			t.Errorf("%s: ListArchived returned %v", tc.desc, got)
 		}
 	}
@@ -4291,6 +4335,8 @@ func TestDeleteTaskWithUniqueLock(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r.client) // clean up db before each test case
 		h.SeedAllScheduledQueues(t, r.client, tc.scheduled)
 
@@ -4306,7 +4352,7 @@ func TestDeleteTaskWithUniqueLock(t *testing.T) {
 			}
 		}
 
-		if r.client.Exists(context.Background(), tc.uniqueKey).Val() != 0 {
+		if r.client.Exists(ctx, tc.uniqueKey).Val() != 0 {
 			t.Errorf("Uniqueness lock %q still exists", tc.uniqueKey)
 		}
 	}
@@ -4591,6 +4637,8 @@ func TestDeleteAllArchivedTasksWithUniqueKey(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r.client) // clean up db before each test case
 		h.SeedAllArchivedQueues(t, r.client, tc.archived)
 
@@ -4609,7 +4657,7 @@ func TestDeleteAllArchivedTasksWithUniqueKey(t *testing.T) {
 		}
 
 		for _, uniqueKey := range tc.uniqueKeys {
-			if r.client.Exists(context.Background(), uniqueKey).Val() != 0 {
+			if r.client.Exists(ctx, uniqueKey).Val() != 0 {
 				t.Errorf("Uniqueness lock %q still exists", uniqueKey)
 			}
 		}
@@ -4996,6 +5044,8 @@ func TestRemoveQueue(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r.client)
 		h.SeedAllPendingQueues(t, r.client, tc.pending)
 		h.SeedAllActiveQueues(t, r.client, tc.inProgress)
@@ -5009,7 +5059,7 @@ func TestRemoveQueue(t *testing.T) {
 				tc.qname, tc.force, err)
 			continue
 		}
-		if r.client.SIsMember(context.Background(), base.AllQueues, tc.qname).Val() {
+		if r.client.SIsMember(ctx, base.AllQueues, tc.qname).Val() {
 			t.Errorf("%q is a member of %q", tc.qname, base.AllQueues)
 		}
 
@@ -5022,12 +5072,12 @@ func TestRemoveQueue(t *testing.T) {
 			base.ArchivedKey(tc.qname),
 		}
 		for _, key := range keys {
-			if r.client.Exists(context.Background(), key).Val() != 0 {
+			if r.client.Exists(ctx, key).Val() != 0 {
 				t.Errorf("key %q still exists", key)
 			}
 		}
 
-		if n := len(r.client.Keys(context.Background(), base.TaskKeyPrefix(tc.qname)+"*").Val()); n != 0 {
+		if n := len(r.client.Keys(ctx, base.TaskKeyPrefix(tc.qname)+"*").Val()); n != 0 {
			t.Errorf("%d keys still exists for tasks", n)
 		}
 	}
@@ -5036,6 +5086,7 @@ func TestRemoveQueue(t *testing.T) {
 func TestRemoveQueueError(t *testing.T) {
 	r := setup(t)
 	defer r.Close()
+
 	m1 := h.NewTaskMessage("task1", nil)
 	m2 := h.NewTaskMessage("task2", nil)
 	m3 := h.NewTaskMessageWithQueue("task3", nil, "custom")
@@ -5432,6 +5483,8 @@ func TestRecordSchedulerEnqueueEventTrimsDataSet(t *testing.T) {
 		key     = base.SchedulerHistoryKey(entryID)
 	)
 
+	ctx := context.Background()
+
 	// Record maximum number of events.
 	for i := 1; i <= maxEvents; i++ {
 		event := base.SchedulerEnqueueEvent{
@@ -5444,7 +5497,7 @@ func TestRecordSchedulerEnqueueEventTrimsDataSet(t *testing.T) {
 	}
 
 	// Make sure the set is full.
-	if n := r.client.ZCard(context.Background(), key).Val(); n != maxEvents {
+	if n := r.client.ZCard(ctx, key).Val(); n != maxEvents {
 		t.Fatalf("unexpected number of events; got %d, want %d", n, maxEvents)
 	}
 
@@ -5456,7 +5509,7 @@ func TestRecordSchedulerEnqueueEventTrimsDataSet(t *testing.T) {
 	if err := r.RecordSchedulerEnqueueEvent(entryID, &event); err != nil {
 		t.Fatalf("RecordSchedulerEnqueueEvent failed: %v", err)
 	}
-	if n := r.client.ZCard(context.Background(), key).Val(); n != maxEvents {
+	if n := r.client.ZCard(ctx, key).Val(); n != maxEvents {
 		t.Fatalf("unexpected number of events; got %d, want %d", n, maxEvents)
 	}
 	events, err := r.ListSchedulerEnqueueEvents(entryID, Pagination{Size: maxEvents})
@@ -5482,6 +5535,8 @@ func TestPause(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r.client)
 
 		err := r.Pause(tc.qname)
@@ -5489,7 +5544,7 @@ func TestPause(t *testing.T) {
 			t.Errorf("Pause(%q) returned error: %v", tc.qname, err)
 		}
 		key := base.PausedKey(tc.qname)
-		if r.client.Exists(context.Background(), key).Val() == 0 {
+		if r.client.Exists(ctx, key).Val() == 0 {
 			t.Errorf("key %q does not exist", key)
 		}
 	}
@@ -5532,6 +5587,8 @@ func TestUnpause(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r.client)
 		for _, qname := range tc.paused {
 			if err := r.Pause(qname); err != nil {
@@ -5544,7 +5601,7 @@ func TestUnpause(t *testing.T) {
 			t.Errorf("Unpause(%q) returned error: %v", tc.qname, err)
 		}
 		key := base.PausedKey(tc.qname)
-		if r.client.Exists(context.Background(), key).Val() == 1 {
+		if r.client.Exists(ctx, key).Val() == 1 {
 			t.Errorf("key %q exists", key)
 		}
 	}
diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go
index 1c47d1ca..17ae4a7b 100644
--- a/internal/rdb/rdb.go
+++ b/internal/rdb/rdb.go
@@ -994,7 +994,7 @@ func (r *RDB) ListGroups(qname string) ([]string, error) {
 // 1) group has reached or exceeded its max size
 // 2) group's oldest task has reached or exceeded its max delay
 // 3) group's latest task has reached or exceeded its grace period
-// if aggreation criteria is met, the command moves those tasks from the group
+// if aggregation criteria is met, the command moves those tasks from the group
 // and put them in an aggregation set. Additionally, if the creation of aggregation set
 // empties the group, it will clear the group name from the all groups set.
 //
@@ -1073,7 +1073,7 @@ return 0
 `)
 
 // Task aggregation should finish within this timeout.
-// Otherwise an aggregation set should be reclaimed by the recoverer.
+// Otherwise, an aggregation set should be reclaimed by the recoverer.
 const aggregationTimeout = 2 * time.Minute
 
 // AggregationCheck checks the group identified by the given queue and group name to see if the tasks in the
diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go
index 822d2319..21aa658f 100644
--- a/internal/rdb/rdb_test.go
+++ b/internal/rdb/rdb_test.go
@@ -2535,10 +2535,12 @@ func TestDeleteExpiredCompletedTasks(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r.client)
 		h.SeedAllCompletedQueues(t, r.client, tc.completed)
 
-		if err := r.DeleteExpiredCompletedTasks(tc.qname, 100); err != nil {
+		if err := r.DeleteExpiredCompletedTasks(ctx, tc.qname, 100); err != nil {
 			t.Errorf("DeleteExpiredCompletedTasks(%q, 100) failed: %v", tc.qname, err)
 			continue
 		}
@@ -2620,10 +2622,12 @@ func TestListLeaseExpired(t *testing.T) {
 	r := setup(t)
 	defer r.Close()
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r.client)
 		h.SeedAllLease(t, r.client, tc.lease)
 
-		got, err := r.ListLeaseExpired(tc.cutoff, tc.qnames...)
+		got, err := r.ListLeaseExpired(ctx, tc.cutoff, tc.qnames...)
 		if err != nil {
 			t.Errorf("%s; ListLeaseExpired(%v) returned error: %v", tc.desc, tc.cutoff, err)
 			continue
@@ -2724,10 +2728,12 @@ func TestExtendLease(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r.client)
 		h.SeedAllLease(t, r.client, tc.lease)
 
-		gotExpirationTime, err := r.ExtendLease(tc.qname, tc.ids...)
+		gotExpirationTime, err := r.ExtendLease(ctx, tc.qname, tc.ids...)
 		if err != nil {
 			t.Fatalf("%s: ExtendLease(%q, %v) returned error: %v", tc.desc, tc.qname, tc.ids, err)
 		}
@@ -3078,9 +3084,11 @@ func TestWriteResult(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r.client)
-		n, err := r.WriteResult(tc.qname, tc.taskID, tc.data)
+		n, err := r.WriteResult(ctx, tc.qname, tc.taskID, tc.data)
 		if err != nil {
 			t.Errorf("WriteResult failed: %v", err)
 			continue
		}
@@ -3737,12 +3745,14 @@ func TestReclaimStaleAggregationSets(t *testing.T) {
 	}
 
 	for _, tc := range tests {
+		ctx := context.Background()
+
 		h.FlushDB(t, r.client)
 		h.SeedRedisZSets(t, r.client, tc.groups)
 		h.SeedRedisZSets(t, r.client, tc.aggregationSets)
 		h.SeedRedisZSets(t, r.client, tc.allAggregationSets)
 
-		if err := r.ReclaimStaleAggregationSets(tc.qname); err != nil {
+		if err := r.ReclaimStaleAggregationSets(ctx, tc.qname); err != nil {
 			t.Errorf("ReclaimStaleAggregationSets returned error: %v", err)
 			continue
 		}
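Reviewer note: below is a minimal caller-side sketch of the context-aware Inspector API that this patch introduces. It is not part of the patch, and the ctx parameter on the List* methods exists only once this change is applied; the Redis address, queue name, page options, and the 5-second timeout are illustrative.

package main

import (
	"context"
	"fmt"
	"log"
	"time"

	"github.com/hibiken/asynq"
)

func main() {
	inspector := asynq.NewInspector(asynq.RedisClientOpt{Addr: "localhost:6379"})
	defer inspector.Close()

	// Bound the Redis round trips made by the listing call; if the deadline
	// expires, the underlying go-redis commands abort with a context error.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	tasks, err := inspector.ListPendingTasks(ctx, "default", asynq.PageSize(10), asynq.Page(1))
	if err != nil {
		log.Fatalf("ListPendingTasks failed: %v", err)
	}
	for _, t := range tasks {
		fmt.Printf("id=%s type=%s queue=%s\n", t.ID, t.Type, t.Queue)
	}
}

Threading the caller's context through Inspector and into the rdb layer (instead of the previous hard-coded context.Background()) lets long listing operations be cancelled and lets callers propagate request deadlines end to end.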