diff --git a/dist_ir/executor/absint.py b/dist_ir/executor/absint.py index fb634be9..67bf9288 100644 --- a/dist_ir/executor/absint.py +++ b/dist_ir/executor/absint.py @@ -78,6 +78,31 @@ def interpret_pmap(self, op: Op, state: AbstractState): return state + def interpret_function_call(self, op: Op, state: AbstractState): + # Find the op's inputs in state's environment and save environment + inputs = tuple(state.env[v] for v in op.inputs) + old_env = state.env + state.env = {} # To enforce variable scoping + + # Change state's function pointer to subfunction (TODO necessary?) + function = state.function + state.function = op.subfunctions[0] + + # Interpret subfunction with appropriate inputs + self.interpret(op.subfunctions[0], inputs, state=state) + + # Find the outputs from the state's env + results = tuple(state.env[v] for v in op.subfunctions[0].outputs) + + # Put the results back into the state's environment + state.env = old_env + for x, val in zip(op.outputs, results): + state.env[x] = val + # Also reset state's function pointer + state.function = function + + return state + def interpret( self, function: Function, inputs: Sequence[Any], state: AbstractState = None ): @@ -93,7 +118,9 @@ def interpret( # Execute ops in topological order: for op in function.ops: - if op.op_type == "Pmap": + if op.op_type == "FnCall": + self.interpret_function_call(op, state) + elif op.op_type == "Pmap": self.interpret_pmap(op, state) else: # Function dispatch: diff --git a/dist_ir/executor/sequential_executor.py b/dist_ir/executor/sequential_executor.py index 9193daf5..e7544956 100644 --- a/dist_ir/executor/sequential_executor.py +++ b/dist_ir/executor/sequential_executor.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Sequence +from typing import Any, Sequence, Tuple from .absint import AbstractInterpreter, convert_impls_to_semantics from .backend_register import BackendRegister @@ -12,34 +12,7 @@ def __init__(self, backend): semantics = convert_impls_to_semantics(BackendRegister[backend]) self.interpreter = AbstractInterpreter(semantics=semantics) - def _compute_op(self, op: Op, inputs: List[Any]): - """Executes the given op and returns its outputs.""" - op_type = op.op_type - if op_type == "Pmap": - # Zip the inputs so that we map over each corresponding value - inputs = zip(*inputs) - # Iterate over the inputs - results = [] - for inps in inputs: - # Execute subfunction with appropriate inputs - outs = self.compute(op.subfunctions[0], inps) - # Match output names to output data using the function output order. - ordered_outs = [outs[e] for e in op.subfunctions[0].outputs] - results.append(ordered_outs) - # Unzip the results - results = tuple(zip(*results)) - return results - if op_type not in BackendRegister[self._backend]: - raise NotImplementedError( - f"No {self._backend} implementation found for op {op_type}" - ) - impl = BackendRegister[self._backend][op_type] - output_data = impl(op, inputs) - if not isinstance(output_data, tuple): - output_data = (output_data,) - return output_data - - def compute(self, function: Function, inputs: Sequence[Any]) -> Dict[Value, Any]: + def compute(self, function: Function, inputs: Sequence[Any]) -> Tuple[Any]: """Executes the function given the specified inputs and returns the final result. Args: @@ -47,7 +20,7 @@ def compute(self, function: Function, inputs: Sequence[Any]) -> Dict[Value, Any] inputs: A sequence of input data represented in the specified backend. Returns: - A map from output value to output data. + A tuple of outputs. """ state = self.interpreter.interpret(function, inputs) return tuple(state.env[v] for v in function.outputs) diff --git a/dist_ir/executor/simulator.py b/dist_ir/executor/simulator.py index 15c13fad..5e39ba74 100644 --- a/dist_ir/executor/simulator.py +++ b/dist_ir/executor/simulator.py @@ -23,10 +23,10 @@ def __init__(self, function: Function, inputs: Sequence[Any]): self.consumers = defaultdict(int) self.trace = [] - def add_trace_event(self, op_name, device, start_time, duration): + def add_trace_event(self, op_type, device, start_time, duration): self.trace.append( { - "name": op_name, + "name": op_type, "ph": "X", "ts": start_time, "dur": duration, @@ -70,7 +70,7 @@ def _simulate_op( # Update the trace and timestamps for device in costs: state.add_trace_event( - op.name, + op.op_type, device, state.timestamps[device], costs[device], @@ -78,22 +78,23 @@ def _simulate_op( state.timestamps[device] += costs[device] # Update the live memory. - for out_edge in op.outputs: - state.consumers[out_edge] = len(state.function.consumers[out_edge]) - # Output value could live on multiple devices (e.g. scatter) so - # update memory on all devices: - output_devices = out_edge.type.get_all_devices() - for output_device in output_devices: - state.live_memory[output_device] += out_edge.type.size() + for out_val, conc_val in zip(op.outputs, outputs): + if isinstance(conc_val, Type): + state.consumers[conc_val] = len(state.function.consumers[out_val]) + # Output value could live on multiple devices (e.g. scatter) so + # update memory on all devices: + output_devices = conc_val.get_all_devices() + for output_device in output_devices: + state.live_memory[output_device] += conc_val.size() # TODO: Can we optimize this using a priority queue? for value in state.consumers: # TODO we are missing a decrement of state.consumers[value] somewhere if state.consumers[value] == 0 and all( value != v for v in state.function.inputs ): - value_devices = value.type.get_all_devices() + value_devices = value.get_all_devices() for device in value_devices: - state.live_memory[device] -= value.type.size() + state.live_memory[device] -= value.size() # Update the peak memory. for device in state.live_memory: diff --git a/dist_ir/executor/type_inference.py b/dist_ir/executor/type_inference.py index ff8a46c8..ad296bef 100644 --- a/dist_ir/executor/type_inference.py +++ b/dist_ir/executor/type_inference.py @@ -287,15 +287,16 @@ def _type_function(function: Function, type_map: Dict[Value, Type]) -> Function: # Invariant: inputs of op are already typed (as ops are toposorted) typed_inputs = tuple(value_map[inp] for inp in op.inputs) - # Recursively convert the subfunctions: - subfunctions = tuple(_type_function(fn, type_map) for fn in op.subfunctions) + # Recursively convert the subfunctions? + # TODO how to handle multiple calls to function with varying types/shapes? + # subfunctions = tuple(_type_function(fn, type_map) for fn in op.subfunctions) new_op = Op( op_type=op.op_type, name=op.name, inputs=typed_inputs, attributes=op.attributes, - subfunctions=subfunctions, + subfunctions=op.subfunctions, output_names=tuple(v.name for v in op.outputs), # Look up output types from type_map output_types=tuple(type_map[v] for v in op.outputs), diff --git a/dist_ir/ir/op.py b/dist_ir/ir/op.py index 29a548db..207585d7 100644 --- a/dist_ir/ir/op.py +++ b/dist_ir/ir/op.py @@ -22,7 +22,17 @@ class Op: output_types: InitVar[Tuple[Type]] = None def __post_init__(self, output_names, output_types): - if self.op_type == "Pmap": + if self.op_type == "FnCall": + # Function calls. Subfunction 0 is the called function + assert len(self.subfunctions) == 1 + # Number of inputs is arbitrary but positive + assert len(self.inputs) > 0 + # Number of inputs matches subfunction + assert len(self.inputs) == len(self.subfunctions[0].inputs) + # Number of outputs is given by subfunction + num_outputs = len(self.subfunctions[0].outputs) + + elif self.op_type == "Pmap": # Handle pmap specially assert len(self.subfunctions) == 1 # Number of inputs is arbitrary but positive diff --git a/test/test_sequential_executor.py b/test/test_sequential_executor.py index 071024d2..46ebfe7a 100644 --- a/test/test_sequential_executor.py +++ b/test/test_sequential_executor.py @@ -331,3 +331,26 @@ def test_pmap_dp(): (res,) = ex.compute(function, [(x_0, x_1), (_wA, _wA), (_wB, _wB)]) assert np.array_equal(res[0], np.matmul(np.matmul(x_0, _wA), _wB)) assert np.array_equal(res[1], np.matmul(np.matmul(x_1, _wA), _wB)) + + +def test_function_call(): + layer = FunctionMaker() + x = layer.add_input_value("x", None) + w = layer.add_input_value("w", None) + _ = layer.add_op("MatMul", inputs=[x, w]) + layer = layer.finalize() + fn = FunctionMaker() + x = fn.add_input_value("x", None) + w1 = fn.add_input_value("w1", None) + w2 = fn.add_input_value("w2", None) + a1 = fn.add_op("FnCall", inputs=[x, w1], subfunctions=[layer]) + _ = fn.add_op("FnCall", inputs=[a1, w2], subfunctions=[layer]) + fn = fn.finalize() + cpprint(fn) + + ex = SequentialExecutor("numpy") + _x = np.arange(16 * 4).reshape((16, 4)) + _w1 = np.ones((4, 2)) + _w2 = np.ones((2, 1)) + (res,) = ex.compute(fn, [_x, _w1, _w2]) + assert np.array_equal(res, np.matmul(np.matmul(_x, _w1), _w2)) diff --git a/test/test_simulator.py b/test/test_simulator.py index 3519830e..966e3f4a 100644 --- a/test/test_simulator.py +++ b/test/test_simulator.py @@ -104,3 +104,27 @@ def test_chrome_trace(): transformed_function, (v.type for v in transformed_function.inputs) ) simulation.dump_chrome_trace("test/trace.json") + + +def test_function_call(): + topology = Topology() + d0 = topology.add_device("gpu") + + layer = FunctionMaker() + x = layer.add_input_value("x", None) + w = layer.add_input_value("w", None) + _ = layer.add_op("MatMul", inputs=[x, w]) + layer = layer.finalize() + fn = FunctionMaker() + x = fn.add_input_value("x", Tensor(Float(), (4, 5), device=d0)) + w1 = fn.add_input_value("w1", Tensor(Float(), (5, 6), device=d0)) + w2 = fn.add_input_value("w2", Tensor(Float(), (6, 2), device=d0)) + a1 = fn.add_op("FnCall", inputs=[x, w1], subfunctions=[layer]) + _ = fn.add_op("FnCall", inputs=[a1, w2], subfunctions=[layer]) + fn = fn.finalize() + fn = infer_types(fn, fn.inputs) + + device_speeds = {"gpu": 1.0e13} + simulator = Simulator(CostModel(topology, device_speeds)) + simulation = simulator.interpret(fn, (v.type for v in fn.inputs)) + simulation.dump_chrome_trace("test/trace.json") diff --git a/test/test_type_inference.py b/test/test_type_inference.py index 3d0d1bbc..2e010906 100644 --- a/test/test_type_inference.py +++ b/test/test_type_inference.py @@ -201,5 +201,24 @@ def test_scatter(): assert xs.type.types[1].device == d1 +def test_function_call(): + layer = FunctionMaker() + x = layer.add_input_value("x", None) + w = layer.add_input_value("w", None) + _ = layer.add_op("MatMul", inputs=[x, w]) + layer = layer.finalize() + fn = FunctionMaker() + x = fn.add_input_value("x", Tensor(Float(), (4, 5))) + w1 = fn.add_input_value("w1", Tensor(Float(), (5, 6))) + w2 = fn.add_input_value("w2", Tensor(Float(), (6, 2))) + a1 = fn.add_op("FnCall", inputs=[x, w1], subfunctions=[layer]) + _ = fn.add_op("FnCall", inputs=[a1, w2], subfunctions=[layer]) + fn = fn.finalize() + + fn = infer_types(fn, [x, w1, w2]) + y = fn.outputs[0] + assert y.type == Tensor(Float(), (4, 2)) + + if __name__ == "__main__": test_pmap()