Skip to content

Commit 0f30fee

Browse files
DeepMindcopybara-github
authored andcommitted
Resolve unsoundness caught by pytype --strict-none-binding.
PiperOrigin-RevId: 707220791
1 parent f612730 commit 0f30fee

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

android_env/components/task_manager.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import json
2323
import re
2424
import threading
25-
from typing import Any
25+
from typing import Any, Optional
2626

2727
from absl import logging
2828
from android_env.components import adb_call_parser as adb_call_parser_lib
@@ -40,6 +40,11 @@
4040
class TaskManager:
4141
"""Handles all events and information related to the task."""
4242

43+
_setup_step_interpreter: setup_step_interpreter.SetupStepInterpreter
44+
_dumpsys_thread: dumpsys_thread.DumpsysThread
45+
_task_start_time: datetime.datetime
46+
_logcat_thread: logcat_thread.LogcatThread
47+
4348
def __init__(
4449
self,
4550
task: task_pb2.Task,
@@ -55,9 +60,6 @@ def __init__(
5560
self._task = task
5661
self._config = config or config_classes.TaskManagerConfig()
5762
self._lock = threading.Lock()
58-
self._logcat_thread = None
59-
self._dumpsys_thread = None
60-
self._setup_step_interpreter = None
6163

6264
# Initialize stats.
6365
self._stats = {
@@ -71,7 +73,6 @@ def __init__(
7173
}
7274

7375
# Initialize internal state
74-
self._task_start_time = None
7576
self._bad_state_counter = 0
7677
self._is_bad_episode = False
7778

@@ -84,6 +85,11 @@ def __init__(
8485

8586
logging.info('Task config: %s', self._task)
8687

88+
@property
89+
def _logcate_thread_ok(self) -> logcat_thread.LogcatThread:
90+
assert self._logcat_thread is not None
91+
return self._logcat_thread
92+
8793
def stats(self) -> dict[str, Any]:
8894
"""Returns a dictionary of stats.
8995
@@ -109,16 +115,16 @@ def start(
109115
"""Starts task processing."""
110116

111117
self._start_logcat_thread(log_stream=log_stream)
112-
self._logcat_thread.resume()
118+
self._logcate_thread_ok.resume()
113119
self._start_dumpsys_thread(adb_call_parser_factory())
114120
self._start_setup_step_interpreter(adb_call_parser_factory())
115121

116122
def reset_task(self) -> None:
117123
"""Resets a task for a new run."""
118124

119-
self._logcat_thread.pause()
125+
self._logcate_thread_ok.pause()
120126
self._setup_step_interpreter.interpret(self._task.reset_steps)
121-
self._logcat_thread.resume()
127+
self._logcate_thread_ok.resume()
122128

123129
# Reset some other variables.
124130
if not self._is_bad_episode:
@@ -139,7 +145,7 @@ def rl_reset(self, observation: dict[str, Any]) -> dm_env.TimeStep:
139145

140146
self._stats['episode_steps'] = 0
141147

142-
self._logcat_thread.line_ready().wait()
148+
self._logcate_thread_ok.line_ready().wait()
143149
with self._lock:
144150
extras = self._get_current_extras()
145151

@@ -156,7 +162,7 @@ def rl_step(self, observation: dict[str, Any]) -> dm_env.TimeStep:
156162

157163
self._stats['episode_steps'] += 1
158164

159-
self._logcat_thread.line_ready().wait()
165+
self._logcate_thread_ok.line_ready().wait()
160166
with self._lock:
161167
reward = self._get_current_reward()
162168
extras = self._get_current_extras()

0 commit comments

Comments
 (0)