diff --git a/optimum_benchmark/backends/vllm/backend.py b/optimum_benchmark/backends/vllm/backend.py index e90f3e7e..9ad36b9a 100644 --- a/optimum_benchmark/backends/vllm/backend.py +++ b/optimum_benchmark/backends/vllm/backend.py @@ -117,35 +117,31 @@ def batch_offline_engine_generate(self, inputs: Dict[str, Any], kwargs: Dict[str self.pretrained_model.add_request( inputs=prompt, request_id=str(i), - params=SamplingParams( - ignore_eos=True, - detokenize=True, - seed=self.config.seed, - n=kwargs.get("num_return_sequences"), - max_tokens=kwargs.get("max_new_tokens"), - min_tokens=kwargs.get("min_new_tokens"), - use_beam_search=kwargs.get("num_beams") > 1, - logits_processors=kwargs.get("logits_processors", None), - ), + params=self.get_sampling_params(kwargs), ) while self.pretrained_model.has_unfinished_requests(): self.pretrained_model.step() + def get_sampling_params(self, kwargs: Dict[str, Any]) -> SamplingParams: + params = SamplingParams( + ignore_eos=True, + detokenize=True, + seed=self.config.seed, + n=kwargs.get("num_return_sequences"), + max_tokens=kwargs.get("max_new_tokens"), + min_tokens=kwargs.get("min_new_tokens"), + logits_processors=kwargs.get("logits_processors", None), + ) + if kwargs.get("num_beams") > 1: + params.logprobs = 2 * kwargs.get("num_beams") + return params + async def single_online_engine_generate(self, prompt: str, request_id: str, kwargs: Dict[str, Any]) -> Any: stream = await self.pretrained_model.add_request( inputs=prompt, request_id=request_id, - params=SamplingParams( - ignore_eos=True, - detokenize=True, - seed=self.config.seed, - n=kwargs.get("num_return_sequences"), - max_tokens=kwargs.get("max_new_tokens"), - min_tokens=kwargs.get("min_new_tokens"), - use_beam_search=kwargs.get("num_beams") > 1, - logits_processors=kwargs.get("logits_processors", None), - ), + params=self.get_sampling_params(), ) async for _ in stream: diff --git a/setup.py b/setup.py index cff0d197..afebbcfd 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,6 @@ "Please install amdsmi from https://github.com/ROCm/amdsmi to enable this feature." ) - EXTRAS_REQUIRE = { "quality": ["ruff"], "testing": ["pytest", "hydra-joblib-launcher"],