diff --git a/src/tevatron/reranker/driver/train.py b/src/tevatron/reranker/driver/train.py index f6f9d7cd..671531b2 100644 --- a/src/tevatron/reranker/driver/train.py +++ b/src/tevatron/reranker/driver/train.py @@ -7,14 +7,13 @@ HfArgumentParser, set_seed, ) -from transformers import TrainingArguments -from tevatron.reranker.arguments import ModelArguments, DataArguments, \ - TevatronTrainingArguments +from tevatron.reranker.arguments import ModelArguments, DataArguments, TevatronTrainingArguments from tevatron.reranker.modeling import RerankerModel from tevatron.reranker.dataset import RerankerTrainDataset from tevatron.reranker.collator import RerankerTrainCollator from tevatron.reranker.trainer import RerankerTrainer +from tevatron.reranker.gc_trainer import GradCacheTrainer logger = logging.getLogger(__name__) @@ -22,17 +21,9 @@ def main(): parser = HfArgumentParser((ModelArguments, DataArguments, TevatronTrainingArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - model_args, data_args, training_args, tevatron_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) else: - model_args, data_args, training_args, tevatron_args = parser.parse_args_into_dataclasses() - model_args: ModelArguments - data_args: DataArguments - training_args: TrainingArguments - tevatron_args: TevatronTrainingArguments - - # Combine TrainingArguments and TevatronTrainingArguments - for key, value in vars(tevatron_args).items(): - setattr(training_args, key, value) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() if ( os.path.exists(training_args.output_dir) @@ -60,7 +51,6 @@ def main(): ) logger.info("Training/evaluation parameters %s", training_args) logger.info("MODEL parameters %s", model_args) - logger.info("Tevatron parameters %s", tevatron_args) set_seed(training_args.seed)