diff --git a/memory_saving_gradients.py b/memory_saving_gradients.py index 0345bb5..931aec9 100644 --- a/memory_saving_gradients.py +++ b/memory_saving_gradients.py @@ -88,8 +88,10 @@ def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs): fwd_ops = [op for op in fwd_ops if not '/assign' in op.name] fwd_ops = [op for op in fwd_ops if not '/Assign' in op.name] fwd_ops = [op for op in fwd_ops if not '/read' in op.name] + fwd_ops = [op for op in fwd_ops if not '/Read' in op.name] ts_all = ge.filter_ts(fwd_ops, True) # get the tensors ts_all = [t for t in ts_all if '/read' not in t.name] + ts_all = [t for t in ts_all if '/Read' not in t.name] ts_all = set(ts_all) - set(xs) - set(ys) # construct list of tensors to checkpoint during forward pass, if not