[Gen] Clone logits before returning when cg=True
tridao committed Aug 28, 2023
1 parent f8aea6e commit 9f42cb6
Showing 2 changed files with 431 additions and 4 deletions.
363 changes: 361 additions & 2 deletions flash_attn/utils/generation.py
@@ -4,10 +4,12 @@
import time
from collections import namedtuple
from dataclasses import dataclass, field
from functools import partial
from typing import Callable, Optional, Sequence, Union

import torch
-from einops import rearrange
+import torch.nn.functional as F
+from einops import rearrange, repeat
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, record_function
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
@@ -205,6 +207,363 @@ def logits_forward_fn(input_ids, position_ids, inference_params):
)


def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, temperature=1.0):
"""Algorithm 1 from [1]
[1] Fast Inference from Transformers via Speculative Decoding
Yaniv Leviathan, Matan Kalman, Yossi Matias
https://arxiv.org/abs/2211.17192
Arguments:
logits: Tensor of shape (batch_size, seqlen + 1, vocab_size)
logits_draft: Tensor of shape (batch_size, seqlen, vocab_size)
tokens_draft: Tensor of shape (batch_size, seqlen)
Return:
tokens: Tensor of shape (batch_size, seqlen + 1)
num_generated_tokens: Tensor of shape (batch_size,), with values in [1, seqlen + 1].
For each sequence in the batch, the number of valid tokens sampled by speculative
sampling (accepted draft tokens plus one final resampled or bonus token).
"""
batch, seqlen_p_1, vocab_size = logits.shape
seqlen = seqlen_p_1 - 1
assert logits_draft.shape == (batch, seqlen, vocab_size)
assert tokens_draft.shape == (batch, seqlen)
assert tokens_draft.dtype in [torch.int64, torch.int32]
# TODO: if top_k = 1 we can simplify things and only work with indices
if top_p > 0.0:
assert top_p <= 1.0, "top-p should be in (0, 1]."
# Clone so that when we modify for top_k / top_p filtering we don't change the original logits
logits = logits / temperature if temperature != 1.0 else logits.clone()
logits_draft = logits_draft / temperature if temperature != 1.0 else logits_draft.clone()
if top_k > 0:
top_k = min(top_k, logits.size(-1)) # Safety check
modify_logits_for_top_k_filtering(logits, top_k)
modify_logits_for_top_k_filtering(logits_draft, top_k)
modify_logits_for_top_p_filtering(logits, top_p)
modify_logits_for_top_p_filtering(logits_draft, top_p)
probs = torch.softmax(logits, dim=-1)
probs_draft = torch.softmax(logits_draft, dim=-1)
gather = lambda probs, tokens: rearrange(
probs.gather(dim=-1, index=rearrange(tokens, "... -> ... 1")), "... 1 -> ..."
)
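# Speculative acceptance rule: draft token x_i is accepted with probability
# min(1, p(x_i) / q(x_i)), where p is the main model's distribution and q the draft model's.
# Drawing u ~ U(0, 1) and accepting when u * q(x_i) <= p(x_i) is equivalent.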
# (batch, seqlen)
accepted = torch.rand(batch, seqlen, device=probs.device) * gather(
probs_draft, tokens_draft
) <= gather(probs[:, :-1], tokens_draft)
accepted_all = accepted.all(dim=-1)
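# accepted.int().argmin(dim=-1) is the index of the first rejected draft token (the first 0
# in `accepted`); if every draft token was accepted, argmin would return 0, so that case is
# overridden with seqlen via torch.where.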
# (batch,)
first_rejected_idx = torch.where(accepted_all, seqlen, accepted.int().argmin(dim=-1))
probs_diff = torch.clamp(probs[:, :-1] - probs_draft, min=0.0)
# torch.multinomial can deal with unnormalized probabilities
# probs_diff /= probs_diff.sum(dim=-1, keepdim=True)
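# Concatenating probs[:, -1:] handles the all-accepted case: the row gathered at
# first_rejected_idx is then the main model's distribution at the final position (the "bonus"
# token); otherwise it is the residual distribution max(0, p - q) at the first rejection.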
resample_probs = torch.cat([probs_diff, probs[:, -1:]], dim=1)
resample_probs = rearrange(
resample_probs.gather(dim=1, index=repeat(first_rejected_idx, "b -> b 1 d", d=vocab_size)),
"b 1 d -> b d",
)
resample = torch.multinomial(resample_probs, num_samples=1).squeeze(dim=-1) # (batch,)
tokens = F.pad(tokens_draft, (0, 1))
tokens[:, first_rejected_idx] = resample  # note: this fancy-indexing write is only correct for batch_size == 1
return tokens, first_rejected_idx + 1
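
# A minimal, illustrative sketch (hypothetical helper, not part of the API above) showing the
# expected shapes when exercising sample_speculative with toy tensors. top_k=0 means pure
# sampling, which keeps the resampling distribution non-degenerate for random inputs.
def _example_sample_speculative():
    batch, seqlen, vocab_size = 1, 3, 8  # arbitrary toy sizes
    logits = torch.randn(batch, seqlen + 1, vocab_size)    # "main model" logits
    logits_draft = torch.randn(batch, seqlen, vocab_size)  # "draft model" logits
    tokens_draft = torch.randint(0, vocab_size, (batch, seqlen))
    tokens, num_generated = sample_speculative(logits, logits_draft, tokens_draft, top_k=0)
    # tokens: (batch, seqlen + 1); entries of num_generated lie in [1, seqlen + 1]
    return tokens, num_generated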


def decode_speculative(
input_ids,
model,
model_draft,
max_length,
speculative_lookahead=3,
top_k=1,
top_p=0.0,
temperature=1.0,
eos_token_id=None,
vocab_size=None,
tensor_parallel=1,
fused_ft_kernel=False,
cg=False,
timing=False,
debug=False,
):
"""
TD: WIP, for my own understanding, lightly tested. Only support batch_size == 1 for now.
Speculative decoding, either greedy or with top-k or top-p sampling.
If top-k = 0, don't limit the number of candidates (pure sampling).
Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
then top-p.
We assume that all sequences in the same batch have the same length.
Arguments:
input_ids: (batch, seq_len)
max_length: int
Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
sequences: (batch, max_length)
scores: tuples of (batch, vocab_size)
"""
batch_size, seqlen_og = input_ids.shape
assert batch_size == 1, "Speculative decoding implementation only supports batch_size=1"
assert eos_token_id is None, "Speculative decoding implementation doesn't support eos_token_id"
if cg:
assert fused_ft_kernel
if not hasattr(model_draft, "_decoding_cache"):
model_draft._decoding_cache = None
model_draft._decoding_cache = update_graph_cache(
model_draft,
model_draft._decoding_cache,
batch_size,
seqlen_og,
max_length,
tensor_parallel=tensor_parallel,
)
inference_params_draft = model_draft._decoding_cache.inference_params
inference_params_draft.max_sequence_len = max_length
inference_params_draft.max_batch_size = batch_size
inference_params_draft.sequence_len_offset = 0
# fused_ft_kernel doesn't support passing in multiple tokens at once
inference_params = InferenceParams(
max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=False
)
else:
inference_params_draft = InferenceParams(
max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
)
inference_params = InferenceParams(
max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=False
)

def logits_forward_fn(model, input_ids, position_ids, inference_params, cg=False):
if not cg:
return model(
input_ids,
position_ids=position_ids,
inference_params=inference_params,
num_last_tokens=1,
).logits.squeeze(dim=1)
else:
return model._decoding_cache.run(
input_ids, position_ids, inference_params.sequence_len_offset
).clone()

logits_postprocess_fn = (
lambda logits: logits[..., :vocab_size] if vocab_size is not None else logits
)

def sample_tokens(
input_ids, model, inference_params, sample_fn, num_tokens=1, cg=False, decoding=True,
last_token_logits=False
):
"""Sample `num_tokens` tokens from the model, given the previous logits.
Also return the logits of the sampled tokens.
Arguments:
input_ids: (batch, seqlen)
decoding: whether we're in the decoding phase or the prefilling phase. Prefill doesn't
need special position_ids.
last_token_logits: whether to return the logits of the last token. Normally we don't need this.
However, for speculative sampling, if the main model accepts all the draft tokens, plus it
samples one new token, then by right at the next iteration the draft model need to evaluate
the logits of the last draft token and the logits of the newly sampled token.
This makes implementation more complicated. So here we just evaluate the logits of the last
token in the draft model to simplify the implementation.
Return:
tokens: (batch, num_tokens)
scores: (batch, num_tokens), which contains @previous_logits and the logits of the next
(num_tokens - 1) tokens. The logits of the last token isn't computed unless last_token_logits=True.
In which case we have scores of shape (batch, num_tokens + 1)
"""
batch_size, seqlen = input_ids.shape
assert num_tokens >= 1
sequences = []
if decoding:
assert seqlen == 1
position_ids = torch.full(
(batch_size, 1),
inference_params.sequence_len_offset,
dtype=torch.long,
device=input_ids.device,
)
else:
position_ids = None
logits = logits_postprocess_fn(
logits_forward_fn(model, input_ids, position_ids, inference_params, cg=decoding and cg)
)
inference_params.sequence_len_offset += input_ids.shape[1]
scores = [logits]
next_token = sample_fn(logits)
sequences.append(next_token)
for i in range(num_tokens):
if i < num_tokens - 1 or last_token_logits:
position_ids = torch.full(
(batch_size, 1),
inference_params.sequence_len_offset,
dtype=torch.long,
device=input_ids.device,
)
logits = logits_postprocess_fn(
logits_forward_fn(
model, rearrange(next_token, "b -> b 1"), position_ids, inference_params, cg=cg
)
)
inference_params.sequence_len_offset += 1
scores.append(logits)
if i < num_tokens - 1:
next_token = sample_fn(logits)
sequences.append(next_token)
return torch.stack(sequences, dim=1), torch.stack(scores, dim=1)

sampling_kwargs = dict(top_k=top_k, top_p=top_p, temperature=temperature)
sample_fn = partial(sample, **sampling_kwargs)
sample_tokens_main = partial(
sample_tokens, model=model, sample_fn=sample_fn, inference_params=inference_params, cg=False
) # main model doesn't use CUDA graph
sample_tokens_draft = partial(
sample_tokens,
model=model_draft,
sample_fn=sample_fn,
last_token_logits=True,
inference_params=inference_params_draft,
cg=cg
)
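# The draft model only ever decodes one token at a time, so it can use the fused / CUDA-graph
# path; the main model has to score n_spec_tokens + 1 tokens per step, which that path doesn't
# support (hence cg=False and fused_ft_kernel=False for the main model above).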

if debug:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
sequences = [input_ids]
scores = []
with torch.inference_mode():
if timing:
if tensor_parallel > 1:
torch.distributed.barrier()
torch.cuda.synchronize()
start = time.time()

if seqlen_og >= max_length - 1:
# Don't do speculative sampling, just sample 1 token from the model
tokens, scores_new = sample_tokens_main(input_ids, num_tokens=1, decoding=False)
sequences.append(tokens)
scores.append(scores_new)
else:
# Sample @n_spec_tokens tokens from the draft model, which @model will then use to
# produce between 1 and 1 + @n_spec_tokens tokens.
# We want seqlen_og + 1 + @n_spec_tokens to be <= @max_length.
n_spec_tokens = min(speculative_lookahead, max_length - seqlen_og - 1)
tokens_draft, scores_draft = sample_tokens_draft(
input_ids,
num_tokens=n_spec_tokens,
decoding=False,
)
if debug:
scores_draft_ref = model_draft(
torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
).logits
print((scores_draft[:, :-1] - scores_draft_ref[:, :-1]).abs().max())

# Evaluate the draft tokens with the model
logits = model(
torch.cat([input_ids, tokens_draft], dim=1),
inference_params=inference_params,
num_last_tokens=n_spec_tokens + 1,
).logits
logits = logits_postprocess_fn(logits)
tokens, num_generated_tokens = sample_speculative(
logits, scores_draft[:, :-1], tokens_draft, **sampling_kwargs
)
if debug:
print(tokens)
print(num_generated_tokens)
# breakpoint()
# TODO: we're using the fact that batch_size == 1
# TODO: check eos_token_id
sequences.append(tokens[:1, : num_generated_tokens[0]])
scores.append(logits[:1, : num_generated_tokens[0]])
# Note that @model has not evaluated the last sampled token yet, so we'll need to pass
# that in the next time we call @model.
inference_params.sequence_len_offset = seqlen_og + num_generated_tokens[0].item() - 1
inference_params_draft.sequence_len_offset = inference_params.sequence_len_offset
if debug:
cur_ids = torch.cat([input_ids, sequences[-1]], dim=1)
scores_ref = model(
cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1
).logits
print((scores[-1] - scores_ref[:, :-1]).abs().max())

while True:
# sequence_len_offset is the total sequence length so far (prompt + generated) - 1
if inference_params.sequence_len_offset >= max_length - 1:
break
if inference_params.sequence_len_offset >= max_length - 2:
# Don't do speculative sampling, just sample 1 token from the model
tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1)
sequences.append(tokens)
scores.append(scores_new)
break
# Sample from draft model
n_spec_tokens = min(
speculative_lookahead, max_length - inference_params_draft.sequence_len_offset - 2
)
tokens_draft, scores_draft = sample_tokens_draft(
sequences[-1][:, -1:], num_tokens=n_spec_tokens
)
if debug:
scores_draft_ref = model_draft(
torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
).logits
print((scores_draft[:, :-1] - scores_draft_ref[:, :-1]).abs().max())
# Evaluate the draft tokens with the model
position_ids = repeat(
torch.arange(
inference_params.sequence_len_offset,
# 1 extra token from last time that hasn't been passed through model
inference_params.sequence_len_offset + n_spec_tokens + 1,
dtype=torch.long,
device=input_ids.device,
),
"s -> b s",
b=batch_size,
)
logits = model(
torch.cat([sequences[-1][:, -1:], tokens_draft], dim=1),
position_ids=position_ids,
inference_params=inference_params,
).logits # (batch, n_spec_tokens + 1, vocab_size)
logits = logits_postprocess_fn(logits)
inference_params.sequence_len_offset += 1
if debug:
logits_ref = model(
torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
).logits
print((logits - logits_ref).abs().max())
tokens, num_generated_tokens = sample_speculative(
logits, scores_draft[:, :-1], tokens_draft, **sampling_kwargs
)
if debug:
print(tokens)
print(num_generated_tokens)
sequences.append(tokens[:1, : num_generated_tokens[0]])
scores.append(logits[:1, : num_generated_tokens[0]])
inference_params.sequence_len_offset += num_generated_tokens[0].item() - 1
inference_params_draft.sequence_len_offset = inference_params.sequence_len_offset
# breakpoint()
if debug:
cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1)
scores_ref = model(
cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1
).logits
print((scores[-1] - scores_ref[:, :-1]).abs().max())

if timing:
if tensor_parallel > 1:
torch.distributed.barrier()
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
sequences = torch.cat(sequences, dim=1)
scores = torch.cat(scores, dim=1)
if debug:
scores_ref = model(sequences).logits
print((scores - scores_ref[:, seqlen_og - 1 : -1]).abs().max())
output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
return output_cls(sequences=sequences, scores=scores)
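
# An illustrative usage sketch (hypothetical helper): `model` and `model_draft` are assumed to
# follow the generation interface used in this file (forward accepting position_ids,
# inference_params and num_last_tokens) and to share a vocabulary, with `model_draft` being the
# smaller, faster model. The hyperparameters below are arbitrary.
def _example_decode_speculative(model, model_draft, input_ids, max_new_tokens=32):
    out = decode_speculative(
        input_ids,  # (1, seqlen): only batch_size == 1 is supported
        model,
        model_draft,
        max_length=input_ids.shape[1] + max_new_tokens,
        speculative_lookahead=4,  # number of draft tokens proposed per step
        top_k=1,  # greedy decoding -> GreedySearchDecoderOnlyOutput
        fused_ft_kernel=False,
        cg=False,
    )
    return out.sequences, out.scores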


class GenerationMixin:
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
raise NotImplementedError
@@ -394,7 +753,7 @@ def run(new_input_ids, new_position_ids, seqlen):
input_ids.copy_(new_input_ids)
position_ids.copy_(new_position_ids)
graph.replay()
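# Cloning is needed because the CUDA graph writes its output into a static buffer that is
# overwritten on the next replay; returning a copy keeps the logits valid across calls.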
-return logits
+return logits.clone()

inference_params.sequence_len_offset = sequence_len_offset_og
return run
