Skip to content

Commit aad0e2a

Browse files
JinhaoLei and jlei2 authored
spawn to start vertex uploader process (apple#1187)
Co-authored-by: Jinhao Lei <jlei2@apple.com>
1 parent f491d76 commit aad0e2a

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

axlearn/cloud/gcp/vertexai_tensorboard.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,13 @@ def upload(self):
137137
resource_name=self._resource_name,
138138
logdir=cfg.summary_dir,
139139
)
140-
self._uploader_proc = multiprocessing.Process(
140+
# Uses spawn method to start the process as it was observed that using fork method may
141+
# cause the training job to fail to exit correctly after completion. It is also suggested by
142+
# the warning message: "RuntimeWarning: os.fork() was called. os.fork() is incompatible
143+
# with multithreaded code, and JAX is multithreaded, so this will likely lead to a
144+
# deadlock.".
145+
ctx = multiprocessing.get_context("spawn")
146+
self._uploader_proc = ctx.Process(
141147
target=_start_vertexai_tensorboard, kwargs=kwargs, daemon=True
142148
)
143149
self._uploader_proc.start()

axlearn/cloud/gcp/vertexai_tensorboard_test.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def start(self): # pylint: disable=no-self-use
3030
class VertexAITensorboardUploaderTest(absltest.TestCase):
3131
"""Tests VertexAITensorboardUploader."""
3232

33-
@mock.patch("multiprocessing.Process", side_effect=fake_process)
33+
@mock.patch("multiprocessing.get_context")
3434
@mock.patch(f"{uploader.TensorBoardUploader.__module__}.TensorBoardUploader", autospec=True)
3535
@mock.patch(
3636
f"{initializer.global_config.__module__}.global_config.create_client",
@@ -45,9 +45,15 @@ class VertexAITensorboardUploaderTest(absltest.TestCase):
4545
return_value=("fake_bucket", "fake_folder"),
4646
)
4747
def test_uploader_calls(
48-
self, bucket_folder_fn, create_client_fn, tb_uploader_class, unused
48+
self,
49+
bucket_folder_fn,
50+
create_client_fn,
51+
tb_uploader_class,
52+
mock_get_context,
4953
): # pylint: disable=no-self-use
50-
del unused
54+
mock_context = mock.MagicMock()
55+
mock_context.Process.side_effect = fake_process
56+
mock_get_context.return_value = mock_context
5157

5258
mock_settings = {
5359
"vertexai_tensorboard": "fake_tb_instance",
@@ -61,6 +67,7 @@ def test_uploader_calls(
6167
)
6268
tb_uploader = cfg.instantiate()
6369
tb_uploader.upload()
70+
mock_get_context.assert_called_once_with("spawn")
6471
create_client_fn.assert_called_once()
6572
bucket_folder_fn.assert_called_once()
6673
tb_uploader_class.return_value.create_experiment.assert_called_once()

0 commit comments

Comments
 (0)