2222import json
2323import re
2424import threading
25- from typing import Any
25+ from typing import Any , Optional
2626
2727from absl import logging
2828from android_env .components import adb_call_parser as adb_call_parser_lib
4040class TaskManager :
4141 """Handles all events and information related to the task."""
4242
43+ _setup_step_interpreter : setup_step_interpreter .SetupStepInterpreter
44+ _dumpsys_thread : dumpsys_thread .DumpsysThread
45+ _task_start_time : datetime .datetime
46+ _logcat_thread : logcat_thread .LogcatThread
47+
4348 def __init__ (
4449 self ,
4550 task : task_pb2 .Task ,
@@ -55,9 +60,6 @@ def __init__(
5560 self ._task = task
5661 self ._config = config or config_classes .TaskManagerConfig ()
5762 self ._lock = threading .Lock ()
58- self ._logcat_thread = None
59- self ._dumpsys_thread = None
60- self ._setup_step_interpreter = None
6163
6264 # Initialize stats.
6365 self ._stats = {
@@ -71,7 +73,6 @@ def __init__(
7173 }
7274
7375 # Initialize internal state
74- self ._task_start_time = None
7576 self ._bad_state_counter = 0
7677 self ._is_bad_episode = False
7778
@@ -84,6 +85,11 @@ def __init__(
8485
8586 logging .info ('Task config: %s' , self ._task )
8687
88+ @property
89+ def _logcate_thread_ok (self ) -> logcat_thread .LogcatThread :
90+ assert self ._logcat_thread is not None
91+ return self ._logcat_thread
92+
8793 def stats (self ) -> dict [str , Any ]:
8894 """Returns a dictionary of stats.
8995
@@ -109,16 +115,16 @@ def start(
109115 """Starts task processing."""
110116
111117 self ._start_logcat_thread (log_stream = log_stream )
112- self ._logcat_thread .resume ()
118+ self ._logcate_thread_ok .resume ()
113119 self ._start_dumpsys_thread (adb_call_parser_factory ())
114120 self ._start_setup_step_interpreter (adb_call_parser_factory ())
115121
116122 def reset_task (self ) -> None :
117123 """Resets a task for a new run."""
118124
119- self ._logcat_thread .pause ()
125+ self ._logcate_thread_ok .pause ()
120126 self ._setup_step_interpreter .interpret (self ._task .reset_steps )
121- self ._logcat_thread .resume ()
127+ self ._logcate_thread_ok .resume ()
122128
123129 # Reset some other variables.
124130 if not self ._is_bad_episode :
@@ -139,7 +145,7 @@ def rl_reset(self, observation: dict[str, Any]) -> dm_env.TimeStep:
139145
140146 self ._stats ['episode_steps' ] = 0
141147
142- self ._logcat_thread .line_ready ().wait ()
148+ self ._logcate_thread_ok .line_ready ().wait ()
143149 with self ._lock :
144150 extras = self ._get_current_extras ()
145151
@@ -156,7 +162,7 @@ def rl_step(self, observation: dict[str, Any]) -> dm_env.TimeStep:
156162
157163 self ._stats ['episode_steps' ] += 1
158164
159- self ._logcat_thread .line_ready ().wait ()
165+ self ._logcate_thread_ok .line_ready ().wait ()
160166 with self ._lock :
161167 reward = self ._get_current_reward ()
162168 extras = self ._get_current_extras ()
0 commit comments