Skip to content

Commit

Permalink
add optimum-intel ipex backend into benchmark (#250)
Browse files Browse the repository at this point in the history
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
  • Loading branch information
yao-matrix authored Aug 30, 2024
1 parent 0b69851 commit 4d3f68e
Show file tree
Hide file tree
Showing 15 changed files with 302 additions and 4 deletions.
48 changes: 48 additions & 0 deletions .github/workflows/test_cli_cpu_ipex.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
name: CLI CPU IPEX Tests

on:
workflow_dispatch:
push:
branches:
- main
paths:
- .github/workflows/test_cli_cpu_ipex.yaml
- "optimum_benchmark/**"
- "docker/**"
- "tests/**"
- "setup.py"
pull_request:
branches:
- main
paths:
- .github/workflows/test_cli_cpu_ipex.yaml
- "optimum_benchmark/**"
- "docker/**"
- "tests/**"
- "setup.py"

concurrency:
cancel-in-progress: true
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}

jobs:
run_cli_cpu_ipex_tests:
runs-on: ubuntu-latest

steps:
- name: Checkout
uses: actions/checkout@v4

- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: "3.10"

- name: Install requirements
run: |
pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -e .[testing,ipex,diffusers,timm]
- name: Run tests
run: pytest -s -k "cli and cpu and ipex"
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ pip install -e .
Depending on the backends you want to use, you can install `optimum-benchmark` with the following extras:

- PyTorch (default): `pip install optimum-benchmark`
- IPEX: `pip install optimum-benchmark[ipex]`
- OpenVINO: `pip install optimum-benchmark[openvino]`
- Torch-ORT: `pip install optimum-benchmark[torch-ort]`
- OnnxRuntime: `pip install optimum-benchmark[onnxruntime]`
Expand Down
6 changes: 4 additions & 2 deletions docker/cpu/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ ENV PATH="/home/user/.local/bin:${PATH}"
RUN apt-get update && apt-get install -y --no-install-recommends \
libgl1-mesa-dev libglib2.0-0 \
sudo build-essential git bash-completion \
python3.10 python3-pip python3.10-dev && \
python3.10 python3-pip python3.10-dev google-perftools && \
apt-get clean && rm -rf /var/lib/apt/lists/* && \
update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1 && \
pip install --no-cache-dir --upgrade pip setuptools wheel
pip install --no-cache-dir --upgrade pip setuptools wheel

ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4"

# Install PyTorch
ARG TORCH_VERSION=stable
Expand Down
37 changes: 37 additions & 0 deletions examples/ipex_llama.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
defaults:
- benchmark
- scenario: inference
- launcher: process
- backend: ipex
- _base_
- _self_

name: ipex_llama

launcher:
numactl: true
numactl_kwargs:
cpunodebind: 0
membind: 0

scenario:
latency: true
memory: true

warmup_runs: 10
iterations: 10
duration: 10

input_shapes:
batch_size: 1
sequence_length: 256
generate_kwargs:
max_new_tokens: 32
min_new_tokens: 32

backend:
device: cpu
export: true
no_weights: true
torch_dtype: bfloat16
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
2 changes: 2 additions & 0 deletions optimum_benchmark/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .backends import (
BackendConfig,
IPEXConfig,
INCConfig,
LlamaCppConfig,
LLMSwarmConfig,
Expand All @@ -24,6 +25,7 @@
"BenchmarkReport",
"EnergyStarConfig",
"InferenceConfig",
"IPEXConfig",
"INCConfig",
"InlineConfig",
"LauncherConfig",
Expand Down
2 changes: 2 additions & 0 deletions optimum_benchmark/backends/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .config import BackendConfig
from .llama_cpp.config import LlamaCppConfig
from .llm_swarm.config import LLMSwarmConfig
from .ipex.config import IPEXConfig
from .neural_compressor.config import INCConfig
from .onnxruntime.config import ORTConfig
from .openvino.config import OVConfig
Expand All @@ -13,6 +14,7 @@
__all__ = [
"PyTorchConfig",
"ORTConfig",
"IPEXConfig",
"OVConfig",
"TorchORTConfig",
"TRTLLMConfig",
Expand Down
Empty file.
131 changes: 131 additions & 0 deletions optimum_benchmark/backends/ipex/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import inspect
from collections import OrderedDict
from tempfile import TemporaryDirectory
from typing import Any, Dict

import torch
from hydra.utils import get_class

from ...generators.dataset_generator import DatasetGenerator
from ...import_utils import is_accelerate_available, is_torch_distributed_available
from ...task_utils import TEXT_GENERATION_TASKS
from ..base import Backend
from ..transformers_utils import fast_weights_init
from .config import IPEXConfig
from .utils import TASKS_TO_IPEXMODEL

if is_accelerate_available():
from accelerate import Accelerator

if is_torch_distributed_available():
import torch.distributed


class IPEXBackend(Backend[IPEXConfig]):
NAME: str = "ipex"

def __init__(self, config: IPEXConfig) -> None:
super().__init__(config)

if self.config.task in TASKS_TO_IPEXMODEL:
self.ipexmodel_class = get_class(TASKS_TO_IPEXMODEL[self.config.task])
self.logger.info(f"\t+ Using IPEXModel class {self.ipexmodel_class.__name__}")
else:
raise NotImplementedError(f"IPEXBackend does not support task {self.config.task}")


def load(self) -> None:
self.logger.info("\t+ Creating backend temporary directory")
self.tmpdir = TemporaryDirectory()

if self.config.no_weights:
self.logger.info("\t+ Creating no weights IPEXModel")
self.create_no_weights_model()
self.logger.info("\t+ Loading no weights IPEXModel")
self._load_ipexmodel_with_no_weights()
else:
self.logger.info("\t+ Loading pretrained IPEXModel")
self._load_ipexmodel_from_pretrained()

self.tmpdir.cleanup()

def _load_automodel_from_pretrained(self) -> None:
self.pretrained_model = self.automodel_loader.from_pretrained(self.config.model, **self.config.model_kwargs)

def _load_automodel_with_no_weights(self) -> None:
original_model, self.config.model = self.config.model, self.no_weights_model

with fast_weights_init():
self._load_automodel_from_pretrained()

self.logger.info("\t+ Tying model weights")
self.pretrained_model.tie_weights()

self.config.model = original_model

def _load_ipexmodel_from_pretrained(self) -> None:
self.pretrained_model = self.ipexmodel_class.from_pretrained(
self.config.model,
export=self.config.export,
device=self.config.device,
**self.config.model_kwargs,
**self.automodel_kwargs,
)

def _load_ipexmodel_with_no_weights(self) -> None:
with fast_weights_init():
original_model, self.config.model = self.config.model, self.no_weights_model
original_export, self.config.export = self.config.export, True
self.logger.info("\t+ Loading no weights IPEXModel")
self._load_ipexmodel_from_pretrained()
self.config.export = original_export
self.config.model = original_model

@property
def automodel_kwargs(self) -> Dict[str, Any]:
kwargs = {}

if self.config.torch_dtype is not None:
kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype)

print(kwargs)

return kwargs

@property
def is_dp_distributed(self) -> bool:
return is_torch_distributed_available() and torch.distributed.is_initialized()

def prepare_input_shapes(self, input_shapes: Dict[str, Any]) -> Dict[str, Any]:
if self.is_dp_distributed:
if input_shapes["batch_size"] % torch.distributed.get_world_size() != 0:
raise ValueError(
f"Batch size {input_shapes['batch_size']} must be divisible by "
f"data parallel world size {torch.distributed.get_world_size()}"
)
# distributing batch size across processes
input_shapes["batch_size"] //= torch.distributed.get_world_size()

# registering input shapes for usage during model reshaping
self.input_shapes = input_shapes

return input_shapes

def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
if self.is_dp_distributed:
with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
inputs = process_inputs

return inputs

def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
return self.pretrained_model.forward(**inputs, **kwargs)

def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
return self.pretrained_model.generate(**inputs, **kwargs)

def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
return self.pretrained_model.generate(**inputs, **kwargs)

def call(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
return self.pretrained_model(**inputs, **kwargs)
37 changes: 37 additions & 0 deletions optimum_benchmark/backends/ipex/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

from ...import_utils import ipex_version
from ..config import BackendConfig

TORCH_DTYPES = ["bfloat16", "float16", "float32", "auto"]

@dataclass
class IPEXConfig(BackendConfig):
name: str = "ipex"
version: Optional[str] = ipex_version()
_target_: str = "optimum_benchmark.backends.ipex.backend.IPEXBackend"

# load options
no_weights: bool = False
torch_dtype: Optional[str] = None

# export options
export: bool = True

def __post_init__(self):
super().__post_init__()

self.device = self.device.lower()
if self.device not in ["cpu", "gpu"]:
raise ValueError(f"IPEXBackend only supports CPU devices, got {self.device}")

if self.model_kwargs.get("torch_dtype", None) is not None:
raise ValueError(
"`torch_dtype` is an explicit argument in the PyTorch backend config. "
"Please remove it from the `model_kwargs` and set it in the backend config directly."
)

if self.torch_dtype is not None and self.torch_dtype not in TORCH_DTYPES:
raise ValueError(f"`torch_dtype` must be one of {TORCH_DTYPES}. Got {self.torch_dtype} instead.")

10 changes: 10 additions & 0 deletions optimum_benchmark/backends/ipex/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
TASKS_TO_IPEXMODEL = {
"fill-mask": "optimum.intel.IPEXModelForMaskedLM",
"text-generation": "optimum.intel.IPEXModelForCausalLM",
"text-classification": "optimum.intel.IPEXModelForSequenceClassification",
"token-classification": "optimum.intel.IPEXModelForTokenClassification",
"question-answering": "optimum.intel.IPEXModelForQuestionAnswering",
"image-classification": "optimum.intel.IPEXModelForImageClassification",
"audio-classification": "optimum.intel.IPEXModelForAudioClassification",
}

2 changes: 2 additions & 0 deletions optimum_benchmark/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Benchmark,
BenchmarkConfig,
EnergyStarConfig,
IPEXConfig,
INCConfig,
InferenceConfig,
InlineConfig,
Expand All @@ -36,6 +37,7 @@
# benchmark configuration
cs.store(name="benchmark", node=BenchmarkConfig)
# backends configurations
cs.store(group="backend", name=IPEXConfig.name, node=IPEXConfig)
cs.store(group="backend", name=OVConfig.name, node=OVConfig)
cs.store(group="backend", name=PyTorchConfig.name, node=PyTorchConfig)
cs.store(group="backend", name=ORTConfig.name, node=ORTConfig)
Expand Down
5 changes: 4 additions & 1 deletion optimum_benchmark/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
_pynvml_available = importlib.util.find_spec("pynvml") is not None
_torch_distributed_available = importlib.util.find_spec("torch.distributed") is not None
_onnxruntime_available = importlib.util.find_spec("onnxruntime") is not None
_ipex_available = importlib.util.find_spec("intel_extension_for_pytorch") is not None
_openvino_available = importlib.util.find_spec("openvino") is not None
_neural_compressor_available = importlib.util.find_spec("neural_compressor") is not None
_codecarbon_available = importlib.util.find_spec("codecarbon") is not None
Expand Down Expand Up @@ -157,11 +158,13 @@ def onnxruntime_version():
except importlib.metadata.PackageNotFoundError:
return None


def openvino_version():
if _openvino_available:
return importlib.metadata.version("openvino")

def ipex_version():
if _ipex_available:
return importlib.metadata.version("intel_extension_for_pytorch")

def neural_compressor_version():
if _neural_compressor_available:
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
except Exception as error:
assert False, "Error: Could not open '%s' due %s\n" % (filepath, error)

MIN_OPTIMUM_VERSION = "1.16.0"
MIN_OPTIMUM_VERSION = "1.18.0"
INSTALL_REQUIRES = [
# HF dependencies
"transformers",
Expand Down Expand Up @@ -69,6 +69,7 @@
"quality": ["ruff"],
"testing": ["pytest", "hydra-joblib-launcher"],
# optimum backends
"ipex":[f"optimum[ipex]>={MIN_OPTIMUM_VERSION}"],
"openvino": [f"optimum[openvino,nncf]>={MIN_OPTIMUM_VERSION}"],
"onnxruntime": [f"optimum[onnxruntime]>={MIN_OPTIMUM_VERSION}"],
"onnxruntime-gpu": [f"optimum[onnxruntime-gpu]>={MIN_OPTIMUM_VERSION}"],
Expand Down
11 changes: 11 additions & 0 deletions tests/configs/cpu_inference_ipex_text_decoders.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
defaults:
# order of inheritance, last one overrides previous ones
- _base_ # inherits from base config
- _cpu_ # inherits from cpu config
- _inference_ # inherits from inference config
- _text_decoders_ # inherits from text decoders config
- _no_weights_ # inherits from no weights config
- _self_ # hydra 1.1 compatibility
- override backend: ipex

name: cpu_inference_ipex_text_decoders
Loading

0 comments on commit 4d3f68e

Please sign in to comment.