Skip to content

Commit

Permalink
Merge pull request #1 from AsynqLab/feature/propagate-context
Browse files Browse the repository at this point in the history
  • Loading branch information
Joker666 authored Oct 10, 2024
2 parents 675de81 + 666ccf2 commit 6856cc7
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 92 deletions.
33 changes: 17 additions & 16 deletions inspector.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package asynq

import (
"context"
"fmt"
"strconv"
"strings"
Expand All @@ -18,7 +19,7 @@ type Inspector struct {
rdb *rdb.RDB
}

// New returns a new instance of Inspector.
// NewInspector returns a new instance of Inspector.
func NewInspector(r RedisConnOpt) *Inspector {
c, ok := r.MakeRedisClient().(redis.UniversalClient)
if !ok {
Expand Down Expand Up @@ -295,13 +296,13 @@ func Page(n int) ListOption {
// ListPendingTasks retrieves pending tasks from the specified queue.
//
// By default, it retrieves the first 30 tasks.
func (i *Inspector) ListPendingTasks(queue string, opts ...ListOption) ([]*TaskInfo, error) {
func (i *Inspector) ListPendingTasks(ctx context.Context, queue string, opts ...ListOption) ([]*TaskInfo, error) {
if err := base.ValidateQueueName(queue); err != nil {
return nil, fmt.Errorf("asynq: %v", err)
}
opt := composeListOptions(opts...)
pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1}
infos, err := i.rdb.ListPending(queue, pgn)
infos, err := i.rdb.ListPending(ctx, queue, pgn)
switch {
case errors.IsQueueNotFound(err):
return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound)
Expand All @@ -323,20 +324,20 @@ func (i *Inspector) ListPendingTasks(queue string, opts ...ListOption) ([]*TaskI
// ListActiveTasks retrieves active tasks from the specified queue.
//
// By default, it retrieves the first 30 tasks.
func (i *Inspector) ListActiveTasks(queue string, opts ...ListOption) ([]*TaskInfo, error) {
func (i *Inspector) ListActiveTasks(ctx context.Context, queue string, opts ...ListOption) ([]*TaskInfo, error) {
if err := base.ValidateQueueName(queue); err != nil {
return nil, fmt.Errorf("asynq: %v", err)
}
opt := composeListOptions(opts...)
pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1}
infos, err := i.rdb.ListActive(queue, pgn)
infos, err := i.rdb.ListActive(ctx, queue, pgn)
switch {
case errors.IsQueueNotFound(err):
return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound)
case err != nil:
return nil, fmt.Errorf("asynq: %v", err)
}
expired, err := i.rdb.ListLeaseExpired(time.Now(), queue)
expired, err := i.rdb.ListLeaseExpired(ctx, time.Now(), queue)
if err != nil {
return nil, fmt.Errorf("asynq: %v", err)
}
Expand All @@ -363,13 +364,13 @@ func (i *Inspector) ListActiveTasks(queue string, opts ...ListOption) ([]*TaskIn
// ListAggregatingTasks retrieves scheduled tasks from the specified group.
//
// By default, it retrieves the first 30 tasks.
func (i *Inspector) ListAggregatingTasks(queue, group string, opts ...ListOption) ([]*TaskInfo, error) {
func (i *Inspector) ListAggregatingTasks(ctx context.Context, queue, group string, opts ...ListOption) ([]*TaskInfo, error) {
if err := base.ValidateQueueName(queue); err != nil {
return nil, fmt.Errorf("asynq: %v", err)
}
opt := composeListOptions(opts...)
pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1}
infos, err := i.rdb.ListAggregating(queue, group, pgn)
infos, err := i.rdb.ListAggregating(ctx, queue, group, pgn)
switch {
case errors.IsQueueNotFound(err):
return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound)
Expand All @@ -392,13 +393,13 @@ func (i *Inspector) ListAggregatingTasks(queue, group string, opts ...ListOption
// Tasks are sorted by NextProcessAt in ascending order.
//
// By default, it retrieves the first 30 tasks.
func (i *Inspector) ListScheduledTasks(queue string, opts ...ListOption) ([]*TaskInfo, error) {
func (i *Inspector) ListScheduledTasks(ctx context.Context, queue string, opts ...ListOption) ([]*TaskInfo, error) {
if err := base.ValidateQueueName(queue); err != nil {
return nil, fmt.Errorf("asynq: %v", err)
}
opt := composeListOptions(opts...)
pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1}
infos, err := i.rdb.ListScheduled(queue, pgn)
infos, err := i.rdb.ListScheduled(ctx, queue, pgn)
switch {
case errors.IsQueueNotFound(err):
return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound)
Expand All @@ -421,13 +422,13 @@ func (i *Inspector) ListScheduledTasks(queue string, opts ...ListOption) ([]*Tas
// Tasks are sorted by NextProcessAt in ascending order.
//
// By default, it retrieves the first 30 tasks.
func (i *Inspector) ListRetryTasks(queue string, opts ...ListOption) ([]*TaskInfo, error) {
func (i *Inspector) ListRetryTasks(ctx context.Context, queue string, opts ...ListOption) ([]*TaskInfo, error) {
if err := base.ValidateQueueName(queue); err != nil {
return nil, fmt.Errorf("asynq: %v", err)
}
opt := composeListOptions(opts...)
pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1}
infos, err := i.rdb.ListRetry(queue, pgn)
infos, err := i.rdb.ListRetry(ctx, queue, pgn)
switch {
case errors.IsQueueNotFound(err):
return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound)
Expand All @@ -450,13 +451,13 @@ func (i *Inspector) ListRetryTasks(queue string, opts ...ListOption) ([]*TaskInf
// Tasks are sorted by LastFailedAt in descending order.
//
// By default, it retrieves the first 30 tasks.
func (i *Inspector) ListArchivedTasks(queue string, opts ...ListOption) ([]*TaskInfo, error) {
func (i *Inspector) ListArchivedTasks(ctx context.Context, queue string, opts ...ListOption) ([]*TaskInfo, error) {
if err := base.ValidateQueueName(queue); err != nil {
return nil, fmt.Errorf("asynq: %v", err)
}
opt := composeListOptions(opts...)
pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1}
infos, err := i.rdb.ListArchived(queue, pgn)
infos, err := i.rdb.ListArchived(ctx, queue, pgn)
switch {
case errors.IsQueueNotFound(err):
return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound)
Expand All @@ -479,13 +480,13 @@ func (i *Inspector) ListArchivedTasks(queue string, opts ...ListOption) ([]*Task
// Tasks are sorted by expiration time (i.e. CompletedAt + Retention) in descending order.
//
// By default, it retrieves the first 30 tasks.
func (i *Inspector) ListCompletedTasks(queue string, opts ...ListOption) ([]*TaskInfo, error) {
func (i *Inspector) ListCompletedTasks(ctx context.Context, queue string, opts ...ListOption) ([]*TaskInfo, error) {
if err := base.ValidateQueueName(queue); err != nil {
return nil, fmt.Errorf("asynq: %v", err)
}
opt := composeListOptions(opts...)
pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1}
infos, err := i.rdb.ListCompleted(queue, pgn)
infos, err := i.rdb.ListCompleted(ctx, queue, pgn)
switch {
case errors.IsQueueNotFound(err):
return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound)
Expand Down
49 changes: 34 additions & 15 deletions inspector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,7 @@ func TestInspectorGetTaskInfoError(t *testing.T) {
func TestInspectorListPendingTasks(t *testing.T) {
r := setup(t)
defer r.Close()

m1 := h.NewTaskMessage("task1", nil)
m2 := h.NewTaskMessage("task2", nil)
m3 := h.NewTaskMessageWithQueue("task3", nil, "critical")
Expand Down Expand Up @@ -718,12 +719,14 @@ func TestInspectorListPendingTasks(t *testing.T) {
}

for _, tc := range tests {
ctx := context.Background()

h.FlushDB(t, r)
for q, msgs := range tc.pending {
h.SeedPendingQueue(t, r, msgs, q)
}

got, err := inspector.ListPendingTasks(tc.qname)
got, err := inspector.ListPendingTasks(ctx, tc.qname)
if err != nil {
t.Errorf("%s; ListPendingTasks(%q) returned error: %v",
tc.desc, tc.qname, err)
Expand Down Expand Up @@ -811,11 +814,13 @@ func TestInspectorListActiveTasks(t *testing.T) {
}

for _, tc := range tests {
ctx := context.Background()

h.FlushDB(t, r)
h.SeedAllActiveQueues(t, r, tc.active)
h.SeedAllLease(t, r, tc.lease)

got, err := inspector.ListActiveTasks(tc.qname)
got, err := inspector.ListActiveTasks(ctx, tc.qname)
if err != nil {
t.Errorf("%s; ListActiveTasks(%q) returned error: %v", tc.qname, tc.desc, err)
continue
Expand Down Expand Up @@ -882,10 +887,12 @@ func TestInspectorListScheduledTasks(t *testing.T) {
}

for _, tc := range tests {
ctx := context.Background()

h.FlushDB(t, r)
h.SeedAllScheduledQueues(t, r, tc.scheduled)

got, err := inspector.ListScheduledTasks(tc.qname)
got, err := inspector.ListScheduledTasks(ctx, tc.qname)
if err != nil {
t.Errorf("%s; ListScheduledTasks(%q) returned error: %v", tc.desc, tc.qname, err)
continue
Expand Down Expand Up @@ -953,10 +960,12 @@ func TestInspectorListRetryTasks(t *testing.T) {
}

for _, tc := range tests {
ctx := context.Background()

h.FlushDB(t, r)
h.SeedAllRetryQueues(t, r, tc.retry)

got, err := inspector.ListRetryTasks(tc.qname)
got, err := inspector.ListRetryTasks(ctx, tc.qname)
if err != nil {
t.Errorf("%s; ListRetryTasks(%q) returned error: %v", tc.desc, tc.qname, err)
continue
Expand Down Expand Up @@ -1023,10 +1032,12 @@ func TestInspectorListArchivedTasks(t *testing.T) {
}

for _, tc := range tests {
ctx := context.Background()

h.FlushDB(t, r)
h.SeedAllArchivedQueues(t, r, tc.archived)

got, err := inspector.ListArchivedTasks(tc.qname)
got, err := inspector.ListArchivedTasks(ctx, tc.qname)
if err != nil {
t.Errorf("%s; ListArchivedTasks(%q) returned error: %v", tc.desc, tc.qname, err)
continue
Expand Down Expand Up @@ -1100,10 +1111,12 @@ func TestInspectorListCompletedTasks(t *testing.T) {
}

for _, tc := range tests {
ctx := context.Background()

h.FlushDB(t, r)
h.SeedAllCompletedQueues(t, r, tc.completed)

got, err := inspector.ListCompletedTasks(tc.qname)
got, err := inspector.ListCompletedTasks(ctx, tc.qname)
if err != nil {
t.Errorf("%s; ListCompletedTasks(%q) returned error: %v", tc.desc, tc.qname, err)
continue
Expand Down Expand Up @@ -1188,14 +1201,16 @@ func TestInspectorListAggregatingTasks(t *testing.T) {
}

for _, tc := range tests {
ctx := context.Background()

h.FlushDB(t, r)
h.SeedTasks(t, r, fxt.tasks)
h.SeedRedisSet(t, r, base.AllQueues, fxt.allQueues)
h.SeedRedisSets(t, r, fxt.allGroups)
h.SeedRedisZSets(t, r, fxt.groups)

t.Run(tc.desc, func(t *testing.T) {
got, err := inspector.ListAggregatingTasks(tc.qname, tc.gname)
got, err := inspector.ListAggregatingTasks(ctx, tc.qname, tc.gname)
if err != nil {
t.Fatalf("ListAggregatingTasks returned error: %v", err)
}
Expand Down Expand Up @@ -1263,7 +1278,9 @@ func TestInspectorListPagination(t *testing.T) {
}

for _, tc := range tests {
got, err := inspector.ListPendingTasks("default", Page(tc.page), PageSize(tc.pageSize))
ctx := context.Background()

got, err := inspector.ListPendingTasks(ctx, "default", Page(tc.page), PageSize(tc.pageSize))
if err != nil {
t.Errorf("ListPendingTask('default') returned error: %v", err)
continue
Expand Down Expand Up @@ -1296,27 +1313,29 @@ func TestInspectorListTasksQueueNotFoundError(t *testing.T) {
}

for _, tc := range tests {
ctx := context.Background()

h.FlushDB(t, r)

if _, err := inspector.ListActiveTasks(tc.qname); !errors.Is(err, tc.wantErr) {
if _, err := inspector.ListActiveTasks(ctx, tc.qname); !errors.Is(err, tc.wantErr) {
t.Errorf("ListActiveTasks(%q) returned error %v, want %v", tc.qname, err, tc.wantErr)
}
if _, err := inspector.ListPendingTasks(tc.qname); !errors.Is(err, tc.wantErr) {
if _, err := inspector.ListPendingTasks(ctx, tc.qname); !errors.Is(err, tc.wantErr) {
t.Errorf("ListPendingTasks(%q) returned error %v, want %v", tc.qname, err, tc.wantErr)
}
if _, err := inspector.ListScheduledTasks(tc.qname); !errors.Is(err, tc.wantErr) {
if _, err := inspector.ListScheduledTasks(ctx, tc.qname); !errors.Is(err, tc.wantErr) {
t.Errorf("ListScheduledTasks(%q) returned error %v, want %v", tc.qname, err, tc.wantErr)
}
if _, err := inspector.ListRetryTasks(tc.qname); !errors.Is(err, tc.wantErr) {
if _, err := inspector.ListRetryTasks(ctx, tc.qname); !errors.Is(err, tc.wantErr) {
t.Errorf("ListRetryTasks(%q) returned error %v, want %v", tc.qname, err, tc.wantErr)
}
if _, err := inspector.ListArchivedTasks(tc.qname); !errors.Is(err, tc.wantErr) {
if _, err := inspector.ListArchivedTasks(ctx, tc.qname); !errors.Is(err, tc.wantErr) {
t.Errorf("ListArchivedTasks(%q) returned error %v, want %v", tc.qname, err, tc.wantErr)
}
if _, err := inspector.ListCompletedTasks(tc.qname); !errors.Is(err, tc.wantErr) {
if _, err := inspector.ListCompletedTasks(ctx, tc.qname); !errors.Is(err, tc.wantErr) {
t.Errorf("ListCompletedTasks(%q) returned error %v, want %v", tc.qname, err, tc.wantErr)
}
if _, err := inspector.ListAggregatingTasks(tc.qname, "mygroup"); !errors.Is(err, tc.wantErr) {
if _, err := inspector.ListAggregatingTasks(ctx, tc.qname, "mygroup"); !errors.Is(err, tc.wantErr) {
t.Errorf("ListAggregatingTasks(%q, \"mygroup\") returned error %v, want %v", tc.qname, err, tc.wantErr)
}
}
Expand Down
Loading

0 comments on commit 6856cc7

Please sign in to comment.