💎 Gemma on TGI Jetstream Pytorch (#99)
* feat(Jetstream Pt): add gemma support

* test(TGI): add gemma 7b slow test that uses Pytorch Jetstream

* doc: update Jetstream Pytorch install command

* refactor(Jetstream Pt): simplify model_class.from_config call

* refactor(engine loader): DRY code

* fix(test): clarify warmup test comment
tengomucho authored Oct 4, 2024
1 parent a0464df commit 1194f61
Showing 6 changed files with 85 additions and 29 deletions.
5 changes: 1 addition & 4 deletions README.md
@@ -52,10 +52,7 @@ Please see the [TGI specific documentation](text-generation-inference) on how to
 `optimum-tpu` provides an optional support of JetStream Pytorch engine inside of TGI. This support can be installed using the dedicated command:
 
 ```shell
-pip install "optimum-tpu[jetstream-pt]" \
-    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-    -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
-    -f https://storage.googleapis.com/libtpu-releases/index.html
+make jetstream_requirements
 ```
 
 To enable the support, export the environment variable `JETSTREAM_PT=1`.
compatibility.py
@@ -24,8 +24,9 @@ def model_can_use_jetstream_pt(model_path: str) -> bool:
     the engine are installed.
     """
     config = AutoConfig.from_pretrained(model_path)
-    # For now only Llama is supported
-    if config.model_type != "llama":
+    # For now few models are supported
+    supported_models = ["llama", "gemma"]
+    if config.model_type not in supported_models:
         return False
     if jetstream_pt_available():
         return True
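For context, a minimal sketch (not part of the commit) of the property this check relies on: every `transformers` config class exposes a `model_type` string, and the gate above passes only when that string is in the supported list and the JetStream Pytorch dependencies import correctly. The config classes below are used with their library defaults, so no checkpoint download is needed.

```python
from transformers import GemmaConfig, LlamaConfig

# Same gate as model_can_use_jetstream_pt, minus the dependency check:
# eligibility is keyed off config.model_type.
supported_models = ["llama", "gemma"]

for config in (LlamaConfig(), GemmaConfig()):
    print(f"{config.model_type}: supported={config.model_type in supported_models}")
# llama: supported=True
# gemma: supported=True
```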
engine_loader.py
@@ -18,17 +18,30 @@
 from transformers import PretrainedConfig
 from transformers import AutoConfig
 
-from .llama_model_exportable_hf import TransformerHf
+from .compatibility import model_can_use_jetstream_pt
+from .gemma_model_hf import GemmaModelHf as GemmaModel
+from .llama_model_exportable_hf import TransformerHf as LlamaModel
 
 
-def load_llama_model_info(config: "PretrainedConfig") -> Any:
+def _get_head_dim(config: "PretrainedConfig") -> int:
+    if hasattr(config, "head_dim"):
+        return config.head_dim
+    return config.hidden_size // config.num_attention_heads
+
+def load_model_info(config: "PretrainedConfig") -> Any:
     num_layers = config.num_hidden_layers
     num_heads = config.num_attention_heads
-    head_dim = config.hidden_size // num_heads
+    head_dim = _get_head_dim(config)
     num_kv_heads = config.num_key_value_heads
     n_reps = num_heads // num_kv_heads
+    if config.model_type == "llama":
+        model_class = LlamaModel
+    elif config.model_type == "gemma":
+        model_class = GemmaModel
+    else:
+        raise ValueError(f"Unsupported model type {config.model_type}")
     model_info = fetch_models.ModelInfo(
-        TransformerHf,
+        model_class=model_class,
         num_layers=num_layers,
         num_kv_heads=num_kv_heads,
         head_dim=head_dim,
@@ -37,21 +50,15 @@ def load_llama_model_info(config: "PretrainedConfig") -> Any:
     return model_info
 
 
-def load_model_info(config: "PretrainedConfig") -> Any:
-    # For now only Llama is supported
-    if config.model_type == "llama":
-        return load_llama_model_info(config)
-    # Other models supports can be added here later
-    return None
-
-
 def create_engine_env_data(
     model_path: str,
     batch_size: int,
     sequence_length: int,
     max_input_tokens: int,
     max_output_tokens: int,
 ) -> Any:
+    if not model_can_use_jetstream_pt(model_path):
+        return None
     # First get config
     config = AutoConfig.from_pretrained(model_path)
     model_info = load_model_info(config)
@@ -86,12 +93,6 @@ def create_engine_env_data(
     return env_data
 
 
-def create_model(model_path: str, env: Any) -> Any:
-    config = AutoConfig.from_pretrained(model_path)
-    if config.model_type == "llama":
-        return TransformerHf.from_config(config, env)
-
-
 def instantiate_model_from_repo_id(
     model_dir: str,
     env: Any,
@@ -103,7 +104,7 @@ def instantiate_model_from_repo_id(
     assert model_info is not None
 
     env.device = "meta"
-    model = create_model(model_dir, env)
+    model = model_info.model_class.from_config(config, env)
     weights = fetch_models._load_weights(model_dir)
     weights = model.convert_hf_weights(weights)
 
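One detail worth illustrating with a small sketch (assuming the default `GemmaConfig` values in `transformers`): Gemma is why `_get_head_dim` prefers an explicit `head_dim` over deriving it, since for the 7B variant the head dimension is not `hidden_size // num_attention_heads`.

```python
from transformers import GemmaConfig

# GemmaConfig defaults mirror the 7B architecture: 16 attention heads of
# dimension 256 over a hidden size of 3072, so the derived value (192) would
# be wrong. _get_head_dim therefore reads config.head_dim when it is present.
config = GemmaConfig()
derived = config.hidden_size // config.num_attention_heads
print(config.head_dim, derived)  # 256 192
```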
gemma_model_hf.py (new file)
@@ -0,0 +1,47 @@

+from jetstream_pt.third_party.gemma import config as gemma_config
+from jetstream_pt.third_party.gemma.model import GemmaModel
+
+#.model_exportable import Transformer, model_args
+from transformers import GemmaConfig, GenerationConfig, GenerationMixin
+
+
+class GemmaModelHf(GemmaModel, GenerationMixin):
+    """Transformer module that uses HF GemmaConfig instead of Jetstream Pytorch GemmaConfig + device.
+    Note that this class also derives from GenerationMixin, so that we can use its methods.
+    """
+
+    def __init__(
+        self,
+        config: GemmaConfig,
+        device,
+        env,
+    ):
+        self.config = config
+        self.generation_config = GenerationConfig.from_model_config(config)
+
+        args = gemma_config.GemmaConfig(
+            vocab_size=config.vocab_size,
+            max_position_embeddings=config.max_position_embeddings,
+            num_hidden_layers=config.num_hidden_layers,
+            num_attention_heads=config.num_attention_heads,
+            num_key_value_heads=config.num_key_value_heads,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            head_dim=config.head_dim,
+            rms_norm_eps=config.rms_norm_eps,
+            dtype="bfloat16",
+            quant=False,  # No quantization support for now
+            tokenizer=None,
+        )
+
+        args.device = device
+        super().__init__(args, env)
+
+
+    @classmethod
+    def from_config(cls, config, env):
+        device = "meta"
+        model = cls(config, device, env)
+        return model
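A rough usage sketch of the new class, mirroring `instantiate_model_from_repo_id` in `engine_loader.py` above. It assumes the `optimum-tpu[jetstream-pt]` extra is installed and that `env` (a JetStream Pytorch environment) and `model_dir` (a local snapshot of the checkpoint) have been prepared elsewhere, so it is illustrative rather than standalone-runnable.

```python
from transformers import AutoConfig

# Sketch only: `env` and `model_dir` are placeholders prepared by the engine
# loader; GemmaModelHf is imported from the module added in this commit.
config = AutoConfig.from_pretrained(model_dir)
env.device = "meta"                            # instantiate on the meta device first
model = GemmaModelHf.from_config(config, env)  # HF config mapped to JetStream args in __init__
# Weights are then loaded separately and adapted with model.convert_hf_weights(),
# exactly as instantiate_model_from_repo_id does above.
```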
14 changes: 12 additions & 2 deletions text-generation-inference/tests/test_decode.py
@@ -117,8 +117,13 @@ def _test_decode_single(params):
            expected_text=" Winston Smith, his chin nuzzled into his breast in an effort to escape the vile wind,",
            top_k=100,
        ),
+        DecodeTestParams(
+            model_id="google/gemma-7b",
+            sequence_length=128,
+            expected_text="\n\nThe time is 1984. The place is Airstrip One, the British",
+        ),
     ],
-    ids=["Llama-2-7b-hf", "Meta-Llama-3-8B"],
+    ids=["Llama-2-7b-hf", "Meta-Llama-3-8B", "gemma-7b"],
 )
 def test_decode_single_jetstream_pytorch_slow(params, do_sample):
     if not jetstream_pt_available():
@@ -136,8 +141,13 @@ def test_decode_single_jetstream_pytorch_slow(params, do_sample):
            expected_text=" The sun was shining and the sky was shining.\nSuddenly, a big wind came and blew the wind away.",
            max_new_tokens=25,
        ),
+        DecodeTestParams(
+            model_id="google/gemma-2b",
+            sequence_length=1024,
+            expected_text="\n\nThe first thing I noticed was the smell of the rain. It was a smell I had never",
+        ),
     ],
-    ids=["TinyLLama-v0"],
+    ids=["TinyLLama-v0", "gemma-2b"],
 )
 def test_decode_single_jetstream_pytorch(params, do_sample):
     if not jetstream_pt_available():
4 changes: 2 additions & 2 deletions text-generation-inference/tests/test_warmup.py
@@ -24,8 +24,8 @@ def test_warmup_jetstream_pytorch():
         model_path, revision="", max_batch_size=2, max_sequence_length=sequence_length
     )
     request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False)
-    # The maximum sequence length of the model is set to 1000, but warmup will round that up to the next power of two
-    # in prefill (256).
+    # The maximum tokens length of the model is intentionally not a power of two, to verify that prefill bucketization
+    # works as expected (250 -> 256).
     max_tokens = 250
     batch = Batch(id=0, requests=[request], size=1, max_tokens=max_tokens)
     generator.warmup(batch)
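The rounding the updated comment refers to can be illustrated with a short sketch (the server's actual bucketing helper may differ; this only shows why 250 lands in the 256 bucket):

```python
# Round a length up to the next power of two, as prefill bucketization does
# during warmup: a request with max_tokens=250 is warmed up against the 256 bucket.
def next_power_of_two(n: int) -> int:
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

for n in (250, 256, 257):
    print(n, "->", next_power_of_two(n))
# 250 -> 256
# 256 -> 256
# 257 -> 512
```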
