Skip to content

Commit

Permalink
fix: trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
sigridjineth committed Aug 24, 2024
1 parent bb5d87c commit beb333e
Showing 1 changed file with 4 additions and 14 deletions.
18 changes: 4 additions & 14 deletions src/tevatron/reranker/driver/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,23 @@
HfArgumentParser,
set_seed,
)
from transformers import TrainingArguments

from tevatron.reranker.arguments import ModelArguments, DataArguments, \
TevatronTrainingArguments
from tevatron.reranker.arguments import ModelArguments, DataArguments, TevatronTrainingArguments
from tevatron.reranker.modeling import RerankerModel
from tevatron.reranker.dataset import RerankerTrainDataset
from tevatron.reranker.collator import RerankerTrainCollator
from tevatron.reranker.trainer import RerankerTrainer
from tevatron.reranker.gc_trainer import GradCacheTrainer

logger = logging.getLogger(__name__)

def main():
parser = HfArgumentParser((ModelArguments, DataArguments, TevatronTrainingArguments))

if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, training_args, tevatron_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args, tevatron_args = parser.parse_args_into_dataclasses()
model_args: ModelArguments
data_args: DataArguments
training_args: TrainingArguments
tevatron_args: TevatronTrainingArguments

# Combine TrainingArguments and TevatronTrainingArguments
for key, value in vars(tevatron_args).items():
setattr(training_args, key, value)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

if (
os.path.exists(training_args.output_dir)
Expand Down Expand Up @@ -60,7 +51,6 @@ def main():
)
logger.info("Training/evaluation parameters %s", training_args)
logger.info("MODEL parameters %s", model_args)
logger.info("Tevatron parameters %s", tevatron_args)

set_seed(training_args.seed)

Expand Down

0 comments on commit beb333e

Please sign in to comment.