
Can't save and load lattice model  #77

@win13676


Hello, I'm having a problem loading the premade models from the TensorFlow Lattice tutorial (https://www.tensorflow.org/lattice/tutorials/premade_models).

When I save the model in the TensorFlow SavedModel format (a path ending in `.tf`) and load it back, I get:

KeyError: 'layers' 

When I save the model in the native Keras format (`.keras`) and load it back, I get:

ValueError: Unknown object: 'CalibratedLatticeEnsembleConfig'. Please ensure you are using a `keras.utils.custom_object_scope` and that this object is included in the scope. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.

I see this issue in the Colab notebook linked from that page (https://colab.research.google.com/github/tensorflow/lattice/blob/master/docs/tutorials/premade_models.ipynb), on Databricks, and when running locally.

Code I use to save and load the model (following https://www.tensorflow.org/lattice/api_docs/python/tfl/premade/CalibratedLattice#save):

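```python
import tensorflow as tf

# .keras path (native Keras format): load_model raises ValueError
rtl_layer_ensemble_model.save("model.keras")
loaded_model = tf.keras.models.load_model("model.keras")

# .tf path (TensorFlow SavedModel format): load_model raises KeyError: 'layers'
rtl_layer_ensemble_model.save("model.tf")
loaded_model = tf.keras.models.load_model("model.tf")

# Directory path (SavedModel format)
rtl_layer_ensemble_model.save("/")
loaded_model = tf.keras.models.load_model("/")
```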

None of the models in the example can be loaded:

linear_model, lattice_model, explicit_ensemble_model, random_ensemble_model, rtl_layer_ensemble_model, prefitting_model, crystals_ensemble_model

Full stack traces:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-40-bc47832831ee> in <cell line: 2>()
      1 rtl_layer_ensemble_model.save("model.tf")
----> 2 loaded_model = tf.keras.models.load_model("model.tf")

2 frames
/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_api.py in load_model(filepath, custom_objects, compile, safe_mode, **kwargs)
    236 
    237     # Legacy case.
--> 238     return legacy_sm_saving_lib.load_model(
    239         filepath, custom_objects=custom_objects, compile=compile, **kwargs
    240     )

/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
     68             # To get the full stack trace, call:
     69             # `tf.debugging.disable_traceback_filtering()`
---> 70             raise e.with_traceback(filtered_tb) from None
     71         finally:
     72             del filtered_tb

/usr/local/lib/python3.10/dist-packages/keras/src/engine/functional.py in reconstruct_from_config(config, custom_objects, created_layers)
   1487 
   1488     # First, we create all layers and enqueue nodes to be processed
-> 1489     for layer_data in config["layers"]:
   1490         process_layer(layer_data)
   1491     # Then we process nodes in order of layer depth.

KeyError: 'layers'
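One plausible reading of this first traceback: the legacy SavedModel loader rebuilds the model through Keras's functional-model code (`reconstruct_from_config` in `functional.py`), which expects a config with a `"layers"` key, whereas the premade model's own `from_config` (the `premade.py` frame in the next traceback) reads a `"model_config"` key. The plain-Python toy below involves no Keras or TFL code and only illustrates that config-shape mismatch; `reconstruct_functional` and both config dicts are hypothetical simplifications.

```python
# Toy illustration only: hypothetical simplifications, not the real Keras/TFL code.


def reconstruct_functional(config):
    """Mimics a loader that assumes a functional-model config."""
    # A functional-style config lists its layers under the "layers" key.
    return [layer_data["class_name"] for layer_data in config["layers"]]


# Simplified functional-model config.
functional_config = {"layers": [{"class_name": "InputLayer"}, {"class_name": "Dense"}]}

# Simplified config keyed the way the premade model's from_config reads it.
premade_style_config = {"model_config": {"class_name": "CalibratedLatticeEnsembleConfig"}}

print(reconstruct_functional(functional_config))  # ['InputLayer', 'Dense']

try:
    reconstruct_functional(premade_style_config)
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'layers'
```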
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-39-c20ad91f21ca> in <cell line: 2>()
      1 rtl_layer_ensemble_model.save("model.keras")
----> 2 loaded_model = tf.keras.models.load_model("model.keras")

6 frames
/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_api.py in load_model(filepath, custom_objects, compile, safe_mode, **kwargs)
    228                 f"with the native Keras format: {list(kwargs.keys())}"
    229             )
--> 230         return saving_lib.load_model(
    231             filepath,
    232             custom_objects=custom_objects,

/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_lib.py in load_model(filepath, custom_objects, compile, safe_mode)
    273 
    274     except Exception as e:
--> 275         raise e
    276     else:
    277         return model

/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_lib.py in load_model(filepath, custom_objects, compile, safe_mode)
    238             # Construct the model from the configuration file in the archive.
    239             with ObjectSharingScope():
--> 240                 model = deserialize_keras_object(
    241                     config_dict, custom_objects, safe_mode=safe_mode
    242                 )

/usr/local/lib/python3.10/dist-packages/keras/src/saving/serialization_lib.py in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
    702     safe_mode_scope = SafeModeScope(safe_mode)
    703     with custom_obj_scope, safe_mode_scope:
--> 704         instance = cls.from_config(inner_config)
    705         build_config = config.get("build_config", None)
    706         if build_config:

/usr/local/lib/python3.10/dist-packages/tensorflow_lattice/python/premade.py in from_config(cls, config, custom_objects)
    146   @classmethod
    147   def from_config(cls, config, custom_objects=None):
--> 148     model_config = tf.keras.utils.legacy.deserialize_keras_object(
    149         config.get('model_config'), custom_objects=custom_objects
    150     )

/usr/local/lib/python3.10/dist-packages/keras/src/saving/legacy/serialization.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    478         # In this case we are dealing with a Keras config dictionary.
    479         config = identifier
--> 480         (cls, cls_config) = class_and_config_for_serialized_keras_object(
    481             config, module_objects, custom_objects, printable_module_name
    482         )

/usr/local/lib/python3.10/dist-packages/keras/src/saving/legacy/serialization.py in class_and_config_for_serialized_keras_object(config, module_objects, custom_objects, printable_module_name)
    363     )
    364     if cls is None:
--> 365         raise ValueError(
    366             f"Unknown {printable_module_name}: '{class_name}'. "
    367             "Please ensure you are using a `keras.utils.custom_object_scope` "

ValueError: Unknown object: 'CalibratedLatticeEnsembleConfig'. Please ensure you are using a `keras.utils.custom_object_scope` and that this object is included in the scope. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.
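In the frames above, TFL's `from_config` passes the nested `model_config` to Keras's legacy `deserialize_keras_object`, which has to map the stored class name `'CalibratedLatticeEnsembleConfig'` back to a Python class and raises when it cannot find that name. The plain-Python toy below mimics such a name-based lookup to show why supplying the class under its stored name avoids the error; `deserialize`, `REGISTERED_OBJECTS`, and this `CalibratedLatticeEnsembleConfig` class are hypothetical stand-ins, not the Keras or TFL implementations.

```python
# Toy illustration only: hypothetical stand-ins, not the real Keras/TFL code.


class CalibratedLatticeEnsembleConfig:
    """Hypothetical stand-in for a TFL model config class."""

    def __init__(self, num_lattices=1):
        self.num_lattices = num_lattices

    @classmethod
    def from_config(cls, config):
        # Rebuild the object from its saved keyword arguments.
        return cls(**config)


# Name-to-class table consulted by default (empty in this toy, so the
# config class is "unknown" unless the caller supplies it).
REGISTERED_OBJECTS = {}


def deserialize(serialized, custom_objects=None):
    """Resolve a stored class name to a class, then rebuild the object."""
    class_name = serialized["class_name"]
    # Caller-supplied custom objects are merged over the default table.
    lookup = {**REGISTERED_OBJECTS, **(custom_objects or {})}
    cls = lookup.get(class_name)
    if cls is None:
        raise ValueError(f"Unknown object: '{class_name}'.")
    return cls.from_config(serialized["config"])


saved = {"class_name": "CalibratedLatticeEnsembleConfig", "config": {"num_lattices": 5}}

try:
    deserialize(saved)  # name not in the lookup table
except ValueError as err:
    print(err)  # Unknown object: 'CalibratedLatticeEnsembleConfig'.

restored = deserialize(
    saved,
    custom_objects={"CalibratedLatticeEnsembleConfig": CalibratedLatticeEnsembleConfig},
)
print(restored.num_lattices)  # 5
```

By analogy, a workaround worth trying (not verified here) is to pass TFL's custom objects at load time, for example `tf.keras.models.load_model("model.keras", custom_objects=tfl.premade.get_custom_objects())`, assuming `import tensorflow_lattice as tfl` and that the mapping returned by `tfl.premade.get_custom_objects()` includes the config classes.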
