Bff/token optimisation #268

Open: wants to merge 29 commits into master

Commits (29)
da582a6  adding tweaks (benjaminfh, Oct 11, 2023)
1b90ace  added token handling for fplans (benjaminfh, Oct 14, 2023)
1a17049  updated token handling (benjaminfh, Oct 14, 2023)
3f01d42  Merge pull request #1 from benjaminfh/bff-training-tweaks (benjaminfh, Oct 14, 2023)
c9ab76b  fixed task name (benjaminfh, Oct 14, 2023)
11573e9  Merge pull request #2 from benjaminfh/bff-training-tweaks (benjaminfh, Oct 14, 2023)
b569caf  updated save last to True (benjaminfh, Oct 14, 2023)
7c27f58  added pad to multiple line (benjaminfh, Oct 31, 2023)
b53748e  removed pad to mtpl (benjaminfh, Oct 31, 2023)
e3a048f  fixed special tokens var (benjaminfh, Oct 31, 2023)
89ccaaf  upgrade transformers req version (benjaminfh, Oct 31, 2023)
5e800ec  added some logging to DonutDataset (benjaminfh, Oct 31, 2023)
777b9d9  i == 1 (benjaminfh, Oct 31, 2023)
b6f7edc  time watching in slow code (benjaminfh, Oct 31, 2023)
1707dfa  json dumping (benjaminfh, Oct 31, 2023)
11bc663  split out code and added perf timers (benjaminfh, Oct 31, 2023)
795d7cf  added print of task name for debug (benjaminfh, Oct 31, 2023)
d8fff8c  more debugging (benjaminfh, Oct 31, 2023)
a4b320a  debugging1111!! (benjaminfh, Oct 31, 2023)
e759001  fixed loading the special tokens (benjaminfh, Oct 31, 2023)
7c7de48  reverted changes (benjaminfh, Nov 1, 2023)
e09b0e1  reverted to master repo (benjaminfh, Nov 1, 2023)
348402d  checking all_special_tokens for perf (benjaminfh, Nov 1, 2023)
54ede19  clean up (benjaminfh, Nov 1, 2023)
dc8a28d  additional_special_tokens bool flag fix (benjaminfh, Nov 1, 2023)
8aaab66  remove my yaml (benjaminfh, Nov 1, 2023)
7303da4  made explicit in all usages (benjaminfh, Nov 1, 2023)
6e8b4bc  updated gitignore (benjaminfh, Nov 14, 2023)
0114d54  rm .DS_Store (benjaminfh, Nov 14, 2023)
2 changes: 2 additions & 0 deletions .gitignore
@@ -137,3 +137,5 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+.DS_Store
12 changes: 7 additions & 5 deletions donut/model.py
@@ -176,7 +176,7 @@ def __init__(
         self.model.forward = self.forward  # to get cross attentions and utilize `generate` function
 
         self.model.config.is_encoder_decoder = True  # to get cross-attention
-        self.add_special_tokens(["<sep/>"])  # <sep/> is used for representing a list in a JSON
+        self.add_special_tokens(["<sep/>"], replace_additional_special_tokens=False)  # <sep/> is used for representing a list in a JSON
         self.model.model.decoder.embed_tokens.padding_idx = self.tokenizer.pad_token_id
         self.model.prepare_inputs_for_generation = self.prepare_inputs_for_inference
 
@@ -199,11 +199,14 @@ def __init__(
                     new_bart_state_dict[x] = bart_state_dict[x]
            self.model.load_state_dict(new_bart_state_dict)
 
-    def add_special_tokens(self, list_of_tokens: List[str]):
+    def add_special_tokens(self, list_of_tokens: List[str], replace_additional_special_tokens=False):
         """
         Add special tokens to tokenizer and resize the token embeddings
         """
-        newly_added_num = self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set(list_of_tokens))})
+        if len(set(list_of_tokens) - set(self.tokenizer.all_special_tokens)) > 0:
+            newly_added_num = self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set(list_of_tokens))}, replace_additional_special_tokens=replace_additional_special_tokens)
+        else:
+            newly_added_num = 0
 
         if newly_added_num > 0:
             self.model.resize_token_embeddings(len(self.tokenizer))
 
Review comment on this hunk: I guess this is the real solution to the detected bug, isn't it?
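
For readers skimming the hunk above, here is a minimal sketch, not code from this PR, of the behaviour the new guard and flag aim for. It assumes a transformers release recent enough to expose replace_additional_special_tokens, and uses "xlm-roberta-base" purely as a stand-in tokenizer. Requests for tokens that are already special should cost nothing (and, in the real model, avoid an embedding resize), while genuinely new tokens are appended to the existing additional special tokens instead of replacing them.

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")  # stand-in, not Donut's actual tokenizer

    def add_special_tokens(list_of_tokens):
        # Skip the tokenizer call (and the embedding resize it would force in the
        # real model) when every requested token is already registered as special.
        if not set(list_of_tokens) - set(tokenizer.all_special_tokens):
            return 0
        return tokenizer.add_special_tokens(
            {"additional_special_tokens": sorted(set(list_of_tokens))},
            replace_additional_special_tokens=False,  # append, do not overwrite
        )

    print(add_special_tokens(["<sep/>"]))           # 1: new token, a resize would happen once
    print(add_special_tokens(["<sep/>"]))           # 0: already special, the guard short-circuits
    print(add_special_tokens(["<yes/>", "<no/>"]))  # 2: appended alongside <sep/>, not replacing it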

@@ -510,8 +513,7 @@ def json2token(self, obj: Any, update_special_tokens_for_json_key: bool = True,
             else:
                 keys = obj.keys()
             for k in keys:
-                if update_special_tokens_for_json_key:
-                    self.decoder.add_special_tokens([fr"<s_{k}>", fr"</s_{k}>"])
+                self.decoder.add_special_tokens([rf"<s_{k}>", rf"</s_{k}>"], replace_additional_special_tokens=False)
                 output += (
                     fr"<s_{k}>"
                     + self.json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
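
The json2token change above is where the flag really matters: every JSON key in every ground truth produces an add_special_tokens call, and in transformers versions where the method replaces the additional_special_tokens list by default, each call would silently drop the key tokens registered before it. A hedged sketch of that failure mode, again with "xlm-roberta-base" as a stand-in tokenizer:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("xlm-roberta-base")

    tok.add_special_tokens({"additional_special_tokens": ["<s_menu>", "</s_menu>"]})
    tok.add_special_tokens({"additional_special_tokens": ["<s_nm>", "</s_nm>"]})
    print(tok.additional_special_tokens)
    # Default behaviour replaces the list: only the <s_nm> pair is still marked special.

    tok.add_special_tokens(
        {"additional_special_tokens": ["<s_price>", "</s_price>"]},
        replace_additional_special_tokens=False,  # the behaviour this PR opts into
    )
    print(tok.additional_special_tokens)
    # The <s_price> pair is appended; the previously registered pair is kept.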
2 changes: 1 addition & 1 deletion donut/util.py
@@ -87,7 +87,7 @@ def __init__(
                 ]
             )
 
-        self.donut_model.decoder.add_special_tokens([self.task_start_token, self.prompt_end_token])
+        self.donut_model.decoder.add_special_tokens([self.task_start_token, self.prompt_end_token], replace_additional_special_tokens=False)
         self.prompt_end_token_id = self.donut_model.decoder.tokenizer.convert_tokens_to_ids(self.prompt_end_token)
 
     def __len__(self) -> int:
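
For the DonutDataset call above, the same pattern applies at scale: the dataset runs json2token over every ground-truth sample while it is being built, so nearly all of the resulting special-token requests are repeats. A purely synthetic illustration (the counts are illustrative, not measurements) of why the early-exit guard in add_special_tokens removes almost all of that work:

    # Simulate 10,000 samples whose ground truth uses the same two JSON keys.
    samples = [{"nm": "latte", "price": "4.50"} for _ in range(10_000)]

    registered: set = set()
    tokenizer_calls = 0

    for sample in samples:
        for key in sample:
            requested = {f"<s_{key}>", f"</s_{key}>"}
            if requested - registered:   # the guard added by this PR
                tokenizer_calls += 1     # stands in for the slow tokenizer/resize path
                registered |= requested

    print(tokenizer_calls)  # 2 with the guard; 20,000 tokenizer calls without it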
8 changes: 5 additions & 3 deletions train.py
@@ -95,9 +95,11 @@ def train(config):
                 "<form/>", "<handwritten/>", "<invoice/>", "<letter/>",
                 "<memo/>", "<news_article/>", "<presentation/>", "<questionnaire/>",
                 "<resume/>", "<scientific_publication/>", "<scientific_report/>", "<specification/>"
-            ])
+            ],
+            replace_additional_special_tokens=False
+        )
         if task_name == "docvqa":
-            model_module.model.decoder.add_special_tokens(["<yes/>", "<no/>"])
+            model_module.model.decoder.add_special_tokens(["<yes/>", "<no/>"], replace_additional_special_tokens=False)
 
         for split in ["train", "validation"]:
             datasets[split].append(
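
A hedged sanity check, not part of this PR, that one could drop into train() after the dataset loop in a docvqa run: DonutDataset also registers its task start and prompt-end tokens while it is being built, and with replace_additional_special_tokens=False that later call should no longer evict the answer tokens added here.

    decoder_tokenizer = model_module.model.decoder.tokenizer
    assert "<yes/>" in decoder_tokenizer.all_special_tokens, "answer tokens were evicted"
    assert "<no/>" in decoder_tokenizer.all_special_tokens, "answer tokens were evicted"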
@@ -173,4 +175,4 @@ def train(config):
     config.exp_version = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") if not args.exp_version else args.exp_version
 
     save_config_file(config, Path(config.result_path) / config.exp_name / config.exp_version)
-    train(config)
+    train(config)