From 2a95b951c5405c06799f5700bac7a63b7aaa16a8 Mon Sep 17 00:00:00 2001 From: Rafi Date: Sat, 5 Oct 2024 12:03:18 -0300 Subject: [PATCH 1/6] Passing context to listMessages --- inspector.go | 11 ++++++----- internal/rdb/inspect.go | 15 ++++++++------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/inspector.go b/inspector.go index a98a2211..35bd162a 100644 --- a/inspector.go +++ b/inspector.go @@ -1,6 +1,7 @@ package asynq import ( + "context" "fmt" "strconv" "strings" @@ -18,7 +19,7 @@ type Inspector struct { rdb *rdb.RDB } -// New returns a new instance of Inspector. +// NewInspector returns a new instance of Inspector. func NewInspector(r RedisConnOpt) *Inspector { c, ok := r.MakeRedisClient().(redis.UniversalClient) if !ok { @@ -295,13 +296,13 @@ func Page(n int) ListOption { // ListPendingTasks retrieves pending tasks from the specified queue. // // By default, it retrieves the first 30 tasks. -func (i *Inspector) ListPendingTasks(queue string, opts ...ListOption) ([]*TaskInfo, error) { +func (i *Inspector) ListPendingTasks(ctx context.Context, queue string, opts ...ListOption) ([]*TaskInfo, error) { if err := base.ValidateQueueName(queue); err != nil { return nil, fmt.Errorf("asynq: %v", err) } opt := composeListOptions(opts...) pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1} - infos, err := i.rdb.ListPending(queue, pgn) + infos, err := i.rdb.ListPending(ctx, queue, pgn) switch { case errors.IsQueueNotFound(err): return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound) @@ -323,13 +324,13 @@ func (i *Inspector) ListPendingTasks(queue string, opts ...ListOption) ([]*TaskI // ListActiveTasks retrieves active tasks from the specified queue. // // By default, it retrieves the first 30 tasks. -func (i *Inspector) ListActiveTasks(queue string, opts ...ListOption) ([]*TaskInfo, error) { +func (i *Inspector) ListActiveTasks(ctx context.Context, queue string, opts ...ListOption) ([]*TaskInfo, error) { if err := base.ValidateQueueName(queue); err != nil { return nil, fmt.Errorf("asynq: %v", err) } opt := composeListOptions(opts...) pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1} - infos, err := i.rdb.ListActive(queue, pgn) + infos, err := i.rdb.ListActive(ctx, queue, pgn) switch { case errors.IsQueueNotFound(err): return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound) diff --git a/internal/rdb/inspect.go b/internal/rdb/inspect.go index a232ceb9..17bc3a82 100644 --- a/internal/rdb/inspect.go +++ b/internal/rdb/inspect.go @@ -612,7 +612,7 @@ func (p Pagination) stop() int64 { } // ListPending returns pending tasks that are ready to be processed. -func (r *RDB) ListPending(qname string, pgn Pagination) ([]*base.TaskInfo, error) { +func (r *RDB) ListPending(ctx context.Context, qname string, pgn Pagination) ([]*base.TaskInfo, error) { var op errors.Op = "rdb.ListPending" exists, err := r.queueExists(qname) if err != nil { @@ -621,7 +621,7 @@ func (r *RDB) ListPending(qname string, pgn Pagination) ([]*base.TaskInfo, error if !exists { return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname}) } - res, err := r.listMessages(qname, base.TaskStatePending, pgn) + res, err := r.listMessages(ctx, qname, base.TaskStatePending, pgn) if err != nil { return nil, errors.E(op, errors.CanonicalCode(err), err) } @@ -629,7 +629,7 @@ func (r *RDB) ListPending(qname string, pgn Pagination) ([]*base.TaskInfo, error } // ListActive returns all tasks that are currently being processed for the given queue. -func (r *RDB) ListActive(qname string, pgn Pagination) ([]*base.TaskInfo, error) { +func (r *RDB) ListActive(ctx context.Context, qname string, pgn Pagination) ([]*base.TaskInfo, error) { var op errors.Op = "rdb.ListActive" exists, err := r.queueExists(qname) if err != nil { @@ -638,7 +638,7 @@ func (r *RDB) ListActive(qname string, pgn Pagination) ([]*base.TaskInfo, error) if !exists { return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname}) } - res, err := r.listMessages(qname, base.TaskStateActive, pgn) + res, err := r.listMessages(ctx, qname, base.TaskStateActive, pgn) if err != nil { return nil, errors.E(op, errors.CanonicalCode(err), err) } @@ -662,7 +662,9 @@ return data `) // listMessages returns a list of TaskInfo in Redis list with the given key. -func (r *RDB) listMessages(qname string, state base.TaskState, pgn Pagination) ([]*base.TaskInfo, error) { +func (r *RDB) listMessages(ctx context.Context, qname string, state base.TaskState, pgn Pagination) ( + []*base.TaskInfo, error, +) { var key string switch state { case base.TaskStateActive: @@ -676,8 +678,7 @@ func (r *RDB) listMessages(qname string, state base.TaskState, pgn Pagination) ( // correct range and reverse the list to get the tasks with pagination. stop := -pgn.start() - 1 start := -pgn.stop() - 1 - res, err := listMessagesCmd.Run(context.Background(), r.client, - []string{key}, start, stop, base.TaskKeyPrefix(qname)).Result() + res, err := listMessagesCmd.Run(ctx, r.client, []string{key}, start, stop, base.TaskKeyPrefix(qname)).Result() if err != nil { return nil, errors.E(errors.Unknown, err) } From ad5fa3dd7ba15713ca337645060dfaebc321c4c6 Mon Sep 17 00:00:00 2001 From: Rafi Date: Sat, 5 Oct 2024 13:24:16 -0300 Subject: [PATCH 2/6] Update inspect_test.go --- internal/rdb/inspect_test.go | 80 ++++++++++++++++++++++++++---------- 1 file changed, 58 insertions(+), 22 deletions(-) diff --git a/internal/rdb/inspect_test.go b/internal/rdb/inspect_test.go index 863c1328..2f270c45 100644 --- a/internal/rdb/inspect_test.go +++ b/internal/rdb/inspect_test.go @@ -32,9 +32,11 @@ func TestAllQueues(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r.client) for _, qname := range tc.queues { - if err := r.client.SAdd(context.Background(), base.AllQueues, qname).Err(); err != nil { + if err := r.client.SAdd(ctx, base.AllQueues, qname).Err(); err != nil { t.Fatalf("could not initialize all queue set: %v", err) } } @@ -274,6 +276,8 @@ func TestCurrentStats(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r.client) // clean up db before each test case for _, qname := range tc.paused { if err := r.Pause(qname); err != nil { @@ -290,7 +294,7 @@ func TestCurrentStats(t *testing.T) { h.SeedRedisZSets(t, r.client, tc.archived) h.SeedRedisZSets(t, r.client, tc.completed) h.SeedRedisZSets(t, r.client, tc.groups) - ctx := context.Background() + for qname, n := range tc.processed { r.client.Set(ctx, base.ProcessedKey(qname, now), n, 0) } @@ -351,16 +355,18 @@ func TestHistoricalStats(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r.client) - r.client.SAdd(context.Background(), base.AllQueues, tc.qname) + r.client.SAdd(ctx, base.AllQueues, tc.qname) // populate last n days data for i := 0; i < tc.n; i++ { ts := now.Add(-time.Duration(i) * 24 * time.Hour) processedKey := base.ProcessedKey(tc.qname, ts) failedKey := base.FailedKey(tc.qname, ts) - r.client.Set(context.Background(), processedKey, (i+1)*1000, 0) - r.client.Set(context.Background(), failedKey, (i+1)*10, 0) + r.client.Set(ctx, processedKey, (i+1)*1000, 0) + r.client.Set(ctx, failedKey, (i+1)*10, 0) } got, err := r.HistoricalStats(tc.qname, tc.n) @@ -562,14 +568,17 @@ func TestGetTaskInfo(t *testing.T) { }, } + ctx := context.Background() + h.SeedAllActiveQueues(t, r.client, fixtures.active) h.SeedAllPendingQueues(t, r.client, fixtures.pending) h.SeedAllScheduledQueues(t, r.client, fixtures.scheduled) h.SeedAllRetryQueues(t, r.client, fixtures.retry) h.SeedAllArchivedQueues(t, r.client, fixtures.archived) h.SeedAllCompletedQueues(t, r.client, fixtures.completed) + // Write result data for the completed task. - if err := r.client.HSet(context.Background(), base.TaskKey(m6.Queue, m6.ID), "result", "foobar").Err(); err != nil { + if err := r.client.HSet(ctx, base.TaskKey(m6.Queue, m6.ID), "result", "foobar").Err(); err != nil { t.Fatalf("Failed to write result data under task key: %v", err) } @@ -788,10 +797,12 @@ func TestListPending(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllPendingQueues(t, r.client, tc.pending) - got, err := r.ListPending(tc.qname, Pagination{Size: 20, Page: 0}) + got, err := r.ListPending(ctx, tc.qname, Pagination{Size: 20, Page: 0}) op := fmt.Sprintf("r.ListPending(%q, Pagination{Size: 20, Page: 0})", tc.qname) if err != nil { t.Errorf("%s = %v, %v, want %v, nil", op, got, err, tc.want) @@ -807,6 +818,7 @@ func TestListPending(t *testing.T) { func TestListPendingPagination(t *testing.T) { r := setup(t) defer r.Close() + var msgs []*base.TaskMessage for i := 0; i < 100; i++ { msg := h.NewTaskMessage(fmt.Sprintf("task %d", i), nil) @@ -841,7 +853,9 @@ func TestListPendingPagination(t *testing.T) { } for _, tc := range tests { - got, err := r.ListPending(tc.qname, Pagination{Size: tc.size, Page: tc.page}) + ctx := context.Background() + + got, err := r.ListPending(ctx, tc.qname, Pagination{Size: tc.size, Page: tc.page}) op := fmt.Sprintf("r.ListPending(%q, Pagination{Size: %d, Page: %d})", tc.qname, tc.size, tc.page) if err != nil { t.Errorf("%s; %s returned error %v", tc.desc, op, err) @@ -907,10 +921,12 @@ func TestListActive(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllActiveQueues(t, r.client, tc.inProgress) - got, err := r.ListActive(tc.qname, Pagination{Size: 20, Page: 0}) + got, err := r.ListActive(ctx, tc.qname, Pagination{Size: 20, Page: 0}) op := fmt.Sprintf("r.ListActive(%q, Pagination{Size: 20, Page: 0})", tc.qname) if err != nil { t.Errorf("%s = %v, %v, want %v, nil", op, got, err, tc.inProgress) @@ -926,6 +942,7 @@ func TestListActive(t *testing.T) { func TestListActivePagination(t *testing.T) { r := setup(t) defer r.Close() + var msgs []*base.TaskMessage for i := 0; i < 100; i++ { msg := h.NewTaskMessage(fmt.Sprintf("task %d", i), nil) @@ -950,7 +967,9 @@ func TestListActivePagination(t *testing.T) { } for _, tc := range tests { - got, err := r.ListActive(tc.qname, Pagination{Size: tc.size, Page: tc.page}) + ctx := context.Background() + + got, err := r.ListActive(ctx, tc.qname, Pagination{Size: tc.size, Page: tc.page}) op := fmt.Sprintf("r.ListActive(%q, Pagination{Size: %d, Page: %d})", tc.qname, tc.size, tc.page) if err != nil { t.Errorf("%s; %s returned error %v", tc.desc, op, err) @@ -1063,8 +1082,10 @@ func TestListScheduledPagination(t *testing.T) { defer r.Close() // create 100 tasks with an increasing number of wait time. for i := 0; i < 100; i++ { + ctx := context.Background() + msg := h.NewTaskMessage(fmt.Sprintf("task %d", i), nil) - if err := r.Schedule(context.Background(), msg, time.Now().Add(time.Duration(i)*time.Second)); err != nil { + if err := r.Schedule(ctx, msg, time.Now().Add(time.Duration(i)*time.Second)); err != nil { t.Fatal(err) } } @@ -1796,11 +1817,13 @@ func TestListTasksError(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + pgn := Pagination{Page: 0, Size: 20} - if _, got := r.ListActive(tc.qname, pgn); !tc.match(got) { + if _, got := r.ListActive(ctx, tc.qname, pgn); !tc.match(got) { t.Errorf("%s: ListActive returned %v", tc.desc, got) } - if _, got := r.ListPending(tc.qname, pgn); !tc.match(got) { + if _, got := r.ListPending(ctx, tc.qname, pgn); !tc.match(got) { t.Errorf("%s: ListPending returned %v", tc.desc, got) } if _, got := r.ListScheduled(tc.qname, pgn); !tc.match(got) { @@ -4291,6 +4314,8 @@ func TestDeleteTaskWithUniqueLock(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllScheduledQueues(t, r.client, tc.scheduled) @@ -4306,7 +4331,7 @@ func TestDeleteTaskWithUniqueLock(t *testing.T) { } } - if r.client.Exists(context.Background(), tc.uniqueKey).Val() != 0 { + if r.client.Exists(ctx, tc.uniqueKey).Val() != 0 { t.Errorf("Uniqueness lock %q still exists", tc.uniqueKey) } } @@ -4591,6 +4616,8 @@ func TestDeleteAllArchivedTasksWithUniqueKey(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllArchivedQueues(t, r.client, tc.archived) @@ -4609,7 +4636,7 @@ func TestDeleteAllArchivedTasksWithUniqueKey(t *testing.T) { } for _, uniqueKey := range tc.uniqueKeys { - if r.client.Exists(context.Background(), uniqueKey).Val() != 0 { + if r.client.Exists(ctx, uniqueKey).Val() != 0 { t.Errorf("Uniqueness lock %q still exists", uniqueKey) } } @@ -4996,6 +5023,8 @@ func TestRemoveQueue(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r.client) h.SeedAllPendingQueues(t, r.client, tc.pending) h.SeedAllActiveQueues(t, r.client, tc.inProgress) @@ -5009,7 +5038,7 @@ func TestRemoveQueue(t *testing.T) { tc.qname, tc.force, err) continue } - if r.client.SIsMember(context.Background(), base.AllQueues, tc.qname).Val() { + if r.client.SIsMember(ctx, base.AllQueues, tc.qname).Val() { t.Errorf("%q is a member of %q", tc.qname, base.AllQueues) } @@ -5022,12 +5051,12 @@ func TestRemoveQueue(t *testing.T) { base.ArchivedKey(tc.qname), } for _, key := range keys { - if r.client.Exists(context.Background(), key).Val() != 0 { + if r.client.Exists(ctx, key).Val() != 0 { t.Errorf("key %q still exists", key) } } - if n := len(r.client.Keys(context.Background(), base.TaskKeyPrefix(tc.qname)+"*").Val()); n != 0 { + if n := len(r.client.Keys(ctx, base.TaskKeyPrefix(tc.qname)+"*").Val()); n != 0 { t.Errorf("%d keys still exists for tasks", n) } } @@ -5036,6 +5065,7 @@ func TestRemoveQueue(t *testing.T) { func TestRemoveQueueError(t *testing.T) { r := setup(t) defer r.Close() + m1 := h.NewTaskMessage("task1", nil) m2 := h.NewTaskMessage("task2", nil) m3 := h.NewTaskMessageWithQueue("task3", nil, "custom") @@ -5432,6 +5462,8 @@ func TestRecordSchedulerEnqueueEventTrimsDataSet(t *testing.T) { key = base.SchedulerHistoryKey(entryID) ) + ctx := context.Background() + // Record maximum number of events. for i := 1; i <= maxEvents; i++ { event := base.SchedulerEnqueueEvent{ @@ -5444,7 +5476,7 @@ func TestRecordSchedulerEnqueueEventTrimsDataSet(t *testing.T) { } // Make sure the set is full. - if n := r.client.ZCard(context.Background(), key).Val(); n != maxEvents { + if n := r.client.ZCard(ctx, key).Val(); n != maxEvents { t.Fatalf("unexpected number of events; got %d, want %d", n, maxEvents) } @@ -5456,7 +5488,7 @@ func TestRecordSchedulerEnqueueEventTrimsDataSet(t *testing.T) { if err := r.RecordSchedulerEnqueueEvent(entryID, &event); err != nil { t.Fatalf("RecordSchedulerEnqueueEvent failed: %v", err) } - if n := r.client.ZCard(context.Background(), key).Val(); n != maxEvents { + if n := r.client.ZCard(ctx, key).Val(); n != maxEvents { t.Fatalf("unexpected number of events; got %d, want %d", n, maxEvents) } events, err := r.ListSchedulerEnqueueEvents(entryID, Pagination{Size: maxEvents}) @@ -5482,6 +5514,8 @@ func TestPause(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r.client) err := r.Pause(tc.qname) @@ -5489,7 +5523,7 @@ func TestPause(t *testing.T) { t.Errorf("Pause(%q) returned error: %v", tc.qname, err) } key := base.PausedKey(tc.qname) - if r.client.Exists(context.Background(), key).Val() == 0 { + if r.client.Exists(ctx, key).Val() == 0 { t.Errorf("key %q does not exist", key) } } @@ -5532,6 +5566,8 @@ func TestUnpause(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r.client) for _, qname := range tc.paused { if err := r.Pause(qname); err != nil { @@ -5544,7 +5580,7 @@ func TestUnpause(t *testing.T) { t.Errorf("Unpause(%q) returned error: %v", tc.qname, err) } key := base.PausedKey(tc.qname) - if r.client.Exists(context.Background(), key).Val() == 1 { + if r.client.Exists(ctx, key).Val() == 1 { t.Errorf("key %q exists", key) } } From 5f627669e37e8cb07801c3223b92b2b68d75a683 Mon Sep 17 00:00:00 2001 From: Rafi Date: Sat, 5 Oct 2024 13:28:19 -0300 Subject: [PATCH 3/6] Update inspector_test.go --- inspector_test.go | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/inspector_test.go b/inspector_test.go index 428058f3..8c11c0fc 100644 --- a/inspector_test.go +++ b/inspector_test.go @@ -671,6 +671,7 @@ func TestInspectorGetTaskInfoError(t *testing.T) { func TestInspectorListPendingTasks(t *testing.T) { r := setup(t) defer r.Close() + m1 := h.NewTaskMessage("task1", nil) m2 := h.NewTaskMessage("task2", nil) m3 := h.NewTaskMessageWithQueue("task3", nil, "critical") @@ -718,12 +719,14 @@ func TestInspectorListPendingTasks(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r) for q, msgs := range tc.pending { h.SeedPendingQueue(t, r, msgs, q) } - got, err := inspector.ListPendingTasks(tc.qname) + got, err := inspector.ListPendingTasks(ctx, tc.qname) if err != nil { t.Errorf("%s; ListPendingTasks(%q) returned error: %v", tc.desc, tc.qname, err) @@ -811,11 +814,13 @@ func TestInspectorListActiveTasks(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r) h.SeedAllActiveQueues(t, r, tc.active) h.SeedAllLease(t, r, tc.lease) - got, err := inspector.ListActiveTasks(tc.qname) + got, err := inspector.ListActiveTasks(ctx, tc.qname) if err != nil { t.Errorf("%s; ListActiveTasks(%q) returned error: %v", tc.qname, tc.desc, err) continue @@ -1263,7 +1268,9 @@ func TestInspectorListPagination(t *testing.T) { } for _, tc := range tests { - got, err := inspector.ListPendingTasks("default", Page(tc.page), PageSize(tc.pageSize)) + ctx := context.Background() + + got, err := inspector.ListPendingTasks(ctx, "default", Page(tc.page), PageSize(tc.pageSize)) if err != nil { t.Errorf("ListPendingTask('default') returned error: %v", err) continue @@ -1296,12 +1303,14 @@ func TestInspectorListTasksQueueNotFoundError(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r) - if _, err := inspector.ListActiveTasks(tc.qname); !errors.Is(err, tc.wantErr) { + if _, err := inspector.ListActiveTasks(ctx, tc.qname); !errors.Is(err, tc.wantErr) { t.Errorf("ListActiveTasks(%q) returned error %v, want %v", tc.qname, err, tc.wantErr) } - if _, err := inspector.ListPendingTasks(tc.qname); !errors.Is(err, tc.wantErr) { + if _, err := inspector.ListPendingTasks(ctx, tc.qname); !errors.Is(err, tc.wantErr) { t.Errorf("ListPendingTasks(%q) returned error %v, want %v", tc.qname, err, tc.wantErr) } if _, err := inspector.ListScheduledTasks(tc.qname); !errors.Is(err, tc.wantErr) { From 8a0e86c443ed8e096aa30af4035e9c467eb2d3e2 Mon Sep 17 00:00:00 2001 From: Rafi Date: Sat, 5 Oct 2024 13:47:51 -0300 Subject: [PATCH 4/6] Passing context --- inspector.go | 22 +++++++-------- inspector_test.go | 30 +++++++++++++------- internal/base/base.go | 24 +++++++++++----- internal/rdb/inspect.go | 24 ++++++++-------- internal/rdb/inspect_test.go | 47 ++++++++++++++++++++++--------- internal/rdb/rdb.go | 8 +++--- internal/rdb/rdb_test.go | 6 ++-- internal/testbroker/testbroker.go | 8 +++--- subscriber.go | 2 +- subscriber_test.go | 4 +-- 10 files changed, 108 insertions(+), 67 deletions(-) diff --git a/inspector.go b/inspector.go index 35bd162a..279e6631 100644 --- a/inspector.go +++ b/inspector.go @@ -364,13 +364,13 @@ func (i *Inspector) ListActiveTasks(ctx context.Context, queue string, opts ...L // ListAggregatingTasks retrieves scheduled tasks from the specified group. // // By default, it retrieves the first 30 tasks. -func (i *Inspector) ListAggregatingTasks(queue, group string, opts ...ListOption) ([]*TaskInfo, error) { +func (i *Inspector) ListAggregatingTasks(ctx context.Context, queue, group string, opts ...ListOption) ([]*TaskInfo, error) { if err := base.ValidateQueueName(queue); err != nil { return nil, fmt.Errorf("asynq: %v", err) } opt := composeListOptions(opts...) pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1} - infos, err := i.rdb.ListAggregating(queue, group, pgn) + infos, err := i.rdb.ListAggregating(ctx, queue, group, pgn) switch { case errors.IsQueueNotFound(err): return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound) @@ -393,13 +393,13 @@ func (i *Inspector) ListAggregatingTasks(queue, group string, opts ...ListOption // Tasks are sorted by NextProcessAt in ascending order. // // By default, it retrieves the first 30 tasks. -func (i *Inspector) ListScheduledTasks(queue string, opts ...ListOption) ([]*TaskInfo, error) { +func (i *Inspector) ListScheduledTasks(ctx context.Context, queue string, opts ...ListOption) ([]*TaskInfo, error) { if err := base.ValidateQueueName(queue); err != nil { return nil, fmt.Errorf("asynq: %v", err) } opt := composeListOptions(opts...) pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1} - infos, err := i.rdb.ListScheduled(queue, pgn) + infos, err := i.rdb.ListScheduled(ctx, queue, pgn) switch { case errors.IsQueueNotFound(err): return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound) @@ -422,13 +422,13 @@ func (i *Inspector) ListScheduledTasks(queue string, opts ...ListOption) ([]*Tas // Tasks are sorted by NextProcessAt in ascending order. // // By default, it retrieves the first 30 tasks. -func (i *Inspector) ListRetryTasks(queue string, opts ...ListOption) ([]*TaskInfo, error) { +func (i *Inspector) ListRetryTasks(ctx context.Context, queue string, opts ...ListOption) ([]*TaskInfo, error) { if err := base.ValidateQueueName(queue); err != nil { return nil, fmt.Errorf("asynq: %v", err) } opt := composeListOptions(opts...) pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1} - infos, err := i.rdb.ListRetry(queue, pgn) + infos, err := i.rdb.ListRetry(ctx, queue, pgn) switch { case errors.IsQueueNotFound(err): return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound) @@ -451,13 +451,13 @@ func (i *Inspector) ListRetryTasks(queue string, opts ...ListOption) ([]*TaskInf // Tasks are sorted by LastFailedAt in descending order. // // By default, it retrieves the first 30 tasks. -func (i *Inspector) ListArchivedTasks(queue string, opts ...ListOption) ([]*TaskInfo, error) { +func (i *Inspector) ListArchivedTasks(ctx context.Context, queue string, opts ...ListOption) ([]*TaskInfo, error) { if err := base.ValidateQueueName(queue); err != nil { return nil, fmt.Errorf("asynq: %v", err) } opt := composeListOptions(opts...) pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1} - infos, err := i.rdb.ListArchived(queue, pgn) + infos, err := i.rdb.ListArchived(ctx, queue, pgn) switch { case errors.IsQueueNotFound(err): return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound) @@ -480,13 +480,13 @@ func (i *Inspector) ListArchivedTasks(queue string, opts ...ListOption) ([]*Task // Tasks are sorted by expiration time (i.e. CompletedAt + Retention) in descending order. // // By default, it retrieves the first 30 tasks. -func (i *Inspector) ListCompletedTasks(queue string, opts ...ListOption) ([]*TaskInfo, error) { +func (i *Inspector) ListCompletedTasks(ctx context.Context, queue string, opts ...ListOption) ([]*TaskInfo, error) { if err := base.ValidateQueueName(queue); err != nil { return nil, fmt.Errorf("asynq: %v", err) } opt := composeListOptions(opts...) pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1} - infos, err := i.rdb.ListCompleted(queue, pgn) + infos, err := i.rdb.ListCompleted(ctx, queue, pgn) switch { case errors.IsQueueNotFound(err): return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound) @@ -719,7 +719,7 @@ func (i *Inspector) ArchiveTask(queue, id string) error { // guarantee that the task with the given id will be canceled. The return // value only indicates whether the cancelation signal has been sent. func (i *Inspector) CancelProcessing(id string) error { - return i.rdb.PublishCancelation(id) + return i.rdb.PublishCancellation(id) } // PauseQueue pauses task processing on the specified queue. diff --git a/inspector_test.go b/inspector_test.go index 8c11c0fc..4537dc56 100644 --- a/inspector_test.go +++ b/inspector_test.go @@ -887,10 +887,12 @@ func TestInspectorListScheduledTasks(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r) h.SeedAllScheduledQueues(t, r, tc.scheduled) - got, err := inspector.ListScheduledTasks(tc.qname) + got, err := inspector.ListScheduledTasks(ctx, tc.qname) if err != nil { t.Errorf("%s; ListScheduledTasks(%q) returned error: %v", tc.desc, tc.qname, err) continue @@ -958,10 +960,12 @@ func TestInspectorListRetryTasks(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r) h.SeedAllRetryQueues(t, r, tc.retry) - got, err := inspector.ListRetryTasks(tc.qname) + got, err := inspector.ListRetryTasks(ctx, tc.qname) if err != nil { t.Errorf("%s; ListRetryTasks(%q) returned error: %v", tc.desc, tc.qname, err) continue @@ -1028,10 +1032,12 @@ func TestInspectorListArchivedTasks(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r) h.SeedAllArchivedQueues(t, r, tc.archived) - got, err := inspector.ListArchivedTasks(tc.qname) + got, err := inspector.ListArchivedTasks(ctx, tc.qname) if err != nil { t.Errorf("%s; ListArchivedTasks(%q) returned error: %v", tc.desc, tc.qname, err) continue @@ -1105,10 +1111,12 @@ func TestInspectorListCompletedTasks(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r) h.SeedAllCompletedQueues(t, r, tc.completed) - got, err := inspector.ListCompletedTasks(tc.qname) + got, err := inspector.ListCompletedTasks(ctx, tc.qname) if err != nil { t.Errorf("%s; ListCompletedTasks(%q) returned error: %v", tc.desc, tc.qname, err) continue @@ -1193,6 +1201,8 @@ func TestInspectorListAggregatingTasks(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r) h.SeedTasks(t, r, fxt.tasks) h.SeedRedisSet(t, r, base.AllQueues, fxt.allQueues) @@ -1200,7 +1210,7 @@ func TestInspectorListAggregatingTasks(t *testing.T) { h.SeedRedisZSets(t, r, fxt.groups) t.Run(tc.desc, func(t *testing.T) { - got, err := inspector.ListAggregatingTasks(tc.qname, tc.gname) + got, err := inspector.ListAggregatingTasks(ctx, tc.qname, tc.gname) if err != nil { t.Fatalf("ListAggregatingTasks returned error: %v", err) } @@ -1313,19 +1323,19 @@ func TestInspectorListTasksQueueNotFoundError(t *testing.T) { if _, err := inspector.ListPendingTasks(ctx, tc.qname); !errors.Is(err, tc.wantErr) { t.Errorf("ListPendingTasks(%q) returned error %v, want %v", tc.qname, err, tc.wantErr) } - if _, err := inspector.ListScheduledTasks(tc.qname); !errors.Is(err, tc.wantErr) { + if _, err := inspector.ListScheduledTasks(ctx, tc.qname); !errors.Is(err, tc.wantErr) { t.Errorf("ListScheduledTasks(%q) returned error %v, want %v", tc.qname, err, tc.wantErr) } - if _, err := inspector.ListRetryTasks(tc.qname); !errors.Is(err, tc.wantErr) { + if _, err := inspector.ListRetryTasks(ctx, tc.qname); !errors.Is(err, tc.wantErr) { t.Errorf("ListRetryTasks(%q) returned error %v, want %v", tc.qname, err, tc.wantErr) } - if _, err := inspector.ListArchivedTasks(tc.qname); !errors.Is(err, tc.wantErr) { + if _, err := inspector.ListArchivedTasks(ctx, tc.qname); !errors.Is(err, tc.wantErr) { t.Errorf("ListArchivedTasks(%q) returned error %v, want %v", tc.qname, err, tc.wantErr) } - if _, err := inspector.ListCompletedTasks(tc.qname); !errors.Is(err, tc.wantErr) { + if _, err := inspector.ListCompletedTasks(ctx, tc.qname); !errors.Is(err, tc.wantErr) { t.Errorf("ListCompletedTasks(%q) returned error %v, want %v", tc.qname, err, tc.wantErr) } - if _, err := inspector.ListAggregatingTasks(tc.qname, "mygroup"); !errors.Is(err, tc.wantErr) { + if _, err := inspector.ListAggregatingTasks(ctx, tc.qname, "mygroup"); !errors.Is(err, tc.wantErr) { t.Errorf("ListAggregatingTasks(%q, \"mygroup\") returned error %v, want %v", tc.qname, err, tc.wantErr) } } diff --git a/internal/base/base.go b/internal/base/base.go index 2d658b47..b8af2952 100644 --- a/internal/base/base.go +++ b/internal/base/base.go @@ -687,7 +687,9 @@ type Broker interface { Archive(ctx context.Context, msg *TaskMessage, errMsg string) error ForwardIfReady(qnames ...string) error - // Group aggregation related methods + /* + Group aggregation related methods + */ AddToGroup(ctx context.Context, msg *TaskMessage, gname string) error AddToGroupUnique(ctx context.Context, msg *TaskMessage, groupKey string, ttl time.Duration) error ListGroups(qname string) ([]string, error) @@ -696,20 +698,28 @@ type Broker interface { DeleteAggregationSet(ctx context.Context, qname, gname, aggregationSetID string) error ReclaimStaleAggregationSets(qname string) error - // Task retention related method + /* + Task retention related method + */ DeleteExpiredCompletedTasks(qname string, batchSize int) error - // Lease related methods + /* + Lease related methods + */ ListLeaseExpired(cutoff time.Time, qnames ...string) ([]*TaskMessage, error) ExtendLease(qname string, ids ...string) (time.Time, error) - // State snapshot related methods + /* + State snapshot related methods + */ WriteServerState(info *ServerInfo, workers []*WorkerInfo, ttl time.Duration) error ClearServerState(host string, pid int, serverID string) error - // Cancelation related methods - CancelationPubSub() (*redis.PubSub, error) // TODO: Need to decouple from redis to support other brokers - PublishCancelation(id string) error + /* + Cancellation related methods + */ + CancellationPubSub() (*redis.PubSub, error) // TODO: Need to decouple from redis to support other brokers + PublishCancellation(id string) error WriteResult(qname, id string, data []byte) (n int, err error) } diff --git a/internal/rdb/inspect.go b/internal/rdb/inspect.go index 17bc3a82..21dbcbcc 100644 --- a/internal/rdb/inspect.go +++ b/internal/rdb/inspect.go @@ -713,7 +713,7 @@ func (r *RDB) listMessages(ctx context.Context, qname string, state base.TaskSta // ListScheduled returns all tasks from the given queue that are scheduled // to be processed in the future. -func (r *RDB) ListScheduled(qname string, pgn Pagination) ([]*base.TaskInfo, error) { +func (r *RDB) ListScheduled(ctx context.Context, qname string, pgn Pagination) ([]*base.TaskInfo, error) { var op errors.Op = "rdb.ListScheduled" exists, err := r.queueExists(qname) if err != nil { @@ -722,7 +722,7 @@ func (r *RDB) ListScheduled(qname string, pgn Pagination) ([]*base.TaskInfo, err if !exists { return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname}) } - res, err := r.listZSetEntries(qname, base.TaskStateScheduled, base.ScheduledKey(qname), pgn) + res, err := r.listZSetEntries(ctx, qname, base.TaskStateScheduled, base.ScheduledKey(qname), pgn) if err != nil { return nil, errors.E(op, errors.CanonicalCode(err), err) } @@ -731,7 +731,7 @@ func (r *RDB) ListScheduled(qname string, pgn Pagination) ([]*base.TaskInfo, err // ListRetry returns all tasks from the given queue that have failed before // and willl be retried in the future. -func (r *RDB) ListRetry(qname string, pgn Pagination) ([]*base.TaskInfo, error) { +func (r *RDB) ListRetry(ctx context.Context, qname string, pgn Pagination) ([]*base.TaskInfo, error) { var op errors.Op = "rdb.ListRetry" exists, err := r.queueExists(qname) if err != nil { @@ -740,7 +740,7 @@ func (r *RDB) ListRetry(qname string, pgn Pagination) ([]*base.TaskInfo, error) if !exists { return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname}) } - res, err := r.listZSetEntries(qname, base.TaskStateRetry, base.RetryKey(qname), pgn) + res, err := r.listZSetEntries(ctx, qname, base.TaskStateRetry, base.RetryKey(qname), pgn) if err != nil { return nil, errors.E(op, errors.CanonicalCode(err), err) } @@ -748,7 +748,7 @@ func (r *RDB) ListRetry(qname string, pgn Pagination) ([]*base.TaskInfo, error) } // ListArchived returns all tasks from the given queue that have exhausted its retry limit. -func (r *RDB) ListArchived(qname string, pgn Pagination) ([]*base.TaskInfo, error) { +func (r *RDB) ListArchived(ctx context.Context, qname string, pgn Pagination) ([]*base.TaskInfo, error) { var op errors.Op = "rdb.ListArchived" exists, err := r.queueExists(qname) if err != nil { @@ -757,7 +757,7 @@ func (r *RDB) ListArchived(qname string, pgn Pagination) ([]*base.TaskInfo, erro if !exists { return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname}) } - zs, err := r.listZSetEntries(qname, base.TaskStateArchived, base.ArchivedKey(qname), pgn) + zs, err := r.listZSetEntries(ctx, qname, base.TaskStateArchived, base.ArchivedKey(qname), pgn) if err != nil { return nil, errors.E(op, errors.CanonicalCode(err), err) } @@ -765,7 +765,7 @@ func (r *RDB) ListArchived(qname string, pgn Pagination) ([]*base.TaskInfo, erro } // ListCompleted returns all tasks from the given queue that have completed successfully. -func (r *RDB) ListCompleted(qname string, pgn Pagination) ([]*base.TaskInfo, error) { +func (r *RDB) ListCompleted(ctx context.Context, qname string, pgn Pagination) ([]*base.TaskInfo, error) { var op errors.Op = "rdb.ListCompleted" exists, err := r.queueExists(qname) if err != nil { @@ -774,7 +774,7 @@ func (r *RDB) ListCompleted(qname string, pgn Pagination) ([]*base.TaskInfo, err if !exists { return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname}) } - zs, err := r.listZSetEntries(qname, base.TaskStateCompleted, base.CompletedKey(qname), pgn) + zs, err := r.listZSetEntries(ctx, qname, base.TaskStateCompleted, base.CompletedKey(qname), pgn) if err != nil { return nil, errors.E(op, errors.CanonicalCode(err), err) } @@ -782,7 +782,7 @@ func (r *RDB) ListCompleted(qname string, pgn Pagination) ([]*base.TaskInfo, err } // ListAggregating returns all tasks from the given group. -func (r *RDB) ListAggregating(qname, gname string, pgn Pagination) ([]*base.TaskInfo, error) { +func (r *RDB) ListAggregating(ctx context.Context, qname, gname string, pgn Pagination) ([]*base.TaskInfo, error) { var op errors.Op = "rdb.ListAggregating" exists, err := r.queueExists(qname) if err != nil { @@ -791,7 +791,7 @@ func (r *RDB) ListAggregating(qname, gname string, pgn Pagination) ([]*base.Task if !exists { return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname}) } - zs, err := r.listZSetEntries(qname, base.TaskStateAggregating, base.GroupKey(qname, gname), pgn) + zs, err := r.listZSetEntries(ctx, qname, base.TaskStateAggregating, base.GroupKey(qname, gname), pgn) if err != nil { return nil, errors.E(op, errors.CanonicalCode(err), err) } @@ -827,8 +827,8 @@ return data // listZSetEntries returns a list of message and score pairs in Redis sorted-set // with the given key. -func (r *RDB) listZSetEntries(qname string, state base.TaskState, key string, pgn Pagination) ([]*base.TaskInfo, error) { - res, err := listZSetEntriesCmd.Run(context.Background(), r.client, []string{key}, +func (r *RDB) listZSetEntries(ctx context.Context, qname string, state base.TaskState, key string, pgn Pagination) ([]*base.TaskInfo, error) { + res, err := listZSetEntriesCmd.Run(ctx, r.client, []string{key}, pgn.start(), pgn.stop(), base.TaskKeyPrefix(qname)).Result() if err != nil { return nil, errors.E(errors.Unknown, err) diff --git a/internal/rdb/inspect_test.go b/internal/rdb/inspect_test.go index 2f270c45..bf509f22 100644 --- a/internal/rdb/inspect_test.go +++ b/internal/rdb/inspect_test.go @@ -1061,10 +1061,12 @@ func TestListScheduled(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllScheduledQueues(t, r.client, tc.scheduled) - got, err := r.ListScheduled(tc.qname, Pagination{Size: 20, Page: 0}) + got, err := r.ListScheduled(ctx, tc.qname, Pagination{Size: 20, Page: 0}) op := fmt.Sprintf("r.ListScheduled(%q, Pagination{Size: 20, Page: 0})", tc.qname) if err != nil { t.Errorf("%s = %v, %v, want %v, nil", op, got, err, tc.want) @@ -1107,7 +1109,9 @@ func TestListScheduledPagination(t *testing.T) { } for _, tc := range tests { - got, err := r.ListScheduled(tc.qname, Pagination{Size: tc.size, Page: tc.page}) + ctx := context.Background() + + got, err := r.ListScheduled(ctx, tc.qname, Pagination{Size: tc.size, Page: tc.page}) op := fmt.Sprintf("r.ListScheduled(%q, Pagination{Size: %d, Page: %d})", tc.qname, tc.size, tc.page) if err != nil { t.Errorf("%s; %s returned error %v", tc.desc, op, err) @@ -1217,10 +1221,12 @@ func TestListRetry(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllRetryQueues(t, r.client, tc.retry) - got, err := r.ListRetry(tc.qname, Pagination{Size: 20, Page: 0}) + got, err := r.ListRetry(ctx, tc.qname, Pagination{Size: 20, Page: 0}) op := fmt.Sprintf("r.ListRetry(%q, Pagination{Size: 20, Page: 0})", tc.qname) if err != nil { t.Errorf("%s = %v, %v, want %v, nil", op, got, err, tc.want) @@ -1237,6 +1243,7 @@ func TestListRetry(t *testing.T) { func TestListRetryPagination(t *testing.T) { r := setup(t) defer r.Close() + // create 100 tasks with an increasing number of wait time. now := time.Now() var seed []base.Z @@ -1264,7 +1271,9 @@ func TestListRetryPagination(t *testing.T) { } for _, tc := range tests { - got, err := r.ListRetry(tc.qname, Pagination{Size: tc.size, Page: tc.page}) + ctx := context.Background() + + got, err := r.ListRetry(ctx, tc.qname, Pagination{Size: tc.size, Page: tc.page}) op := fmt.Sprintf("r.ListRetry(%q, Pagination{Size: %d, Page: %d})", tc.qname, tc.size, tc.page) if err != nil { @@ -1370,10 +1379,12 @@ func TestListArchived(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllArchivedQueues(t, r.client, tc.archived) - got, err := r.ListArchived(tc.qname, Pagination{Size: 20, Page: 0}) + got, err := r.ListArchived(ctx, tc.qname, Pagination{Size: 20, Page: 0}) op := fmt.Sprintf("r.ListArchived(%q, Pagination{Size: 20, Page: 0})", tc.qname) if err != nil { t.Errorf("%s = %v, %v, want %v, nil", op, got, err, tc.want) @@ -1414,7 +1425,9 @@ func TestListArchivedPagination(t *testing.T) { } for _, tc := range tests { - got, err := r.ListArchived(tc.qname, Pagination{Size: tc.size, Page: tc.page}) + ctx := context.Background() + + got, err := r.ListArchived(ctx, tc.qname, Pagination{Size: tc.size, Page: tc.page}) op := fmt.Sprintf("r.ListArchived(Pagination{Size: %d, Page: %d})", tc.size, tc.page) if err != nil { @@ -1510,10 +1523,12 @@ func TestListCompleted(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllCompletedQueues(t, r.client, tc.completed) - got, err := r.ListCompleted(tc.qname, Pagination{Size: 20, Page: 0}) + got, err := r.ListCompleted(ctx, tc.qname, Pagination{Size: 20, Page: 0}) op := fmt.Sprintf("r.ListCompleted(%q, Pagination{Size: 20, Page: 0})", tc.qname) if err != nil { t.Errorf("%s = %v, %v, want %v, nil", op, got, err, tc.want) @@ -1554,7 +1569,9 @@ func TestListCompletedPagination(t *testing.T) { } for _, tc := range tests { - got, err := r.ListCompleted(tc.qname, Pagination{Size: tc.size, Page: tc.page}) + ctx := context.Background() + + got, err := r.ListCompleted(ctx, tc.qname, Pagination{Size: tc.size, Page: tc.page}) op := fmt.Sprintf("r.ListCompleted(Pagination{Size: %d, Page: %d})", tc.size, tc.page) if err != nil { @@ -1653,6 +1670,8 @@ func TestListAggregating(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r.client) h.SeedRedisSet(t, r.client, base.AllQueues, fxt.allQueues) h.SeedRedisSets(t, r.client, fxt.allGroups) @@ -1660,7 +1679,7 @@ func TestListAggregating(t *testing.T) { h.SeedRedisZSets(t, r.client, fxt.groups) t.Run(tc.desc, func(t *testing.T) { - got, err := r.ListAggregating(tc.qname, tc.gname, Pagination{}) + got, err := r.ListAggregating(ctx, tc.qname, tc.gname, Pagination{}) if err != nil { t.Fatalf("ListAggregating returned error: %v", err) } @@ -1774,7 +1793,9 @@ func TestListAggregatingPagination(t *testing.T) { h.SeedRedisZSets(t, r.client, fxt.groups) t.Run(tc.desc, func(t *testing.T) { - got, err := r.ListAggregating(tc.qname, tc.gname, Pagination{Page: tc.page, Size: tc.size}) + ctx := context.Background() + + got, err := r.ListAggregating(ctx, tc.qname, tc.gname, Pagination{Page: tc.page, Size: tc.size}) if err != nil { t.Fatalf("ListAggregating returned error: %v", err) } @@ -1826,13 +1847,13 @@ func TestListTasksError(t *testing.T) { if _, got := r.ListPending(ctx, tc.qname, pgn); !tc.match(got) { t.Errorf("%s: ListPending returned %v", tc.desc, got) } - if _, got := r.ListScheduled(tc.qname, pgn); !tc.match(got) { + if _, got := r.ListScheduled(ctx, tc.qname, pgn); !tc.match(got) { t.Errorf("%s: ListScheduled returned %v", tc.desc, got) } - if _, got := r.ListRetry(tc.qname, pgn); !tc.match(got) { + if _, got := r.ListRetry(ctx, tc.qname, pgn); !tc.match(got) { t.Errorf("%s: ListRetry returned %v", tc.desc, got) } - if _, got := r.ListArchived(tc.qname, pgn); !tc.match(got) { + if _, got := r.ListArchived(ctx, tc.qname, pgn); !tc.match(got) { t.Errorf("%s: ListArchived returned %v", tc.desc, got) } } diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go index 74454beb..e50ffa08 100644 --- a/internal/rdb/rdb.go +++ b/internal/rdb/rdb.go @@ -1442,8 +1442,8 @@ func (r *RDB) ClearSchedulerEntries(scheduelrID string) error { } // CancelationPubSub returns a pubsub for cancelation messages. -func (r *RDB) CancelationPubSub() (*redis.PubSub, error) { - var op errors.Op = "rdb.CancelationPubSub" +func (r *RDB) CancellationPubSub() (*redis.PubSub, error) { + var op errors.Op = "rdb.CancellationPubSub" ctx := context.Background() pubsub := r.client.Subscribe(ctx, base.CancelChannel) _, err := pubsub.Receive(ctx) @@ -1455,8 +1455,8 @@ func (r *RDB) CancelationPubSub() (*redis.PubSub, error) { // PublishCancelation publish cancelation message to all subscribers. // The message is the ID for the task to be canceled. -func (r *RDB) PublishCancelation(id string) error { - var op errors.Op = "rdb.PublishCancelation" +func (r *RDB) PublishCancellation(id string) error { + var op errors.Op = "rdb.PublishCancellation" ctx := context.Background() if err := r.client.Publish(ctx, base.CancelChannel, id).Err(); err != nil { return errors.E(op, errors.Unknown, fmt.Sprintf("redis pubsub publish error: %v", err)) diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go index c8d9ce86..822d2319 100644 --- a/internal/rdb/rdb_test.go +++ b/internal/rdb/rdb_test.go @@ -3023,9 +3023,9 @@ func TestCancelationPubSub(t *testing.T) { r := setup(t) defer r.Close() - pubsub, err := r.CancelationPubSub() + pubsub, err := r.CancellationPubSub() if err != nil { - t.Fatalf("(*RDB).CancelationPubSub() returned an error: %v", err) + t.Fatalf("(*RDB).CancellationPubSub() returned an error: %v", err) } cancelCh := pubsub.Channel() @@ -3046,7 +3046,7 @@ func TestCancelationPubSub(t *testing.T) { publish := []string{"one", "two", "three"} for _, msg := range publish { - r.PublishCancelation(msg) + r.PublishCancellation(msg) } // allow for message to reach subscribers. diff --git a/internal/testbroker/testbroker.go b/internal/testbroker/testbroker.go index 8e562148..16844fc9 100644 --- a/internal/testbroker/testbroker.go +++ b/internal/testbroker/testbroker.go @@ -186,22 +186,22 @@ func (tb *TestBroker) ClearServerState(host string, pid int, serverID string) er return tb.real.ClearServerState(host, pid, serverID) } -func (tb *TestBroker) CancelationPubSub() (*redis.PubSub, error) { +func (tb *TestBroker) CancellationPubSub() (*redis.PubSub, error) { tb.mu.Lock() defer tb.mu.Unlock() if tb.sleeping { return nil, errRedisDown } - return tb.real.CancelationPubSub() + return tb.real.CancellationPubSub() } -func (tb *TestBroker) PublishCancelation(id string) error { +func (tb *TestBroker) PublishCancellation(id string) error { tb.mu.Lock() defer tb.mu.Unlock() if tb.sleeping { return errRedisDown } - return tb.real.PublishCancelation(id) + return tb.real.PublishCancellation(id) } func (tb *TestBroker) WriteResult(qname, id string, data []byte) (int, error) { diff --git a/subscriber.go b/subscriber.go index dbcf029e..560a168f 100644 --- a/subscriber.go +++ b/subscriber.go @@ -55,7 +55,7 @@ func (s *subscriber) start(wg *sync.WaitGroup) { ) // Try until successfully connect to Redis. for { - pubsub, err = s.broker.CancelationPubSub() + pubsub, err = s.broker.CancellationPubSub() if err != nil { s.logger.Errorf("cannot subscribe to cancelation channel: %v", err) select { diff --git a/subscriber_test.go b/subscriber_test.go index ac51dfbd..3e95f142 100644 --- a/subscriber_test.go +++ b/subscriber_test.go @@ -47,7 +47,7 @@ func TestSubscriber(t *testing.T) { // wait for subscriber to establish connection to pubsub channel time.Sleep(time.Second) - if err := rdbClient.PublishCancelation(tc.publishID); err != nil { + if err := rdbClient.PublishCancellation(tc.publishID); err != nil { t.Fatalf("could not publish cancelation message: %v", err) } @@ -106,7 +106,7 @@ func TestSubscriberWithRedisDown(t *testing.T) { called = true }) - if err := r.PublishCancelation(id); err != nil { + if err := r.PublishCancellation(id); err != nil { t.Fatalf("could not publish cancelation message: %v", err) } From 6115926900fa5bc22bc72f58300b844f992be9ad Mon Sep 17 00:00:00 2001 From: Rafi Date: Sun, 6 Oct 2024 16:21:18 -0300 Subject: [PATCH 5/6] Fix context more --- inspector.go | 2 +- internal/rdb/rdb_test.go | 20 +++++++++++++++----- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/inspector.go b/inspector.go index 279e6631..ece8b174 100644 --- a/inspector.go +++ b/inspector.go @@ -337,7 +337,7 @@ func (i *Inspector) ListActiveTasks(ctx context.Context, queue string, opts ...L case err != nil: return nil, fmt.Errorf("asynq: %v", err) } - expired, err := i.rdb.ListLeaseExpired(time.Now(), queue) + expired, err := i.rdb.ListLeaseExpired(ctx, time.Now(), queue) if err != nil { return nil, fmt.Errorf("asynq: %v", err) } diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go index 822d2319..21aa658f 100644 --- a/internal/rdb/rdb_test.go +++ b/internal/rdb/rdb_test.go @@ -2535,10 +2535,12 @@ func TestDeleteExpiredCompletedTasks(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r.client) h.SeedAllCompletedQueues(t, r.client, tc.completed) - if err := r.DeleteExpiredCompletedTasks(tc.qname, 100); err != nil { + if err := r.DeleteExpiredCompletedTasks(ctx, tc.qname, 100); err != nil { t.Errorf("DeleteExpiredCompletedTasks(%q, 100) failed: %v", tc.qname, err) continue } @@ -2620,10 +2622,12 @@ func TestListLeaseExpired(t *testing.T) { r := setup(t) defer r.Close() for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r.client) h.SeedAllLease(t, r.client, tc.lease) - got, err := r.ListLeaseExpired(tc.cutoff, tc.qnames...) + got, err := r.ListLeaseExpired(ctx, tc.cutoff, tc.qnames...) if err != nil { t.Errorf("%s; ListLeaseExpired(%v) returned error: %v", tc.desc, tc.cutoff, err) continue @@ -2724,10 +2728,12 @@ func TestExtendLease(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r.client) h.SeedAllLease(t, r.client, tc.lease) - gotExpirationTime, err := r.ExtendLease(tc.qname, tc.ids...) + gotExpirationTime, err := r.ExtendLease(ctx, tc.qname, tc.ids...) if err != nil { t.Fatalf("%s: ExtendLease(%q, %v) returned error: %v", tc.desc, tc.qname, tc.ids, err) } @@ -3078,9 +3084,11 @@ func TestWriteResult(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r.client) - n, err := r.WriteResult(tc.qname, tc.taskID, tc.data) + n, err := r.WriteResult(ctx, tc.qname, tc.taskID, tc.data) if err != nil { t.Errorf("WriteResult failed: %v", err) continue @@ -3737,12 +3745,14 @@ func TestReclaimStaleAggregationSets(t *testing.T) { } for _, tc := range tests { + ctx := context.Background() + h.FlushDB(t, r.client) h.SeedRedisZSets(t, r.client, tc.groups) h.SeedRedisZSets(t, r.client, tc.aggregationSets) h.SeedRedisZSets(t, r.client, tc.allAggregationSets) - if err := r.ReclaimStaleAggregationSets(tc.qname); err != nil { + if err := r.ReclaimStaleAggregationSets(ctx, tc.qname); err != nil { t.Errorf("ReclaimStaleAggregationSets returned error: %v", err) continue } From 666ccf24aeb9b73440fb9087be55e2d3c60449e9 Mon Sep 17 00:00:00 2001 From: Rafi Date: Sun, 6 Oct 2024 20:23:16 -0300 Subject: [PATCH 6/6] Update rdb.go --- internal/rdb/rdb.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go index 1c47d1ca..17ae4a7b 100644 --- a/internal/rdb/rdb.go +++ b/internal/rdb/rdb.go @@ -994,7 +994,7 @@ func (r *RDB) ListGroups(qname string) ([]string, error) { // 1) group has reached or exceeded its max size // 2) group's oldest task has reached or exceeded its max delay // 3) group's latest task has reached or exceeded its grace period -// if aggreation criteria is met, the command moves those tasks from the group +// if aggregation criteria is met, the command moves those tasks from the group // and put them in an aggregation set. Additionally, if the creation of aggregation set // empties the group, it will clear the group name from the all groups set. // @@ -1073,7 +1073,7 @@ return 0 `) // Task aggregation should finish within this timeout. -// Otherwise an aggregation set should be reclaimed by the recoverer. +// Otherwise, an aggregation set should be reclaimed by the recoverer. const aggregationTimeout = 2 * time.Minute // AggregationCheck checks the group identified by the given queue and group name to see if the tasks in the