diff --git a/distribution_strategy/parameter_server_training/Dockerfile.resnet_cifar_ps_strategy b/distribution_strategy/parameter_server_training/Dockerfile.resnet_cifar_ps_strategy
new file mode 100644
index 00000000..72ec3977
--- /dev/null
+++ b/distribution_strategy/parameter_server_training/Dockerfile.resnet_cifar_ps_strategy
@@ -0,0 +1,22 @@
+FROM tensorflow/tensorflow:nightly
+
+RUN apt-get update && \
+    apt-get install -y python3 python3-pip
+
+RUN pip3 install absl-py && \
+    pip3 install portpicker
+
+# Install git and vim
+RUN apt-get update && \
+    apt-get install -y git && \
+    apt-get install -y vim
+
+RUN git clone --single-branch --branch benchmark https://github.com/tensorflow/models.git && \
+    mv models tensorflow_models && \
+    git clone https://github.com/tensorflow/model-optimization.git && \
+    mv model-optimization tensorflow_model_optimization
+
+COPY resnet_cifar_ps_strategy.py /
+
+ENV PYTHONPATH "${PYTHONPATH}:/:/tensorflow_models"
+CMD ["python", "/resnet_cifar_ps_strategy.py"]
diff --git a/distribution_strategy/parameter_server_training/README.md b/distribution_strategy/parameter_server_training/README.md
new file mode 100644
index 00000000..734c46fd
--- /dev/null
+++ b/distribution_strategy/parameter_server_training/README.md
@@ -0,0 +1,167 @@
+# Parameter Server Training Using Distribution Strategies
+
+This directory provides an example of running parameter server training with Distribution Strategies.
+
+Please first read the [documentation](https://www.tensorflow.org/tutorials/distribute/parameter_server_training) on Distribution Strategy for parameter server training. We also assume that readers of this page are familiar with [Google Cloud](https://cloud.google.com/) and its [Kubernetes Engine](https://cloud.google.com/kubernetes-engine/).
+
+This directory contains the following files:
+
+- kubernetes/template.yaml.jinja: jinja template used for generating Kubernetes manifests
+- kubernetes/render_template.py: script for rendering the jinja template
+- Dockerfile.resnet_cifar_ps_strategy: a Dockerfile to build the model image
+- resnet_cifar_ps_strategy.py: a ResNet example using the CIFAR-10 dataset for parameter server training
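+
+Each TensorFlow task pod created from the Kubernetes template (everything except TensorBoard) receives a `TF_CONFIG` environment variable and decides its role from it. As a minimal sketch of the control flow in resnet_cifar_ps_strategy.py (assuming the TF 2.x nightly build used in the Dockerfile; the evaluator branch is omitted):
+
+```python
+import tensorflow as tf
+
+# TFConfigClusterResolver parses the TF_CONFIG environment variable that the
+# Kubernetes manifests inject into every pod.
+cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
+
+if cluster_resolver.task_type in ("worker", "ps"):
+  # Workers and parameter servers start a gRPC server and wait for the
+  # coordinator to connect and dispatch work.
+  server = tf.distribute.Server(
+      cluster_resolver.cluster_spec(),
+      job_name=cluster_resolver.task_type,
+      task_index=cluster_resolver.task_id,
+      protocol="grpc",
+      start=True)
+  server.join()
+else:
+  # The chief creates the strategy and drives training via a coordinator.
+  strategy = tf.distribute.experimental.ParameterServerStrategy(
+      cluster_resolver)
+  coordinator = (
+      tf.distribute.experimental.coordinator.ClusterCoordinator(strategy))
+  # ... build the model under strategy.scope() and run training steps with
+  # coordinator.schedule(...).
+```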
+
+## Prerequisites
+
+1. First you need to have a Google Cloud project. Either create a new project or use an existing one.
+
+2. Install the
+   [gcloud command-line tools](https://cloud.google.com/functions/docs/quickstart)
+   on your system, log in, set the project and zone, etc.
+
+3. Install [Docker](https://docs.docker.com/get-docker/) for your system.
+
+4. Install kubectl:
+
+   ```bash
+   gcloud components install kubectl
+   ```
+
+5. Start a Kubernetes cluster either with the `gcloud` command as shown below or with the
+   [GKE](https://cloud.google.com/kubernetes-engine/) web UI. Using more CPUs or nodes may require increasing your CPU [quotas](https://cloud.google.com/compute/quotas#requesting_additional_quota).
+
+   ```bash
+   gcloud container clusters create <cluster-name> --zone=us-west1-a --num-nodes=6 --machine-type=e2-standard-4
+   ```
+
+6. Set the context for `kubectl` so that `kubectl` knows which cluster to use:
+
+   ```bash
+   kubectl config use-context <cluster-context>
+   ```
+
+7. Create a
+   [service account](https://cloud.google.com/compute/docs/access/service-accounts)
+   and download its key file in JSON format. Assign the Storage Admin role for
+   [Google Cloud Storage](https://cloud.google.com/storage/) to this service account:
+
+   ```bash
+   gcloud iam service-accounts create <account-id> --display-name="<display-name>"
+   ```
+
+   ```bash
+   gcloud projects add-iam-policy-binding <project-id> \
+     --member="serviceAccount:<account-id>@<project-id>.iam.gserviceaccount.com" \
+     --role="roles/storage.admin"
+   ```
+
+8. Create a Kubernetes secret from the JSON key file of your service account:
+
+   ```bash
+   kubectl create secret generic credential --from-file=key.json=<path-to-key-file>
+   ```
+
+9. Enable the GCR ([Google Container Registry](https://cloud.google.com/container-registry)) service for your project using either the GCP web UI or the gcloud tool:
+
+   ```bash
+   gcloud services enable containerregistry.googleapis.com
+   ```
+
+10. Configure Docker to authenticate with Container Registry:
+
+    ```bash
+    gcloud auth configure-docker
+    ```
+
+## How to run the example
+
+1. Create three buckets for model data, checkpoints and training logs using either the GCP web UI or the gsutil tool (included with the gcloud tool you have installed above):
+
+   ```bash
+   gsutil mb gs://<bucket-name>
+   ```
+
+   You will use these bucket names to modify `data_dir`, `checkpoint_dir` and `train_log_dir` in step #4.
+
+2. Download the CIFAR-10 data and place it in your `data_dir` bucket. Head to the [ResNet in TensorFlow](https://github.com/tensorflow/models/tree/r1.13.0/official/resnet#cifar-10) directory to obtain the CIFAR-10 data. Alternatively, you can download and extract the data yourself using this [direct link](https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz).
+
+   ```bash
+   python cifar10_download_and_extract.py
+   ```
+
+   Upload the contents of the cifar-10-batches-bin directory to your `data_dir` bucket:
+
+   ```bash
+   gsutil -m cp cifar-10-batches-bin/* gs://<data-dir-bucket>/
+   ```
+
+3. Now let's build the Docker image:
+
+   ```bash
+   docker build --no-cache -t resnet_cifar_ps_strategy:v1 -f Dockerfile.resnet_cifar_ps_strategy .
+   ```
+
+   and push the image to the
+   [Google Cloud Container Registry](https://cloud.google.com/container-registry/):
+
+   ```bash
+   docker tag resnet_cifar_ps_strategy:v1 gcr.io/<project-id>/resnet_cifar_ps_strategy:v1
+   docker push gcr.io/<project-id>/resnet_cifar_ps_strategy:v1
+   ```
+
+4. Modify the variables in template.yaml.jinja. You may want to change `name`,
+   `image`, `train_log_dir`, `script` and `cmdline_args`. (An example of the `TF_CONFIG` value that the template renders for each pod is shown at the end of this README.)
+
+   * `name`: name for your cluster, e.g. "my-parameter-server-example".
+   * `image`: the name of your docker image.
+   * `worker_replicas`: number of worker pods.
+   * `ps_replicas`: number of parameter server pods.
+   * `num_gpus_per_worker`: number of GPUs (this does not apply to this example, since the parameter server distribution strategy does not have GPU support yet).
+   * `has_coordinator`: flag for creating the coordinator job.
+   * `has_eval`: flag for creating the evaluator job (this is set to False in the default template in order to use inline distributed evaluation; setting this flag to True enables side-car evaluation).
+   * `has_tensorboard`: flag for creating the TensorBoard job.
+   * `script`: the script in the docker image to run.
+   * `train_log_dir`: used for logging training accuracy.
+   * `cmdline_args`: the command line arguments passed to the `script`.
+   * `credential_secret_json`: the filename under which your service account key was registered in the Kubernetes secret (key.json above).
+   * `credential_secret_key`: the name of the Kubernetes secret used for storing your service account key.
+   * `port`: the port for all tasks, including TensorBoard.
+   * `use_node_port`: flag for using NodePort as the service type. The jinja template generates an ingress only for TensorBoard when this flag is set to `True`. Setting this flag to `False` uses LoadBalancer for all pods, assigning them external IPs (which may be limited by your public IP address quota).
+
+5. Start the training and evaluation on the cluster.
+
+   You may want to verify the generated Kubernetes manifests by running the following:
+
+   ```bash
+   cd kubernetes
+   python render_template.py template.yaml.jinja | kubectl create -f - --dry-run=client
+   ```
+
+   After making sure that the above command succeeds, you can start the cluster (removing the dry-run flag):
+
+   ```bash
+   python render_template.py template.yaml.jinja | kubectl create -f -
+   ```
+
+   You'll see that your cluster has started training. You can inspect the logs of
+   workers or use TensorBoard to watch your model training.
+
+   ```bash
+   kubectl get pods
+   ```
+
+   ```bash
+   kubectl logs -f <pod-name>
+   ```
+
+6. You can find the TensorBoard service's public IP address on the Services & Ingress page of GKE, and access TensorBoard at http://<external-ip> (or http://<external-ip>:5000 if you have set `use_node_port` to `False`) in your browser.
+
+   The training accuracy graph should look like the following:
+
+   ![Training accuracy - TensorBoard](images/tf-dist-ps-tensorboard.png)
+
+7. Destroy the cluster:
+
+   ```bash
+   gcloud container clusters delete <cluster-name>
+   ```
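+
+For reference, the `tf_config` macro in kubernetes/template.yaml.jinja renders a `TF_CONFIG` JSON string into each non-TensorBoard pod. With the default template variables, the value injected into worker 0 is equivalent to the following Python dict (hostnames are derived from `name` and `port`; this is shown purely as an illustration of the macro's output):
+
+```python
+# TF_CONFIG rendered for worker 0 with the template defaults:
+# name="resnet-cifar-ps-strategy-example", port=5000,
+# worker_replicas=5, ps_replicas=2, has_coordinator=True.
+tf_config_worker_0 = {
+    "cluster": {
+        "worker": ["resnet-cifar-ps-strategy-example-worker-%d:5000" % i
+                   for i in range(5)],
+        "ps": ["resnet-cifar-ps-strategy-example-ps-%d:5000" % i
+               for i in range(2)],
+        "chief": ["resnet-cifar-ps-strategy-example-chief-0:5000"],
+    },
+    "task": {"type": "worker", "index": "0"},
+}
+```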
diff --git a/distribution_strategy/parameter_server_training/images/tf-dist-ps-tensorboard.png b/distribution_strategy/parameter_server_training/images/tf-dist-ps-tensorboard.png
new file mode 100644
index 00000000..78e15d6f
Binary files /dev/null and b/distribution_strategy/parameter_server_training/images/tf-dist-ps-tensorboard.png differ
diff --git a/distribution_strategy/parameter_server_training/kubernetes/render_template.py b/distribution_strategy/parameter_server_training/kubernetes/render_template.py
new file mode 100755
index 00000000..dcc0ff21
--- /dev/null
+++ b/distribution_strategy/parameter_server_training/kubernetes/render_template.py
@@ -0,0 +1,12 @@
+#!/usr/bin/env python
+
+from __future__ import print_function
+
+import jinja2
+import sys
+
+if len(sys.argv) != 2:
+  print("usage: {} <template-file>".format(sys.argv[0]), file=sys.stderr)
+  sys.exit(1)
+with open(sys.argv[1], "r") as f:
+  print(jinja2.Template(f.read()).render())
diff --git a/distribution_strategy/parameter_server_training/kubernetes/template.yaml.jinja b/distribution_strategy/parameter_server_training/kubernetes/template.yaml.jinja
new file mode 100644
index 00000000..92ace0c9
--- /dev/null
+++ b/distribution_strategy/parameter_server_training/kubernetes/template.yaml.jinja
@@ -0,0 +1,150 @@
+{%- set name = "resnet-cifar-ps-strategy-example" -%}
+{%- set image = "gcr.io/tensorflow-experimental/resnet_cifar_ps_strategy:v1" -%}
+{%- set worker_replicas = 5 -%}
+{%- set ps_replicas = 2 -%}
+{%- set num_gpus_per_worker = 0 -%}
+{%- set has_coordinator = True -%}
+{%- set has_eval = False -%}
+{%- set has_tensorboard = True -%}
+{%- set script = "/resnet_cifar_ps_strategy.py" -%}
+{%- set train_log_dir = "gs://cifar10-train-log/" -%}
+{%- set cmdline_args = [
+    "--data_dir=gs://cifar10-data/",
+    "--checkpoint_dir=gs://cifar10-ckpt/",
+    "--train_log_dir=" + train_log_dir
+    ] -%}
+{%- set credential_secret_json = "key.json" -%}
+{%- set credential_secret_key = "credential" -%}
+{%- set port = 5000 -%}
+{%- set use_node_port = True -%}
+
+
+{%- set replicas = {
+    "worker": worker_replicas,
+    "ps": ps_replicas,
+    "chief": has_coordinator|int,
+    "evaluator": has_eval|int,
+    "tensorboard": 
has_tensorboard|int + } -%} + +{%- macro worker_hosts() -%} + {% for i in range(worker_replicas) %} + \"{{ name }}-worker-{{ i }}:{{ port }}\"{%- if not loop.last -%},{%- endif -%} + {% endfor %} +{%- endmacro -%} + +{%- macro ps_hosts() -%} + {% for i in range(ps_replicas) %} + \"{{ name }}-ps-{{ i }}:{{ port }}\"{%- if not loop.last -%},{%- endif -%} + {% endfor %} +{%- endmacro -%} + +{%- macro tf_config(task_type, task_id) -%} +{ + \"cluster\": { + \"worker\": [{{ worker_hosts() }}] + {%- if ps_replicas > 0 %}, + \"ps\": [{{ ps_hosts() }} + ]{% endif %} + {%- if has_coordinator %}, + \"chief\": [ + \"{{ name }}-chief-0:{{ port }}\" + ] + {%- endif %} + }, + \"task\": { + \"type\": \"{{ task_type }}\", + \"index\": \"{{ task_id }}\" + } +} +{%- endmacro -%} + +{% for job in ["chief", "worker", "ps", "evaluator", "tensorboard"] -%} +{%- for i in range(replicas[job]) -%} +{% if job == "tensorboard" and use_node_port %} +kind: Ingress +apiVersion: networking.k8s.io/v1beta1 +metadata: + name: tensorboard-ingress +spec: + backend: + serviceName: {{ name }}-{{ job }}-{{ i }} + servicePort: {{ port }} +--- +{% endif -%} +kind: Service +apiVersion: v1 +metadata: + name: {{ name }}-{{ job }}-{{ i }} +spec: + type: {{ 'NodePort' if use_node_port else 'LoadBalancer' }} + selector: + name: {{ name }} + job: {{ job }} + task: "{{ i }}" + ports: + - port: {{ port }} + {%- if use_node_port %} + targetPort: {{ port }} + {%- endif %} +--- +kind: Deployment +apiVersion: apps/v1 +metadata: + name: {{ name }}-{{ job }}-{{ i }} +spec: + replicas: 1 + selector: + matchLabels: + name: {{ name }} + job: {{ job }} + task: "{{ i }}" + template: + metadata: + labels: + name: {{ name }} + job: {{ job }} + task: "{{ i }}" + spec: + containers: +{%- if job == "tensorboard" %} + - name: tensorflow + image: tensorflow/tensorflow +{%- else %} + - name: tensorflow + image: {{ image }} +{%- endif %} + env: +{%- if job != "tensorboard" %} + - name: TF_CONFIG + value: "{{ tf_config(job, i) }}" +{%- endif %} + - name: GOOGLE_APPLICATION_CREDENTIALS + value: "/var/secrets/google/{{ credential_secret_json }}" + ports: + - containerPort: {{ port }} +{%- if job == "tensorboard" %} + command: + - "tensorboard" + args: + - "--logdir={{ train_log_dir }}" + - "--port={{ port }}" + - "--host=0.0.0.0" +{%- else %} + command: + - "python" + - "{{ script }}" + {%- for cmdline_arg in cmdline_args %} + - "{{ cmdline_arg }}" + {%- endfor -%} +{%- endif %} + volumeMounts: + - name: credential + mountPath: /var/secrets/google + volumes: + - name: credential + secret: + secretName: {{ credential_secret_key }} +--- +{% endfor %} +{%- endfor -%} diff --git a/distribution_strategy/parameter_server_training/resnet_cifar_ps_strategy.py b/distribution_strategy/parameter_server_training/resnet_cifar_ps_strategy.py new file mode 100644 index 00000000..57d6f88e --- /dev/null +++ b/distribution_strategy/parameter_server_training/resnet_cifar_ps_strategy.py @@ -0,0 +1,297 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+# Lint as: python3
+"""ResNet Cifar + ParameterServerStrategy example."""
+import os
+from datetime import datetime
+import multiprocessing
+from absl import app
+from absl import flags
+from absl import logging
+import portpicker
+import tensorflow as tf
+from tensorflow_models.official.benchmark.models import cifar_preprocessing
+from tensorflow_models.official.benchmark.models import resnet_cifar_model
+from tensorflow_models.official.vision.image_classification.resnet import common as img_class_common

+flags.DEFINE_string("checkpoint_dir", "gs://cifar10_ckpt/",
+                    "Directory for writing model checkpoints.")
+flags.DEFINE_string("data_dir", "gs://cifar10_data/",
+                    "Directory for the ResNet CIFAR-10 model input. Follow the "
+                    "instructions here to get the CIFAR-10 data: "
+                    "https://github.com/tensorflow/models/tree/r1.13.0/official/resnet#cifar-10")
+flags.DEFINE_string("train_log_dir", "gs://cifar10_train_log/",
+                    "Directory for ResNet CIFAR-10 training logs.")
+flags.DEFINE_boolean(
+    "use_in_process_cluster", False,
+    "Whether to use an in-process cluster for testing.")
+flags.DEFINE_boolean(
+    "run_in_process_training", True,
+    "When using --use_in_process_cluster, whether to run training (True) or "
+    "side-car evaluation (False).")
+
+FLAGS = flags.FLAGS
+
+TRAIN_EPOCHS = 182
+STEPS_PER_EPOCH = 781
+BATCH_SIZE = 64
+EVAL_BATCH_SIZE = 8
+EVAL_STEPS_PER_EPOCH = 88
+
+
+def create_in_process_cluster(num_workers, num_ps):
+  """Creates and starts local servers and returns a cluster resolver."""
+  worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
+  ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
+
+  cluster_dict = {}
+  cluster_dict["worker"] = ["localhost:%s" % port for port in worker_ports]
+  if num_ps > 0:
+    cluster_dict["ps"] = ["localhost:%s" % port for port in ps_ports]
+
+  cluster_spec = tf.train.ClusterSpec(cluster_dict)
+
+  # Workers need some inter-op threads to work properly.
+  worker_config = tf.compat.v1.ConfigProto()
+  if multiprocessing.cpu_count() < num_workers + 1:
+    worker_config.inter_op_parallelism_threads = num_workers + 1
+
+  for i in range(num_workers):
+    tf.distribute.Server(
+        cluster_spec,
+        job_name="worker",
+        protocol="grpc",
+        config=worker_config,
+        task_index=i,
+        start=True)
+
+  for i in range(num_ps):
+    tf.distribute.Server(
+        cluster_spec,
+        job_name="ps",
+        protocol="grpc",
+        task_index=i,
+        start=True)
+
+  cluster_resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
+      cluster_spec, rpc_layer="grpc")
+  return cluster_resolver
+
+
+def run_tf_server_and_wait(cluster_resolver):
+  """Starts a TF server for this task and blocks until the process is killed."""
+  assert cluster_resolver.task_type in ("worker", "ps")
+  server = tf.distribute.Server(
+      cluster_resolver.cluster_spec(),
+      job_name=cluster_resolver.task_type,
+      task_index=cluster_resolver.task_id,
+      protocol=cluster_resolver.rpc_layer or "grpc",
+      start=True)
+  server.join()
+
+
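+# Note: train_resnet_cifar() below runs only on the coordinator (the "chief"
+# task); worker and parameter server pods only execute
+# run_tf_server_and_wait() above. For a quick local try-out without
+# Kubernetes, an in-process cluster can be used instead (a sketch mirroring
+# the --use_in_process_cluster path in main() below):
+#
+#   cluster_resolver = create_in_process_cluster(num_workers=3, num_ps=1)
+#   train_resnet_cifar(cluster_resolver)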
+def train_resnet_cifar(cluster_resolver):
+  """Trains the resnet56 model using parameter server distribution strategy.
+
+  Args:
+    cluster_resolver: cluster resolver to give necessary information to
+      set up distributed training.
+  """
+
+  strategy = tf.distribute.experimental.ParameterServerStrategy(
+      cluster_resolver)
+  coordinator = (
+      tf.distribute.experimental.coordinator.ClusterCoordinator(strategy))
+  with strategy.scope():
+    model = resnet_cifar_model.resnet56()
+
+    initial_learning_rate = (
+        img_class_common.BASE_LEARNING_RATE * BATCH_SIZE / 128)
+    # Using the learning rate schedule from the model garden:
+    # tensorflow_models/official/benchmark/models/resnet_cifar_main.py
+    lr_segments = [  # (multiplier, epoch to start) tuples
+        (0.1, 91), (0.01, 136), (0.001, 182)
+    ]
+    lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
+        boundaries=list(p[1] * STEPS_PER_EPOCH for p in lr_segments),
+        values=[initial_learning_rate] +
+        list(p[0] * initial_learning_rate for p in lr_segments))
+    optimizer = img_class_common.get_optimizer(lr_schedule)
+
+    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
+        name="train_accuracy")
+    eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
+        name="eval_accuracy")
+
+  @tf.function
+  def worker_train_fn(iterator):
+
+    def replica_fn(inputs):
+      """Per-replica training step."""
+      batch_data, labels = inputs
+      with tf.GradientTape() as tape:
+        predictions = model(batch_data, training=True)
+        xent_loss = tf.keras.losses.SparseCategoricalCrossentropy(
+            reduction=tf.keras.losses.Reduction.NONE)(labels, predictions)
+        loss = (
+            tf.nn.compute_average_loss(xent_loss) +
+            tf.nn.scale_regularization_loss(model.losses))
+      gradients = tape.gradient(loss, model.trainable_variables)
+
+      optimizer.apply_gradients(
+          zip(gradients, model.trainable_variables))
+      train_accuracy.update_state(labels, predictions)
+      return loss
+
+    inputs = next(iterator)
+    losses = strategy.run(replica_fn, args=(inputs,))
+    return strategy.reduce(tf.distribute.ReduceOp.SUM, losses, axis=None)
+
+  @tf.function
+  def worker_eval_fn(iterator):
+
+    def eval_fn(inputs):
+      """Per-replica evaluation step."""
+      batch_data, labels = inputs
+      predictions = model(batch_data, training=False)
+      eval_accuracy.update_state(labels, predictions)
+
+    inputs = next(iterator)
+    strategy.run(eval_fn, args=(inputs,))
+
+  checkpoint_manager = tf.train.CheckpointManager(
+      tf.train.Checkpoint(model=model, optimizer=optimizer),
+      FLAGS.checkpoint_dir,
+      max_to_keep=2)
+  if checkpoint_manager.latest_checkpoint:
+    checkpoint = checkpoint_manager.checkpoint
+    checkpoint.restore(
+        checkpoint_manager.latest_checkpoint
+    ).assert_existing_objects_matched()
+
+  train_dataset_fn = lambda _: cifar_preprocessing.input_fn(
+      is_training=True,
+      data_dir=FLAGS.data_dir,
+      batch_size=BATCH_SIZE,
+      parse_record_fn=cifar_preprocessing.parse_record,
+      dtype=tf.float32,
+      drop_remainder=True)
+  eval_dataset_fn = lambda _: cifar_preprocessing.input_fn(
+      is_training=False,
+      data_dir=FLAGS.data_dir,
+      batch_size=EVAL_BATCH_SIZE,
+      parse_record_fn=cifar_preprocessing.parse_record,
+      dtype=tf.float32)
+
+  # The following wrappers will allow efficient prefetching to GPUs
+  # when GPUs are supported by ParameterServerStrategy.
+  @tf.function
+  def per_worker_train_dataset_fn():
+    return strategy.distribute_datasets_from_function(train_dataset_fn)
+
+  @tf.function
+  def per_worker_eval_dataset_fn():
+    return strategy.distribute_datasets_from_function(eval_dataset_fn)
+
+  per_worker_train_dataset = coordinator.create_per_worker_dataset(
+      per_worker_train_dataset_fn)
+  per_worker_eval_dataset = coordinator.create_per_worker_dataset(
+      per_worker_eval_dataset_fn)
+
+  global_steps = int(optimizer.iterations.numpy())
+  logging.info("Training starts with global_steps = %d", global_steps)
+  current_time = datetime.now().strftime("%Y%m%d-%H%M%S")
+  train_log_dir = FLAGS.train_log_dir + current_time
+  train_summary_writer = tf.summary.create_file_writer(train_log_dir)
+
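+  # coordinator.schedule() is asynchronous: it queues a tf.function call onto
+  # any available worker and returns a RemoteValue immediately;
+  # coordinator.join() blocks until all scheduled functions have finished,
+  # re-raising any error from the workers. Metric values read after join()
+  # therefore cover the whole epoch.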
+  for epoch in range(global_steps // STEPS_PER_EPOCH, TRAIN_EPOCHS):
+    per_worker_train_iterator = iter(per_worker_train_dataset)
+    per_worker_eval_iterator = iter(per_worker_eval_dataset)
+    for _ in range(STEPS_PER_EPOCH):
+      coordinator.schedule(worker_train_fn, args=(per_worker_train_iterator,))
+    coordinator.join()
+    logging.info("Finished joining at epoch %d. Training accuracy: %f.",
+                 epoch, train_accuracy.result())
+
+    # Since we are running inline evaluation below, a side-car evaluator job
+    # is not necessary.
+    for _ in range(EVAL_STEPS_PER_EPOCH):
+      coordinator.schedule(worker_eval_fn, args=(per_worker_eval_iterator,))
+    coordinator.join()
+    logging.info("Finished joining at epoch %d. Evaluation accuracy: %f.",
+                 epoch, eval_accuracy.result())
+
+    with train_summary_writer.as_default():
+      tf.summary.scalar('train_accuracy', train_accuracy.result(), step=epoch)
+      tf.summary.scalar('eval_accuracy', eval_accuracy.result(), step=epoch)
+    train_accuracy.reset_states()
+    eval_accuracy.reset_states()
+    checkpoint_manager.save()
+
+
+def evaluate_resnet_cifar():
+  """Evaluates the resnet56 model.
+
+  This method provides side-car evaluation using the checkpoints written by
+  the training job.
+  """
+  eval_model = resnet_cifar_model.resnet56()
+  eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
+      name="eval_accuracy")
+  eval_model.compile(metrics=eval_accuracy)
+  eval_dataset = cifar_preprocessing.input_fn(
+      is_training=False,
+      data_dir=FLAGS.data_dir,
+      batch_size=BATCH_SIZE,
+      parse_record_fn=cifar_preprocessing.parse_record)
+
+  checkpoint = tf.train.Checkpoint(model=eval_model)
+
+  for latest_checkpoint in tf.train.checkpoints_iterator(
+      FLAGS.checkpoint_dir):
+    try:
+      checkpoint.restore(latest_checkpoint).expect_partial()
+    except tf.errors.OpError:
+      # The checkpoint may be deleted by the training job while we are
+      # reading it, in which case we skip it and wait for the next one.
+      continue
+
+    # Optionally add callbacks to write summaries.
+    eval_model.evaluate(eval_dataset)
+
+    # Evaluation finishes when it has evaluated the last epoch.
+    if latest_checkpoint.endswith("-{}".format(TRAIN_EPOCHS)):
+      break
+
+
+def main(_):
+  if FLAGS.use_in_process_cluster:
+    if FLAGS.run_in_process_training:
+      cluster_resolver = create_in_process_cluster(3, 1)
+      train_resnet_cifar(cluster_resolver)
+    else:
+      evaluate_resnet_cifar()
+  else:
+    os.environ["grpc_fail_fast"] = "use_caller"
+    cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
+    if cluster_resolver.task_type in ("worker", "ps"):
+      run_tf_server_and_wait(cluster_resolver)
+    elif cluster_resolver.task_type == "evaluator":
+      evaluate_resnet_cifar()
+    else:
+      train_resnet_cifar(cluster_resolver)
+
+
+if __name__ == "__main__":
+  app.run(main)