Skip to content

Commit

Permalink
Feat: reimplement vllm backend beam search using logprobs
Browse files Browse the repository at this point in the history
  • Loading branch information
vicoooo26 committed Oct 23, 2024
1 parent 561deca commit 86d41ee
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 21 deletions.
36 changes: 16 additions & 20 deletions optimum_benchmark/backends/vllm/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,35 +117,31 @@ def batch_offline_engine_generate(self, inputs: Dict[str, Any], kwargs: Dict[str
self.pretrained_model.add_request(
inputs=prompt,
request_id=str(i),
params=SamplingParams(
ignore_eos=True,
detokenize=True,
seed=self.config.seed,
n=kwargs.get("num_return_sequences"),
max_tokens=kwargs.get("max_new_tokens"),
min_tokens=kwargs.get("min_new_tokens"),
use_beam_search=kwargs.get("num_beams") > 1,
logits_processors=kwargs.get("logits_processors", None),
),
params=self.get_sampling_params(kwargs),
)

while self.pretrained_model.has_unfinished_requests():
self.pretrained_model.step()

def get_sampling_params(self, kwargs: Dict[str, Any]) -> SamplingParams:
    """Build the vLLM ``SamplingParams`` for a generate call.

    Args:
        kwargs: generation keyword arguments (``num_return_sequences``,
            ``max_new_tokens``, ``min_new_tokens``, ``num_beams``,
            ``logits_processors``). ``num_beams`` may be absent or ``None``,
            in which case greedy/sampling (no beam emulation) is assumed.

    Returns:
        A configured ``SamplingParams`` instance. When ``num_beams > 1``,
        ``logprobs`` is set to ``2 * num_beams`` so beam search can be
        emulated from per-step logprobs (native ``use_beam_search`` was
        removed from vLLM's SamplingParams).
    """
    params = SamplingParams(
        ignore_eos=True,
        detokenize=True,
        seed=self.config.seed,
        n=kwargs.get("num_return_sequences"),
        max_tokens=kwargs.get("max_new_tokens"),
        min_tokens=kwargs.get("min_new_tokens"),
        logits_processors=kwargs.get("logits_processors", None),
    )
    # Fetch once; fall back to 1 so a missing/None "num_beams" does not
    # raise TypeError on the comparison (original fetched twice and
    # crashed when the key was absent).
    num_beams = kwargs.get("num_beams") or 1
    if num_beams > 1:
        params.logprobs = 2 * num_beams
    return params

async def single_online_engine_generate(self, prompt: str, request_id: str, kwargs: Dict[str, Any]) -> Any:
stream = await self.pretrained_model.add_request(
inputs=prompt,
request_id=request_id,
params=SamplingParams(
ignore_eos=True,
detokenize=True,
seed=self.config.seed,
n=kwargs.get("num_return_sequences"),
max_tokens=kwargs.get("max_new_tokens"),
min_tokens=kwargs.get("min_new_tokens"),
use_beam_search=kwargs.get("num_beams") > 1,
logits_processors=kwargs.get("logits_processors", None),
),
params=self.get_sampling_params(),
)

async for _ in stream:
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
"Please install amdsmi from https://github.com/ROCm/amdsmi to enable this feature."
)


EXTRAS_REQUIRE = {
"quality": ["ruff"],
"testing": ["pytest", "hydra-joblib-launcher"],
Expand Down

0 comments on commit 86d41ee

Please sign in to comment.