WIP: problem with do_sample and RNG
JETSTREAM_PT=1 python -m pytest -sv text-generation-inference/tests/test_tinyllama.py -k "prefill and single and sample"
tengomucho committed Oct 25, 2024
1 parent 4a11067 commit 35640f0
Showing 3 changed files with 17 additions and 16 deletions.
@@ -242,7 +242,15 @@ def set(self, slot: Slot):
         self._curslot = slot
 
     def select(self, logits: jnp.ndarray) -> int:
-        return self._curslot.select(logits)
+        def _inner_select(logits):
+            breakpoint()
+            print(f"PrefillSlot.select id {self._curslot.id}")
+            return self._curslot.select(logits)
+        token = jax.pure_callback(
+            _inner_select,
+            result_shape_dtypes=jax.ShapeDtypeStruct((), jnp.int32),
+            logits=logits)
+        return token
 
 class TpuGeneratorJetStream(Generator):
     """A Generator for models running on TPU, single threaded."""
@@ -457,6 +465,8 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         )
         slot.reset(truncated_input_ids, selector)
         slot.update_rng_key()
+        print(slot._selector.key)
+        breakpoint()
         # To allow jit'ing the select function, we need to wrap it in a partial
         slot_select = jax.tree_util.Partial(self.prefill_slot.select)
         # Ask for prefill and insert
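The comment retained in the hunk above ("To allow jit'ing the select function...") refers to jax.tree_util.Partial: unlike functools.partial, a Partial is registered as a pytree, so the wrapped callable can be passed as a runtime argument to a jit'ed function. A small sketch under that assumption (hypothetical names):

    import jax
    import jax.numpy as jnp

    def apply_select(select_fn, logits):
        return select_fn(logits)

    apply_select_jit = jax.jit(apply_select)

    # A plain functools.partial would fail to flatten as a jit argument;
    # jax.tree_util.Partial participates in pytree flattening.
    select = jax.tree_util.Partial(lambda logits: jnp.argmax(logits))
    apply_select_jit(select, jnp.array([0.1, 2.0, 0.3]))  # -> 1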
@@ -183,10 +183,12 @@ def update_rng_key(self):
         self.key, _ = jax.random.split(self.key)
 
     def _sample(self, scores: jnp.ndarray) -> jnp.ndarray:
+        jax.debug.print(f"_sample logits shape: {scores.shape}")
         do_top_k = self.logits_warper.top_k > 0 and self.logits_warper.top_k < scores.shape[-1]
         do_top_p = self.logits_warper.top_p < 1.0 and self.logits_warper.top_p > 0.0
 
         if do_top_k:
+            print(f"Will do top-k sampling")
             return sampling_utils.sample_topk_logits(
                 scores,
                 self.logits_warper.top_k,
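Two details in this hunk are easy to miss. First, update_rng_key splits self.key on every call so a PRNG key is never consumed twice; reusing a key with the same logits reproduces the exact same draw, which is the kind of do_sample misbehavior this WIP commit is chasing. Second, the f-string inside jax.debug.print interpolates at trace time; it works here only because scores.shape is static, while runtime values need the jax.debug.print("shape: {s}", s=...) form, and the plain print under if do_top_k: fires when the function is traced or run eagerly, not on every compiled call. A minimal sketch of key splitting plus top-k sampling (standalone, not the sampling_utils implementation):

    import jax
    import jax.numpy as jnp

    def sample_topk(key, logits, k, temperature):
        # Keep only the k largest logits, then sample among them.
        topk_vals, topk_idx = jax.lax.top_k(logits, k)
        choice = jax.random.categorical(key, topk_vals / temperature)
        return topk_idx[choice]

    key = jax.random.PRNGKey(0)
    logits = jnp.array([0.1, 2.0, 0.3, 1.5])
    # Split before each use; never reuse a consumed key.
    key, subkey = jax.random.split(key)
    sample_topk(subkey, logits, k=2, temperature=1.0)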
19 changes: 4 additions & 15 deletions text-generation-inference/tests/test_tinyllama.py
@@ -6,9 +6,8 @@
 from tqdm import tqdm
 
 
-MODEL_ID = "google/gemma-2b"
+MODEL_ID = "Maykeye/TinyLLama-v0"
-SEQUENCE_LENGTH = 1024
+SEQUENCE_LENGTH = 256
 
 
 @pytest.fixture(scope="module")
@@ -20,7 +19,7 @@ def test_info(model_path):
     generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=1, max_sequence_length=1)
     info = generator.info
     assert info.requires_padding is True
-    assert info.device_type == "xla"
+    assert info.device_type == "meta"
     assert info.window_size == 0
     assert info.speculate == 0

@@ -61,12 +60,8 @@ def test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_
     assert len(generations) == batch_size
     for g in generations:
         tokens = g.tokens
-        try:
-            assert tokens.ids == [token_id]
-            assert tokens.texts == [token_text]
-        except AssertionError:
-            print(tokens)
-            breakpoint()
+        assert tokens.ids == [token_id]
+        assert tokens.texts == [token_text]
 
 def test_decode_multiple(model_path):
     generator = AutoGenerator.from_pretrained(model_path,
@@ -94,7 +89,6 @@ def test_decode_multiple(model_path):
         tokens[g.request_id].append(g.tokens.ids[0])
     assert len(tokens[0]) == gen_tokens
     assert next_batch.size == 1
-    print(tokens[0])
     # Add a second request
     request = create_request(id=1, inputs=input_text, max_new_tokens=max_new_tokens)
     batch = Batch(id=1, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH)
@@ -113,7 +107,6 @@ def test_decode_multiple(model_path):
     for g in generations:
         tokens[g.request_id].append(g.tokens.ids[0])
     batches = [next_batch]
-    print(tokens)
     # Verify we now only have one pending request
     assert next_batch.size == 1
     assert len(tokens[0]) == max_new_tokens
@@ -134,9 +127,5 @@ def test_decode_multiple(model_path):
     assert next_batch is None
     output = generations[0].generated_text
     assert output.generated_tokens == max_new_tokens
-    print(tokens[0])
-    print(tokens[1])
-    print(tokens[0] == tokens[1])
-    return
     assert tokens[0] == tokens[1]
     assert output.text == generated_text
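
The restored assertion tokens[0] == tokens[1] checks that the same prompt yields the same tokens whether a request runs alone or next to another; with do_sample enabled, that additionally requires per-request RNG to be reproducible, which is the issue named in the commit title. A sketch of the reproducibility property being relied on (illustrative only):

    import jax
    import jax.numpy as jnp

    def draw_tokens(seed, logits, steps=5):
        # Identical seeds and identical split schedules give identical draws.
        key = jax.random.PRNGKey(seed)
        out = []
        for _ in range(steps):
            key, subkey = jax.random.split(key)
            out.append(int(jax.random.categorical(subkey, logits)))
        return out

    logits = jnp.array([0.5, 1.5, 0.2])
    assert draw_tokens(0, logits) == draw_tokens(0, logits)

If the per-slot key advances inconsistently between otherwise identical requests, the two token streams diverge and the assertion fails.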
