trlX allows you to fine-tune 🤗 Hugging Face supported language models of up to 20B parameters (such as gpt2, gpt-j, and gpt-neox, as well as T5-based models, including google/t5-v1_1 and google/flan-t5) using reinforcement learning, with either a provided reward function or a reward-labeled dataset. Two algorithms are implemented: Proximal Policy Optimization (PPO) and Implicit Language Q-Learning (ILQL).
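PPO updates the policy by optimizing a clipped surrogate objective. As a rough illustration, the snippet below computes the standard clipped loss from the original PPO formulation (Schulman et al.) for a batch of per-token log-probabilities in plain Python. It is a sketch, not trlX's implementation; the function name `ppo_clipped_loss` and the default `clip_range=0.2` are assumptions chosen for illustration.

```python
import math
from typing import Sequence


def ppo_clipped_loss(
    logprobs: Sequence[float],
    old_logprobs: Sequence[float],
    advantages: Sequence[float],
    clip_range: float = 0.2,
) -> float:
    """Standard PPO clipped surrogate loss (to minimize), averaged over a batch.

    Hypothetical illustrative helper; not trlX's internal code.
    """
    if not (len(logprobs) == len(old_logprobs) == len(advantages)) or not logprobs:
        raise ValueError("inputs must be non-empty and of equal length")
    losses = []
    for lp, old_lp, adv in zip(logprobs, old_logprobs, advantages):
        # Probability ratio between the current policy and the one that sampled the data
        ratio = math.exp(lp - old_lp)
        unclipped = ratio * adv
        # Restrict the ratio to [1 - clip_range, 1 + clip_range]
        clipped = min(max(ratio, 1.0 - clip_range), 1.0 + clip_range) * adv
        # Pessimistic bound: keep the smaller objective, negate it to get a loss
        losses.append(-min(unclipped, clipped))
    return sum(losses) / len(losses)
```

Taking the minimum of the clipped and unclipped terms means the policy gains nothing from pushing the ratio far outside the clip range, which keeps each update step small.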
You can read more about trlX in our documentation.
Want to collect human annotations for your RL application? Check out CHEESE, our library for human-in-the-loop (HiTL) data collection.
Install from source:
git clone https://github.com/CarperAI/trlx.git
cd trlx
pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 # CUDA 11.6 build; use the index URL matching your CUDA version
pip install -e .
You can train a model using a reward function or a reward-labeled dataset.
Using a reward function (this one rewards each sample for every occurrence of 'cats'):
import trlx
trainer = trlx.train('gpt2', reward_fn=lambda samples, **kwargs: [sample.count('cats') for sample in samples])
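The lambda above can also be written as a named function, which is easier to test before handing it to the trainer. This sketch assumes, as the lambda implies, that trlX calls `reward_fn` with a list of generated strings plus optional keyword arguments and expects one numeric score per sample; `cat_reward_fn` is a hypothetical name.

```python
from typing import List


def cat_reward_fn(samples: List[str], **kwargs) -> List[float]:
    """Hypothetical reward function: score each sample by how often it contains 'cats'.

    Extra keyword arguments (assumed to be passed by the trainer) are accepted and ignored.
    """
    return [float(sample.count("cats")) for sample in samples]


if __name__ == "__main__":
    print(cat_reward_fn(["cats and more cats", "dogs only"]))
```

It can then be passed to the trainer in place of the lambda, e.g. `trlx.train('gpt2', reward_fn=cat_reward_fn)`.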
Using a reward-labeled dataset (a sequence of samples followed by their rewards):
trainer = trlx.train('EleutherAI/gpt-j-6B', dataset=[('dolphins', 'geese'), (1.0, 100.0)])
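If your data arrives as (text, reward) pairs, a small helper can reshape it into the layout the dataset example passes to `dataset=`: one sequence of samples followed by one sequence of rewards. This is a sketch assuming that two-sequence layout is what `trlx.train` expects, as the example suggests; `to_reward_labeled_dataset` is a hypothetical helper, not part of trlX.

```python
from typing import List, Sequence, Tuple


def to_reward_labeled_dataset(pairs: Sequence[Tuple[str, float]]) -> List[tuple]:
    """Hypothetical helper: split (sample, reward) pairs into [samples, rewards]."""
    if not pairs:
        return [(), ()]
    samples, rewards = zip(*pairs)
    # Cast rewards to float so integer labels behave like the example's 1.0 / 100.0
    return [tuple(samples), tuple(float(r) for r in rewards)]


if __name__ == "__main__":
    print(to_reward_labeled_dataset([("dolphins", 1.0), ("geese", 100)]))
```

Keeping each sample next to its reward until the last moment makes it harder for the two sequences to drift out of alignment.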
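Trainers can also generate text; here tokenizer is the Hugging Face tokenizer for the trained model (for example, loaded with AutoTokenizer.from_pretrained('EleutherAI/gpt-j-6B')):
trainer.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'), do_sample=True)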
Save the resulting model as a Hugging Face pretrained language model:
trainer.save_pretrained('/path/to/output/folder/')
🩹 Warning: Only the AcceleratePPOTrainer can currently write Hugging Face transformers to disk with save_pretrained, because ILQL trainers require inference behavior that the available transformers architectures do not yet support.
To launch distributed training with 🤗 Accelerate:
accelerate config # choose the DeepSpeed option when prompted
accelerate launch examples/simulacra.py
To launch a hyperparameter sweep with Ray Tune:
python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml examples/ppo_sentiments.py
For more usage, see the examples.
For development, check out these guidelines and read our docs.
Many thanks to Leandro von Werra for contributing trl, a library that initially inspired this repo.