
Mimicking Keras Training API #5

@abhmul

Description

To be able to swap pyjet models and Keras models seamlessly, pyjet needs to adopt the same training API as Keras and interact with training-related Keras objects (e.g. callbacks) in the same way. To do so, the following needs to happen:

  • Write a compile method for SLModel that takes in the loss functions, optimizers, and metrics
  • Support multiple inputs, multiple losses, and multiple outputs in the same fashion as Keras
  • Make fit_generator build a training log for callbacks, just like Keras
  • Modify fit_generator to return a History object
  • Replace save_state and load_state with save_weights and load_weights, respectively
  • Create save and load_model, just like in Keras
  • Make call (not forward) the method users implement for the network's forward pass, so that forward can be used for shared forward-pass overhead
  • Create non-generator equivalents (i.e. create Keras's fit, evaluate, and predict methods)
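The compile and call/forward items above could look roughly like the sketch below. It is pure Python so it runs without PyTorch, and every name in it (the toy SLModel, compute_loss, TwoHeadModel, squared_error) is hypothetical, not pyjet's actual code. compile only stores the loss functions, optimizer, and metrics; subclasses implement call, as Keras layers do; and forward wraps call with shared overhead, which in this sketch is just normalizing the outputs into a list so single- and multi-output models are handled the same way.

```python
class SLModel:
    """Hypothetical toy stand-in for pyjet's SLModel (not the real class)."""

    def __init__(self):
        self.optimizer = None
        self.loss_fns = []
        self.metric_fns = []

    def compile(self, optimizer, loss, metrics=None):
        # Keras-style compile: accept one loss or a list with one loss per output.
        self.optimizer = optimizer
        self.loss_fns = list(loss) if isinstance(loss, (list, tuple)) else [loss]
        self.metric_fns = list(metrics or [])

    def call(self, *inputs):
        # Subclasses put the actual computation here.
        raise NotImplementedError

    def forward(self, *inputs):
        # Shared overhead lives here; this sketch only wraps the outputs
        # in a list so single- and multi-output models look the same.
        outputs = self.call(*inputs)
        return list(outputs) if isinstance(outputs, (list, tuple)) else [outputs]

    def compute_loss(self, inputs, targets):
        # Multi-input: inputs is a list unpacked into call.
        # Multi-output/multi-loss: one loss per (output, target) pair, summed.
        outputs = self.forward(*inputs)
        if len(outputs) != len(self.loss_fns):
            raise ValueError("expected one loss function per model output")
        return sum(fn(o, t) for fn, o, t in zip(self.loss_fns, outputs, targets))


class TwoHeadModel(SLModel):
    # Hypothetical model with one input and two outputs.
    def call(self, x):
        return x * 2, x + 1


def squared_error(pred, target):
    return (pred - target) ** 2


model = TwoHeadModel()
model.compile(optimizer="sgd", loss=[squared_error, squared_error])
print(model.compute_loss([3], [5, 4]))  # outputs are 6 and 4, so (6-5)**2 + (4-4)**2 → 1
```

Keeping the overhead in forward means model authors never have to repeat it, which mirrors how Keras's __call__ wraps a layer's call.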
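The fit_generator, History, and non-generator fit items might be sketched as follows. This assumes a minimal callback interface modeled on the Keras one (on_train_begin, on_epoch_begin, on_batch_end, on_epoch_end, on_train_end), but the Callback and History classes here are simplified stand-ins, not Keras's own. It also assumes a hypothetical train_on_batch(x, y) function that runs one optimizer step and returns the batch loss. fit_generator builds a logs dict per batch and per epoch, passes it to every callback, and returns the History object; fit simply slices arrays into batches and delegates to fit_generator.

```python
class Callback:
    """Minimal stand-in for the Keras callback interface (hypothetical)."""
    def on_train_begin(self, logs=None): pass
    def on_epoch_begin(self, epoch, logs=None): pass
    def on_batch_end(self, batch, logs=None): pass
    def on_epoch_end(self, epoch, logs=None): pass
    def on_train_end(self, logs=None): pass


class History(Callback):
    # Like keras.callbacks.History: records one list of values per log key.
    def on_train_begin(self, logs=None):
        self.epoch = []
        self.history = {}

    def on_epoch_end(self, epoch, logs=None):
        self.epoch.append(epoch)
        for key, value in (logs or {}).items():
            self.history.setdefault(key, []).append(value)


def fit_generator(train_on_batch, generator, steps_per_epoch, epochs=1, callbacks=None):
    """Hypothetical Keras-style loop; train_on_batch(x, y) returns the batch loss."""
    history = History()
    callbacks = list(callbacks or []) + [history]
    for cb in callbacks:
        cb.on_train_begin({})
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch, {})
        total = 0.0
        for step in range(steps_per_epoch):
            x, y = next(generator)
            loss = train_on_batch(x, y)
            total += loss
            for cb in callbacks:
                cb.on_batch_end(step, {"batch": step, "loss": loss})
        # The epoch-level training log handed to every callback.
        epoch_logs = {"loss": total / steps_per_epoch}
        for cb in callbacks:
            cb.on_epoch_end(epoch, epoch_logs)
    for cb in callbacks:
        cb.on_train_end({})
    return history


def fit(train_on_batch, x, y, batch_size, epochs=1, callbacks=None):
    # Non-generator equivalent: an endless batch generator over the arrays.
    steps = (len(x) + batch_size - 1) // batch_size

    def batches():
        while True:
            for i in range(0, len(x), batch_size):
                yield x[i:i + batch_size], y[i:i + batch_size]

    return fit_generator(train_on_batch, batches(), steps, epochs, callbacks)


# Toy step: the "loss" is just the batch mean of y, so nothing is learned.
def toy_step(x, y):
    return sum(y) / len(y)


hist = fit(toy_step, x=[0, 1, 2, 3], y=[1.0, 3.0, 5.0, 7.0], batch_size=2, epochs=2)
print(hist.history)  # batch losses 2.0 and 6.0 average to 4.0 → {'loss': [4.0, 4.0]}
```

Appending History last means user callbacks see the same logs dict first, and returning it gives callers the familiar history.history lookup.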
