diff --git a/dist_ir/executor/simulator.py b/dist_ir/executor/simulator.py index f80a17e8..0b453417 100644 --- a/dist_ir/executor/simulator.py +++ b/dist_ir/executor/simulator.py @@ -39,11 +39,14 @@ def __init__(self, function: Function, inputs: Sequence[Any]): self.timestamps = defaultdict(float) self.peak_memory = defaultdict(lambda: 0) # Values are tuples of (device, memory_used) + # TODO values are actually (timestamp, memory_used)? + # TODO should they be (op, memory_used_during_op)? self.live_memory = defaultdict(lambda: [(0, 0)]) self.consumers = defaultdict(int) self.trace = [] self._function_inputs_set = set(function.inputs) + # TODO this should look at `inputs` instead? for inp in function.inputs: if inp.type is None or inp.type.device is None: continue @@ -133,10 +136,14 @@ def _simulate_op( state.consumers[out_edge] = len(state.function.consumers[out_edge]) output_devices = _get_all_devices([output]) for output_device in output_devices: + # TODO if an output is on more than one device this adds the same + # size to the memory usage of all devices! live_memory_deltas[output_device] += output.size() + # TODO update peak memory in update_live_memory? state.update_live_memory(live_memory_deltas) # Update the peak memory. + # TODO this assumes no op is in-place for device in state.live_memory: state.peak_memory[device] = max( state.peak_memory[device], state.live_memory[device][-1][1] @@ -159,6 +166,7 @@ def _simulate_op( input_devices = _get_all_devices([inp]) for input_device in input_devices: live_memory_deltas[input_device] -= inp.size() + # TODO doesn't this result in two entries in state.live_memory with same timestamp? state.update_live_memory(live_memory_deltas)