diff --git a/README.md b/README.md index b592fee..77e5f33 100644 --- a/README.md +++ b/README.md @@ -115,6 +115,26 @@ batch = np.random.rand(10000, 784) result = y.eval({x: batch}) ``` +##### Convert a Keras Model + +You can also convert Keras models (although many layers are still not supported, see test output for the keras.py test suite) + +```python +>>> import tfdeploy as td +>>> from keras.models import Sequential, Model +>>> from keras.layers import Convolution2D +>>> k_model = Sequential() +>>> k_model.add(Convolution2D(5, (3,3), input_shape = (9,9,1))) +>>> k_model.compile('sgd', 'mse') +>>> t_model, i_names, o_names = td.deploy_keras(k_model) +>>> type(t_model) + +>>> i_names +OrderedDict([('conv2d_1_input', 'conv2d_1_input:0')]) +>>> o_names +OrderedDict([('conv2d_1', 'conv2d_1/BiasAdd:0')]) +``` + ##### Write your own `Operation` tfdeploy supports most of the `Operation`'s [implemented in tensorflow](https://www.tensorflow.org/versions/master/api_docs/python/math_ops.html). However, if you miss one (in that case, submit a PR or an issue ;) ) or if you're using custom ops, you might want to extend tfdeploy by defining a new class op that inherits from `tfdeploy.Operation`: diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..25c3db8 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,4 @@ +numpy +tensorflow>=1.0 +matplotlib +keras>=2.0 \ No newline at end of file diff --git a/setup.py b/setup.py index 8a02e4e..44ca63d 100644 --- a/setup.py +++ b/setup.py @@ -10,9 +10,14 @@ readme = os.path.join(os.path.dirname(os.path.abspath(__file__)), "README.md") if os.path.isfile(readme): cmd = "pandoc --from=markdown --to=rst " + readme - p = Popen(cmd, stdout=PIPE, stderr=PIPE, shell=True, executable="/bin/bash") - out, err = p.communicate() - if p.returncode != 0: + try: + p = Popen(cmd, stdout=PIPE, stderr=PIPE, shell=True, executable="/bin/bash") + out, err = p.communicate() + returncode = p.returncode + except FileNotFoundError as file_exp: + print('pandoc and/or bash not found') + out, err, returncode = -1,str(file_exp), -1 + if returncode != 0: print("pandoc conversion failed: " + err) long_description = out else: diff --git a/tests/keras.py b/tests/keras.py new file mode 100644 index 0000000..e918f42 --- /dev/null +++ b/tests/keras.py @@ -0,0 +1,201 @@ +# -*- coding: utf-8 -*- + + +import os +import unittest +from itertools import product + +import numpy as np +import tensorflow as tf + +from .base import TestCase, td + +# noinspection PyUnresolvedReferences +try: + + from keras.models import Sequential, Model, Input + from keras.layers import Dense, Convolution2D, MaxPooling2D, UpSampling2D, BatchNormalization, Dropout, Reshape, \ + Conv2DTranspose, LSTM, LeakyReLU, Activation, RepeatVector, Lambda, LocallyConnected2D + from keras.optimizers import Adam + from keras.backend import tensorflow_backend as tfb + import keras.backend as K + from keras import applications as kapps # for bigger prebuilt models + + KERAS_MISSING = False +except ImportError: + KERAS_MISSING = True + +UNSUPPORTED_LAYERS = ['Dropout', 'BatchNormalization', 'UpSampling2D', 'Convolution2DTranspose', + 'LSTM', 'RepeatVector', 'LocallyConnected2D'] + +__all__ = ["KerasTestCase"] + + +@unittest.skipIf(KERAS_MISSING, "requires Keras to be installed") +class KerasTestCase(TestCase): + def __init__(self, *args, **kwargs): + super(KerasTestCase, self).__init__(*args, **kwargs) + td.setup(tf) + K.set_image_dim_ordering('tf') + + def test_deploy_tool(self): + c_model = KerasTestCase._build_simple_2d(use_leakyrelu=True, use_pooling=True) + t_model, in_mapping, out_mapping = td.deploy_keras(c_model) + print(in_mapping) + print(out_mapping) + + self.assertIsInstance(t_model, td.Model, "Output should be tfdeploy model") + self.assertEqual(len(in_mapping), 1, "only one input") + self.assertIn('Reshape_input', in_mapping, "Reshape not found in input") + self.assertIn('MaxPooling2D', out_mapping, "MaxPooling not found in output") + self.assertEqual(len(out_mapping), 1, "only one ouput") + for c_mapping in [in_mapping, out_mapping]: + for keras_name, tf_name in c_mapping.items(): + cur_tensor = t_model.get(tf_name) + self.assertIsNotNone(cur_tensor, "Layer: {} -> TF:{}, not found in model".format( + keras_name, tf_name)) + self.assertIsInstance(cur_tensor, td.Tensor, "Layer should be tensor: {}".format(cur_tensor)) + + def test_cnn_models(self): + model_kwargs = dict(use_dense=False, use_dropout=False, use_pooling=False, use_bn=False, use_upsample=False, + use_conv2dtrans=False, use_lstm=False, use_leakyrelu=False, use_repeatvec=False, + use_lambda=False, use_locallyconnected=False) + + def _try_args(**kw_args): + new_args = model_kwargs.copy() + new_args.update(kw_args) + return new_args + + test_models = [('base_cnn', KerasTestCase._build_simple_2d())] + test_models += [(c_arg, + KerasTestCase._build_simple_2d(**_try_args(**{c_arg: True}))) + for c_arg in model_kwargs.keys()] + + deployed_models = [] + for i, (model_name, cur_keras_model) in enumerate(test_models): + model_layers = ','.join(map(lambda x: x.name, cur_keras_model.layers)) + out_path = "%04d.pkl" % i + try: + deployed_models += \ + KerasTestCase.export_keras_model(cur_keras_model, out_path, model_name=model_layers) + except td.UnknownOperationException as uoe: + print('Model {}: {}'.format(i, model_name), 'could not be serialized', uoe) + bad_layer_count = sum([us_layer in model_layers for us_layer in UNSUPPORTED_LAYERS]) + self.assertGreater(bad_layer_count, 0, + "Model contains no unsupported layers {}, " + "Unsupported Layers:{}".format(model_layers, UNSUPPORTED_LAYERS)) + + self.assertGreater(len(deployed_models), 0, "No models could be tested") + print("Testing #{} models".format(len(deployed_models))) + for c_model_pkl in deployed_models: + result = KerasTestCase.deploy_model(c_model_pkl) + self.assertIsNotNone(result, "Result should not be empty") + self.assertEqual(len(result.shape), 4, "Output should be 4D Tensor: {}".format(result.shape)) + os.remove(c_model_pkl['path']) + + @unittest.skip("Takes quite awhile to run (and fails for all models)") + def test_big_models(self): + """ + A test for bigger commonly used pretrained models (for this we skip the weights) + :return: + """ + kapp_kwargs = dict( + input_shape=(99, 99, 3), + weights=None, + include_top=False # so we can use different sizes + ) + test_models = [] + + test_models += [('Resnet50', kapps.ResNet50(**kapp_kwargs))] + test_models += [('InceptionV3', kapps.InceptionV3(**kapp_kwargs))] + test_models += [('VGG19', kapps.VGG19(**kapp_kwargs))] + test_models += [('Xception', kapps.Xception(**kapp_kwargs))] + + for i, (model_name, cur_keras_model) in enumerate(test_models): + + model_layers = ','.join(map(lambda x: x.name, cur_keras_model.layers)) + out_path = "%04d.pkl" % i + try: + c_model_pkl = KerasTestCase.export_keras_model(cur_keras_model, out_path, model_name=model_layers) + except td.UnknownOperationException as uoe: + print('Model {}: {}'.format(i, model_layers), 'could not be serialized', uoe) + bad_layer_count = sum([us_layer in model_layers for us_layer in UNSUPPORTED_LAYERS]) + self.assertGreater(bad_layer_count, 0, + "Model contains no unsupported layers {}, " + "Unsupported Layers:{}".format(model_layers, UNSUPPORTED_LAYERS)) + continue + except tf.errors.RESOURCE_EXHAUSTED: + # many of the bigger models take up quite a bit of GPU memory + print('Model {} with #{} layers is too big for memory'.format(model_name, len(cur_keras_model.layers))) + + result = KerasTestCase.deploy_model(c_model_pkl, np.random.uniform(0, 1, size=(299, 299, 3))) + self.assertIsNotNone(result, "Result should not be empty") + self.assertEqual(len(result.shape), 4, "Output should be 4D Tensor: {}".format(result.shape)) + os.remove(c_model_pkl['path']) + + @staticmethod + def deploy_model(c_model_pkl, input=None): + model = td.Model(c_model_pkl['path']) + inp, outp = model.get(c_model_pkl['input'], c_model_pkl['output']) + if input is None: + input = np.random.rand(50, 81) + return outp.eval({inp: input}) + + @staticmethod + def export_keras_model(in_ks_model, out_path, model_name): + td_model = td.Model() + td_model.add(in_ks_model.get_output_at(0), + tfb.get_session()) # y and all its ops and related tensors are added recursively + + td_model.save(out_path) + return [dict(path=out_path, + output=in_ks_model.get_output_at(0).name, + input=in_ks_model.get_input_at(0).name, + name=model_name)] + + @staticmethod + def compile_model(i_model): + i_model.compile(optimizer=Adam(lr=2e-3), loss='mse') + + @staticmethod + def _build_simple_2d(use_dense=False, use_dropout=False, use_pooling=False, use_bn=False, use_upsample=False, + use_conv2dtrans=False, use_lstm=False, use_leakyrelu=False, use_repeatvec=False, + use_lambda=False, use_locallyconnected=False): + """ + Simple function for building CNN models with various layers turned on and off + :param use_dropout: + :param use_pooling: maxpooling2d + :param use_bn: batchnormalization + :param use_upsample: + :return: + """ + out_model = Sequential() + if use_lstm: + out_model.add(Reshape(target_shape=(1, 81), input_shape=(81,), name='Reshape_LSTM')) + out_model.add(LSTM(81, name='LSTM')) + if use_dense: + out_model.add(Dense(81, input_shape=(81,), name='Dense')) + if use_repeatvec: + out_model.add(RepeatVector(3, input_shape=(81,), name='RepeatVector')) + out_model.add(Lambda(lambda x: x[0, :], name='Lambda')) + out_model.add(Reshape(target_shape=(9, 9, 1), input_shape=(81,), name='Reshape')) + out_model.add(Convolution2D(2, (3, 3), input_shape=(9, 9, 1), name='Convolution2D')) + if use_lambda: + out_model.add(Lambda(lambda x: x + 1, name='Lambda_add')) + if use_leakyrelu: + out_model.add(LeakyReLU(0.1, name='LeakyRelu')) + if use_dropout: + out_model.add(Dropout(0.5, name='Dropout')) + if use_pooling: + out_model.add(MaxPooling2D((2, 2), name='MaxPooling2D')) + if use_upsample: + out_model.add(UpSampling2D((2, 2), name='UpSampling2D')) + if use_bn: + out_model.add(BatchNormalization(name='BatchNormalization')) + if use_conv2dtrans: + out_model.add(Conv2DTranspose(2, kernel_size=(3, 3), strides=(2, 2), name='Convolution2DTranspose')) + if use_locallyconnected: + out_model.add(LocallyConnected2D(3, (3, 3), name='LocallyConnected2D')) + + KerasTestCase.compile_model(out_model) + return out_model diff --git a/tfdeploy.py b/tfdeploy.py index 5672021..b9d91fd 100644 --- a/tfdeploy.py +++ b/tfdeploy.py @@ -19,6 +19,7 @@ "InvalidImplementationException", "UnknownImplementationException", "EnsembleMismatchException", "ScipyOperationException", "reset", "optimize", "print_tensor", "print_op", "print_tf_tensor", "print_tf_op", + "deploy_keras", "IMPL_NUMPY", "IMPL_SCIPY", "IMPLS", "METHOD_MEAN", "METHOD_MAX", "METHOD_MIN", "METHOD_CUSTOM", "METHODS", "HAS_SCIPY"] @@ -29,6 +30,7 @@ import re from uuid import uuid4 from functools import reduce +from collections import OrderedDict try: # python 2 @@ -56,7 +58,6 @@ def wrapper(cls): return metaclass(cls.__name__, cls.__bases__, orig_vars) return wrapper - class Model(object): """ A trained model that contains one or more converted tensorflow graphs. When *path* is set, a @@ -177,7 +178,7 @@ def save(self, path): """ path = os.path.expandvars(os.path.expanduser(path)) with open(path, "wb") as f: - pickle.dump(self.roots, f) + pickle.dump(self.roots, f, protocol = 2) # make it python2 compatible always class TensorRegister(type): @@ -196,6 +197,47 @@ def __call__(cls, tf_tensor, *args, **kwargs): return cls.instances[tf_tensor] +def deploy_keras(in_keras_model): + # type: (keras.models.Model) -> Tuple[Model, Dict[str, str], Dict[str, str]] + """ + Converts a keras model (>2.0) to a tfdeploy model and provides a list of input and output + mappings from keras layer names to tensorflow tensor names + :param in_keras_model: the keras model to convert + :return: the tfdeploy model, a dictionary mapping inputs and a dictionary mapping outputs + The dictionaries map keras layer names to tensorflow names + Usage + ==== + >>> from keras.models import Sequential, Model + >>> from keras.layers import Convolution2D + >>> k_model = Sequential() + >>> k_model.add(Convolution2D(5, (3,3), input_shape = (9,9,1))) + >>> k_model.compile('sgd', 'mse') + >>> t_model, i_names, o_names = deploy_keras(k_model) + >>> type(t_model) + + >>> i_names + OrderedDict([('conv2d_1_input', 'conv2d_1_input:0')]) + >>> o_names + OrderedDict([('conv2d_1', 'conv2d_1/BiasAdd:0')]) + """ + try: + from keras.backend import tensorflow_backend as tfb + except ImportError: + raise NotImplementedError("Keras is not installed or not setup with the tensorflow backend!") + + td_model = Model() + keras_in_mapping = OrderedDict() + for i, in_name in enumerate(in_keras_model.input_names): + keras_in_mapping[in_name] = in_keras_model.get_input_at(i).name + + keras_out_mapping = OrderedDict() + for i, out_name in enumerate(in_keras_model.output_names): + keras_out_mapping[out_name] = in_keras_model.get_output_at(i).name + td_model.add(in_keras_model.get_output_at(i), + tfb.get_session()) # y and all its ops and related tensors are added recursively + + return td_model, keras_in_mapping, keras_out_mapping + @add_metaclass(TensorRegister) class Tensor(object): """