Commit
Add static KV cache and test on Gemma-2B (#4)
* chore(style): run make style

* chore(style): update pyproject to avoid ruff warning

* fix(tgi): sequence length should be based on sequence_length config

  It was previously using n_positions in some cases, but that attribute is not available on some model configs.

* feat(modeling): model is immediately loaded on device

* debug: added env var to debug on CPU

  If the DBG_DEVICE env var is set, it will be used to set the device for the model.

* feat(test): reduce overhead when retrieving model

  This avoids loading the model twice.

* feat(modeling): make compilation optional

  Compilation is now optional and can be enabled with the environment variable DBG_COMPILE. This is because:

  1. Some models (notably gemma) produce bugs when the model is compiled.
  2. The shapes of the model inference input params change, triggering recompilation and leading to slow performance.
  3. With the added xm.mark_step, performance is actually better when the model is not compiled. XLA builds a graph anyway, so performance is going to be good.

* feat: add @torch.no_grad decorators to decode and prefill

  This avoids useless gradient calculations.

* chore(generator): create buffers on device to avoid moving them

* refactor(generator): some model params are passed as dict

  This will make it possible to pass different params in different model configurations later.

* feat: use static KV cache when available

  Some models, like Gemma and Llama, support static KV cache in transformers. For these, it is possible to use this feature, leading to much higher performance.

* fix(CI): added HF_TOKEN to use models that require it

  Also manually install accelerate to avoid memory issues when loading gemma.

* fix(CI): adapt expected result in do_sample test

  The test produces different results now that some operations are done in a slightly different order.
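The two debug switches described above (DBG_DEVICE and DBG_COMPILE) could be read along these lines. This is only a minimal sketch, not the code from this commit: the `DebugSettings` class and `read_debug_settings` helper are hypothetical names, and treating any non-empty DBG_COMPILE value as "enabled" is an assumption.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class DebugSettings:
    """Hypothetical container for the two debug switches."""

    device: Optional[str]  # None means "keep the default accelerator device"
    compile_model: bool    # compilation is opt-in


def read_debug_settings(env: Mapping[str, str] = os.environ) -> DebugSettings:
    # DBG_DEVICE: when set, it overrides the device the model is loaded on
    # (for example "cpu" to debug without an accelerator).
    device = env.get("DBG_DEVICE") or None
    # DBG_COMPILE: assumption, any non-empty value enables compilation.
    compile_model = bool(env.get("DBG_COMPILE"))
    return DebugSettings(device=device, compile_model=compile_model)
```

A caller would then load the model on `settings.device` when it is not `None`, and wrap the model in a compile step only when `settings.compile_model` is true, so the uncompiled path stays the default.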
1 parent 4c75f88, commit fdcd7ea
Showing 10 changed files with 216 additions and 73 deletions.