`backprompt` provides a data structure which allows a user to dynamically construct prompts while avoiding repeated LLM computations.
In many large-scale tasks performed by LLMs, a particular prompt is used many times—once for each instance of the task. In cases like these, the amount of computation performed by future LLM calls can be reduced by caching and re-using the LLM's representation of the prompt.
`backprompt` takes this well-known idea a step further by additionally caching LLM representations of intermediate text in the prompt. Intermediate caching may be useful when one needs to dynamically adjust the prompt without having to re-compute the LLM's representation of it. `backprompt` abstracts the complex process of prompt construction and caching as plain-old string concatenation.
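For background, here's a minimal sketch of the underlying caching idea using Hugging Face transformers' `past_key_values` directly (an illustration of the concept, not `backprompt`'s implementation):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained('gpt2')
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# Run the shared prompt once and keep its attention keys and values
prompt_ids = tokenizer('Hello there.', return_tensors='pt').input_ids
with torch.no_grad():
    prompt_out = model(prompt_ids, use_cache=True)
cache = prompt_out.past_key_values

# Process only the new text, re-using the cached prompt representation
suffix_ids = tokenizer(' General Kenobi...', return_tensors='pt').input_ids
with torch.no_grad():
    suffix_out = model(suffix_ids, past_key_values=cache, use_cache=True)
next_token_logits = suffix_out.logits[:, -1, :]
```

Caching the prompt this way is standard; `backprompt` additionally caches and re-uses representations of intermediate pieces of text so that prompts can be re-assembled cheaply.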
See the notebook `demos/minimal_example.ipynb` for a more realistic use case. Here's a toy demo:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from backprompt import Text

# Load a GPT model and its tokenizer
model_name = 'gpt2'
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
mt = (model, tokenizer)

# Wrap strings in Text and construct them via concatenation
context = Text('Hello there.', mt)
choices = [Text(' Senator', mt), Text(' General', mt)]
endings = [Text(' Amidala', mt), Text(' Kenobi...', mt)]
texts = [context + choice + ending for choice in choices for ending in endings]

print(texts[-1].string)
# Hello there. General Kenobi...

# Get next-token logits by calling every Text object.
# The punchline is that you don't have to worry about repeated computation
for text in texts:
    text()

texts[-1].model_repr[1].logits[:, -1, :]
```
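As a hypothetical follow-up (assuming `model_repr[1]` is a standard transformers causal LM output whose `logits` tensor has shape `(batch, sequence, vocab)`), you could greedily decode the next token from those logits:

```python
import torch

# Hypothetical follow-up: greedily pick the next token for the last text.
# Assumes model_repr[1] is a standard transformers CausalLMOutput.
next_token_logits = texts[-1].model_repr[1].logits[:, -1, :]
next_token_id = next_token_logits.argmax(dim=-1)
print(tokenizer.decode(next_token_id))
```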
Install from source:

```bash
python -m pip install git+https://github.com/kddubey/backprompt.git
```
If you basically know how backprop works (watch this YouTube video), and you basically know how a decoder-only autoregressive language model works (watch this YouTube video), then you know how `backprompt` works :-)
Analogies:
- backprop → "intermediate" gradient of a function; `backprompt` → attention block keys and values.
- backprop → gradient of a function; `backprompt` → token logits.
- backprop → chain rule; `backprompt` → tensor concatenation.
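To make the last analogy concrete, here's a small self-contained sketch (plain transformers, not `backprompt` internals) showing that the cached keys and values grow by concatenation along the sequence dimension as new text is processed:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained('gpt2')
tokenizer = AutoTokenizer.from_pretrained('gpt2')

context_ids = tokenizer('Hello there.', return_tensors='pt').input_ids
with torch.no_grad():
    out = model(context_ids, use_cache=True)
# Layer-0 keys have shape (batch, n_heads, seq_len, head_dim)
print(out.past_key_values[0][0].shape)  # seq_len == number of context tokens

suffix_ids = tokenizer(' General Kenobi...', return_tensors='pt').input_ids
with torch.no_grad():
    out = model(suffix_ids, past_key_values=out.past_key_values, use_cache=True)
# The suffix's keys/values were concatenated onto the cached context's
print(out.past_key_values[0][0].shape)  # seq_len == context + suffix tokens
```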
TODO: graph visualization

TODO: expand test cases

Run the tests with:

```bash
pytest
```
Research
- What's the computational complexity of using past keys and values wrt # tokens?
- Do few-shot prompts exhibit interesting independencies? If so, one could construct prompts using different examples on the fly.
Code
- Expand tests
- More autoregressive LMs
- More string breakdowns
- Graph visualization
- Allow for frozen representations / custom independencies in the graph
- Batching
- Eager mode
- `ModelRepr` dataclass for convenience
- Add and update a `token_logprobs` attribute to the LM output obj
- By default, only keep the last (non-pad) token's logits in the LM output obj
- Documentation?