architecture: factor HFCompatible out #954
Conversation
The move looks reasonable; I am on the fence about using garak.resources.api to represent wrappers for dependencies. One idea would be to use garak/resources/huggingface/__init__.py to expose the class. That would allow keeping each class definition in its own file, imported into __init__.py for exposure as more Compatible types are identified over time (rough sketch below). Just a thought that came to mind, no strong argument in favor of this at this time.
This PR also suggests there is another consumer for this mixin in buffs.paraphrase.PegasusT5.
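A minimal sketch of that layout, assuming a hypothetical per-class module name (hf_compatible.py); the actual file split is open:

```python
# garak/resources/huggingface/__init__.py  (hypothetical layout)
# Each *Compatible mixin lives in its own module and is re-exported here,
# so consumers keep a stable import path as more mixins are added over time.
from garak.resources.huggingface.hf_compatible import HFCompatible

__all__ = ["HFCompatible"]
```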
In garak/buffs/paraphrase.py (outdated diff):

```python
from garak.buffs.base import Buff
from garak.resources.api.huggingface import HFCompatible


class PegasusT5(Buff):
```
Just an observation, not required for this PR: it looks like this class could benefit from a refactor to use HFCompatible and expose para_model_name and hf_args as DEFAULT_PARAMS.
Suggested change:

```diff
-class PegasusT5(Buff):
+class PegasusT5(Buff, HFCompatible):
```
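A rough sketch of what that could look like, assuming Buff exposes a DEFAULT_PARAMS dict to merge with (garak's usual plugin pattern); the model name and hf_args values are placeholders:

```python
from garak.buffs.base import Buff
from garak.resources.api.huggingface import HFCompatible


class PegasusT5(Buff, HFCompatible):
    DEFAULT_PARAMS = Buff.DEFAULT_PARAMS | {
        "para_model_name": "tuner007/pegasus_paraphrase",  # placeholder value
        "hf_args": {"device": "cpu"},  # consumed by HFCompatible device selection
    }
```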
This change is incomplete: the class also needs to be extended to consume the HFCompatible mixin in _load_model(). self.torch_device moves to the standardized self.device, which should be detected/populated from hf_args["device"] with a call to self._select_hf_device() in _load_model():
```python
def __init__(self, config_root=_config) -> None:
    self.max_length = 60
    self.temperature = 1.5
    self.num_return_sequences = 6
    self.num_beams = self.num_return_sequences
    self.tokenizer = None
    self.para_model = None
    super().__init__(config_root=config_root)

def _load_model(self):
    from transformers import PegasusForConditionalGeneration, PegasusTokenizer

    self.device = self._select_hf_device()
    model_kwargs = self._gather_hf_params(
        hf_constructor=PegasusForConditionalGeneration.from_pretrained
    )  # will defer to device_map; if device_map was `auto` it may not match self.device
    self.para_model = PegasusForConditionalGeneration.from_pretrained(
        self.para_model_name, **model_kwargs
    ).to(self.device)
    self.tokenizer = PegasusTokenizer.from_pretrained(self.para_model_name)
```
Not an issue for this PR, however I suspect a few more items should be promoted to DEFAULT_PARAMS. max_length, temperature, num_return_sequences, and possibly num_beams (if its value does not always have to equal num_return_sequences) should likely be exposed as configurable. Since Fast looks like it may also have similar items to promote, I think that can be deferred.
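Extending the sketch above, those generation settings could sit in the same DEFAULT_PARAMS block (values copied from the current hard-coded __init__ defaults; the exact layout is an assumption):

```python
class PegasusT5(Buff, HFCompatible):
    DEFAULT_PARAMS = Buff.DEFAULT_PARAMS | {
        "para_model_name": "tuner007/pegasus_paraphrase",  # placeholder value
        "hf_args": {"device": "cpu"},
        # promoted from hard-coded values in __init__
        "max_length": 60,
        "temperature": 1.5,
        "num_return_sequences": 6,
        "num_beams": 6,  # only if it need not always equal num_return_sequences
    }
```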
Agree. Thanks for the details. Will mark as ready for review when out of draft.
Adding the model_kwargs part led to the paraphraser returning all blanks. Did the rest of the integration and added a test to catch this unwanted behaviour.
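A hedged sketch of that kind of regression check; the test name, the _get_response helper, and the sample sentence are assumptions rather than the exact test added here:

```python
from garak import _config
from garak.buffs.paraphrase import PegasusT5


def test_pegasus_paraphrases_are_not_blank():
    buff = PegasusT5(config_root=_config)
    buff._load_model()  # assumed lazy-load hook, mirroring the snippet above
    outputs = buff._get_response("The quick brown fox jumps over the lazy dog.")
    assert len(outputs) > 0
    # the bug described above produced empty strings, so require non-blank text
    assert all(o.strip() for o in outputs)
```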
Still getting blank results when using _gather_hf_params; will take a look in a bit, but if you have suggestions, they're welcome!
Let's drop the _gather_hf_params() call for now, as PegasusForConditionalGeneration.from_pretrained() does not look like it handles the extra args the way the current code expects; I suspect device vs device_map is also impacting the expectations.
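A minimal sketch of the simplified loader under that suggestion, keeping _select_hf_device() but passing no extra kwargs to from_pretrained() (one reading of the suggestion, not the final code):

```python
def _load_model(self):
    from transformers import PegasusForConditionalGeneration, PegasusTokenizer

    # keep the standardized device selection from HFCompatible, but skip
    # _gather_hf_params() until the kwarg handling is sorted out
    self.device = self._select_hf_device()
    self.para_model = PegasusForConditionalGeneration.from_pretrained(
        self.para_model_name
    ).to(self.device)
    self.tokenizer = PegasusTokenizer.from_pretrained(self.para_model_name)
```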
HFCompatible was embedded in generators.base, tying slow-to-import HF-specific stuff to base classes. This PR moves HFCompatible to a separate module, with a candidate location in garak.resources.api.huggingface, enabling fast base class loading.
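For illustration, the import swap a consumer like buffs.paraphrase would make (old path per the description above, new path per the candidate location):

```python
# before: the mixin lived alongside the generator base classes
# from garak.generators.base import HFCompatible

# after: HF-specific code is pulled in only where it is needed
from garak.resources.api.huggingface import HFCompatible
```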