@@ -31,6 +31,8 @@ def __getitem__(self, idx):
 # prepares lm_labels from target_ids, returns examples with keys as expected by the forward method
 # this is necessary because the trainer directly passes this dict as arguments to the model
 # so make sure the keys match the parameter names of the forward method
+# Note*1: The original code (with transformers v3.4.0) returned a dict with "lm_labels".
+# Support for this was removed in transformers v4.0.0 and replaced with "labels".
 class T2TDataCollator:
     def __call__(self, batch):
         input_ids = torch.stack([example['input_ids'] for example in batch])
@@ -39,7 +41,7 @@ def __call__(self, batch):
         attention_mask = torch.stack([example['attention_mask'] for example in batch])
         decoder_attention_mask = torch.stack([example['target_attention_mask'] for example in batch])
         return {'input_ids': input_ids, 'attention_mask': attention_mask,
-                'lm_labels': lm_labels, 'decoder_attention_mask': decoder_attention_mask}
+                'labels': lm_labels, 'decoder_attention_mask': decoder_attention_mask}  # Note*1


 # Note that for save_steps, steps means gradient updates (not batch) so if
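For context, the dict returned by the collator is unpacked directly into the model's forward call, so under transformers v4.x the key has to be "labels". A minimal sketch of that forward signature, assuming the stock "t5-small" checkpoint (illustrative only, not part of this commit):

from transformers import T5ForConditionalGeneration, T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("t5-small")            # illustrative checkpoint
model = T5ForConditionalGeneration.from_pretrained("t5-small")

enc = tokenizer("translate English to German: Hello", return_tensors="pt")
dec = tokenizer("Hallo", return_tensors="pt")

# Same keys the collator returns: 'labels' replaces the pre-v4.0.0 'lm_labels'.
out = model(input_ids=enc.input_ids,
            attention_mask=enc.attention_mask,
            labels=dec.input_ids,
            decoder_attention_mask=dec.attention_mask)
print(out.loss)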
@@ -80,8 +82,11 @@ def train(self):
             len(valid_dataset), len(valid_dataset.bad_indexes)))
         # Train the model
         print('Training')
+        # trainer = T5Trainer(model=self.model, args=self.training_args, train_dataset=train_dataset,
+        #                     eval_dataset=valid_dataset, data_collator=T2TDataCollator(), prediction_loss_only=True)
+        # prediction_loss_only=True moved to training_args for compatibility with transformers v4.0.0
         trainer = T5Trainer(model=self.model, args=self.training_args, train_dataset=train_dataset,
-                            eval_dataset=valid_dataset, data_collator=T2TDataCollator(), prediction_loss_only=True)
+                            eval_dataset=valid_dataset, data_collator=T2TDataCollator())
         trainer.train()
         # Save the results
         print('Saving model')
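Since the Trainer constructor no longer accepts prediction_loss_only in transformers v4.0.0, the flag moves into TrainingArguments. A hedged sketch of how training_args could carry it (output_dir and the numeric values are placeholders, not taken from this repo):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./t5_output",           # placeholder path
    per_device_train_batch_size=8,      # placeholder batch size
    save_steps=500,                      # "steps" means gradient updates, per the comment above
    prediction_loss_only=True,           # replaces Trainer(..., prediction_loss_only=True) from v3.4.0
)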