Commit eb2f13c

train generate model updates for transformers 4.0.0

1 parent: 95d0852
3 files changed: 12 additions, 3 deletions

amrlib/models/generate_t5/trainer.py

Lines changed: 7 additions & 2 deletions

@@ -31,6 +31,8 @@ def __getitem__(self, idx):
 # prepares lm_labels from target_ids, returns examples with keys as expected by the forward method
 # this is necessacry because the trainer directly passes this dict as arguments to the model
 # so make sure the keys match the parameter names of the forward method
+# Note*1: The original code (with transformers v3.4.0) returned dict with "lm_labels".
+# Support for this was removed in transformers v4.0.0 and replaced it with "labels"
 class T2TDataCollator:
     def __call__(self, batch):
         input_ids = torch.stack([example['input_ids'] for example in batch])
@@ -39,7 +41,7 @@ def __call__(self, batch):
         attention_mask = torch.stack([example['attention_mask'] for example in batch])
         decoder_attention_mask = torch.stack([example['target_attention_mask'] for example in batch])
         return {'input_ids': input_ids, 'attention_mask': attention_mask,
-                'lm_labels': lm_labels, 'decoder_attention_mask': decoder_attention_mask }
+                'labels': lm_labels, 'decoder_attention_mask': decoder_attention_mask }    # Note*1


 # Note that for save_steps, steps means gradient updates (not batch) so if
@@ -80,8 +82,11 @@ def train(self):
             len(valid_dataset), len(valid_dataset.bad_indexes)))
         # Train the model
         print('Training')
+        # trainer = T5Trainer(model=self.model, args=self.training_args, train_dataset=train_dataset,
+        #                     eval_dataset=valid_dataset, data_collator=T2TDataCollator(), prediction_loss_only=True)
+        # prediction_loss_only=True moved to training_args for compatibility with transformers v4.0.0
         trainer = T5Trainer(model=self.model, args=self.training_args, train_dataset=train_dataset,
-                            eval_dataset=valid_dataset, data_collator=T2TDataCollator(), prediction_loss_only=True)
+                            eval_dataset=valid_dataset, data_collator=T2TDataCollator())
         trainer.train()
         # Save the results
         print('Saving model')
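For context, both changes in this file track the transformers 4.0.0 API: the collator must now return "labels" instead of "lm_labels", and prediction_loss_only is no longer accepted by the Trainer constructor. Below is a minimal sketch of the resulting setup; it is illustrative only, not part of the commit, and it assumes the plain transformers Trainer as a stand-in for T5Trainer, example fields named 'target_ids', and pad-id 0 masked to -100 for the loss.

    import torch
    from transformers import Trainer, TrainingArguments

    class T2TDataCollator:
        def __call__(self, batch):
            input_ids = torch.stack([example['input_ids'] for example in batch])
            attention_mask = torch.stack([example['attention_mask'] for example in batch])
            # assumption: targets are stored as 'target_ids' and pad id 0 is excluded from the loss
            lm_labels = torch.stack([example['target_ids'] for example in batch])
            lm_labels[lm_labels == 0] = -100
            decoder_attention_mask = torch.stack([example['target_attention_mask'] for example in batch])
            # transformers >= 4.0.0: the model's forward() takes "labels"; "lm_labels" was removed
            return {'input_ids': input_ids, 'attention_mask': attention_mask,
                    'labels': lm_labels, 'decoder_attention_mask': decoder_attention_mask}

    # model, train_dataset and valid_dataset are assumed to be built elsewhere
    # prediction_loss_only now lives on TrainingArguments rather than the Trainer constructor
    training_args = TrainingArguments(output_dir='amrlib/data/model_generate_t5',
                                      prediction_loss_only=True, num_train_epochs=8)
    trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset,
                      eval_dataset=valid_dataset, data_collator=T2TDataCollator())
    trainer.train()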

amrlib/models/parse_gsii/modules/transformer.py

Lines changed: 4 additions & 0 deletions

@@ -200,6 +200,10 @@ def in_proj_qkv(self, query):
         # See release notes for v1.7 (torch.chunk) for an explanation. A temporary fix is to use unsafe_chunk instead.
         # See https://discuss.pytorch.org/t/runtimeerror-for-chunk-inplace-operation-new-with-torch-1-7/105334
         return self._in_proj(query).unsafe_chunk(3, dim=-1)
+        # Possible solution...
+        # proj = self._in_proj(query)
+        # sz = proj.size()[2] // 3
+        # return proj[:,:,:sz], proj[:,:,sz:2*sz], proj[:,:,2*sz:]

     def in_proj_kv(self, key):
         return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)
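The commented-out alternative above avoids torch.chunk entirely by slicing the projection into three equal pieces along the last dimension, which sidesteps the in-place restriction that torch 1.7 added to chunk's outputs. A small standalone check (illustrative, not part of the repo; tensor sizes are made up) that the slicing split matches chunk(3, dim=-1):

    import torch

    # hypothetical shape: (time, batch, 3 * embed_dim)
    proj = torch.randn(10, 4, 3 * 512)
    sz = proj.size()[2] // 3
    q, k, v = proj[:, :, :sz], proj[:, :, sz:2*sz], proj[:, :, 2*sz:]
    q2, k2, v2 = proj.chunk(3, dim=-1)
    assert torch.equal(q, q2) and torch.equal(k, k2) and torch.equal(v, v2)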

configs/model_generate_t5.json

Lines changed: 1 addition & 1 deletion

@@ -13,8 +13,8 @@
     "output_dir" : "amrlib/data/model_generate_t5",
     "do_train" : true,
     "do_eval" : false,
-    "evaluate_during_training" : false,
     "overwrite_output_dir" : false,
+    "prediction_loss_only" : true,
     "num_train_epochs" : 8,
     "save_steps" : 1000,
     "save_total_limit" : 2,
