I wanted to create a single issue that collects TODOs around the LLM feature that was added recently, so that everything is in one place.
Have a "fast/greedy" option for predict - it is not necessary to calculate probabilities for all classes up until the last token. When class A has probability p_A and class B has p_B(t) < p_A at token t, then no matter what the probability for tokens >t, p_B cannot exceed p_A anymore.
A small use case where the classifiers are used as a transformer in a bigger pipeline, e.g. to extract structured knowledge from a text ("Does this product description contain the size of the item?")
A way to format the text/labels/few-shot samples before they're string-interpolated, maybe Jinja2?
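For illustration, a Jinja2-based prompt could look roughly like this; the template and its variables (`labels`, `examples`, `text`) are hypothetical, not an existing API:

```python
from jinja2 import Template

prompt_template = Template(
    "Classify the text into one of the following labels: "
    "{{ labels | join(', ') }}.\n"
    "{% for example in examples %}"
    "Text: {{ example.text }}\nLabel: {{ example.label }}\n"
    "{% endfor %}"
    "Text: {{ text }}\nLabel:"
)

prompt = prompt_template.render(
    labels=["positive", "negative"],
    examples=[{"text": "Great product!", "label": "positive"}],
    text="Not worth the money.",
)
print(prompt)
```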
Test if this works with a more diverse range of LLMs
Enable multi-label classification. Would probably require sigmoid instead of softmax and an (empirically determined?) threshold.
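To make that concrete, a rough sketch of sigmoid + threshold over per-label scores; what the scores would be exactly (e.g. summed log-probabilities) and how the threshold is chosen are open questions:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# hypothetical per-label scores for 2 samples and 3 labels
scores = np.array([
    [1.2, -0.5, 0.3],
    [-2.0, 0.8, -0.1],
])
probas = sigmoid(scores)                     # independent probability per label
threshold = 0.5                              # would have to be determined empirically
y_pred = (probas >= threshold).astype(int)   # multi-hot predictions
```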
Check if it is possible to enable caching for encoder-decoder LLMs like flan-t5.
Sampling strategy for few-shot learning:
Right now, the sampling is hard-coded and basically tries to add each label at least once. This seems reasonable but there are situations where other strategies could make sense. Therefore, I would like to see a feature that allows setting the sampling strategy as a parameter. Options that come to mind:
Stratified sampling: roughly what we have now, but not quite
Fully random sampling: sample regardless of label
Similarity-based sampling: use the current sample to find similar samples from the training data (maybe with a simple tfidf vector? see the sketch after this list)
Custom sampling: Allow users to pass a callable that performs the sampling
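Here is a rough sketch of what the similarity-based option could look like with a TF-IDF vectorizer; the function name and signature are made up for illustration:

```python
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def sample_similar(x, X_train, y_train, n_samples=5):
    """Return the n_samples training examples most similar to x."""
    vectorizer = TfidfVectorizer().fit(X_train)
    similarities = cosine_similarity(
        vectorizer.transform([x]), vectorizer.transform(X_train)
    ).ravel()
    indices = np.argsort(similarities)[::-1][:n_samples]
    return [(X_train[i], y_train[i]) for i in indices]
```

The custom sampling option would then just accept a user-provided callable with a similar signature.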
Fine-tuning:
Instead of in-context learning via few-shot samples, as in `FewShotClassifier`, it can often be more performant (both from a runtime and a scoring perspective) to fine-tune on the training data. We could consider using peft under the hood, which is agnostic with regard to the training framework, so it should work with skorch. This would be implemented in a separate class.
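A rough sketch of how peft could be plugged in under the hood; the base model and LoRA hyperparameters below are placeholders, not a tested configuration:

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model
lora_config = LoraConfig(
    r=8,                         # rank of the low-rank update matrices
    lora_alpha=16,
    target_modules=["c_attn"],   # attention projection module in gpt2
    lora_dropout=0.1,
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()  # only a small fraction of weights is trainable
```

The wrapped model is still a regular torch module, so it could in principle be handed to skorch for training, as mentioned above.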
Refactor to use `forward` instead of `generate`:
For this change, it is not clear whether it would actually be better than the existing implementation.
Right now we rely on the `generate` method of transformers models, but we could instead use `forward`: by constructing the whole token sequence (input + label) and returning the corresponding logits, we can calculate the probabilities without having to go through logit processors + forcing. Some advantages are that the code could be simplified and that batching becomes easier to add (right now we predict one sample+label at a time; with this we could predict all labels in a single batch). Disadvantages are that `generate` does some heavy lifting for encoder-decoder models, which we would have to reproduce, and that we lose caching. In practice, which approach is faster depends on many factors: batch size (memory!), length of the input, length and overlap of the labels, etc.
Here is some sample code that I adapted from my colleague Joao; it demonstrates this approach and returns the exact same probabilities as our existing approach:
```python
import numpy as np
import torch

model, tokenizer = ...
result = []
for x in X:
    result.append([])
    input_length = len(tokenizer(get_prompt(x))['input_ids'])
    for label in labels:
        # Build the prompt for the given label
        inputs = tokenizer(
            [get_prompt(x) + label],
            return_tensors="pt",
            padding=True,
        ).to(model.device)
        # Run the forward pass, extract the logits
        # TODO fails for enc-dec
        # TODO allow setting batch size
        with torch.no_grad():
            logits = model(
                input_ids=inputs.input_ids, attention_mask=inputs.attention_mask
            ).logits
        # Discard the logits that correspond to the input (remember: the logits at
        # index N correspond to the token at index N + 1. The first token has no
        # logits, the last set of logits corresponds to a token not present in the
        # input. We have to shift by 1.)
        logits = logits[:, input_length - 1:-1, :]
        # Compute the probabilities for each label. Here you need to be careful to
        # remove the padding that may exist.
        probas = logits.softmax(dim=-1).cpu().numpy()
        label_ids = inputs.input_ids[:, input_length:]
        padding_mask = ~label_ids.eq(tokenizer.pad_token_id)
        label_token_probas = probas[
            0, np.arange(0, label_ids.shape[1], dtype=int), label_ids[0].cpu().numpy()
        ]
        label_proba = label_token_probas.prod()
        result[-1].append((label, label_proba.item()))
```
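For completeness, once `result` is filled, the per-label probabilities could be turned into predictions roughly like this (whether and how the rows should be normalized for `predict_proba` is left open here):

```python
import numpy as np

probas = np.array([[proba for _, proba in row] for row in result])
y_pred = [labels[i] for i in probas.argmax(axis=1)]
```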