From fd295912922515f3110fffda3cdb9c5ad600a905 Mon Sep 17 00:00:00 2001
From: Alvaro Moran <6949769+tengomucho@users.noreply.github.com>
Date: Sat, 6 Jul 2024 13:08:25 +0200
Subject: [PATCH] More Inference Endpoints features and fixes (#68)

* feat(generator): better handle exceptions on multiprocessing

  This will raise an error, signaling there was a problem. Before, the root
  thread was getting stuck waiting for an agent that had died; this way it
  should exit.

* feat(tgi): add more debug on server

* chore(docker): entrypoint json output is set by default

  It is possible to disable it by setting JSON_OUTPUT_DISABLE. It is now also
  possible to configure the batch size.

* feat(generator): add bucketing functions to use in prefill

* feat(generator): store position_id in current slot

  This will further simplify the implementation of prefill bucketing.

* fix(generator): correct input_ids and attention_mask padding

* fix(TGI): fix input truncation

  Truncation was sub-optimal, and it was done on the wrong side.

* feat(generator): enable logs on children processes

* feat(tgi): warmup runs prefill/decode on all supported combinations

  This will prevent XLA compilation at inference time. Note that I had to
  disable dynamo compilation, though, otherwise the model was not generating
  correct results. This leads to slower generation, but at least generation
  seems stable now.

* ci(tgi): create images when pushing on current branch

  This will allow testing IE before release.

* feat(tgi): reversed loop order in warmup to test memory limits earlier

* chore(ci): remove image generation for this branch
---
 optimum/tpu/xla_mp_comm.py                    |   5 +
 .../docker/entrypoint.sh                      |  15 +-
 .../text_generation_server/generator.py       | 237 ++++++++++++++----
 .../server/text_generation_server/server.py   |  13 +-
 .../tests/test_prefill_truncate.py            |  14 +-
 5 files changed, 224 insertions(+), 60 deletions(-)
 mode change 100644 => 100755 text-generation-inference/docker/entrypoint.sh

diff --git a/optimum/tpu/xla_mp_comm.py b/optimum/tpu/xla_mp_comm.py
index 069de5de..be2062dc 100644
--- a/optimum/tpu/xla_mp_comm.py
+++ b/optimum/tpu/xla_mp_comm.py
@@ -11,6 +11,8 @@ def __init__(self, manager: mp.Manager):
         self.root_command = manager.list()
         self.agent_ready = manager.Event()
         self.output_data = manager.list()
+        self.agent_error = manager.Event()
+        self.agent_error.clear()
 
     def send(self, command: int, *args) -> ListProxy:
         """Send a command and arguments to the agents and wait for the response.
@@ -30,6 +32,8 @@ def send(self, command: int, *args) -> ListProxy:
         self.root_bell.set()
         # wait again until agent is ready, meaning command has been processed
         self.agent_ready.wait()
+        if self.agent_error.is_set():
+            raise RuntimeError("Error on one of the threads, stopping.")
         ret = self.output_data
         return ret
 
@@ -41,6 +45,7 @@ def __init__(self, root_mailbox: RootMailbox):
         self.root_command = root_mailbox.root_command
         self.agent_ready = root_mailbox.agent_ready
         self.output_data = root_mailbox.output_data
+        self.agent_error = root_mailbox.agent_error
 
     def receive(self) -> ListProxy:
         """Wait for a command from the root process and return it.
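The `agent_error` event added above gives the root process a way to notice that an agent failed instead of blocking forever on `agent_ready`. A minimal standalone sketch of that handshake (illustrative names only, not code from this patch):

# Sketch of the error-signaling handshake enabled by the new agent_error event.
# Names (agent, root_bell, ...) are illustrative, not taken from optimum-tpu.
import multiprocessing as mp


def agent(root_bell, agent_ready, agent_error):
    root_bell.wait()  # wait for a command from the root
    try:
        raise RuntimeError("simulated failure while processing the command")
    except Exception:
        agent_error.set()   # signal the failure ...
        agent_ready.set()   # ... but still wake the root so it does not hang
        raise


if __name__ == "__main__":
    manager = mp.Manager()
    root_bell, agent_ready, agent_error = manager.Event(), manager.Event(), manager.Event()
    proc = mp.Process(target=agent, args=(root_bell, agent_ready, agent_error))
    proc.start()
    root_bell.set()     # send a command
    agent_ready.wait()  # returns even though the agent died
    if agent_error.is_set():
        raise RuntimeError("Error on one of the threads, stopping.")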
diff --git a/text-generation-inference/docker/entrypoint.sh b/text-generation-inference/docker/entrypoint.sh
old mode 100644
new mode 100755
index 7dee246f..b56eb564
--- a/text-generation-inference/docker/entrypoint.sh
+++ b/text-generation-inference/docker/entrypoint.sh
@@ -5,6 +5,18 @@ ulimit -l 68719476736
 
 # Hugging Face Hub related
+if [[ -z "${BATCH_SIZE}" ]]; then
+  BATCH_SIZE=4
+fi
+export BATCH_SIZE="${BATCH_SIZE}"
+
+if [[ -z "${JSON_OUTPUT_DISABLE}" ]]; then
+  JSON_OUTPUT_DISABLE=--json-output
+else
+  JSON_OUTPUT_DISABLE=""
+fi
+export JSON_OUTPUT_DISABLE="${JSON_OUTPUT_DISABLE}"
+
 if [[ -z "${MODEL_ID}" ]]; then
   echo "MODEL_ID must be set"
   exit 1
@@ -12,6 +24,7 @@ fi
 export MODEL_ID="${MODEL_ID}"
 
 text-generation-launcher --port 8080 \
-  --max-batch-size 4 \
+  --max-batch-size ${BATCH_SIZE} \
+  ${JSON_OUTPUT_DISABLE} \
   --model-id ${MODEL_ID}
diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py
index ee65411c..4254e58f 100644
--- a/text-generation-inference/server/text_generation_server/generator.py
+++ b/text-generation-inference/server/text_generation_server/generator.py
@@ -1,7 +1,10 @@
 import copy
 import logging
 import os
+import sys
 import time
+import traceback
+from bisect import bisect_left
 from enum import Enum
 from typing import Dict, List, Optional, Tuple
 
@@ -25,7 +28,9 @@
     GeneratedText,
     Generation,
     InfoResponse,
+    NextTokenChooserParameters,
     Request,
+    StoppingCriteriaParameters,
     Tokens,
 )
 
@@ -34,6 +39,28 @@
 optimum_logger = logging.getLogger("optimum.tpu")
 optimum_logger.setLevel("CRITICAL")
 
+# These will do some bucketing on prefill lengths to avoid too many different sizes
+PREFILL_LENGTHS = [
+    16,
+    32,
+    64,
+    128,
+    256,
+    512,
+    1024,
+    2048,
+    4096,
+    8192,
+    16384,
+    32768,
+]
+
+def take_nearest_length(length: int) -> int:
+    """Gets the nearest length to the right in a set of lengths."""
+    pos = bisect_left(PREFILL_LENGTHS, length)
+    if pos == len(PREFILL_LENGTHS):
+        return PREFILL_LENGTHS[-1]
+    return PREFILL_LENGTHS[pos]
 
 class Slot:
     """Represents a slot in a static batch"""
@@ -66,6 +93,7 @@ def clear(self):
         self._next_text = ""
         self._kv_cache = None
         self._truncate = 0
+        self._position_id = 0
 
     @property
     def id(self) -> int:
@@ -95,14 +123,24 @@ def generation_config(self) -> GenerationConfig:
     def generated_tokens(self) -> int:
         return self._generated_tokens
 
-    @property
-    def cur_position(self) -> int:
-        return self._next_text_token_start
-
     @property
     def truncate(self) -> int:
         return self._truncate
 
+    @property
+    def position_id(self) -> int:
+        return self._position_id
+
+    @position_id.setter
+    def position_id(self, cur_pos: int):
+        self._position_id = cur_pos
+
+    @property
+    def cache_position(self) -> int:
+        # This corresponds to the cache position for this slot
+        return self._next_text_token_start
+
+
     def assign(self, batch_id: int, request: Request, generation_config: GenerationConfig):
         """Assign a request to a slot.
@@ -272,6 +310,7 @@ def __init__(
         # Specify padding options for decoder-only architecture
         tokenizer.pad_token_id = tokenizer.eos_token_id
         tokenizer.padding_side = "left"
+        tokenizer.truncation_side = "left"
         self.tokenizer = tokenizer
         self.special_tokens = self.tokenizer.all_special_ids
         # Slots are empty to begin with, they will be populated as new batches arrive
@@ -288,7 +327,7 @@ def __init__(
         )
         self._supports_static_cache = False
         # compile model when possible to accelerate decoding
-        if model.device.type == "xla" and ("DBG_COMPILE" in os.environ or self._supports_static_cache):
+        if model.device.type == "xla" and ("DBG_COMPILE" in os.environ):
             self.model_one_token = torch.compile(model, backend="openxla")
             logger.debug("Model compiled for decoding")
         else:
@@ -304,6 +343,34 @@ def info(self) -> InfoResponse:
             device_type="xla",
         )
 
+    def _create_dummy_request(self, max_tokens: int) -> Batch:
+        """Create a dummy request for warmup."""
+        # Generate a random input with slightly more tokens than requested, because special tokens are going to be
+        # skipped.
+        MARGIN = 10
+        input_tokens = torch.randint(self.model.config.vocab_size, (1, max_tokens + MARGIN), dtype=torch.int64)
+        text = self.tokenizer.decode(input_tokens[0], skip_special_tokens=True)
+        # These are just dummy params to allow Request creation
+        parameters = NextTokenChooserParameters(
+            temperature=1.0,
+            top_k=None,
+            top_p=None,
+            do_sample=False,
+            seed=None,
+            repetition_penalty=1.0,
+            typical_p=1.0,
+        )
+        stopping_parameters = StoppingCriteriaParameters(max_new_tokens=20, ignore_eos_token=True)
+        dummy_request = Request(
+            id=0,
+            inputs=text,
+            truncate=max_tokens,
+            parameters=parameters,
+            stopping_parameters=stopping_parameters,
+        )
+        return dummy_request
+
+
     def warmup(self, batch: Batch) -> int:
         """Verify if the hardware can support the target load.
 
@@ -315,6 +382,7 @@ def warmup(self, batch: Batch) -> int:
             The maximum number of tokens the model supports.
         """
         logger.debug("Warming up the model")
+        start = time.time()
         # Just check that the warmup request parameters match the model capacity
         # NOTE: later self.model.config.batch_size might become self.model.config.max_batch_size.
         if self.model.config.batch_size is not None:
@@ -326,9 +394,40 @@ def warmup(self, batch: Batch) -> int:
             raise ValueError(
                 f"Inconsistent server configuration: please make sure max-prefill-tokens does not exceed {batch_size} x max-input-length."
             )
-        self.prefill(batch)
-        self.clear()
-        return batch_size * self.model.config.sequence_length
+
+        # Counter-intuitively, now we ignore the input batch. Instead, we create dummy batches to cover all possible
+        # batch sizes and sequence lengths.
+        seq_len = self.model.config.sequence_length
+        bucket_seq_len = take_nearest_length(seq_len)
+        requests = [self._create_dummy_request(seq_len) for _ in range(batch_size)]
+        for _ in reversed(range(batch_size)):
+            # Prefill with different truncate sizes to test all prefill lengths. The list is reversed so the longest
+            # sequences are tested first and, if there is a memory failure, it will appear sooner.
+            for l in reversed(PREFILL_LENGTHS):
+                # Skip all the unsupported lengths
+                if l > bucket_seq_len:
+                    continue
+                # Set all truncate values for all requests
+                for r in requests:
+                    r.truncate = l
+                    r.stopping_parameters.max_new_tokens = 10
+                warmup_batch = Batch(id=0,
+                                     requests=requests,
+                                     size=len(requests),
+                                     max_tokens=batch.max_tokens)
+                logger.debug(f"Warmup for {len(requests)} requests, truncate value {l} seq_len {seq_len}")
+                _generations, next_batch = self.prefill(warmup_batch)
+                if next_batch is not None:
+                    self.decode([next_batch])
+                else:
+                    logger.debug(f"No decode on warmup for {len(requests)}x{l}")
+                self.clear()
+            # remove the last request to decrease the batch size
+            requests.pop()
+
+        elapsed = time.time() - start
+        logger.debug(f"Warmup done, took {elapsed:.2f}s")
+        return batch_size * seq_len
 
     @torch.no_grad
     def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
@@ -365,18 +464,20 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         # for unfinished requests.
 
         # Prepare inputs. They need to be tokenized and truncated afterwards.
-        tokenized_inputs = []
-        seq_length = 0
-        for _, slot in enumerate(self.slots):
-            cur_input = slot.cached_text
-            tokenized_input = self.tokenizer(cur_input, return_tensors="pt")
-            # Truncate the input to truncation in the slot (coming from request) and the maximum sequence length
-            tokenized_input.input_ids = tokenized_input.input_ids[:, -slot.truncate:][:self.model.config.sequence_length]
-            tokenized_input.attention_mask = tokenized_input.attention_mask[:, -slot.truncate:][:self.model.config.sequence_length]
-            tokenized_inputs.append(tokenized_input)
-            # Update the maximum sequence length
-            seq_length = max(seq_length, tokenized_input.input_ids.size(-1))
-
+        max_len = 0
+        batch_inputs = []
+        for slot in self.slots:
+            batch_inputs.append(slot.cached_text)
+            max_len = max(max_len, slot.truncate)
+        if max_len == 0:
+            max_len = self.model.config.sequence_length
+        tokenized_inputs = self.tokenizer(batch_inputs,
+                                          return_tensors="pt",
+                                          padding=True,
+                                          truncation=True,
+                                          max_length=max_len)
+        seq_length = tokenized_inputs.input_ids.size(-1)
+        seq_length = min(seq_length, self.model.config.sequence_length)
         batch_size = len(self.slots)
         # Initialize input_ids and attention_mask with padding (to make them all the same size)
         input_ids = torch.full((batch_size, seq_length), self.tokenizer.pad_token_id, dtype=torch.int64)
@@ -390,7 +491,11 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         # Each slot must be reset with the padded inputs and masks
         for i, slot in enumerate(self.slots):
             assert slot.state != slot.state.EMPTY
-            input_ids[i, : tokenized_inputs[i].input_ids.size(-1)] = tokenized_inputs[i].input_ids
+
+            truncation = min(tokenized_inputs.input_ids.size(-1), input_ids.size(-1))
+            if slot.truncate > 0:
+                truncation = min(truncation, slot.truncate)
+            input_ids[i, -truncation:] = tokenized_inputs.input_ids[i, -truncation:]
             slot_input_ids = input_ids[i : i + 1, :]
             # Padded input ids are also required to set logits processors and stopping criterias
             selector = TokenSelector.create(
@@ -401,7 +506,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
                 seed=slot.seed,
             )
             slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
-            attention_mask[i, : tokenized_inputs[i].attention_mask.size(-1)] = tokenized_inputs[i].attention_mask
+            attention_mask[i, -truncation:] = tokenized_inputs.attention_mask[i, -truncation:]
             if self._supports_static_cache:
                 # Attention mask does not need to be tracked when using static cache
                 slot_attention_mask = None
@@ -412,6 +517,9 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
             self.past_key_values = None
         # Obtain position ids using attention mask.
         position_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
+        # Save position id for every slot
+        for slot, position_id in zip(self.slots, position_ids):
+            slot.position_id = position_id.max().item() + 1
 
         extra_args = {}
         if self._supports_static_cache:
@@ -484,6 +592,7 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
             fill_value=self.tokenizer.eos_token_id,
             dtype=torch.int64,
         )
+        cache_position = torch.zeros([1], dtype=torch.int64)
         for i, slot in enumerate(self.slots):
             if slot.state != Slot.State.EMPTY:
                 # input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
@@ -497,7 +606,8 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
                     dtype=torch.int64,
                 )
                 attention_mask.index_put_([torch.tensor([i])], slot.attention_mask)
-                position_ids.index_put_([torch.tensor([i])], torch.tensor(slot.cur_position))
+                position_ids.index_put_([torch.tensor([i])], torch.tensor(slot.position_id))
+                cache_position = torch.maximum(cache_position, torch.tensor([slot.cache_position]))
         if input_ids is None:
             raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")
         extra_args = {}
@@ -506,13 +616,17 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
         else:
             extra_args["attention_mask"] = attention_mask.to(self.model.device)
             extra_args["past_key_values"] = self.past_key_values
-        return self._generate_token(
+        generations, next_batch = self._generate_token(
             next_batch_id,
             input_ids.to(self.model.device),
             self.model_one_token,
             position_ids=position_ids.to(self.model.device),
             **extra_args,
         )
+        for slot, gen in zip(self.slots, generations):
+            slot.position_id += len(gen.tokens.ids)
+
+        return generations, next_batch
 
     def _generate_token(
         self, next_batch_id: int, input_ids: torch.LongTensor, model: torch.nn.Module, **forward_extra_params
@@ -667,6 +781,16 @@ def _mp_fn(
     # create agent mailbox out of root's one
     mailbox = AgentMailbox(root_mailbox)
 
+    # re-init logger for each child process
+    logger.logger.add(
+        sys.stdout,
+        format="{message}",
+        filter="text_generation_server",
+        level="DEBUG",
+        backtrace=True,
+        diagnose=False,
+    )
+
     logger.debug(
         f"Rank {rank} on {device} real device {xm.xla_real_devices([device])} ordinal {xm.get_ordinal()} "
         + f"world size {world_size}"
@@ -690,33 +814,44 @@ def return_to_caller(*data):
         xm.rendezvous("wait_command")
         command, data = mailbox.command_data
         logger.debug(f"Generator@{rank} {command.name}")
-        if command == GeneratorCommand.INFO:
-            info = generator.info
-            return_to_caller(info.SerializeToString())
-        if command == GeneratorCommand.WARMUP:
-            batch = Batch.FromString(data[0])
-            return_to_caller(generator.warmup(batch=batch))
-        if command == GeneratorCommand.PREFILL:
-            batch = Batch.FromString(data[0])
-            generations, cached_batch = generator.prefill(batch=batch)
-            s_cached_batch = cached_batch.SerializeToString() if cached_batch is not None else None
-            return_to_caller([g.SerializeToString() for g in generations], s_cached_batch)
-        if command == GeneratorCommand.DECODE:
-            batches = [CachedBatch.FromString(b) for b in data[0]]
-            generations, cached_batch = generator.decode(batches=batches)
-            s_cached_batch = cached_batch.SerializeToString() if cached_batch is not None else None
-            return_to_caller([g.SerializeToString() for g in generations], s_cached_batch)
-        if command == GeneratorCommand.FILTER:
-            batch_id, request_ids = data
-            cached_batch = generator.filter(batch_id, request_ids)
-            return_to_caller(cached_batch.SerializeToString())
-        if command == GeneratorCommand.CLEAR:
-            generator.clear()
-        if command == GeneratorCommand.DELETE:
-            if rank == 0:
-                # Set agent to ready
-                mailbox.agent_ready.set()
-            break
+        try:
+            if command == GeneratorCommand.INFO:
+                info = generator.info
+                return_to_caller(info.SerializeToString())
+            if command == GeneratorCommand.WARMUP:
+                batch = Batch.FromString(data[0])
+                return_to_caller(generator.warmup(batch=batch))
+            if command == GeneratorCommand.PREFILL:
+                batch = Batch.FromString(data[0])
+                generations, cached_batch = generator.prefill(batch=batch)
+                s_cached_batch = cached_batch.SerializeToString() if cached_batch is not None else None
+                return_to_caller([g.SerializeToString() for g in generations], s_cached_batch)
+            if command == GeneratorCommand.DECODE:
+                batches = [CachedBatch.FromString(b) for b in data[0]]
+                generations, cached_batch = generator.decode(batches=batches)
+                s_cached_batch = cached_batch.SerializeToString() if cached_batch is not None else None
+                return_to_caller([g.SerializeToString() for g in generations], s_cached_batch)
+            if command == GeneratorCommand.FILTER:
+                batch_id, request_ids = data
+                cached_batch = generator.filter(batch_id, request_ids)
+                return_to_caller(cached_batch.SerializeToString())
+            if command == GeneratorCommand.CLEAR:
+                generator.clear()
+            if command == GeneratorCommand.DELETE:
+                if rank == 0:
+                    # Set agent to ready
+                    mailbox.agent_ready.set()
+                break
+        except Exception as e:
+            logger.error(f"Error in command {command.name}")
+            mailbox.agent_error.set()
+            mailbox.agent_ready.set()
+            exc_info = sys.exc_info()
+            logger.error(''.join(traceback.format_exception(*exc_info)))
+            raise e
+        # If the error only happened on one of the threads, all of them should exit
+        if mailbox.agent_error.is_set():
+            return
 
 
 def model_loop_fn(*args):
diff --git a/text-generation-inference/server/text_generation_server/server.py b/text-generation-inference/server/text_generation_server/server.py
index f4771cf0..e53e8e14 100644
--- a/text-generation-inference/server/text_generation_server/server.py
+++ b/text-generation-inference/server/text_generation_server/server.py
@@ -17,15 +17,19 @@ def __init__(self, generator: Generator, server_urls: List[str]):
         self.server_urls = server_urls
 
     async def Info(self, request, context):
+        logger.debug("Info")
         return self.generator.info
 
     async def Health(self, request, context):
+        logger.debug("Health")
         return generate_pb2.HealthResponse()
 
     async def ServiceDiscovery(self, request, context):
+        logger.debug("ServiceDiscovery")
         return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
 
     async def ClearCache(self, request, context):
+        logger.debug("ClearCache")
         if request.HasField("id"):
             self.generator.clear(request.id)
         else:
@@ -33,18 +37,25 @@ async def ClearCache(self, request, context):
         return generate_pb2.ClearCacheResponse()
 
     async def FilterBatch(self, request, context):
+        logger.debug("FilterBatch")
         filtered_batch = self.generator.filter(request.batch_id, request.request_ids)
         return generate_pb2.FilterBatchResponse(batch=filtered_batch)
 
     async def Warmup(self, request, context):
+        logger.info("Warmup (this can take several minutes)")
         max_tokens = self.generator.warmup(request.batch)
-        return generate_pb2.WarmupResponse(max_supported_total_tokens=max_tokens)
+        ret = generate_pb2.WarmupResponse(max_supported_total_tokens=max_tokens)
+        logger.info("Warmup done")
+        return ret
 
     async def Prefill(self, request, context):
+        logger.debug("Prefill")
+        batch = request.batch
         generations, batch = self.generator.prefill(request.batch)
         return generate_pb2.PrefillResponse(generations=generations, batch=batch)
 
     async def Decode(self, request, context):
+        logger.debug("Decode")
         generations, batch = self.generator.decode(request.batches)
         return generate_pb2.DecodeResponse(generations=generations, batch=batch)
 
diff --git a/text-generation-inference/tests/test_prefill_truncate.py b/text-generation-inference/tests/test_prefill_truncate.py
index cac8db1f..e3666c66 100644
--- a/text-generation-inference/tests/test_prefill_truncate.py
+++ b/text-generation-inference/tests/test_prefill_truncate.py
@@ -13,23 +13,23 @@ def test_prefill_truncate():
     generator = TpuGenerator.from_pretrained(
         model_path, revision="", max_batch_size=1, max_sequence_length=sequence_length
     )
-    input_text = "This is a secret part. Once upon a time,"
+    input_text = "This is something I will tell by the end of the story"
     request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False)
     batch = Batch(id=0, requests=[request], size=1, max_tokens=sequence_length)
     generations, _ = generator.prefill(batch)
     assert len(generations) == 1
-    assert generations[0].tokens.ids == [635]
-    assert generations[0].tokens.texts == [" there"]
+    assert generations[0].tokens.ids == [31843]
+    assert generations[0].tokens.texts == ["."]
 
     # Now re-test but with truncate
     generator.clear()
     request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False)
-    # This will only leave 5 tokens
-    request.truncate = 5
+    # This will only keep the last 6 tokens
+    request.truncate = 6
     batch = Batch(id=0, requests=[request], size=1, max_tokens=sequence_length)
     generations, _ = generator.prefill(batch)
     assert len(generations) == 1
-    assert generations[0].tokens.ids == [260]
-    assert generations[0].tokens.texts == [" a"]
+    assert generations[0].tokens.ids == [291]
+    assert generations[0].tokens.texts == [" and"]
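
For reference, the bucketing that drives the new warmup can be exercised on its own. A standalone sketch follows, with PREFILL_LENGTHS and take_nearest_length copied from the patch; the batch size of 4 and 1024-token sequence length are assumed values, and the loop only prints the combinations a warmup would cover:

# Standalone sketch of the prefill-length bucketing and of the (batch size, length)
# combinations the warmup walks through. The batch_size/seq_len values are assumptions.
from bisect import bisect_left

PREFILL_LENGTHS = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]


def take_nearest_length(length: int) -> int:
    """Round a length up to the nearest bucket (clamped to the largest bucket)."""
    pos = bisect_left(PREFILL_LENGTHS, length)
    if pos == len(PREFILL_LENGTHS):
        return PREFILL_LENGTHS[-1]
    return PREFILL_LENGTHS[pos]


batch_size, seq_len = 4, 1024                  # assumed server configuration
bucket_seq_len = take_nearest_length(seq_len)  # -> 1024

# Batch sizes shrink from largest to smallest, and for each one every supported bucket
# is tried from longest to shortest, so memory failures surface as early as possible.
for bs in range(batch_size, 0, -1):
    for length in reversed([l for l in PREFILL_LENGTHS if l <= bucket_seq_len]):
        print(f"warmup prefill/decode: batch_size={bs}, truncate={length}")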