Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions dist_ir/executor/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,14 @@ def __init__(self, function: Function, inputs: Sequence[Any]):
self.timestamps = defaultdict(float)
self.peak_memory = defaultdict(lambda: 0)
# Values are tuples of (device, memory_used)
# TODO values are actually (timestamp, memory_used)?
# TODO should they be (op, memory_used_during_op)?
self.live_memory = defaultdict(lambda: [(0, 0)])
self.consumers = defaultdict(int)
self.trace = []
self._function_inputs_set = set(function.inputs)

# TODO this should look at `inputs` instead?
for inp in function.inputs:
if inp.type is None or inp.type.device is None:
continue
Expand Down Expand Up @@ -133,10 +136,14 @@ def _simulate_op(
state.consumers[out_edge] = len(state.function.consumers[out_edge])
output_devices = _get_all_devices([output])
for output_device in output_devices:
# TODO if an output is on more than one device this adds the same
# size to the memory usage of all devices!
live_memory_deltas[output_device] += output.size()
# TODO update peak memory in update_live_memory?
state.update_live_memory(live_memory_deltas)

# Update the peak memory.
# TODO this assumes no op is in-place
for device in state.live_memory:
state.peak_memory[device] = max(
state.peak_memory[device], state.live_memory[device][-1][1]
Expand All @@ -159,6 +166,7 @@ def _simulate_op(
input_devices = _get_all_devices([inp])
for input_device in input_devices:
live_memory_deltas[input_device] -= inp.size()
# TODO doesn't this result in two entries in state.live_memory with same timestamp?
state.update_live_memory(live_memory_deltas)


Expand Down