Add static KV cache and test on Gemma-2B (#4)
* chore(style): run make style

* chore(style): update pyproject to avoid ruff warning

* fix(tgi): sequence length should be based on sequence_length config

It was previously using n_positions in some places, but that attribute is not
available on some model configs.

* feat(modeling): model is immediately loaded on device

* debug: added env var to debug on CPU

If the DBG_DEVICE env var is set, it will be used to set the device for the
model.
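
A minimal sketch of the intended behaviour (the helper name is illustrative; only
the DBG_DEVICE variable and the "xla" default come from the modeling.py change below):

```python
import os


def resolve_device() -> str:
    """Pick the model device, letting DBG_DEVICE override the XLA default (e.g. DBG_DEVICE=cpu)."""
    return os.environ.get("DBG_DEVICE", "xla")
```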

* feat(test): reduce overhead when retrieving model

This will avoid loading the model twice.

* feat(modeling): make compilation optional

Make compilation optional; it can be enabled with the DBG_COMPILE
environment variable (see the sketch after this list). This is because:

1. Some models (notably Gemma) produce bugs when compiled.
2. The shapes of the inference input params change between calls,
   triggering recompilation and leading to slow performance.
3. With the added xm.mark_step, performance is actually better when the
   model is not compiled: XLA builds a graph anyway, so performance
   remains good.
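
A hedged sketch of that gating (DBG_COMPILE, the openxla_eval backend and
xm.mark_step all appear in the diffs below; the helper names are illustrative):

```python
import os

import torch
import torch_xla.core.xla_model as xm


def maybe_compile(model, device: str = "xla"):
    # Compilation is opt-in: only compile on the XLA device and only when DBG_COMPILE is set.
    if device == "xla" and "DBG_COMPILE" in os.environ:
        model = torch.compile(model, backend="openxla_eval")
    return model


def step_barrier():
    # Even without torch.compile, XLA lazily builds a graph for each step; marking a
    # step before the forward keeps that graph stable across iterations.
    xm.mark_step()
```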

* feat: add @torch.no_grad decorators to decode and prefill

This avoids unnecessary gradient calculations during decode and prefill.
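
For illustration, the decorator form used (class and method names are illustrative,
mirroring the decorators added in generator.py below):

```python
import torch


class Generator:
    @torch.no_grad
    def decode(self, batches):
        # No gradients are tracked here; this method only runs inference.
        ...
```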

* chore(generator): create buffers on device to avoid moving them

* refactor(generator): some model params are passed as dict

This will make it possible to pass different params for different model
configurations later.

* feat: use static KV cache when available

Some models, like Gemma and Llama, support a static KV cache in
transformers. For these, enabling it leads to much higher performance.
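
A sketch of how the static cache is enabled, following the _setup_cache/StaticCache
usage added in generator.py below (model id and sizes are illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, StaticCache

model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")

# Models that support it (e.g. Gemma, Llama) expose _setup_cache to pre-allocate
# KV buffers with a fixed shape, so every decode step can reuse the same XLA graph.
if getattr(model, "_setup_cache", None) is not None:
    batch_size, max_cache_len = 4, 1024
    model._setup_cache(StaticCache, batch_size, max_cache_len)
    # Forward calls then pass explicit cache positions instead of past_key_values,
    # e.g. the positions of the prompt tokens during prefill:
    cache_position = torch.arange(8)
```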

* fix(CI): added HF_TOKEN to use models that require it

Also manually install accelerate to avoid memory issues when loading
gemma.

* fix(CI): adapt expected result in do_sample test

The test produces different results now that some operations are done in a
slightly different order.
tengomucho authored Mar 15, 2024
1 parent 4c75f88 commit fdcd7ea
Showing 10 changed files with 216 additions and 73 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/test-pytorch-xla-tpu-tgi.yml
@@ -32,4 +32,6 @@ jobs:
run: python -c "import torch_xla.core.xla_model as xm; assert xm.xla_device().type == 'xla', 'XLA device not available'"

- name: Build and test TGI server
run: make tgi_test
run: |
pip install accelerate==0.27.2
HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} make tgi_test
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -17,17 +17,17 @@ line-length = 119
target-version = ['py38']
extend-exclude = '.ipynb'

[tool.ruff]
[lint.ruff]
# Never enforce `E501` (line length violations).
ignore = ["C901", "E501", "E741", "W605"]
select = ["C", "E", "F", "I", "W"]
line-length = 119

# Ignore import violations in all `__init__.py` files.
[tool.ruff.per-file-ignores]
[lint.ruff.per-file-ignores]
"__init__.py" = ["E402", "F401", "F403", "F811"]

[tool.ruff.isort]
[lint.ruff.isort]
lines-after-imports = 2
known-first-party = ["optimum.tpu"]

2 changes: 1 addition & 1 deletion text-generation-inference/integration-tests/conftest.py
@@ -85,7 +85,7 @@ def docker_launcher(
trust_remote_code: bool = False,
):
# TODO: consider finding out how to forward a port in the container instead of leaving it to 80.
#For now this is necessary because TPU dockers require to run with net=host and privileged mode.
# For now this is necessary because TPU dockers require to run with net=host and privileged mode.
port = 80

args = ["--model-id", model_id, "--env"]
9 changes: 7 additions & 2 deletions text-generation-inference/integration-tests/test_gpt2.py
@@ -38,7 +38,9 @@ async def test_model_single_request(tgi_client):
decoder_input_details=True,
)
assert response.details.generated_tokens == 17
assert response.generated_text == "\n\nDeep learning is a technique that allows you to learn something from a set of"
assert (
response.generated_text == "\n\nDeep learning is a technique that allows you to learn something from a set of"
)

# Greedy bounded with input
response = await tgi_client.generate(
@@ -64,7 +66,10 @@ async def test_model_single_request(tgi_client):
seed=42,
decoder_input_details=True,
)
assert 'The deep neural networks that we create are essentially "miniature" neural networks that can easily be trained' in response.generated_text
assert (
'The deep neural networks that we create are essentially "miniature" neural networks that can easily be trained'
in response.generated_text
)


@pytest.mark.asyncio
134 changes: 96 additions & 38 deletions text-generation-inference/server/text_generation_server/generator.py
@@ -6,8 +6,9 @@
from typing import List, Optional, Tuple

import torch
import torch_xla.core.xla_model as xm
from loguru import logger
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from transformers import AutoTokenizer, PreTrainedTokenizerBase, StaticCache
from transformers.generation import GenerationConfig

from .modeling import TpuModelForCausalLM
@@ -94,10 +95,11 @@ class State(Enum):
PAUSE = 1
READY = 2

def __init__(self, id: int, tokenizer: PreTrainedTokenizerBase):
def __init__(self, id: int, tokenizer: PreTrainedTokenizerBase, device: [str, torch.device]):
self._id = id
self._tokenizer = tokenizer
self.clear()
self._device = device

def clear(self):
"""Clear the slot and mark it as available."""
@@ -106,7 +108,7 @@ def clear(self):
self._inputs = ""
self._generation_config = None
self._tokens = []
self._mask = []
self._mask = None
self._selector = None
self._generated_tokens = 0
self._next_text_token_start = 0
@@ -139,6 +141,10 @@ def generation_config(self) -> GenerationConfig:
def generated_tokens(self) -> int:
return self._generated_tokens

@property
def cur_position(self) -> int:
return self._next_text_token_start

def assign(self, request: Request, generation_config: GenerationConfig):
"""Assign a request to a slot.
@@ -179,7 +185,10 @@ def reset(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, s
self._next_text_token_start = 0
self._next_text_token_end = torch.numel(self._tokens)
self._next_text = ""
self._mask = attention_mask.clone()
if attention_mask is not None:
self._mask = attention_mask.clone()
else:
self._mask = None
self._selector = selector

def pause(self):
@@ -238,8 +247,12 @@ def append(self, next_token: int) -> str:
Return:
The corresponding decoded text (if any).
"""
self._tokens = torch.cat([self._tokens, torch.LongTensor([next_token])])
self._mask = torch.cat([self._mask, torch.LongTensor([1])])
self._tokens = torch.cat(
[self._tokens, torch.tensor([next_token], device=self._device, dtype=self._tokens.dtype)]
)
# Update mask only if it was set previously
if self._mask is not None:
self._mask = torch.cat([self._mask, torch.tensor([1], device=self._device, dtype=self._mask.dtype)])
self._generated_tokens += 1
next_text = self._decode_next_tokens()
# Now that a new token has been generated, we can append the previous one to the generated text
@@ -296,8 +309,16 @@ def __init__(
tokenizer.padding_side = "left"
self.tokenizer = tokenizer
self.special_tokens = self.tokenizer.all_special_ids
self.slots = [Slot(i, tokenizer) for i in range(self.model.config.batch_size)]
self.slots = [Slot(i, tokenizer, self.model.device) for i in range(self.model.config.batch_size)]
self.past_key_values = None
# _setup_cache is specific to some models (e.g.: Gemma and Llama). In those cases it is possible to setup
# a static cache, otherwise it is not.
self.use_static_cache = True
if getattr(self.model, "_setup_cache", False) is False:
logger.warning(
f"Static cache not available for {self.model.__class__.__name__}. Performance will be affected"
)
self.use_static_cache = False

@property
def info(self) -> InfoResponse:
@@ -326,8 +347,9 @@ def warmup(self, batch: Batch) -> int:
f"Inconsistent server configuration: please make sure max-prefill-tokens does not exceed {batch_size} x max-input-length."
)
self.prefill(batch)
return self.model.config.batch_size * self.model.config.n_positions
return self.model.config.batch_size * self.model.config.sequence_length

@torch.no_grad
def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
"""Prefill new requests.
@@ -361,9 +383,9 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
# for unfinished requests.
inputs = [slot.cached_text for slot in self.slots]
# Tokenize with padding
padded_inputs = self.tokenizer(inputs, return_tensors="pt", padding=True)
padded_inputs = self.tokenizer(inputs, return_tensors="pt", padding=True).to(self.model.device)
# If needed truncate sequences to fit into the static dimensions
seq_length = min(padded_inputs.input_ids.shape[-1], self.model.config.n_positions)
seq_length = min(padded_inputs.input_ids.shape[-1], self.model.config.sequence_length)
input_ids = padded_inputs.input_ids[:, :seq_length]
attention_mask = padded_inputs.attention_mask[:, :seq_length]
# Pause previously active slots during generation and store their last token.
@@ -377,17 +399,36 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
slot_input_ids = input_ids[i : i + 1, :]
# Padded input ids are also required to set logits processors and stopping criterias
selector = TokenSelector.create(
slot_input_ids, slot.generation_config, self.model, self.model.config.n_positions, seed=slot.seed
slot_input_ids,
slot.generation_config,
self.model,
self.model.config.sequence_length,
seed=slot.seed,
)
slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
slot_attention_mask = attention_mask[i]
if self.use_static_cache:
# Attention mask does not need to be tracked when using static cache
slot_attention_mask = None
else:
slot_attention_mask = attention_mask[i]
slot.reset(slot_input_ids, slot_attention_mask, selector)
# Clear KV cache
self.past_key_values = None
# Pause previously active slots during generation.
# The KV cache of paused slots will be prefilled during generation but new tokens
# will be ignored, as they have already been generated and sent back in the last decode.
generation, next_batch = self._generate_token(batch.id, input_ids, attention_mask)
# Obtain position ids using attention mask.
position_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
position_ids = position_ids[:, -input_ids.shape[-1] :]

extra_args = {}
if self.use_static_cache:
self.model._setup_cache(StaticCache, len(self.slots), self.model.config.sequence_length)
extra_args["cache_position"] = torch.arange(seq_length, device=self.model.device)
else:
# Reset/clear KV cache
self.past_key_values = None
generation, next_batch = self._generate_token(
batch.id, input_ids, attention_mask=attention_mask, position_ids=position_ids, **extra_args
)

# Reactivate previously active slots for the next decode, and append
# back their next token.
for slot, next_token in zip(active_slots, next_tokens):
@@ -396,6 +437,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
logger.debug("Model ready for decoding")
return generation, next_batch

@torch.no_grad
def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
"""Decode the specified prefilled requests.
@@ -416,46 +458,62 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
# Reconstruct input_ids and attention_mask from slots
input_ids = None
attention_mask = None
position_ids = torch.zeros(
[self.model.config.batch_size, 1],
dtype=torch.int64,
device=self.model.device,
)
for i, slot in enumerate(self.slots):
if slot.state != Slot.State.EMPTY:
if input_ids is None:
# Create blank inputs covering all slots (even empty ones)
input_ids = torch.full(
[self.model.config.batch_size, 1], fill_value=self.tokenizer.eos_token_id, dtype=torch.int64
[self.model.config.batch_size, 1],
fill_value=self.tokenizer.eos_token_id,
dtype=torch.int64,
device=self.model.device,
)
# input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
input_ids[i, 0] = slot.next_token
if attention_mask is None:
# Create default mask covering all slots (even empty ones)
attention_mask = torch.zeros(
[self.model.config.batch_size, slot.attention_mask.size(-1)], dtype=torch.int64
)
attention_mask[i, :] = slot.attention_mask
if not self.use_static_cache:
# When using dynamic cache, the whole attention mask needs to be passed over to the model at each iteration.
if attention_mask is None:
# Create default mask covering all slots (even empty ones)
attention_mask = torch.zeros(
[self.model.config.batch_size, slot.attention_mask.size(-1)],
dtype=torch.int64,
device=self.model.device,
)
attention_mask[i, :] = slot.attention_mask
position_ids[i, 0] = slot.cur_position
if input_ids is None:
raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")
return self._generate_token(next_batch_id, input_ids, attention_mask)
extra_args = {}
if self.use_static_cache:
extra_args["cache_position"] = position_ids.max().unsqueeze(0)
else:
extra_args["attention_mask"] = attention_mask
extra_args["past_key_values"] = self.past_key_values
return self._generate_token(next_batch_id, input_ids, position_ids=position_ids, **extra_args)

def _generate_token(
self, next_batch_id: int, input_ids: torch.LongTensor, attention_mask: Optional[torch.LongTensor] = None
self, next_batch_id: int, input_ids: torch.LongTensor, **forward_extra_params
) -> Tuple[List[Generation], CachedBatch]:
# Obtain position ids using attention mask.
position_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
position_ids = position_ids[:, -input_ids.shape[-1] :]
# Move input params to device
input_ids = input_ids.to(self.model.device)
attention_mask = attention_mask.to(self.model.device)
position_ids = position_ids.to(self.model.device)
# Add barrier to allow next graph step to always be the same
xm.mark_step()
# Forward
outputs = self.model(
input_ids,
past_key_values=self.past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
return_dict=True,
use_cache=True,
**forward_extra_params,
)
# Save KV cache
self.past_key_values = outputs.past_key_values
if not self.use_static_cache:
# Save KV cache
self.past_key_values = outputs.past_key_values
# Barrier for XLA model
xm.mark_step(wait=False)

generations = []
active_slots = False
for i, slot in enumerate(self.slots):
@@ -507,7 +565,7 @@ def _generate_token(

def _cached_batch(self, batch_id: int, request_ids: List):
size = len(request_ids)
max_tokens = size * self.model.config.n_positions
max_tokens = size * self.model.config.sequence_length
return CachedBatch(id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens)

def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
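
As an aside, the position-id computation used in prefill above can be sanity-checked in
isolation: with a left-padded mask the padded positions are clamped to 0 and the real
tokens count up from 0 (illustrative values):

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])  # one sequence, left-padded by two tokens
position_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
print(position_ids)  # tensor([[0, 0, 0, 1, 2]])
```
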
@@ -2,10 +2,11 @@
import time
from pathlib import Path
from typing import Optional

from huggingface_hub import snapshot_download
from loguru import logger
from transformers import AutoConfig
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from huggingface_hub import snapshot_download


def get_export_kwargs_from_env():
@@ -19,9 +19,9 @@
from typing import Any

import torch
import torch_xla.core.xla_model as xm
from loguru import logger
from transformers import AutoModelForCausalLM
from transformers.utils import is_accelerate_available


# TODO: For now TpuModelForCausalLM is just a shallow wrapper of
@@ -38,7 +38,23 @@ def from_pretrained(
*model_args: Any,
**kwargs: Any,
):
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
if "PJRT_DEVICE" not in environ:
logger.info("PJRT_DEVICE environment variable not found. Setting it to 'TPU'.")
environ["PJRT_DEVICE"] = "TPU"
if "DBG_DEVICE" in environ:
device = environ["DBG_DEVICE"]
logger.debug(f"Device set to: {device}")
else:
device = "xla"
if is_accelerate_available():
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path, device_map=device, *model_args, **kwargs
)
else:
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
)
model.to(device)
# Update config with specific data)
if task is not None or getattr(model.config, "task", None) is None:
model.config.task = task
@@ -47,13 +63,10 @@ def from_pretrained(
if sequence_length is not None or getattr(model.config, "sequence_length", None) is None:
model.config.sequence_length = sequence_length

if "PJRT_DEVICE" not in environ:
logger.warning("PJRT_DEVICE environment variable not found. Setting it to 'TPU'.")
environ["PJRT_DEVICE"] = "TPU"
dev = xm.xla_device()
# Do eval, move model to device and compile
model.to(dev)
# Do eval, and compile
model.eval()
model = torch.compile(model, backend="openxla_eval")
if device == "xla" and "DBG_COMPILE" in environ:
model = torch.compile(model, backend="openxla_eval")
logger.debug("Model compiled.")

return model
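
A hedged usage sketch of this loader (the model id is illustrative; DBG_DEVICE and
DBG_COMPILE are the debug switches introduced in this commit):

```python
import os

from text_generation_server.modeling import TpuModelForCausalLM

# Optional debug override: run on CPU and leave the model uncompiled.
os.environ.setdefault("DBG_DEVICE", "cpu")

# The loader sets PJRT_DEVICE if needed, places the model on the chosen device
# (using accelerate's device_map when available), and only calls torch.compile
# when DBG_COMPILE is set and the device is "xla".
model = TpuModelForCausalLM.from_pretrained("google/gemma-2b")
print(model.device)
```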