diff --git a/mnist.py b/mnist.py index 877eb5f..925f725 100644 --- a/mnist.py +++ b/mnist.py @@ -98,7 +98,7 @@ def define_mnist_flags(): flags_core.define_base() flags_core.define_performance(num_parallel_calls=False) flags_core.define_image() - data_dir = os.path.abspath(os.environ.get('PS_JOBSPACE', os.getcwd()) + '/data') + data_dir = os.path.abspath(os.environ.get('DATA_DIR', os.getcwd() + '/data')) model_dir = os.path.abspath(os.environ.get('PS_MODEL_PATH', os.getcwd() + '/models') + '/mnist') export_dir = os.path.abspath(os.environ.get('PS_MODEL_PATH', os.getcwd() + '/models')) flags.adopt_module_key_flags(flags_core)