Commit
Merge branch 'main' into huiyingl/peftmerge
Signed-off-by: Huiying Li <willwin.lee@gmail.com>
HuiyingLi committed Nov 17, 2024
2 parents 9992509 + 7d03e1f commit 75ea08e
Showing 13 changed files with 543 additions and 109 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/release.yml
@@ -23,7 +23,7 @@ on:

jobs:
release:
uses: NVIDIA/NeMo-FW-CI-templates/.github/workflows/_release_library.yml@v0.10.0
uses: NVIDIA/NeMo-FW-CI-templates/.github/workflows/_release_library.yml@v0.12.3
with:
release-ref: ${{ inputs.release-ref }}
image-name: nemo_container
@@ -39,3 +39,4 @@ jobs:
TWINE_USERNAME: ${{ secrets.TWINE_USERNAME }}
TWINE_PASSWORD: ${{ secrets.TWINE_PASSWORD }}
SLACK_RELEASE_ENDPOINT: ${{ secrets.SLACK_RELEASE_ENDPOINT }}
PAT: ${{ secrets.PAT }}
7 changes: 7 additions & 0 deletions nemo/collections/llm/__init__.py
@@ -237,3 +237,10 @@
__all__.append("deploy")
except ImportError as error:
logging.warning(f"The deploy module could not be imported: {error}")

try:
from nemo.collections.llm.api import evaluate

__all__.append("evaluate")
except ImportError as error:
logging.warning(f"The evaluate module could not be imported: {error}")
207 changes: 126 additions & 81 deletions nemo/collections/llm/api.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from copy import deepcopy
from pathlib import Path
@@ -256,84 +255,13 @@ def validate(
return app_state.exp_dir


def get_trtllm_deployable(
nemo_checkpoint,
model_type,
triton_model_repository,
num_gpus,
tensor_parallelism_size,
pipeline_parallelism_size,
max_input_len,
max_output_len,
max_batch_size,
dtype,
):
from nemo.export.tensorrt_llm import TensorRTLLM

if triton_model_repository is None:
trt_llm_path = "/tmp/trt_llm_model_dir/"
Path(trt_llm_path).mkdir(parents=True, exist_ok=True)
else:
trt_llm_path = triton_model_repository

if nemo_checkpoint is None and triton_model_repository is None:
raise ValueError(
"The provided model repository is not a valid TensorRT-LLM model "
"directory. Please provide a --nemo_checkpoint or a TensorRT-LLM engine."
)

if nemo_checkpoint is None and not os.path.isdir(triton_model_repository):
raise ValueError(
"The provided model repository is not a valid TensorRT-LLM model "
"directory. Please provide a --nemo_checkpoint or a valid TensorRT-LLM engine."
)

if nemo_checkpoint is not None and model_type is None:
raise ValueError("Model type is required to be defined if a nemo checkpoint is provided.")

trt_llm_exporter = TensorRTLLM(
model_dir=trt_llm_path,
load_model=(nemo_checkpoint is None),
)

if nemo_checkpoint is not None:
try:
logging.info("Export operation will be started to export the nemo checkpoint to TensorRT-LLM.")
trt_llm_exporter.export(
nemo_checkpoint_path=nemo_checkpoint,
model_type=model_type,
n_gpus=num_gpus,
tensor_parallelism_size=tensor_parallelism_size,
pipeline_parallelism_size=pipeline_parallelism_size,
max_input_len=max_input_len,
max_output_len=max_output_len,
max_batch_size=max_batch_size,
dtype=dtype,
)
except Exception as error:
raise RuntimeError("An error has occurred during the model export. Error message: " + str(error))

return trt_llm_exporter


def store_args_to_json(triton_http_address, triton_port, triton_request_timeout, openai_format_response):
args_dict = {
"triton_service_ip": triton_http_address,
"triton_service_port": triton_port,
"triton_request_timeout": triton_request_timeout,
"openai_format_response": openai_format_response,
}
with open("nemo/deploy/service/config.json", "w") as f:
json.dump(args_dict, f)


@run.cli.entrypoint(namespace="llm")
def deploy(
nemo_checkpoint: Path = None,
model_type: str = "llama",
triton_model_name: str = "xxx",
triton_model_name: str = 'triton_model',
triton_model_version: Optional[int] = 1,
triton_port: int = 8080,
triton_port: int = 8000,
triton_http_address: str = "0.0.0.0",
triton_request_timeout: int = 60,
triton_model_repository: Path = None,
@@ -344,21 +272,61 @@ def deploy(
max_input_len: int = 256,
max_output_len: int = 256,
max_batch_size: int = 8,
start_rest_service: bool = False,
start_rest_service: bool = True,
rest_service_http_address: str = "0.0.0.0",
rest_service_port: int = 8000,
openai_format_response: bool = False,
rest_service_port: int = 8080,
openai_format_response: bool = True,
output_generation_logits: bool = True,
):
"""
Deploys a nemo model on a PyTriton server by converting the nemo ckpt to trtllm.
Also starts a rest service that is used to send OpenAI API compatible input requests
to the PyTriton server.
Args:
nemo_checkpoint (Path): Path for nemo checkpoint.
model_type (str): Type of the model. Choices: gpt, llama, falcon, starcoder. Default: llama.
triton_model_name (str): Name for the model that gets deployed on PyTriton. Please ensure that the same model
name is passed to the evaluate method for the model to be accessible while sending evaluation requests.
Default: 'triton_model'.
triton_model_version (Optional[int]): Version for the triton model. Default: 1.
triton_port (int): Port for the PyTriton server. Default: 8000.
triton_http_address (str): HTTP address for the PyTriton server. Default: "0.0.0.0".
triton_request_timeout (int): Timeout in seconds for Triton server. Default: 60.
triton_model_repository (Path): Folder for the trt-llm conversion; the trt-llm engine gets saved in this specified
path. If None, saves it in the /tmp dir. Default: None.
num_gpus (int): Number of GPUs for export to trtllm and deploy. Default: 1.
tensor_parallelism_size (int): Tensor parallelism size. Default: 1.
pipeline_parallelism_size (int): Pipeline parallelism size. Default: 1.
dtype (str): dtype of the TensorRT-LLM model. Default: "bfloat16".
max_input_len (int): Max input length of the model. Default: 256.
max_output_len (int): Max output length of the model. Default: 256.
max_batch_size (int): Max batch size of the model. Default: 8.
start_rest_service (bool): Start rest service that is used to send evaluation requests to the PyTriton server.
Needs to be True to be able to run evaluation. Default: True.
rest_service_http_address (str): HTTP address for the rest service. Default: "0.0.0.0".
rest_service_port (int): Port for the rest service. Default: 8080.
openai_format_response (bool): Return the response from PyTriton server in OpenAI compatible format. Needs to
be True while running evaluation. Default: True.
output_generation_logits (bool): If True, builds the trtllm engine with gather_generation_logits set to True.
generation_logits are used to compute the logProb of the output token. Default: True.
"""
from nemo.collections.llm import deploy
from nemo.deploy import DeployPyTriton

deploy.unset_environment_variables()
if start_rest_service:
if triton_port == rest_service_port:
logging.error("REST service port and Triton server port cannot use the same port.")
return
# Store triton ip, port and other args relevant for REST API in config.json to be accessible by rest_model_api.py
store_args_to_json(triton_http_address, triton_port, triton_request_timeout, openai_format_response)

triton_deployable = get_trtllm_deployable(
# Store triton ip, port and other args relevant for REST API as env vars to be accessible by rest_model_api.py
os.environ['TRITON_HTTP_ADDRESS'] = triton_http_address
os.environ['TRITON_PORT'] = str(triton_port)
os.environ['TRITON_REQUEST_TIMEOUT'] = str(triton_request_timeout)
os.environ['OPENAI_FORMAT_RESPONSE'] = str(openai_format_response)
os.environ['OUTPUT_GENERATION_LOGITS'] = str(output_generation_logits)

triton_deployable = deploy.get_trtllm_deployable(
nemo_checkpoint,
model_type,
triton_model_repository,
@@ -369,6 +337,7 @@
max_output_len,
max_batch_size,
dtype,
output_generation_logits,
)

try:
@@ -383,6 +352,7 @@

logging.info("Triton deploy function will be called.")
nm.deploy()
nm.run()
except Exception as error:
logging.error("Error message has occurred during deploy function. Error message: " + str(error))
return
Expand Down Expand Up @@ -416,6 +386,81 @@ def deploy(
nm.stop()


def evaluate(
nemo_checkpoint_path: Path,
url: str = "http://0.0.0.0:8080/v1",
model_name: str = "triton_model",
eval_task: str = "gsm8k",
num_fewshot: Optional[int] = None,
limit: Optional[Union[int, float]] = None,
bootstrap_iters: int = 100000,
# inference params
max_tokens_to_generate: Optional[int] = 256,
temperature: Optional[float] = 0.000000001,
top_p: Optional[float] = 0.0,
top_k: Optional[int] = 1,
add_bos: Optional[bool] = False,
):
"""
Evaluates nemo model deployed on PyTriton server (via trtllm) using lm-evaluation-harness
(https://github.com/EleutherAI/lm-evaluation-harness/tree/main).
Args:
nemo_checkpoint_path (Path): Path for nemo 2.0 checkpoint. This is used to get the tokenizer from the ckpt
which is required to tokenize the evaluation input and output prompts.
url (str): rest service url and port that were used in the deploy method above in the format:
http://{rest_service_http}:{rest_service_port}. POST requests with evaluation input prompts
(from lm-eval-harness) are sent to this url and are then passed to the model deployed on the PyTriton server.
The rest service url and port serve as the entry point to evaluate the model deployed on the PyTriton server.
model_name (str): Name of the model that is deployed on PyTriton server. It should be the same as
triton_model_name passed to the deploy method above to be able to launch evaluation. Default: "triton_model".
eval_task (str): task to be evaluated on. For ex: "gsm8k", "gsm8k_cot", "mmlu", "lambada". Default: "gsm8k".
These are the tasks that are supported currently. Any other task of type generate_until or loglikelihood from
lm-evaluation-harness can be run, but only the above mentioned ones are tested. Tasks of type
loglikelihood_rolling are not supported yet.
num_fewshot (int): number of examples in few-shot context. Default: None.
limit (Union[int, float]): Limit the number of examples per task. If <1 (i.e. a float between 0 and 1), limit
is a fraction of the total number of examples. If an int x, then run evaluation only on x samples
from the eval dataset. Default: None, which means eval is run on the entire dataset.
bootstrap_iters (int): Number of iterations for bootstrap statistics, used when calculating stderrs. Set to 0
for no stderr calculations to be performed. Default: 100000.
# inference params
max_tokens_to_generate (int): max tokens to generate. Default: 256.
temperature: Optional[float]: float value between 0 and 1. A temperature of 0 indicates greedy decoding, where the
token with the highest probability is chosen. Temperature can't be set to 0.0 currently, due to a bug with TRTLLM
(# TODO to be investigated). Hence a very small value is used as the default. Default: 0.000000001.
top_p: Optional[float]: float value between 0 and 1. Limits the next-token choice to the top tokens within a certain
cumulative probability. top_p=0 means the model will only consider the single most likely token for the next
prediction. Default: 0.0.
top_k: Optional[int]: Limits the next-token choice to a certain number (K) of the top tokens. top_k=1 means the model
will only consider the single most likely token for the next prediction. Default: 1.
add_bos: Optional[bool]: whether a special token representing the beginning of a sequence should be added when
encoding a string. Default: False, since for CausalLM it is typically set to False. If needed, set add_bos to True.
"""
try:
# lm-evaluation-harness import
from lm_eval import evaluator
except ImportError:
raise ImportError(
"Please ensure that lm-evaluation-harness is installed in your env as it is required " "to run evaluations"
)

from nemo.collections.llm import evaluation

# Get tokenizer from nemo ckpt. This works only with NeMo 2.0 ckpt.
tokenizer = io.load_context(Path(nemo_checkpoint_path) / "context", subpath="model").tokenizer
# Wait for rest service to be ready before starting evaluation
evaluation.wait_for_rest_service(rest_url=f"{url}/v1/health")
# Create an object of the NeMoFWLM which is passed as a model to evaluator.simple_evaluate
model = evaluation.NeMoFWLMEval(
model_name, url, tokenizer, max_tokens_to_generate, temperature, top_p, top_k, add_bos
)
results = evaluator.simple_evaluate(
model=model, tasks=eval_task, limit=limit, num_fewshot=num_fewshot, bootstrap_iters=bootstrap_iters
)

print("score", results['results'][eval_task])


@run.cli.entrypoint(name="import", namespace="llm")
def import_ckpt(
model: pl.LightningModule,
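In api.py above, deploy() now hands the REST-related settings to the service process through environment variables (TRITON_HTTP_ADDRESS, TRITON_PORT, TRITON_REQUEST_TIMEOUT, OPENAI_FORMAT_RESPONSE, OUTPUT_GENERATION_LOGITS) instead of writing nemo/deploy/service/config.json. A minimal sketch of how a consumer such as rest_model_api.py might read them back; the helper below is hypothetical, only the variable names come from this commit. Note that the booleans were written with str(), so they arrive as the strings "True"/"False":

```python
import os

# Hypothetical reader (not part of this commit) for the settings that deploy()
# exports as environment variables instead of writing config.json.
def read_triton_settings() -> dict:
    return {
        "triton_http_address": os.environ.get("TRITON_HTTP_ADDRESS", "0.0.0.0"),
        "triton_port": int(os.environ.get("TRITON_PORT", "8000")),
        "triton_request_timeout": int(os.environ.get("TRITON_REQUEST_TIMEOUT", "60")),
        # str(True)/str(False) were stored, so compare against the literal string.
        "openai_format_response": os.environ.get("OPENAI_FORMAT_RESPONSE", "False") == "True",
        "output_generation_logits": os.environ.get("OUTPUT_GENERATION_LOGITS", "False") == "True",
    }
```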
3 changes: 3 additions & 0 deletions nemo/collections/llm/deploy/__init__.py
@@ -0,0 +1,3 @@
from nemo.collections.llm.deploy.base import get_trtllm_deployable, unset_environment_variables

__all__ = ["unset_environment_variables", "get_trtllm_deployable"]