Gradient checkpointing seems to conflict with Keras batch norm #47

@demmerichs

I tried this out but get an error when computing the gradients with the provided function using manually selected checkpoints. Three chained exceptions are raised at once, and I am not sure which part of my graph is actually causing them, so I would appreciate some hints that would help me come up with a minimal failing example. I currently use TF 1.13.1 and, in particular, tf.keras.layers.BatchNormalization (I mention this only because it shows up in the error message). Is there any hope that this is an easy fix?

Traceback (most recent call last):
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 415, in _MaybeCompile
    xla_compile = op.get_attr("_XlaCompile")
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2413, in get_attr
    raise ValueError(str(e))
ValueError: Operation 'optimizer/head/convolve_batch_activate_20/batch_normalization_v1_21/cond/ReadVariableOp_1/Switch' has no attr named '_XlaCompile'.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 455, in _apply_op_helper
    as_ref=input_arg.is_ref)
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1240, in internal_convert_n_to_tensor
    ctx=ctx))
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1175, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 977, in _TensorTensorConversionFunction
    (dtype.name, t.dtype.name, str(t)))
ValueError: Tensor conversion requested dtype float32 for Tensor with dtype resource: 'Tensor("optimizer/gradients/optimizer/head/convolve_batch_activate_20/batch_normalization_v1_21/cond/ReadVariableOp_1/Switch_grad/Switch_1:1", shape=(), dtype=resource)'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "./src/sadt.py", line 544, in <module>
    with SpaceAndDeformableTimeNetwork(cfg, datasets) as exp:
  File "/lhome/davidj2/code/sync/space_and_deformable_time/src/xxsflow/experiments/base_experiment.py", line 42, in __enter__
    self.build_graph()
  File "./src/sadt.py", line 298, in build_graph
    self.optimizer_op = self.optimizer
  File "/lhome/davidj2/code/sync/space_and_deformable_time/src/xxsflow/utils.py", line 388, in wrapped_function
    setattr(self, attribute, function(self))
  File "./src/sadt.py", line 267, in optimizer
    grads = grads = tf.gradients(self.loss, tf.trainable_variables())
  File "/lhome/davidj2/code/sync/space_and_deformable_time/packages/gradient_checkpointing/memory_saving_gradients.py", line 40, in gradients_collection
    return gradients(ys, xs, grad_ys, checkpoints='collection', **kwargs)
  File "/lhome/davidj2/code/sync/space_and_deformable_time/packages/gradient_checkpointing/memory_saving_gradients.py", line 227, in gradients
    dv = tf_gradients(ys=copied_ys, xs=boundary+xs, grad_ys=grad_ys, **kwargs)
  File "/lhome/davidj2/code/sync/space_and_deformable_time/packages/gradient_checkpointing/memory_saving_gradients.py", line 27, in tf_gradients
    return tf_gradient_function(ys, *args, **kwargs)
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 664, in gradients
    unconnected_gradients)
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 965, in _GradientsHelper
    lambda: grad_fn(op, *out_grads))
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 420, in _MaybeCompile
    return grad_fn()  # Exit early
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 965, in <lambda>
    lambda: grad_fn(op, *out_grads))
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_grad.py", line 88, in _SwitchGrad
    return merge([false_grad, true_grad])[0], None
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 466, in merge
    return gen_control_flow_ops.merge(inputs, name)
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/ops/gen_control_flow_ops.py", line 418, in merge
    "Merge", inputs=inputs, name=name)
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 483, in _apply_op_helper
    raise TypeError("%s that don't all match." % prefix)
TypeError: Tensors in list passed to 'inputs' of 'Merge' Op have types [float32, resource] that don't all match.
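
The op names in the traceback contain `/cond/` and `Switch`, which suggests (but does not prove) that the failing gradient passes through a conditional branch created inside the batch norm layer. As a first step towards a minimal failing example, here is a rough sketch that splits candidate checkpoint names into those that look like they sit inside such a branch and those that do not. The helper names `in_control_flow` and `split_checkpoint_candidates` are hypothetical and not part of memory_saving_gradients.py, the second name in the example list is made up, and the check relies only on name scopes, which is a heuristic.

```python
import re

# Hypothetical helper, not part of memory_saving_gradients.py.
# Heuristic: a name is "suspect" if it has a `cond` name scope or ends in a
# Switch/Merge op, as the failing op in the traceback does.
CONTROL_FLOW_PATTERN = re.compile(
    r"(^|/)cond(_\d+)?/|/(Switch|Merge)(_\d+)?(:\d+)?$"
)


def in_control_flow(name):
    """Return True if the op/tensor name looks like part of a conditional branch."""
    return bool(CONTROL_FLOW_PATTERN.search(name))


def split_checkpoint_candidates(names):
    """Split names into (safe, suspect) lists based on the heuristic above."""
    safe = [n for n in names if not in_control_flow(n)]
    suspect = [n for n in names if in_control_flow(n)]
    return safe, suspect


candidates = [
    # Taken from the first ValueError in the traceback.
    "optimizer/head/convolve_batch_activate_20/batch_normalization_v1_21/cond/ReadVariableOp_1/Switch",
    # Made-up name for illustration only.
    "optimizer/head/convolve_batch_activate_20/Relu:0",
]

safe, suspect = split_checkpoint_candidates(candidates)
print("safe:", safe)
print("suspect:", suspect)
```

With the real graph, the list of names could come from the tensors passed as checkpoints; dropping the suspect ones and retrying might show whether the conditional branch is what triggers the `Merge` type mismatch.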
