From b4d63bf768f9c649b5aca4d2dfe197635dcfab34 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Tue, 26 Jan 2021 20:08:08 +0000 Subject: [PATCH 01/46] Implement a `denormalize` custom Jaxpr operator simplifying MCX logpdf by removing constants. --- mcx/core/jaxpr_ops.py | 178 +++++++++++++++++++++++++++++++++++ tests/core/jaxpr_ops_test.py | 111 ++++++++++++++++++++++ 2 files changed, 289 insertions(+) create mode 100644 mcx/core/jaxpr_ops.py create mode 100644 tests/core/jaxpr_ops_test.py diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py new file mode 100644 index 00000000..c3df00c3 --- /dev/null +++ b/mcx/core/jaxpr_ops.py @@ -0,0 +1,178 @@ +"""A collection of operations/transformations on JAX expressions. +""" +import jax.core +import jax.lax +from jax.util import safe_map + +from functools import wraps +from typing import List, Dict, Tuple, Any + +Array = Any + + +def jax_lax_identity(x: Array) -> Array: + """Identity operator. + + Intrinsingly, it seems jax.lax does not have a public identity operation? + """ + return x + + +def jaxpr_find_constvars( + jaxpr: jax.core.Jaxpr, consts: List[jax.core.Var] +) -> List[jax.core.Var]: + """Find all intermediates variables in a JAX expression which are expected to be constants. + + Parameters + ---------- + jaxpr: JAX expression. + consts: List of known constant variables in the JAX expression. + + Returns + ------- + List of all intermediate constant variables. + """ + constvars_dict = {str(v): v for v in consts} + for eqn in jaxpr.eqns: + # Are inputs literal or const variables? + is_const_invars = [ + str(v) in constvars_dict or type(v) is jax.core.Literal for v in eqn.invars + ] + if all(is_const_invars): + constvars_dict.update({str(v): v for v in eqn.outvars}) + return list(constvars_dict.values()) + + +def jaxpr_find_denormalize_mapping( + jaxpr: jax.core.Jaxpr, consts: List[jax.core.Var] +) -> Dict[jax.core.Var, Tuple[jax.core.Primitive, jax.core.Var]]: + """Find all assignment simplifications in a JAX expression when denormalizing. + + More specifically, this method is looking to simplify `add` and `sub` operations, with output linear + with respect to the Jaxpr outputs, and where one of the input is constant. It returns the simplified mapping + between input and output of `add`/`sub` ops which can be removed. + + Parameters + ---------- + jaxpr: JAX expression. + consts: List of known constant variables in the JAX expression. + + Returns + ------- + Simplified mapping between `add` output and input (with the proper assignment lax op `identity` or `neg`). + """ + denormalize_mapping = {} + # List of linear ops which can be traversed backward from the outputs. + denorm_supported_linear_ops = [ + jax.lax.broadcast_in_dim_p, + jax.lax.broadcast_p, + jax.lax.neg_p, + jax.lax.reshape_p, + jax.lax.squeeze_p, + ] + # Collection of variables linear with respect to the Jaxpr final outputs. + linear_vars = set(jaxpr.outvars) + + # Traversing backward the graph of operations. + for eqn in jaxpr.eqns[::-1]: + if eqn.primitive in denorm_supported_linear_ops: + # Can continue denormalizing inputs if all outputs are in the linear vars collection. + if all([o in linear_vars for o in eqn.outvars]): + linear_vars |= set(eqn.invars) + elif eqn.primitive == jax.lax.add_p and eqn.outvars[0] in linear_vars: + lhs_invar, rhs_invar = eqn.invars[0], eqn.invars[1] + # Mapping the output to the non-const input. + if lhs_invar in consts or type(lhs_invar) is jax.core.Literal: + linear_vars.add(rhs_invar) + denormalize_mapping[eqn.outvars[0]] = (jax_lax_identity, rhs_invar) + elif rhs_invar in consts or type(rhs_invar) is jax.core.Literal: + linear_vars.add(lhs_invar) + denormalize_mapping[eqn.outvars[0]] = (jax_lax_identity, lhs_invar) + elif eqn.primitive == jax.lax.sub_p and eqn.outvars[0] in linear_vars: + lhs_invar, rhs_invar = eqn.invars[0], eqn.invars[1] + # Mapping the output to the non-const input (or the negative). + if lhs_invar in consts or type(lhs_invar) is jax.core.Literal: + linear_vars.add(rhs_invar) + denormalize_mapping[eqn.outvars[0]] = (jax.lax.neg, rhs_invar) + elif rhs_invar in consts or type(rhs_invar) is jax.core.Literal: + linear_vars.add(lhs_invar) + denormalize_mapping[eqn.outvars[0]] = (jax_lax_identity, lhs_invar) + + return denormalize_mapping + + +def jaxpr_denormalize(jaxpr, consts, *args): + """Denormalize a Jaxpr, i.e. removing any normalizing constant added to the output. + + This method is analysing the Jaxpr graph, simplifying it by skipping any unnecessary constant + addition, and then it runs the method step-by-step to get the output values. + + Parameters + ---------- + jaxpr: JAX expression. + consts: Values assigned to the Jaxpr constant variables. + args: Input values to the method. + + Returns + ------- + Output values of the denormalized logpdf. + """ + # Denormalized simplification mapping. + denorm_mapping = jaxpr_find_denormalize_mapping(jaxpr, jaxpr.constvars) + # Mapping from variable -> value + env = {} + + def read(var): + # Literals are values baked into the Jaxpr + if type(var) is jax.core.Literal: + return var.val + return env[var] + + def write(var, val): + env[var] = val + + # Bind args and consts to environment + write(jax.core.unitvar, jax.core.unit) + safe_map(write, jaxpr.invars, args) + safe_map(write, jaxpr.constvars, consts) + + # Similar to a classic eval Jaxpr loop, just skipping the op with mapping available + for eqn in jaxpr.eqns: + if len(eqn.outvars) == 1 and eqn.outvars[0] in denorm_mapping: + # Output registered: skip the primitive and map directly to one of the input. + outvar = eqn.outvars[0] + map_primitive, map_invar = ( + denorm_mapping[outvar][0], + denorm_mapping[outvar][1], + ) + # Mapping the inval to output var (identity or neg). + inval = read(map_invar) + outval = map_primitive(inval) + write(outvar, outval) + else: + # Usual map: calling the primitive and mapping the output values. + invals = safe_map(read, eqn.invars) + outvals = eqn.primitive.bind(*invals, **eqn.params) + if not eqn.primitive.multiple_results: + outvals = [outvals] + safe_map(write, eqn.outvars, outvals) + # Read the final result of the Jaxpr from the environment + return safe_map(read, jaxpr.outvars) + + +def denormalize(logpdf_fn): + """Denormalizing decorator for MCX logpdfs. + + The method returned by the `denormalize` decorator is a simplification of the Jaxpr, + removing any constant in the output logpdf. + """ + + @wraps(logpdf_fn) + def wrapped(*args, **kwargs): + # TODO: flattening/unflattening of inputs/outputs? + closed_jaxpr = jax.make_jaxpr(logpdf_fn)(*args, **kwargs) + out = jaxpr_denormalize(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args) + # Assuming a single output at the moment? + return out[0] + + return wrapped diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py new file mode 100644 index 00000000..16435671 --- /dev/null +++ b/tests/core/jaxpr_ops_test.py @@ -0,0 +1,111 @@ +import jax +import jax.lax +import numpy as onp +import pytest +from jax import numpy as np + +from mcx.core.jaxpr_ops import ( + jax_lax_identity, + jaxpr_find_constvars, + jaxpr_find_denormalize_mapping, + denormalize, +) + + +def test__jaxpr_find_constvars__propagate_constants(): + def foo(x): + return x + np.ones((2,)) + np.exp(2.0) + + typed_jaxpr = jax.make_jaxpr(foo)(1.0) + + # All inputs consts, outputs should be consts! + constvars = jaxpr_find_constvars( + typed_jaxpr.jaxpr, typed_jaxpr.jaxpr.invars + typed_jaxpr.jaxpr.constvars + ) + for outvar in typed_jaxpr.jaxpr.outvars: + assert outvar in constvars + + +denorm_expected_add_mapping_op = [ + {"fn": lambda x: x + 1.0, "expected_op": jax_lax_identity}, + {"fn": lambda x: 1.0 + x, "expected_op": jax_lax_identity}, + {"fn": lambda x: x - 1.0, "expected_op": jax_lax_identity}, + {"fn": lambda x: 1.0 - x, "expected_op": jax.lax.neg}, +] + + +@pytest.mark.parametrize("case", denorm_expected_add_mapping_op) +def test__jaxpr_find_denormalize_mapping__add_sub__proper_mapping(case): + typed_jaxpr = jax.make_jaxpr(case["fn"])(1.0) + denorm_map = jaxpr_find_denormalize_mapping( + typed_jaxpr.jaxpr, typed_jaxpr.jaxpr.constvars + ) + invar = typed_jaxpr.jaxpr.invars[0] + outvar = typed_jaxpr.jaxpr.outvars[0] + + # Proper mapping of the output to the input. + assert len(denorm_map) == 1 + assert outvar in denorm_map + assert denorm_map[outvar][0] == case["expected_op"] + assert denorm_map[outvar][1] == invar + + +denorm_linear_op_propagating = [ + {"fn": lambda x: -(x + 1.0), "expected_op": jax_lax_identity}, + {"fn": lambda x: np.expand_dims(1.0 - x, axis=0), "expected_op": jax.lax.neg}, + {"fn": lambda x: np.reshape(1.0 - x, (1, 1)), "expected_op": jax.lax.neg}, + { + "fn": lambda x: np.squeeze(np.expand_dims(1.0 - x, axis=0)), + "expected_op": jax.lax.neg, + }, +] + + +@pytest.mark.parametrize("case", denorm_linear_op_propagating) +def test__jaxpr_find_denormalize_mapping__linear_op_propagating__proper_mapping(case): + typed_jaxpr = jax.make_jaxpr(case["fn"])(1.0) + denorm_map = jaxpr_find_denormalize_mapping( + typed_jaxpr.jaxpr, typed_jaxpr.jaxpr.constvars + ) + invar = typed_jaxpr.jaxpr.invars[0] + + # Proper mapping of the output to the input. + assert len(denorm_map) == 1 + map_op, map_invar = list(denorm_map.values())[0] + assert map_op == case["expected_op"] + assert map_invar == invar + + +denorm_non_linear_fn = [ + {"fn": lambda x: np.sin(x + 1.0)}, + {"fn": lambda x: np.abs(x + 1.0)}, + {"fn": lambda x: np.exp(x + 1.0)}, + {"fn": lambda x: x * (x + 1.0)}, +] + + +@pytest.mark.parametrize("case", denorm_non_linear_fn) +def test__jaxpr_find_denormalize_mapping__non_linear_fn__empty_mapping(case): + typed_jaxpr = jax.make_jaxpr(case["fn"])(1.0) + denorm_map = jaxpr_find_denormalize_mapping( + typed_jaxpr.jaxpr, typed_jaxpr.jaxpr.constvars + ) + assert len(denorm_map) == 0 + + +denormalize_test_cases = [ + {"fn": lambda x: x + 1.0, "exp_denorm_fn": lambda x: x}, + { + "fn": lambda x: 2.0 - np.sin(x + 1.0), + "exp_denorm_fn": lambda x: -np.sin(x + 1.0), + }, +] + + +@pytest.mark.parametrize("case", denormalize_test_cases) +def test__denormalize__proper_simplication(case): + denorm_fn = denormalize(case["fn"]) + exp_denorm_fn = case["exp_denorm_fn"] + + inval = 1.0 + assert np.allclose(denorm_fn(inval), exp_denorm_fn(inval)) From 6398e8ffbc40a62820ca64deb096ac1a348e65d6 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sun, 31 Jan 2021 15:47:35 +0000 Subject: [PATCH 02/46] wip --- mcx/core/jaxpr_ops.py | 1 + tests/core/jaxpr_ops_test.py | 22 ++++++++++++++++------ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index c3df00c3..ec7209f2 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -69,6 +69,7 @@ def jaxpr_find_denormalize_mapping( jax.lax.neg_p, jax.lax.reshape_p, jax.lax.squeeze_p, + jax.lax.reduce_sum_p, ] # Collection of variables linear with respect to the Jaxpr final outputs. linear_vars = set(jaxpr.outvars) diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index 16435671..8d7d4ae4 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -94,10 +94,21 @@ def test__jaxpr_find_denormalize_mapping__non_linear_fn__empty_mapping(case): denormalize_test_cases = [ - {"fn": lambda x: x + 1.0, "exp_denorm_fn": lambda x: x}, + {"fn": lambda x: x + 1.0, "denorm_fn": lambda x: x, "inval": 2.0}, { "fn": lambda x: 2.0 - np.sin(x + 1.0), - "exp_denorm_fn": lambda x: -np.sin(x + 1.0), + "denorm_fn": lambda x: -np.sin(x + 1.0), + "inval": 2.0, + }, + { + "fn": lambda x: 2.0 - np.sin(x + 1.0), + "denorm_fn": lambda x: -np.sin(x + 1.0), + "inval": 2.0, + }, + { + "fn": lambda x: np.sum(x + 2.0), + "denorm_fn": lambda x: np.sum(x), + "inval": np.ones((10,)), }, ] @@ -105,7 +116,6 @@ def test__jaxpr_find_denormalize_mapping__non_linear_fn__empty_mapping(case): @pytest.mark.parametrize("case", denormalize_test_cases) def test__denormalize__proper_simplication(case): denorm_fn = denormalize(case["fn"]) - exp_denorm_fn = case["exp_denorm_fn"] - - inval = 1.0 - assert np.allclose(denorm_fn(inval), exp_denorm_fn(inval)) + expected_denorm_fn = case["denorm_fn"] + inval = case["inval"] + assert np.allclose(denorm_fn(inval), expected_denorm_fn(inval)) From 04cc11f36fc93cf02bd02c0fed222e4e9e37111e Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sat, 6 Feb 2021 14:35:09 +0000 Subject: [PATCH 03/46] wip --- mcx/core/jaxpr_ops.py | 59 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index ec7209f2..7e1bd96c 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -5,9 +5,66 @@ from jax.util import safe_map from functools import wraps -from typing import List, Dict, Tuple, Any +from typing import List, Dict, Tuple, Any, Type, TypeVar, Callable Array = Any +TState = TypeVar("TState") + +jaxpr_high_order_primitives_to_subjaxprs = { + jax.lax.cond_p: lambda jxpr: (None,), + jax.lax.while_p: None, + jax.lax.scan_p: None, + jax.core.CallPrimitive: None, # xla_call, from jax.jit + jax.core.MapPrimitive: None, +} +"""Collection of high-order Jax primitives, with sub-Jaxprs. +""" + + +def jaxpr_visitor( + jaxpr: jax.core.Jaxpr, + initial_state: TState, + visitor_fn: Callable[[jax.core.JaxprEqn, TState, Any], TState], + init_sub_state_fn: Callable[[jax.core.JaxprEqn, TState], List[TState]], + reverse: bool = False, +) -> Tuple[TState, List[Any]]: + """Visitor pattern on a Jaxpr, traversing equations and supporting higher-order primitives + with sub-Jaxprs. + + Parameters + ---------- + initial_state: Initial state to feed to the visitor method. + visitor_fn: Visitor function, taking an input state and Jaxpr, outputting an updated state. + init_sub_state_fn: Initializing method for higher-order primitives sub-Jaxprs. Taking as input + the existing state, and outputting input states to respective sub-Jaxprs. + reverse: Traverse the Jaxpr equations in reverse order. + Returns + ------- + Output state of the last iteration. + """ + state = initial_state + subjaxprs_visit = [] + + equations = jax.eqns if not reverse else jax.eqns[::-1] + for eqn in equations: + if eqn.primitive in jaxpr_high_order_primitives_to_subjaxprs: + sub_jaxprs = jaxpr_high_order_primitives_to_subjaxprs[eqn.primitive] + sub_states = init_sub_state_fn(eqn, state) + # Map visitor method to each sub-jaxpr. + res_sub_states = [ + jaxpr_visitor( + sub_jaxpr, sub_state, visitor_fn, init_sub_state_fn, reverse + ) + for sub_jaxpr, sub_state in zip(sub_jaxprs, sub_states) + ] + # Reduce, to update the current state. + sate = visitor_fn(eqn, state, res_sub_states) + subjaxprs_visit.append(res_sub_states) + else: + # Common Jaxpr equation: apply the visitor and update state. + state = visitor_fn(eqn, state, None) + subjaxprs_visit.append(None) + return state, subjaxprs_visit def jax_lax_identity(x: Array) -> Array: From ffb7329ddee3cf4d3c9b8b0000c46325d24af773 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sat, 6 Feb 2021 19:37:13 +0000 Subject: [PATCH 04/46] wip --- mcx/core/jaxpr_ops.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index 7e1bd96c..ee713a70 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -11,14 +11,20 @@ TState = TypeVar("TState") jaxpr_high_order_primitives_to_subjaxprs = { - jax.lax.cond_p: lambda jxpr: (None,), - jax.lax.while_p: None, - jax.lax.scan_p: None, - jax.core.CallPrimitive: None, # xla_call, from jax.jit - jax.core.MapPrimitive: None, + jax.lax.cond_p: lambda jxpr: jxpr.params["branches"], + jax.lax.while_p: lambda jxpr: ( + jxpr.params["cond_jaxpr"], + jxpr.params["body_jaxpr"], + ), + jax.lax.scan_p: lambda jxpr: (jxpr.params["jaxpr"],), + jax.core.CallPrimitive: lambda jxpr: ( + jxpr.params["call_jaxpr"], + ), # xla_call, from jax.jit + jax.core.MapPrimitive: lambda jxpr: (jxpr.params["call_jaxpr"],), } """Collection of high-order Jax primitives, with sub-Jaxprs. """ +jaxpr_high_order_primitives = set(jaxpr_high_order_primitives_to_subjaxprs.keys()) def jaxpr_visitor( @@ -57,7 +63,7 @@ def jaxpr_visitor( ) for sub_jaxpr, sub_state in zip(sub_jaxprs, sub_states) ] - # Reduce, to update the current state. + # Reduce to update the current state. sate = visitor_fn(eqn, state, res_sub_states) subjaxprs_visit.append(res_sub_states) else: From c2c53a429948e36d5b5141bf9e3f033297bb6afa Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sat, 6 Feb 2021 19:37:46 +0000 Subject: [PATCH 05/46] wip --- mcx/core/jaxpr_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index ee713a70..88c04aa3 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -64,7 +64,7 @@ def jaxpr_visitor( for sub_jaxpr, sub_state in zip(sub_jaxprs, sub_states) ] # Reduce to update the current state. - sate = visitor_fn(eqn, state, res_sub_states) + state = visitor_fn(eqn, state, res_sub_states) subjaxprs_visit.append(res_sub_states) else: # Common Jaxpr equation: apply the visitor and update state. From 5d8b681a27e80e617975daff51e218b1f8e9c6db Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sat, 6 Feb 2021 19:56:00 +0000 Subject: [PATCH 06/46] wip --- mcx/core/jaxpr_ops.py | 79 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 69 insertions(+), 10 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index 88c04aa3..1b1d92a2 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -27,11 +27,19 @@ jaxpr_high_order_primitives = set(jaxpr_high_order_primitives_to_subjaxprs.keys()) +def jax_lax_identity(x: Array) -> Array: + """Identity operator. + + Intrinsingly, it seems jax.lax does not have a public identity operation? + """ + return x + + def jaxpr_visitor( jaxpr: jax.core.Jaxpr, initial_state: TState, visitor_fn: Callable[[jax.core.JaxprEqn, TState, Any], TState], - init_sub_state_fn: Callable[[jax.core.JaxprEqn, TState], List[TState]], + init_sub_states_fn: Callable[[jax.core.JaxprEqn, TState], List[TState]], reverse: bool = False, ) -> Tuple[TState, List[Any]]: """Visitor pattern on a Jaxpr, traversing equations and supporting higher-order primitives @@ -41,7 +49,7 @@ def jaxpr_visitor( ---------- initial_state: Initial state to feed to the visitor method. visitor_fn: Visitor function, taking an input state and Jaxpr, outputting an updated state. - init_sub_state_fn: Initializing method for higher-order primitives sub-Jaxprs. Taking as input + init_sub_states_fn: Initializing method for higher-order primitives sub-Jaxprs. Taking as input the existing state, and outputting input states to respective sub-Jaxprs. reverse: Traverse the Jaxpr equations in reverse order. Returns @@ -51,15 +59,15 @@ def jaxpr_visitor( state = initial_state subjaxprs_visit = [] - equations = jax.eqns if not reverse else jax.eqns[::-1] + equations = jaxpr.eqns if not reverse else jaxpr.eqns[::-1] for eqn in equations: if eqn.primitive in jaxpr_high_order_primitives_to_subjaxprs: sub_jaxprs = jaxpr_high_order_primitives_to_subjaxprs[eqn.primitive] - sub_states = init_sub_state_fn(eqn, state) + sub_states = init_sub_states_fn(eqn, state) # Map visitor method to each sub-jaxpr. res_sub_states = [ jaxpr_visitor( - sub_jaxpr, sub_state, visitor_fn, init_sub_state_fn, reverse + sub_jaxpr, sub_state, visitor_fn, init_sub_states_fn, reverse ) for sub_jaxpr, sub_state in zip(sub_jaxprs, sub_states) ] @@ -73,12 +81,46 @@ def jaxpr_visitor( return state, subjaxprs_visit -def jax_lax_identity(x: Array) -> Array: - """Identity operator. +def jaxpr_find_constvars_visitor_fn( + eqn: jax.core.JaxprEqn, + state: Dict[Any, bool], + sub_states: List[Tuple[Dict, Any]] = None, +) -> Dict[Any, bool]: + """fdsafads""" + primitive_type = type(eqn.primitive) + + # Reduce logic of high level primitives: combine the results to update the state. + if ( + eqn.primitive in jaxpr_high_order_primitives_to_subjaxprs + or primitive_type in jaxpr_high_order_primitives_to_subjaxprs + ): + if primitive_type == jax.core.CallPrimitive: + # Jit compiled sub-jaxpr. + sub_jaxpr = eqn.params["call_jaxpr"] + sub_state = sub_states[0][0] + for eqn_outvar, sub_outvar in zip(eqn.outvars, sub_jaxpr.outvars): + # Add a constant variables if marked constant in the sub-jaxpr. + if sub_outvar in sub_state: + state[eqn_outvar] = sub_state[sub_outvar] + else: + # TODO: support other primitive. No constants marked at the moment. + pass + return state + + # Common ops logic: are inputs literal or const variables? + is_const_invars = [ + str(v) in state or type(v) is jax.core.Literal for v in eqn.invars + ] + if all(is_const_invars): + state.update({v: False for v in eqn.outvars}) + return state - Intrinsingly, it seems jax.lax does not have a public identity operation? - """ - return x + +def jaxpr_find_constvars_init_sub_states_fn( + eqn: jax.core.JaxprEqn, state: Dict[Any, bool] +) -> List[Dict[Any, bool]]: + """fdsafads""" + pass def jaxpr_find_constvars( @@ -91,6 +133,23 @@ def jaxpr_find_constvars( jaxpr: JAX expression. consts: List of known constant variables in the JAX expression. + Returns + ------- + List of all intermediate constant variables. + """ + pass + + +def jaxpr_find_constvars_old( + jaxpr: jax.core.Jaxpr, consts: List[jax.core.Var] +) -> List[jax.core.Var]: + """Find all intermediates variables in a JAX expression which are expected to be constants. + + Parameters + ---------- + jaxpr: JAX expression. + consts: List of known constant variables in the JAX expression. + Returns ------- List of all intermediate constant variables. From 098ca4f43a9b564d9789fbf3b1b1031068ee06d2 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sat, 6 Feb 2021 20:05:46 +0000 Subject: [PATCH 07/46] wip --- mcx/core/jaxpr_ops.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index 1b1d92a2..d274f0fb 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -103,7 +103,7 @@ def jaxpr_find_constvars_visitor_fn( if sub_outvar in sub_state: state[eqn_outvar] = sub_state[sub_outvar] else: - # TODO: support other primitive. No constants marked at the moment. + # TODO: support other high primitive. No constants marked at the moment. pass return state @@ -120,7 +120,22 @@ def jaxpr_find_constvars_init_sub_states_fn( eqn: jax.core.JaxprEqn, state: Dict[Any, bool] ) -> List[Dict[Any, bool]]: """fdsafads""" - pass + # Mapping the current state to the sub-jaxprs. + primitive_type = type(eqn.primitive) + sub_jaxprs = jaxpr_high_order_primitives_to_subjaxprs[eqn.primitive] + sub_init_states = None + + if primitive_type == jax.core.CallPrimitive: + # Jit compiled sub-jaxpr: map eqn inputs to sub-jaxpr inputs. + sub_init_state = {} + for eqn_invar, sub_invar in zip(eqn.invars, sub_jaxprs[0].invars): + # Add a constant variables if marked constant in the sub-jaxpr. + if eqn_invar in state or type(sub_invar) is jax.core.Literal: + sub_init_state[sub_invar] = False + else: + # TODO: support other high primitives. + sub_init_states = [{} for _ in range(len(sub_jaxprs))] + return sub_init_states def jaxpr_find_constvars( From 2f125daf8769e5a38942e73c1acb7cc52e281dbc Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sat, 6 Feb 2021 20:13:15 +0000 Subject: [PATCH 08/46] wip --- mcx/core/jaxpr_ops.py | 73 ++++++++++++++++++++++++------------------- 1 file changed, 40 insertions(+), 33 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index d274f0fb..5039a571 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -10,6 +10,8 @@ Array = Any TState = TypeVar("TState") +ConstVarState = Dict[jax.core.Var, bool] + jaxpr_high_order_primitives_to_subjaxprs = { jax.lax.cond_p: lambda jxpr: jxpr.params["branches"], jax.lax.while_p: lambda jxpr: ( @@ -38,8 +40,9 @@ def jax_lax_identity(x: Array) -> Array: def jaxpr_visitor( jaxpr: jax.core.Jaxpr, initial_state: TState, - visitor_fn: Callable[[jax.core.JaxprEqn, TState, Any], TState], - init_sub_states_fn: Callable[[jax.core.JaxprEqn, TState], List[TState]], + visitor_fn: Callable[[jax.core.JaxprEqn, TState], TState], + map_sub_states_fn: Callable[[jax.core.JaxprEqn, TState], List[TState]], + reduce_sub_states_fn: Callable[[jax.core.JaxprEqn, TState, List[TState]], TState], reverse: bool = False, ) -> Tuple[TState, List[Any]]: """Visitor pattern on a Jaxpr, traversing equations and supporting higher-order primitives @@ -62,51 +65,36 @@ def jaxpr_visitor( equations = jaxpr.eqns if not reverse else jaxpr.eqns[::-1] for eqn in equations: if eqn.primitive in jaxpr_high_order_primitives_to_subjaxprs: + init_sub_states = map_sub_states_fn(eqn, state) sub_jaxprs = jaxpr_high_order_primitives_to_subjaxprs[eqn.primitive] - sub_states = init_sub_states_fn(eqn, state) # Map visitor method to each sub-jaxpr. res_sub_states = [ jaxpr_visitor( - sub_jaxpr, sub_state, visitor_fn, init_sub_states_fn, reverse + sub_jaxpr, + sub_state, + visitor_fn, + map_sub_states_fn, + reduce_sub_states_fn, + reverse, ) - for sub_jaxpr, sub_state in zip(sub_jaxprs, sub_states) + for sub_jaxpr, sub_state in zip(sub_jaxprs, init_sub_states) ] # Reduce to update the current state. - state = visitor_fn(eqn, state, res_sub_states) + state = reduce_sub_states_fn(eqn, state, [v[0] for v in res_sub_states]) subjaxprs_visit.append(res_sub_states) else: # Common Jaxpr equation: apply the visitor and update state. - state = visitor_fn(eqn, state, None) + state = visitor_fn(eqn, state) subjaxprs_visit.append(None) return state, subjaxprs_visit def jaxpr_find_constvars_visitor_fn( eqn: jax.core.JaxprEqn, - state: Dict[Any, bool], - sub_states: List[Tuple[Dict, Any]] = None, -) -> Dict[Any, bool]: + state: ConstVarState, + sub_states: List[Tuple[ConstVarState, Any]] = None, +) -> ConstVarState: """fdsafads""" - primitive_type = type(eqn.primitive) - - # Reduce logic of high level primitives: combine the results to update the state. - if ( - eqn.primitive in jaxpr_high_order_primitives_to_subjaxprs - or primitive_type in jaxpr_high_order_primitives_to_subjaxprs - ): - if primitive_type == jax.core.CallPrimitive: - # Jit compiled sub-jaxpr. - sub_jaxpr = eqn.params["call_jaxpr"] - sub_state = sub_states[0][0] - for eqn_outvar, sub_outvar in zip(eqn.outvars, sub_jaxpr.outvars): - # Add a constant variables if marked constant in the sub-jaxpr. - if sub_outvar in sub_state: - state[eqn_outvar] = sub_state[sub_outvar] - else: - # TODO: support other high primitive. No constants marked at the moment. - pass - return state - # Common ops logic: are inputs literal or const variables? is_const_invars = [ str(v) in state or type(v) is jax.core.Literal for v in eqn.invars @@ -116,9 +104,9 @@ def jaxpr_find_constvars_visitor_fn( return state -def jaxpr_find_constvars_init_sub_states_fn( - eqn: jax.core.JaxprEqn, state: Dict[Any, bool] -) -> List[Dict[Any, bool]]: +def jaxpr_find_constvars_map_sub_states_fn( + eqn: jax.core.JaxprEqn, state: ConstVarState +) -> List[ConstVarState]: """fdsafads""" # Mapping the current state to the sub-jaxprs. primitive_type = type(eqn.primitive) @@ -138,6 +126,25 @@ def jaxpr_find_constvars_init_sub_states_fn( return sub_init_states +def jaxpr_find_constvars_reduce_sub_states_fn( + eqn: jax.core.JaxprEqn, state: ConstVarState, sub_states: List[ConstVarState] +) -> ConstVarState: + """fdsafads""" + primitive_type = type(eqn.primitive) + if primitive_type == jax.core.CallPrimitive: + # Jit compiled sub-jaxpr. + sub_jaxpr = eqn.params["call_jaxpr"] + sub_state = sub_states[0][0] + for eqn_outvar, sub_outvar in zip(eqn.outvars, sub_jaxpr.outvars): + # Add a constant variables if marked constant in the sub-jaxpr. + if sub_outvar in sub_state: + state[eqn_outvar] = sub_state[sub_outvar] + else: + # TODO: support other high primitive. No constants marked at the moment. + pass + return state + + def jaxpr_find_constvars( jaxpr: jax.core.Jaxpr, consts: List[jax.core.Var] ) -> List[jax.core.Var]: From 5369cc75bbd12a907a19fc69b3b2d26864b2725d Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sat, 6 Feb 2021 20:14:50 +0000 Subject: [PATCH 09/46] wip --- mcx/core/jaxpr_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index 5039a571..cab1f065 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -121,7 +121,7 @@ def jaxpr_find_constvars_map_sub_states_fn( if eqn_invar in state or type(sub_invar) is jax.core.Literal: sub_init_state[sub_invar] = False else: - # TODO: support other high primitives. + # TODO: support other high primitives. No constants passed at the moment. sub_init_states = [{} for _ in range(len(sub_jaxprs))] return sub_init_states @@ -140,7 +140,7 @@ def jaxpr_find_constvars_reduce_sub_states_fn( if sub_outvar in sub_state: state[eqn_outvar] = sub_state[sub_outvar] else: - # TODO: support other high primitive. No constants marked at the moment. + # TODO: support other high primitives. No constants passed at the moment. pass return state From 3652b600cc898e827993faeb9ad4b6a57003ae75 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sun, 7 Feb 2021 12:32:06 +0000 Subject: [PATCH 10/46] wip --- mcx/core/jaxpr_ops.py | 51 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 47 insertions(+), 4 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index cab1f065..0e8bf5ff 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -37,6 +37,40 @@ def jax_lax_identity(x: Array) -> Array: return x +def jax_is_high_order_primitive( + eqn: jax.core.JaxprEqn, +) -> bool: + """Is the input Jaxpr equation corresponding to a high-order Jax primitive? + + Parameters + ---------- + Returns + ------- + """ + is_high_order = (eqn.primitive in jaxpr_high_order_primitives_to_subjaxprs) or ( + type(eqn.primitive) in jaxpr_high_order_primitives_to_subjaxprs + ) + return is_high_order + + +def jaxpr_find_sub_jaxprs( + eqn: jax.core.JaxprEqn, +) -> List[jax.core.Jaxpr]: + """Is the input Jaxpr equation corresponding to a high-order Jax primitive? + + Parameters + ---------- + Returns + ------- + """ + primitive_type = type(eqn.primitive) + if eqn.primitive in jaxpr_high_order_primitives_to_subjaxprs: + return jaxpr_high_order_primitives_to_subjaxprs[eqn.primitive] + elif primitive_type in jaxpr_high_order_primitives_to_subjaxprs: + return jaxpr_high_order_primitives_to_subjaxprs[primitive_type] + return [] + + def jaxpr_visitor( jaxpr: jax.core.Jaxpr, initial_state: TState, @@ -64,9 +98,9 @@ def jaxpr_visitor( equations = jaxpr.eqns if not reverse else jaxpr.eqns[::-1] for eqn in equations: - if eqn.primitive in jaxpr_high_order_primitives_to_subjaxprs: + if jax_is_high_order_primitive(eqn): init_sub_states = map_sub_states_fn(eqn, state) - sub_jaxprs = jaxpr_high_order_primitives_to_subjaxprs[eqn.primitive] + sub_jaxprs = jaxpr_find_sub_jaxprs[eqn] # Map visitor method to each sub-jaxpr. res_sub_states = [ jaxpr_visitor( @@ -110,7 +144,7 @@ def jaxpr_find_constvars_map_sub_states_fn( """fdsafads""" # Mapping the current state to the sub-jaxprs. primitive_type = type(eqn.primitive) - sub_jaxprs = jaxpr_high_order_primitives_to_subjaxprs[eqn.primitive] + sub_jaxprs = jaxpr_find_sub_jaxprs(eqn) sub_init_states = None if primitive_type == jax.core.CallPrimitive: @@ -159,7 +193,16 @@ def jaxpr_find_constvars( ------- List of all intermediate constant variables. """ - pass + const_state = {} + const_state, _ = jaxpr_visitor( + jaxpr, + const_state, + jaxpr_find_constvars_visitor_fn, + jaxpr_find_constvars_map_sub_states_fn, + jaxpr_find_constvars_reduce_sub_states_fn, + reverse=False, + ) + return list(const_state) def jaxpr_find_constvars_old( From cf60bf3c4aac3953c7b8dabfbef17297012f7a72 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sun, 7 Feb 2021 12:44:40 +0000 Subject: [PATCH 11/46] wip --- mcx/core/jaxpr_ops.py | 9 +- tests/core/jaxpr_ops_test.py | 188 ++++++++++++++++++----------------- 2 files changed, 99 insertions(+), 98 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index 0e8bf5ff..3145c964 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -126,13 +126,11 @@ def jaxpr_visitor( def jaxpr_find_constvars_visitor_fn( eqn: jax.core.JaxprEqn, state: ConstVarState, - sub_states: List[Tuple[ConstVarState, Any]] = None, ) -> ConstVarState: """fdsafads""" # Common ops logic: are inputs literal or const variables? - is_const_invars = [ - str(v) in state or type(v) is jax.core.Literal for v in eqn.invars - ] + # NOTE: Jax literal are not hashable! + is_const_invars = [type(v) is jax.core.Literal or v in state for v in eqn.invars] if all(is_const_invars): state.update({v: False for v in eqn.outvars}) return state @@ -193,7 +191,8 @@ def jaxpr_find_constvars( ------- List of all intermediate constant variables. """ - const_state = {} + # Start with the list of input constants. + const_state = {c: False for c in consts} const_state, _ = jaxpr_visitor( jaxpr, const_state, diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index 8d7d4ae4..607ffe1e 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -22,100 +22,102 @@ def foo(x): constvars = jaxpr_find_constvars( typed_jaxpr.jaxpr, typed_jaxpr.jaxpr.invars + typed_jaxpr.jaxpr.constvars ) + print(typed_jaxpr) + print(constvars, typed_jaxpr.jaxpr.invars, typed_jaxpr.jaxpr.outvars) for outvar in typed_jaxpr.jaxpr.outvars: assert outvar in constvars -denorm_expected_add_mapping_op = [ - {"fn": lambda x: x + 1.0, "expected_op": jax_lax_identity}, - {"fn": lambda x: 1.0 + x, "expected_op": jax_lax_identity}, - {"fn": lambda x: x - 1.0, "expected_op": jax_lax_identity}, - {"fn": lambda x: 1.0 - x, "expected_op": jax.lax.neg}, -] - - -@pytest.mark.parametrize("case", denorm_expected_add_mapping_op) -def test__jaxpr_find_denormalize_mapping__add_sub__proper_mapping(case): - typed_jaxpr = jax.make_jaxpr(case["fn"])(1.0) - denorm_map = jaxpr_find_denormalize_mapping( - typed_jaxpr.jaxpr, typed_jaxpr.jaxpr.constvars - ) - invar = typed_jaxpr.jaxpr.invars[0] - outvar = typed_jaxpr.jaxpr.outvars[0] - - # Proper mapping of the output to the input. - assert len(denorm_map) == 1 - assert outvar in denorm_map - assert denorm_map[outvar][0] == case["expected_op"] - assert denorm_map[outvar][1] == invar - - -denorm_linear_op_propagating = [ - {"fn": lambda x: -(x + 1.0), "expected_op": jax_lax_identity}, - {"fn": lambda x: np.expand_dims(1.0 - x, axis=0), "expected_op": jax.lax.neg}, - {"fn": lambda x: np.reshape(1.0 - x, (1, 1)), "expected_op": jax.lax.neg}, - { - "fn": lambda x: np.squeeze(np.expand_dims(1.0 - x, axis=0)), - "expected_op": jax.lax.neg, - }, -] - - -@pytest.mark.parametrize("case", denorm_linear_op_propagating) -def test__jaxpr_find_denormalize_mapping__linear_op_propagating__proper_mapping(case): - typed_jaxpr = jax.make_jaxpr(case["fn"])(1.0) - denorm_map = jaxpr_find_denormalize_mapping( - typed_jaxpr.jaxpr, typed_jaxpr.jaxpr.constvars - ) - invar = typed_jaxpr.jaxpr.invars[0] - - # Proper mapping of the output to the input. - assert len(denorm_map) == 1 - map_op, map_invar = list(denorm_map.values())[0] - assert map_op == case["expected_op"] - assert map_invar == invar - - -denorm_non_linear_fn = [ - {"fn": lambda x: np.sin(x + 1.0)}, - {"fn": lambda x: np.abs(x + 1.0)}, - {"fn": lambda x: np.exp(x + 1.0)}, - {"fn": lambda x: x * (x + 1.0)}, -] - - -@pytest.mark.parametrize("case", denorm_non_linear_fn) -def test__jaxpr_find_denormalize_mapping__non_linear_fn__empty_mapping(case): - typed_jaxpr = jax.make_jaxpr(case["fn"])(1.0) - denorm_map = jaxpr_find_denormalize_mapping( - typed_jaxpr.jaxpr, typed_jaxpr.jaxpr.constvars - ) - assert len(denorm_map) == 0 - - -denormalize_test_cases = [ - {"fn": lambda x: x + 1.0, "denorm_fn": lambda x: x, "inval": 2.0}, - { - "fn": lambda x: 2.0 - np.sin(x + 1.0), - "denorm_fn": lambda x: -np.sin(x + 1.0), - "inval": 2.0, - }, - { - "fn": lambda x: 2.0 - np.sin(x + 1.0), - "denorm_fn": lambda x: -np.sin(x + 1.0), - "inval": 2.0, - }, - { - "fn": lambda x: np.sum(x + 2.0), - "denorm_fn": lambda x: np.sum(x), - "inval": np.ones((10,)), - }, -] - - -@pytest.mark.parametrize("case", denormalize_test_cases) -def test__denormalize__proper_simplication(case): - denorm_fn = denormalize(case["fn"]) - expected_denorm_fn = case["denorm_fn"] - inval = case["inval"] - assert np.allclose(denorm_fn(inval), expected_denorm_fn(inval)) +# denorm_expected_add_mapping_op = [ +# {"fn": lambda x: x + 1.0, "expected_op": jax_lax_identity}, +# {"fn": lambda x: 1.0 + x, "expected_op": jax_lax_identity}, +# {"fn": lambda x: x - 1.0, "expected_op": jax_lax_identity}, +# {"fn": lambda x: 1.0 - x, "expected_op": jax.lax.neg}, +# ] + + +# @pytest.mark.parametrize("case", denorm_expected_add_mapping_op) +# def test__jaxpr_find_denormalize_mapping__add_sub__proper_mapping(case): +# typed_jaxpr = jax.make_jaxpr(case["fn"])(1.0) +# denorm_map = jaxpr_find_denormalize_mapping( +# typed_jaxpr.jaxpr, typed_jaxpr.jaxpr.constvars +# ) +# invar = typed_jaxpr.jaxpr.invars[0] +# outvar = typed_jaxpr.jaxpr.outvars[0] + +# # Proper mapping of the output to the input. +# assert len(denorm_map) == 1 +# assert outvar in denorm_map +# assert denorm_map[outvar][0] == case["expected_op"] +# assert denorm_map[outvar][1] == invar + + +# denorm_linear_op_propagating = [ +# {"fn": lambda x: -(x + 1.0), "expected_op": jax_lax_identity}, +# {"fn": lambda x: np.expand_dims(1.0 - x, axis=0), "expected_op": jax.lax.neg}, +# {"fn": lambda x: np.reshape(1.0 - x, (1, 1)), "expected_op": jax.lax.neg}, +# { +# "fn": lambda x: np.squeeze(np.expand_dims(1.0 - x, axis=0)), +# "expected_op": jax.lax.neg, +# }, +# ] + + +# @pytest.mark.parametrize("case", denorm_linear_op_propagating) +# def test__jaxpr_find_denormalize_mapping__linear_op_propagating__proper_mapping(case): +# typed_jaxpr = jax.make_jaxpr(case["fn"])(1.0) +# denorm_map = jaxpr_find_denormalize_mapping( +# typed_jaxpr.jaxpr, typed_jaxpr.jaxpr.constvars +# ) +# invar = typed_jaxpr.jaxpr.invars[0] + +# # Proper mapping of the output to the input. +# assert len(denorm_map) == 1 +# map_op, map_invar = list(denorm_map.values())[0] +# assert map_op == case["expected_op"] +# assert map_invar == invar + + +# denorm_non_linear_fn = [ +# {"fn": lambda x: np.sin(x + 1.0)}, +# {"fn": lambda x: np.abs(x + 1.0)}, +# {"fn": lambda x: np.exp(x + 1.0)}, +# {"fn": lambda x: x * (x + 1.0)}, +# ] + + +# @pytest.mark.parametrize("case", denorm_non_linear_fn) +# def test__jaxpr_find_denormalize_mapping__non_linear_fn__empty_mapping(case): +# typed_jaxpr = jax.make_jaxpr(case["fn"])(1.0) +# denorm_map = jaxpr_find_denormalize_mapping( +# typed_jaxpr.jaxpr, typed_jaxpr.jaxpr.constvars +# ) +# assert len(denorm_map) == 0 + + +# denormalize_test_cases = [ +# {"fn": lambda x: x + 1.0, "denorm_fn": lambda x: x, "inval": 2.0}, +# { +# "fn": lambda x: 2.0 - np.sin(x + 1.0), +# "denorm_fn": lambda x: -np.sin(x + 1.0), +# "inval": 2.0, +# }, +# { +# "fn": lambda x: 2.0 - np.sin(x + 1.0), +# "denorm_fn": lambda x: -np.sin(x + 1.0), +# "inval": 2.0, +# }, +# { +# "fn": lambda x: np.sum(x + 2.0), +# "denorm_fn": lambda x: np.sum(x), +# "inval": np.ones((10,)), +# }, +# ] + + +# @pytest.mark.parametrize("case", denormalize_test_cases) +# def test__denormalize__proper_simplication(case): +# denorm_fn = denormalize(case["fn"]) +# expected_denorm_fn = case["denorm_fn"] +# inval = case["inval"] +# assert np.allclose(denorm_fn(inval), expected_denorm_fn(inval)) From ba4b560db65830c554a2a0c4ecdf6b9687431b90 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sun, 7 Feb 2021 12:47:33 +0000 Subject: [PATCH 12/46] wip --- mcx/core/jaxpr_ops.py | 7 ++++--- tests/core/jaxpr_ops_test.py | 6 ++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index 3145c964..ca82f5bb 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -65,9 +65,9 @@ def jaxpr_find_sub_jaxprs( """ primitive_type = type(eqn.primitive) if eqn.primitive in jaxpr_high_order_primitives_to_subjaxprs: - return jaxpr_high_order_primitives_to_subjaxprs[eqn.primitive] + return jaxpr_high_order_primitives_to_subjaxprs[eqn.primitive](eqn) elif primitive_type in jaxpr_high_order_primitives_to_subjaxprs: - return jaxpr_high_order_primitives_to_subjaxprs[primitive_type] + return jaxpr_high_order_primitives_to_subjaxprs[primitive_type](eqn) return [] @@ -100,7 +100,7 @@ def jaxpr_visitor( for eqn in equations: if jax_is_high_order_primitive(eqn): init_sub_states = map_sub_states_fn(eqn, state) - sub_jaxprs = jaxpr_find_sub_jaxprs[eqn] + sub_jaxprs = jaxpr_find_sub_jaxprs(eqn) # Map visitor method to each sub-jaxpr. res_sub_states = [ jaxpr_visitor( @@ -148,6 +148,7 @@ def jaxpr_find_constvars_map_sub_states_fn( if primitive_type == jax.core.CallPrimitive: # Jit compiled sub-jaxpr: map eqn inputs to sub-jaxpr inputs. sub_init_state = {} + print(sub_jaxprs) for eqn_invar, sub_invar in zip(eqn.invars, sub_jaxprs[0].invars): # Add a constant variables if marked constant in the sub-jaxpr. if eqn_invar in state or type(sub_invar) is jax.core.Literal: diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index 607ffe1e..431a7d18 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -14,6 +14,12 @@ def test__jaxpr_find_constvars__propagate_constants(): def foo(x): + @jax.jit + def g(y): + return y + np.ones((2,)) + + return g(x) + np.exp(2.0) + return x + np.ones((2,)) + np.exp(2.0) typed_jaxpr = jax.make_jaxpr(foo)(1.0) From 3d80f5caca4bc84e878e2fbdc7e3dd8fa6311567 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sun, 7 Feb 2021 19:40:55 +0000 Subject: [PATCH 13/46] wip --- mcx/core/jaxpr_ops.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index ca82f5bb..fa79a054 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -143,20 +143,19 @@ def jaxpr_find_constvars_map_sub_states_fn( # Mapping the current state to the sub-jaxprs. primitive_type = type(eqn.primitive) sub_jaxprs = jaxpr_find_sub_jaxprs(eqn) - sub_init_states = None if primitive_type == jax.core.CallPrimitive: # Jit compiled sub-jaxpr: map eqn inputs to sub-jaxpr inputs. sub_init_state = {} - print(sub_jaxprs) for eqn_invar, sub_invar in zip(eqn.invars, sub_jaxprs[0].invars): # Add a constant variables if marked constant in the sub-jaxpr. if eqn_invar in state or type(sub_invar) is jax.core.Literal: sub_init_state[sub_invar] = False + return [sub_init_state] else: # TODO: support other high primitives. No constants passed at the moment. sub_init_states = [{} for _ in range(len(sub_jaxprs))] - return sub_init_states + return sub_init_states def jaxpr_find_constvars_reduce_sub_states_fn( @@ -167,7 +166,7 @@ def jaxpr_find_constvars_reduce_sub_states_fn( if primitive_type == jax.core.CallPrimitive: # Jit compiled sub-jaxpr. sub_jaxpr = eqn.params["call_jaxpr"] - sub_state = sub_states[0][0] + sub_state = sub_states[0] for eqn_outvar, sub_outvar in zip(eqn.outvars, sub_jaxpr.outvars): # Add a constant variables if marked constant in the sub-jaxpr. if sub_outvar in sub_state: From f39bce6ee3ca2c705a4cd95dbab108f071501b5a Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sun, 7 Feb 2021 19:46:55 +0000 Subject: [PATCH 14/46] wip --- tests/core/jaxpr_ops_test.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index 431a7d18..db64c038 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -12,24 +12,24 @@ ) -def test__jaxpr_find_constvars__propagate_constants(): - def foo(x): - @jax.jit - def g(y): - return y + np.ones((2,)) +find_constvars_test_functions = [ + # Simple constant propagation. + {"fn": lambda x: x + np.ones((2,)) + np.exp(2.0)}, + # Handle properly jax.jit sub-jaxpr. + {"fn": lambda x: jax.jit(lambda y: y + np.ones((2,)))(x) + np.exp(2.0)}, + # TODO: test pmap, while, scan, cond. +] - return g(x) + np.exp(2.0) - return x + np.ones((2,)) + np.exp(2.0) - - typed_jaxpr = jax.make_jaxpr(foo)(1.0) +@pytest.mark.parametrize("case", find_constvars_test_functions) +def test__jaxpr_find_constvars__propagate_constants(case): + test_fn = case["fn"] + typed_jaxpr = jax.make_jaxpr(test_fn)(1.0) # All inputs consts, outputs should be consts! constvars = jaxpr_find_constvars( typed_jaxpr.jaxpr, typed_jaxpr.jaxpr.invars + typed_jaxpr.jaxpr.constvars ) - print(typed_jaxpr) - print(constvars, typed_jaxpr.jaxpr.invars, typed_jaxpr.jaxpr.outvars) for outvar in typed_jaxpr.jaxpr.outvars: assert outvar in constvars From 29435e7d996ac6aa220c4afb70f6ca8a57fc0605 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sun, 7 Feb 2021 19:58:56 +0000 Subject: [PATCH 15/46] wip --- mcx/core/jaxpr_ops.py | 57 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 4 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index fa79a054..6cf30bb4 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -8,9 +8,12 @@ from typing import List, Dict, Tuple, Any, Type, TypeVar, Callable Array = Any +"""Generic Array type. +""" TState = TypeVar("TState") +"""Generic Jaxpr visitor state. +""" -ConstVarState = Dict[jax.core.Var, bool] jaxpr_high_order_primitives_to_subjaxprs = { jax.lax.cond_p: lambda jxpr: jxpr.params["branches"], @@ -89,6 +92,7 @@ def jaxpr_visitor( init_sub_states_fn: Initializing method for higher-order primitives sub-Jaxprs. Taking as input the existing state, and outputting input states to respective sub-Jaxprs. reverse: Traverse the Jaxpr equations in reverse order. + Returns ------- Output state of the last iteration. @@ -123,11 +127,29 @@ def jaxpr_visitor( return state, subjaxprs_visit +ConstVarState = Dict[jax.core.Var, bool] +"""Const variables visitor state: dictionary associating const variables with their status. +""" + + def jaxpr_find_constvars_visitor_fn( eqn: jax.core.JaxprEqn, state: ConstVarState, ) -> ConstVarState: - """fdsafads""" + """Jaxpr find const variables visitor method: propagating constant through ops. + + This method is implementing a very simple logic, assuming that as method in Jax should be pure + functions, any output of a function with constant inputs is constant. + + Parameters + ---------- + eqn: Jaxpr equation. + state: Current collection of constant variables. + + Returns + ------- + Updated constant variables collection with outputs of the Jaxpr equation. + """ # Common ops logic: are inputs literal or const variables? # NOTE: Jax literal are not hashable! is_const_invars = [type(v) is jax.core.Literal or v in state for v in eqn.invars] @@ -139,7 +161,20 @@ def jaxpr_find_constvars_visitor_fn( def jaxpr_find_constvars_map_sub_states_fn( eqn: jax.core.JaxprEqn, state: ConstVarState ) -> List[ConstVarState]: - """fdsafads""" + """Map the constant variables collection to sub-jaxprs initial constant collections. + + The method is performing a simple mapping of constant variables of the main jaxpr to the inputs + of the sub-jaxprs. + + Parameters + ---------- + eqn: Jaxpr equation with high order primitive (xla_call, ...). + state: Constant variables collection. + + Returns + ------- + List of initial const variable states corresponding to each sub-jaxpr. + """ # Mapping the current state to the sub-jaxprs. primitive_type = type(eqn.primitive) sub_jaxprs = jaxpr_find_sub_jaxprs(eqn) @@ -161,7 +196,21 @@ def jaxpr_find_constvars_map_sub_states_fn( def jaxpr_find_constvars_reduce_sub_states_fn( eqn: jax.core.JaxprEqn, state: ConstVarState, sub_states: List[ConstVarState] ) -> ConstVarState: - """fdsafads""" + """Reduce the collection of sub-jaxpr const variables states to update to main jaxpr state. + + The method is performing a simple update of the main jaxpr state using the result of the + sub-jaxprs (i.e. whether the latter are constants). + + Parameters + ---------- + eqn: Main jaxpr equation. + state: Main jaxpr current const variables state. + sub_states: Sub-jaxprs final const variables states. + + Returns + ------- + Updated main Jaxpr constant variables state. + """ primitive_type = type(eqn.primitive) if primitive_type == jax.core.CallPrimitive: # Jit compiled sub-jaxpr. From 9b9f5fb636ade69592fd291cdb81a4e012cf8da1 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Mon, 8 Feb 2021 19:47:09 +0000 Subject: [PATCH 16/46] wip --- mcx/core/jaxpr_ops.py | 21 +++++++++++++++++---- tests/core/jaxpr_ops_test.py | 12 ++++++++++-- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index 6cf30bb4..40fa6e1c 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -4,6 +4,7 @@ import jax.lax from jax.util import safe_map +import enum from functools import wraps from typing import List, Dict, Tuple, Any, Type, TypeVar, Callable @@ -127,7 +128,19 @@ def jaxpr_visitor( return state, subjaxprs_visit -ConstVarState = Dict[jax.core.Var, bool] +class ConstVarStatus(enum.Enum): + """Const variable status. + + For the application to constant simplification in logpdfs, we do not need a full constant + folding in the graph of operations (which can be quite expensive, in computation and memory), but + just to keep track of constant variables and whether these are non-finite or not. + """ + + Unknown = 0 + NonFinite = 1 + + +ConstVarState = Dict[jax.core.Var, ConstVarStatus] """Const variables visitor state: dictionary associating const variables with their status. """ @@ -154,7 +167,7 @@ def jaxpr_find_constvars_visitor_fn( # NOTE: Jax literal are not hashable! is_const_invars = [type(v) is jax.core.Literal or v in state for v in eqn.invars] if all(is_const_invars): - state.update({v: False for v in eqn.outvars}) + state.update({v: ConstVarStatus.Unknown for v in eqn.outvars}) return state @@ -228,7 +241,7 @@ def jaxpr_find_constvars_reduce_sub_states_fn( def jaxpr_find_constvars( jaxpr: jax.core.Jaxpr, consts: List[jax.core.Var] -) -> List[jax.core.Var]: +) -> Dict[jax.core.Var, ConstVarStatus]: """Find all intermediates variables in a JAX expression which are expected to be constants. Parameters @@ -250,7 +263,7 @@ def jaxpr_find_constvars( jaxpr_find_constvars_reduce_sub_states_fn, reverse=False, ) - return list(const_state) + return const_state def jaxpr_find_constvars_old( diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index db64c038..feac6f96 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -9,14 +9,20 @@ jaxpr_find_constvars, jaxpr_find_denormalize_mapping, denormalize, + ConstVarStatus, ) find_constvars_test_functions = [ # Simple constant propagation. - {"fn": lambda x: x + np.ones((2,)) + np.exp(2.0)}, + {"fn": lambda x: x + np.ones((2,)) + np.exp(2.0), "status": ConstVarStatus.Unknown}, # Handle properly jax.jit sub-jaxpr. - {"fn": lambda x: jax.jit(lambda y: y + np.ones((2,)))(x) + np.exp(2.0)}, + { + "fn": lambda x: jax.jit(lambda y: y + np.ones((2,)))(x) + np.exp(2.0), + "status": ConstVarStatus.Unknown, + }, + # Simple inf constant propagation. + {"fn": lambda x: x + np.ones((2,)) + np.inf, "status": ConstVarStatus.NonFinite}, # TODO: test pmap, while, scan, cond. ] @@ -24,6 +30,7 @@ @pytest.mark.parametrize("case", find_constvars_test_functions) def test__jaxpr_find_constvars__propagate_constants(case): test_fn = case["fn"] + expected_status = case["status"] typed_jaxpr = jax.make_jaxpr(test_fn)(1.0) # All inputs consts, outputs should be consts! @@ -32,6 +39,7 @@ def test__jaxpr_find_constvars__propagate_constants(case): ) for outvar in typed_jaxpr.jaxpr.outvars: assert outvar in constvars + assert constvars[outvar] == expected_status # denorm_expected_add_mapping_op = [ From 89d76d23adabccb8aef2bc8b871c9aed08f11faa Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Mon, 8 Feb 2021 20:09:18 +0000 Subject: [PATCH 17/46] wip --- mcx/core/jaxpr_ops.py | 45 +++++++++++++++++++++++++++++------- tests/core/jaxpr_ops_test.py | 7 +++--- 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index 40fa6e1c..713129aa 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -3,7 +3,9 @@ import jax.core import jax.lax from jax.util import safe_map +import numpy as np +import copy import enum from functools import wraps from typing import List, Dict, Tuple, Any, Type, TypeVar, Callable @@ -163,11 +165,30 @@ def jaxpr_find_constvars_visitor_fn( ------- Updated constant variables collection with outputs of the Jaxpr equation. """ + + def get_var_status(v) -> ConstVarStatus: + if type(v) is jax.core.Literal: + # Non-finite if all entries are non-finite. + return ( + ConstVarStatus.Unknown + if np.any(np.isfinite(v.val)) + else ConstVarStatus.NonFinite + ) + return state.get(v, ConstVarStatus.Unknown) + # Common ops logic: are inputs literal or const variables? # NOTE: Jax literal are not hashable! is_const_invars = [type(v) is jax.core.Literal or v in state for v in eqn.invars] + status_invars = [get_var_status(v) for v in eqn.invars] if all(is_const_invars): - state.update({v: ConstVarStatus.Unknown for v in eqn.outvars}) + # Using a form of heuristic here: outputs are non-finite if one the input is. Should + # refine this logic per op. + outvar_status = ( + ConstVarStatus.NonFinite + if any([s == ConstVarStatus.NonFinite for s in status_invars]) + else ConstVarStatus.Unknown + ) + state.update({v: outvar_status for v in eqn.outvars}) return state @@ -196,9 +217,16 @@ def jaxpr_find_constvars_map_sub_states_fn( # Jit compiled sub-jaxpr: map eqn inputs to sub-jaxpr inputs. sub_init_state = {} for eqn_invar, sub_invar in zip(eqn.invars, sub_jaxprs[0].invars): - # Add a constant variables if marked constant in the sub-jaxpr. - if eqn_invar in state or type(sub_invar) is jax.core.Literal: - sub_init_state[sub_invar] = False + if eqn_invar in state: + # Add a constant variables if marked constant in the sub-jaxpr. + sub_init_state[sub_invar] = state[eqn_invar] + elif type(sub_invar) is jax.core.Literal: + # Literal argument: check the value fo the status. + sub_init_state[sub_invar] = ( + ConstVarStatus.Unknown + if np.any(np.isfinite(sub_invar.val)) + else ConstVarStatus.NonFinite + ) return [sub_init_state] else: # TODO: support other high primitives. No constants passed at the moment. @@ -240,21 +268,22 @@ def jaxpr_find_constvars_reduce_sub_states_fn( def jaxpr_find_constvars( - jaxpr: jax.core.Jaxpr, consts: List[jax.core.Var] + jaxpr: jax.core.Jaxpr, constvars: Dict[jax.core.Var, ConstVarStatus] ) -> Dict[jax.core.Var, ConstVarStatus]: """Find all intermediates variables in a JAX expression which are expected to be constants. Parameters ---------- jaxpr: JAX expression. - consts: List of known constant variables in the JAX expression. + constvars: List of known constant variables in the JAX expression. Returns ------- List of all intermediate constant variables. """ - # Start with the list of input constants. - const_state = {c: False for c in consts} + # Start with the collection of input constants. + const_state = copy.copy(constvars) + print(jaxpr) const_state, _ = jaxpr_visitor( jaxpr, const_state, diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index feac6f96..770b2856 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -34,9 +34,10 @@ def test__jaxpr_find_constvars__propagate_constants(case): typed_jaxpr = jax.make_jaxpr(test_fn)(1.0) # All inputs consts, outputs should be consts! - constvars = jaxpr_find_constvars( - typed_jaxpr.jaxpr, typed_jaxpr.jaxpr.invars + typed_jaxpr.jaxpr.constvars - ) + constvars = {v: ConstVarStatus.Unknown for v in typed_jaxpr.jaxpr.invars} + constvars.update({v: ConstVarStatus.Unknown for v in typed_jaxpr.jaxpr.constvars}) + + constvars = jaxpr_find_constvars(typed_jaxpr.jaxpr, constvars) for outvar in typed_jaxpr.jaxpr.outvars: assert outvar in constvars assert constvars[outvar] == expected_status From da25e8ebeff19aca7432a892f71dbf8102ee162d Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Mon, 8 Feb 2021 20:10:10 +0000 Subject: [PATCH 18/46] wip --- tests/core/jaxpr_ops_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index 770b2856..4d03e023 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -23,6 +23,11 @@ }, # Simple inf constant propagation. {"fn": lambda x: x + np.ones((2,)) + np.inf, "status": ConstVarStatus.NonFinite}, + # Handle properly jax.jit sub-jaxpr. + { + "fn": lambda x: jax.jit(lambda y: y + np.full((2,), np.inf))(x) + np.exp(2.0), + "status": ConstVarStatus.NonFinite, + }, # TODO: test pmap, while, scan, cond. ] From 8113a51a49aac50ea185fafb029afb5db5f5459b Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Tue, 9 Feb 2021 19:28:13 +0000 Subject: [PATCH 19/46] wip --- mcx/core/jaxpr_ops.py | 62 +++++++++++------------------------- tests/core/jaxpr_ops_test.py | 2 +- 2 files changed, 19 insertions(+), 45 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index 713129aa..02f5ec02 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -147,6 +147,18 @@ class ConstVarStatus(enum.Enum): """ +def get_variable_const_status(v: Any, state: ConstVarState) -> ConstVarStatus: + """Get the constant status on a variable (or literal).""" + if type(v) is jax.core.Literal: + # Non-finite if all entries are non-finite. + return ( + ConstVarStatus.Unknown + if np.any(np.isfinite(v.val)) + else ConstVarStatus.NonFinite + ) + return state.get(v, None) + + def jaxpr_find_constvars_visitor_fn( eqn: jax.core.JaxprEqn, state: ConstVarState, @@ -166,27 +178,18 @@ def jaxpr_find_constvars_visitor_fn( Updated constant variables collection with outputs of the Jaxpr equation. """ - def get_var_status(v) -> ConstVarStatus: - if type(v) is jax.core.Literal: - # Non-finite if all entries are non-finite. - return ( - ConstVarStatus.Unknown - if np.any(np.isfinite(v.val)) - else ConstVarStatus.NonFinite - ) - return state.get(v, ConstVarStatus.Unknown) - # Common ops logic: are inputs literal or const variables? # NOTE: Jax literal are not hashable! is_const_invars = [type(v) is jax.core.Literal or v in state for v in eqn.invars] - status_invars = [get_var_status(v) for v in eqn.invars] + status_invars = [get_variable_const_status(v, state) for v in eqn.invars] if all(is_const_invars): # Using a form of heuristic here: outputs are non-finite if one the input is. Should # refine this logic per op. + any_non_finite_invar = any( + [s == ConstVarStatus.NonFinite for s in status_invars] + ) outvar_status = ( - ConstVarStatus.NonFinite - if any([s == ConstVarStatus.NonFinite for s in status_invars]) - else ConstVarStatus.Unknown + ConstVarStatus.NonFinite if any_non_finite_invar else ConstVarStatus.Unknown ) state.update({v: outvar_status for v in eqn.outvars}) return state @@ -222,11 +225,7 @@ def jaxpr_find_constvars_map_sub_states_fn( sub_init_state[sub_invar] = state[eqn_invar] elif type(sub_invar) is jax.core.Literal: # Literal argument: check the value fo the status. - sub_init_state[sub_invar] = ( - ConstVarStatus.Unknown - if np.any(np.isfinite(sub_invar.val)) - else ConstVarStatus.NonFinite - ) + sub_init_state[sub_invar] = get_variable_const_status(sub_invar, None) return [sub_init_state] else: # TODO: support other high primitives. No constants passed at the moment. @@ -295,31 +294,6 @@ def jaxpr_find_constvars( return const_state -def jaxpr_find_constvars_old( - jaxpr: jax.core.Jaxpr, consts: List[jax.core.Var] -) -> List[jax.core.Var]: - """Find all intermediates variables in a JAX expression which are expected to be constants. - - Parameters - ---------- - jaxpr: JAX expression. - consts: List of known constant variables in the JAX expression. - - Returns - ------- - List of all intermediate constant variables. - """ - constvars_dict = {str(v): v for v in consts} - for eqn in jaxpr.eqns: - # Are inputs literal or const variables? - is_const_invars = [ - str(v) in constvars_dict or type(v) is jax.core.Literal for v in eqn.invars - ] - if all(is_const_invars): - constvars_dict.update({str(v): v for v in eqn.outvars}) - return list(constvars_dict.values()) - - def jaxpr_find_denormalize_mapping( jaxpr: jax.core.Jaxpr, consts: List[jax.core.Var] ) -> Dict[jax.core.Var, Tuple[jax.core.Primitive, jax.core.Var]]: diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index 4d03e023..ab94431b 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -23,7 +23,7 @@ }, # Simple inf constant propagation. {"fn": lambda x: x + np.ones((2,)) + np.inf, "status": ConstVarStatus.NonFinite}, - # Handle properly jax.jit sub-jaxpr. + # Handle properly jax.jit sub-jaxpr + inf constant. { "fn": lambda x: jax.jit(lambda y: y + np.full((2,), np.inf))(x) + np.exp(2.0), "status": ConstVarStatus.NonFinite, From f1c310e9f3ca9b173f9066bd62fa9ea64de1b6a6 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Tue, 9 Feb 2021 19:33:39 +0000 Subject: [PATCH 20/46] wip --- tests/core/jaxpr_ops_test.py | 86 +++++++++++++++++++----------------- 1 file changed, 46 insertions(+), 40 deletions(-) diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index ab94431b..b33dc098 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -1,3 +1,4 @@ +from typing import TypedDict import jax import jax.lax import numpy as onp @@ -48,54 +49,59 @@ def test__jaxpr_find_constvars__propagate_constants(case): assert constvars[outvar] == expected_status -# denorm_expected_add_mapping_op = [ -# {"fn": lambda x: x + 1.0, "expected_op": jax_lax_identity}, -# {"fn": lambda x: 1.0 + x, "expected_op": jax_lax_identity}, -# {"fn": lambda x: x - 1.0, "expected_op": jax_lax_identity}, -# {"fn": lambda x: 1.0 - x, "expected_op": jax.lax.neg}, -# ] +denorm_expected_add_mapping_op = [ + {"fn": lambda x: x + 1.0, "expected_op": jax_lax_identity}, + {"fn": lambda x: 1.0 + x, "expected_op": jax_lax_identity}, + {"fn": lambda x: x - 1.0, "expected_op": jax_lax_identity}, + {"fn": lambda x: 1.0 - x, "expected_op": jax.lax.neg}, +] -# @pytest.mark.parametrize("case", denorm_expected_add_mapping_op) -# def test__jaxpr_find_denormalize_mapping__add_sub__proper_mapping(case): -# typed_jaxpr = jax.make_jaxpr(case["fn"])(1.0) -# denorm_map = jaxpr_find_denormalize_mapping( -# typed_jaxpr.jaxpr, typed_jaxpr.jaxpr.constvars -# ) -# invar = typed_jaxpr.jaxpr.invars[0] -# outvar = typed_jaxpr.jaxpr.outvars[0] +@pytest.mark.parametrize("case", denorm_expected_add_mapping_op) +def test__jaxpr_find_denormalize_mapping__add_sub__proper_mapping(case): + typed_jaxpr = jax.make_jaxpr(case["fn"])(1.0) + denorm_map = jaxpr_find_denormalize_mapping( + typed_jaxpr.jaxpr, typed_jaxpr.jaxpr.constvars + ) + invar = typed_jaxpr.jaxpr.invars[0] + outvar = typed_jaxpr.jaxpr.outvars[0] -# # Proper mapping of the output to the input. -# assert len(denorm_map) == 1 -# assert outvar in denorm_map -# assert denorm_map[outvar][0] == case["expected_op"] -# assert denorm_map[outvar][1] == invar + # Proper mapping of the output to the input. + assert len(denorm_map) == 1 + assert outvar in denorm_map + assert denorm_map[outvar][0] == case["expected_op"] + assert denorm_map[outvar][1] == invar -# denorm_linear_op_propagating = [ -# {"fn": lambda x: -(x + 1.0), "expected_op": jax_lax_identity}, -# {"fn": lambda x: np.expand_dims(1.0 - x, axis=0), "expected_op": jax.lax.neg}, -# {"fn": lambda x: np.reshape(1.0 - x, (1, 1)), "expected_op": jax.lax.neg}, -# { -# "fn": lambda x: np.squeeze(np.expand_dims(1.0 - x, axis=0)), -# "expected_op": jax.lax.neg, -# }, -# ] +denorm_linear_op_propagating = [ + {"fn": lambda x: -(x + 1.0), "expected_op": jax_lax_identity}, + {"fn": lambda x: np.expand_dims(1.0 - x, axis=0), "expected_op": jax.lax.neg}, + {"fn": lambda x: np.reshape(1.0 - x, (1, 1)), "expected_op": jax.lax.neg}, + { + "fn": lambda x: np.squeeze(np.expand_dims(1.0 - x, axis=0)), + "expected_op": jax.lax.neg, + }, + {"fn": lambda x: jax.jit(lambda y: 1.0 - y)(x), "expected_op": jax.lax.neg}, + {"fn": lambda x: 2.0 * (1.0 - x), "expected_op": jax.lax.neg}, + {"fn": lambda x: (1.0 - x) / 2.0, "expected_op": jax.lax.neg}, +] -# @pytest.mark.parametrize("case", denorm_linear_op_propagating) -# def test__jaxpr_find_denormalize_mapping__linear_op_propagating__proper_mapping(case): -# typed_jaxpr = jax.make_jaxpr(case["fn"])(1.0) -# denorm_map = jaxpr_find_denormalize_mapping( -# typed_jaxpr.jaxpr, typed_jaxpr.jaxpr.constvars -# ) -# invar = typed_jaxpr.jaxpr.invars[0] +@pytest.mark.parametrize("case", denorm_linear_op_propagating) +def test__jaxpr_find_denormalize_mapping__linear_op_propagating__proper_mapping(case): + typed_jaxpr = jax.make_jaxpr(case["fn"])(1.0) + denorm_map = jaxpr_find_denormalize_mapping( + typed_jaxpr.jaxpr, typed_jaxpr.jaxpr.constvars + ) + invar = typed_jaxpr.jaxpr.invars[0] + + print(typed_jaxpr) -# # Proper mapping of the output to the input. -# assert len(denorm_map) == 1 -# map_op, map_invar = list(denorm_map.values())[0] -# assert map_op == case["expected_op"] -# assert map_invar == invar + # Proper mapping of the output to the input. + assert len(denorm_map) == 1 + map_op, map_invar = list(denorm_map.values())[0] + assert map_op == case["expected_op"] + assert map_invar == invar # denorm_non_linear_fn = [ From d6e452b122809dee375dfc48d22363137c405fcf Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Tue, 9 Feb 2021 19:35:12 +0000 Subject: [PATCH 21/46] wip --- tests/core/jaxpr_ops_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index b33dc098..26793963 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -82,7 +82,7 @@ def test__jaxpr_find_denormalize_mapping__add_sub__proper_mapping(case): "expected_op": jax.lax.neg, }, {"fn": lambda x: jax.jit(lambda y: 1.0 - y)(x), "expected_op": jax.lax.neg}, - {"fn": lambda x: 2.0 * (1.0 - x), "expected_op": jax.lax.neg}, + {"fn": lambda x: np.full((2,), 2.0) * (1.0 - x), "expected_op": jax.lax.neg}, {"fn": lambda x: (1.0 - x) / 2.0, "expected_op": jax.lax.neg}, ] From 51c7ed5d30854ad4a3c7a3468d71827d600aac87 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Wed, 10 Feb 2021 19:56:40 +0000 Subject: [PATCH 22/46] wip --- mcx/core/jaxpr_ops.py | 73 +++++++++++++++++++++++------------- tests/core/jaxpr_ops_test.py | 41 +++++++++++++------- 2 files changed, 73 insertions(+), 41 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index 02f5ec02..2dbcea66 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -7,6 +7,8 @@ import copy import enum + +from dataclasses import dataclass from functools import wraps from typing import List, Dict, Tuple, Any, Type, TypeVar, Callable @@ -130,32 +132,43 @@ def jaxpr_visitor( return state, subjaxprs_visit -class ConstVarStatus(enum.Enum): - """Const variable status. +@dataclass +class ConstVarInfo: + """Const variable additional information. - For the application to constant simplification in logpdfs, we do not need a full constant + For the application of constant simplification in logpdfs, we do not need a full constant folding in the graph of operations (which can be quite expensive, in computation and memory), but - just to keep track of constant variables and whether these are non-finite or not. + just to keep track of constant variables, and some additional information: + + Parameters + ---------- + is_non_finite: whether the constant is non-finite. false by default. + is_uniform: whether the constant is a uniform tensor. false by default. """ - Unknown = 0 - NonFinite = 1 + is_non_finite: bool = False + is_uniform: bool = False -ConstVarState = Dict[jax.core.Var, ConstVarStatus] -"""Const variables visitor state: dictionary associating const variables with their status. +ConstVarState = Dict[jax.core.Var, ConstVarInfo] +"""Const variables visitor state: dictionary associating const variables with their info. """ -def get_variable_const_status(v: Any, state: ConstVarState) -> ConstVarStatus: - """Get the constant status on a variable (or literal).""" +def get_variable_const_info(v: Any, state: ConstVarState) -> ConstVarInfo: + """Get the constant info on a variable (or literal). + + Parameters + ---------- + v: Input variable or literal + state: Constant variable state. + + Returns + ------- + Optional const info. None if not a constant variable. + """ if type(v) is jax.core.Literal: - # Non-finite if all entries are non-finite. - return ( - ConstVarStatus.Unknown - if np.any(np.isfinite(v.val)) - else ConstVarStatus.NonFinite - ) + return ConstVarInfo(is_non_finite=not bool(np.isfinite(v.val)), is_uniform=True) return state.get(v, None) @@ -181,17 +194,23 @@ def jaxpr_find_constvars_visitor_fn( # Common ops logic: are inputs literal or const variables? # NOTE: Jax literal are not hashable! is_const_invars = [type(v) is jax.core.Literal or v in state for v in eqn.invars] - status_invars = [get_variable_const_status(v, state) for v in eqn.invars] + invars_const_info = [get_variable_const_info(v, state) for v in eqn.invars] if all(is_const_invars): - # Using a form of heuristic here: outputs are non-finite if one the input is. Should - # refine this logic per op. - any_non_finite_invar = any( - [s == ConstVarStatus.NonFinite for s in status_invars] + # Using a form of heuristic here: outputs are non-finite if one the input is. + # TODO: refine this logic per op. + is_non_finite_outvars = any( + [s is not None and s.is_non_finite for s in invars_const_info] + ) + # Another heuristic to refine! if all inputs are uniform, all outputs are uniform. + # TODO: refine the logic per op supported. + is_uniform_outvars = all( + [s is not None and s.is_uniform for s in invars_const_info] ) - outvar_status = ( - ConstVarStatus.NonFinite if any_non_finite_invar else ConstVarStatus.Unknown + + outvar_const_info = ConstVarInfo( + is_non_finite=is_non_finite_outvars, is_uniform=is_uniform_outvars ) - state.update({v: outvar_status for v in eqn.outvars}) + state.update({v: copy.copy(outvar_const_info) for v in eqn.outvars}) return state @@ -225,7 +244,7 @@ def jaxpr_find_constvars_map_sub_states_fn( sub_init_state[sub_invar] = state[eqn_invar] elif type(sub_invar) is jax.core.Literal: # Literal argument: check the value fo the status. - sub_init_state[sub_invar] = get_variable_const_status(sub_invar, None) + sub_init_state[sub_invar] = get_variable_const_info(sub_invar, None) return [sub_init_state] else: # TODO: support other high primitives. No constants passed at the moment. @@ -267,8 +286,8 @@ def jaxpr_find_constvars_reduce_sub_states_fn( def jaxpr_find_constvars( - jaxpr: jax.core.Jaxpr, constvars: Dict[jax.core.Var, ConstVarStatus] -) -> Dict[jax.core.Var, ConstVarStatus]: + jaxpr: jax.core.Jaxpr, constvars: Dict[jax.core.Var, ConstVarInfo] +) -> Dict[jax.core.Var, ConstVarInfo]: """Find all intermediates variables in a JAX expression which are expected to be constants. Parameters diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index 26793963..5649e34d 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -3,31 +3,40 @@ import jax.lax import numpy as onp import pytest -from jax import numpy as np +import jax.numpy as jnp +import numpy as np from mcx.core.jaxpr_ops import ( jax_lax_identity, jaxpr_find_constvars, jaxpr_find_denormalize_mapping, denormalize, - ConstVarStatus, + ConstVarInfo, ) find_constvars_test_functions = [ # Simple constant propagation. - {"fn": lambda x: x + np.ones((2,)) + np.exp(2.0), "status": ConstVarStatus.Unknown}, + { + "fn": lambda x: x + np.ones((2,)) + np.exp(2.0), + "info": ConstVarInfo(False, True), + }, + # Simple constant propagation, non-uniform. + { + "fn": lambda x: x + jnp.array([1.0, 2.0]) + np.exp(2.0), + "info": ConstVarInfo(False, False), + }, # Handle properly jax.jit sub-jaxpr. { - "fn": lambda x: jax.jit(lambda y: y + np.ones((2,)))(x) + np.exp(2.0), - "status": ConstVarStatus.Unknown, + "fn": lambda x: jax.jit(lambda y: y + jnp.ones((2,)))(x) + np.exp(2.0), + "info": ConstVarInfo(False, True), }, # Simple inf constant propagation. - {"fn": lambda x: x + np.ones((2,)) + np.inf, "status": ConstVarStatus.NonFinite}, + {"fn": lambda x: x + np.ones((2,)) + np.inf, "info": ConstVarInfo(True, True)}, # Handle properly jax.jit sub-jaxpr + inf constant. { - "fn": lambda x: jax.jit(lambda y: y + np.full((2,), np.inf))(x) + np.exp(2.0), - "status": ConstVarStatus.NonFinite, + "fn": lambda x: jax.jit(lambda y: y + jnp.full((2,), np.inf))(x) + np.exp(2.0), + "info": ConstVarInfo(True, True), }, # TODO: test pmap, while, scan, cond. ] @@ -36,17 +45,21 @@ @pytest.mark.parametrize("case", find_constvars_test_functions) def test__jaxpr_find_constvars__propagate_constants(case): test_fn = case["fn"] - expected_status = case["status"] + expected_const_info = case["info"] typed_jaxpr = jax.make_jaxpr(test_fn)(1.0) + print(typed_jaxpr.consts, typed_jaxpr.jaxpr.constvars) + # All inputs consts, outputs should be consts! - constvars = {v: ConstVarStatus.Unknown for v in typed_jaxpr.jaxpr.invars} - constvars.update({v: ConstVarStatus.Unknown for v in typed_jaxpr.jaxpr.constvars}) + constvars = {v: ConstVarInfo(False, True) for v in typed_jaxpr.jaxpr.invars} + constvars.update( + {v: ConstVarInfo(False, True) for v in typed_jaxpr.jaxpr.constvars} + ) constvars = jaxpr_find_constvars(typed_jaxpr.jaxpr, constvars) for outvar in typed_jaxpr.jaxpr.outvars: assert outvar in constvars - assert constvars[outvar] == expected_status + assert constvars[outvar] == expected_const_info denorm_expected_add_mapping_op = [ @@ -81,9 +94,9 @@ def test__jaxpr_find_denormalize_mapping__add_sub__proper_mapping(case): "fn": lambda x: np.squeeze(np.expand_dims(1.0 - x, axis=0)), "expected_op": jax.lax.neg, }, - {"fn": lambda x: jax.jit(lambda y: 1.0 - y)(x), "expected_op": jax.lax.neg}, + # {"fn": lambda x: jax.jit(lambda y: 1.0 - y)(x), "expected_op": jax.lax.neg}, {"fn": lambda x: np.full((2,), 2.0) * (1.0 - x), "expected_op": jax.lax.neg}, - {"fn": lambda x: (1.0 - x) / 2.0, "expected_op": jax.lax.neg}, + {"fn": lambda x: (1.0 - x) / (np.ones((2,)) * 2.0), "expected_op": jax.lax.neg}, ] From c370c661410c3a989a0810c0666461b1ffc6fa9b Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Thu, 11 Feb 2021 19:38:53 +0000 Subject: [PATCH 23/46] wip --- tests/core/jaxpr_ops_test.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index 5649e34d..54365b81 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -21,11 +21,11 @@ "fn": lambda x: x + np.ones((2,)) + np.exp(2.0), "info": ConstVarInfo(False, True), }, - # Simple constant propagation, non-uniform. - { - "fn": lambda x: x + jnp.array([1.0, 2.0]) + np.exp(2.0), - "info": ConstVarInfo(False, False), - }, + # TODO: Simple constant propagation, non-uniform. + # { + # "fn": lambda x: x + jnp.array([1.0, 2.0]) + np.exp(2.0), + # "info": ConstVarInfo(False, False), + # }, # Handle properly jax.jit sub-jaxpr. { "fn": lambda x: jax.jit(lambda y: y + jnp.ones((2,)))(x) + np.exp(2.0), @@ -88,15 +88,15 @@ def test__jaxpr_find_denormalize_mapping__add_sub__proper_mapping(case): denorm_linear_op_propagating = [ {"fn": lambda x: -(x + 1.0), "expected_op": jax_lax_identity}, - {"fn": lambda x: np.expand_dims(1.0 - x, axis=0), "expected_op": jax.lax.neg}, - {"fn": lambda x: np.reshape(1.0 - x, (1, 1)), "expected_op": jax.lax.neg}, + {"fn": lambda x: jnp.expand_dims(1.0 - x, axis=0), "expected_op": jax.lax.neg}, + {"fn": lambda x: jnp.reshape(1.0 - x, (1, 1)), "expected_op": jax.lax.neg}, { - "fn": lambda x: np.squeeze(np.expand_dims(1.0 - x, axis=0)), + "fn": lambda x: jnp.squeeze(jnp.expand_dims(1.0 - x, axis=0)), "expected_op": jax.lax.neg, }, # {"fn": lambda x: jax.jit(lambda y: 1.0 - y)(x), "expected_op": jax.lax.neg}, - {"fn": lambda x: np.full((2,), 2.0) * (1.0 - x), "expected_op": jax.lax.neg}, - {"fn": lambda x: (1.0 - x) / (np.ones((2,)) * 2.0), "expected_op": jax.lax.neg}, + # {"fn": lambda x: jnp.full((2,), 2.0) * (1.0 - x), "expected_op": jax.lax.neg}, + # {"fn": lambda x: (1.0 - x) / (jnp.ones((2,)) * 2.0), "expected_op": jax.lax.neg}, ] From 6c635e28bca1f0c2c353d96ada819b0b26714afb Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sat, 13 Feb 2021 19:09:36 +0000 Subject: [PATCH 24/46] wip --- mcx/core/jaxpr_ops.py | 133 +++++++++++++++++++++++++++++++++-- tests/core/jaxpr_ops_test.py | 21 +++--- 2 files changed, 139 insertions(+), 15 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index 2dbcea66..b195dc65 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -10,7 +10,7 @@ from dataclasses import dataclass from functools import wraps -from typing import List, Dict, Tuple, Any, Type, TypeVar, Callable +from typing import List, Dict, Optional, Set, Tuple, Any, Type, TypeVar, Callable Array = Any """Generic Array type. @@ -18,7 +18,10 @@ TState = TypeVar("TState") """Generic Jaxpr visitor state. """ - +TRecState = Tuple[TState, List[Optional[List["TRecState"]]]] +"""Full recursive state, representing the visitor state of the Jaxpr as well as +the sub-states of all sub-jaxprs. +""" jaxpr_high_order_primitives_to_subjaxprs = { jax.lax.cond_p: lambda jxpr: jxpr.params["branches"], @@ -86,7 +89,7 @@ def jaxpr_visitor( map_sub_states_fn: Callable[[jax.core.JaxprEqn, TState], List[TState]], reduce_sub_states_fn: Callable[[jax.core.JaxprEqn, TState, List[TState]], TState], reverse: bool = False, -) -> Tuple[TState, List[Any]]: +) -> TRecState: """Visitor pattern on a Jaxpr, traversing equations and supporting higher-order primitives with sub-Jaxprs. @@ -153,6 +156,10 @@ class ConstVarInfo: ConstVarState = Dict[jax.core.Var, ConstVarInfo] """Const variables visitor state: dictionary associating const variables with their info. """ +ConstVarRecState = Tuple[ConstVarState, List[Optional[List["ConstVarRecState"]]]] + + +# Garthoks = Union[Garthok, Iterable['Garthoks']] def get_variable_const_info(v: Any, state: ConstVarState) -> ConstVarInfo: @@ -287,7 +294,7 @@ def jaxpr_find_constvars_reduce_sub_states_fn( def jaxpr_find_constvars( jaxpr: jax.core.Jaxpr, constvars: Dict[jax.core.Var, ConstVarInfo] -) -> Dict[jax.core.Var, ConstVarInfo]: +) -> ConstVarRecState: """Find all intermediates variables in a JAX expression which are expected to be constants. Parameters @@ -302,7 +309,7 @@ def jaxpr_find_constvars( # Start with the collection of input constants. const_state = copy.copy(constvars) print(jaxpr) - const_state, _ = jaxpr_visitor( + const_rec_state = jaxpr_visitor( jaxpr, const_state, jaxpr_find_constvars_visitor_fn, @@ -310,10 +317,124 @@ def jaxpr_find_constvars( jaxpr_find_constvars_reduce_sub_states_fn, reverse=False, ) - return const_state + return const_rec_state + + +DenormalizeState = Tuple[ + Dict[jax.core.Var, Tuple[Any, jax.core.Var]], Set[jax.core.Var], ConstVarRecState +] +"""Denormalization state, combination of: + - dictionary of variable mapping, corresponding to `add` or `sub` ops which can be simplified; + - set of variables which can be traverse backward for denormalization; + - full recursive const variable state of the Jaxpr. +""" +DenormalizeRecState = Tuple[ + DenormalizeState, List[Optional[List["DenormalizeRecState"]]] +] + +denorm_supported_linear_ops = [ + jax.lax.broadcast_in_dim_p, + jax.lax.broadcast_p, + jax.lax.neg_p, + jax.lax.reshape_p, + jax.lax.squeeze_p, + jax.lax.reduce_sum_p, +] + + +def jaxpr_denorm_mapping_visitor_fn( + eqn: jax.core.JaxprEqn, + state: DenormalizeState, +) -> DenormalizeState: + """pass + fdsafas + """ + # Un-stack input complex input state! + denorm_map_dict, denorm_valid_vars, constvar_full_state = state + constvar_state, _ = constvar_full_state + + def is_var_constant(v: Any) -> bool: + return type(v) is jax.core.Literal or v in constvar_state + + if eqn.primitive in denorm_supported_linear_ops: + # Can continue denormalizing inputs if all outputs are in the linear vars collection. + if all([o in denorm_valid_vars for o in eqn.outvars]): + denorm_valid_vars |= set(eqn.invars) + elif eqn.primitive == jax.lax.add_p and eqn.outvars[0] in denorm_valid_vars: + lhs_invar, rhs_invar = eqn.invars[0], eqn.invars[1] + # Mapping the output to the non-const input. + if is_var_constant(lhs_invar): + denorm_valid_vars.add(rhs_invar) + denorm_map_dict[eqn.outvars[0]] = (jax_lax_identity, rhs_invar) + elif is_var_constant(rhs_invar): + denorm_valid_vars.add(lhs_invar) + denorm_map_dict[eqn.outvars[0]] = (jax_lax_identity, lhs_invar) + elif eqn.primitive == jax.lax.sub_p and eqn.outvars[0] in denorm_valid_vars: + lhs_invar, rhs_invar = eqn.invars[0], eqn.invars[1] + # Mapping the output to the non-const input (or the negative). + if is_var_constant(lhs_invar): + denorm_valid_vars.add(rhs_invar) + denorm_map_dict[eqn.outvars[0]] = (jax.lax.neg, rhs_invar) + elif is_var_constant(rhs_invar): + denorm_valid_vars.add(lhs_invar) + denorm_map_dict[eqn.outvars[0]] = (jax_lax_identity, lhs_invar) + + # Re-construct updated state. + return (denorm_map_dict, denorm_valid_vars, constvar_full_state) + + +def jaxpr_denorm_mapping_map_sub_states_fn( + eqn: jax.core.JaxprEqn, state: DenormalizeState +) -> List[DenormalizeState]: + """""" + denorm_map_dict, denorm_valid_vars, constvar_full_state = state + sub_jaxprs = jaxpr_find_sub_jaxprs(eqn) + # TODO: fix properly. + return [state for _ in sub_jaxprs] + + +def jaxpr_denorm_mapping_reduce_sub_states_fn( + eqn: jax.core.JaxprEqn, state: DenormalizeState, sub_states: List[DenormalizeState] +) -> DenormalizeState: + """""" + # TODO: fix properly. + return state def jaxpr_find_denormalize_mapping( + jaxpr: jax.core.Jaxpr, constvar_state: ConstVarRecState +) -> DenormalizeRecState: + """Find all assignment simplifications in a JAX expression when denormalizing. + + More specifically, this method is looking to simplify `add` and `sub` operations, with output linear + with respect to the Jaxpr outputs, and where one of the input is constant. It returns the simplified mapping + between input and output of `add`/`sub` ops which can be removed. + + Parameters + ---------- + jaxpr: JAX expression. + consts: List of known constant variables in the JAX expression. + + Returns + ------- + Simplified mapping between `add` output and input (with the proper assignment lax op `identity` or `neg`). + """ + # Initialize the denormalize state, starting from the ouput variables. + denormalize_mapping = {} + denorm_valid_vars = set(jaxpr.outvars) + denorm_state = (denormalize_mapping, denorm_valid_vars, constvar_state) + denorm_rec_state = jaxpr_visitor( + jaxpr, + denorm_state, + jaxpr_denorm_mapping_visitor_fn, + jaxpr_denorm_mapping_map_sub_states_fn, + jaxpr_denorm_mapping_reduce_sub_states_fn, + reverse=True, + ) + return denorm_rec_state + + +def jaxpr_find_denormalize_mapping_old( jaxpr: jax.core.Jaxpr, consts: List[jax.core.Var] ) -> Dict[jax.core.Var, Tuple[jax.core.Primitive, jax.core.Var]]: """Find all assignment simplifications in a JAX expression when denormalizing. diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index 54365b81..6351e350 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -56,7 +56,7 @@ def test__jaxpr_find_constvars__propagate_constants(case): {v: ConstVarInfo(False, True) for v in typed_jaxpr.jaxpr.constvars} ) - constvars = jaxpr_find_constvars(typed_jaxpr.jaxpr, constvars) + constvars, _ = jaxpr_find_constvars(typed_jaxpr.jaxpr, constvars) for outvar in typed_jaxpr.jaxpr.outvars: assert outvar in constvars assert constvars[outvar] == expected_const_info @@ -73,9 +73,12 @@ def test__jaxpr_find_constvars__propagate_constants(case): @pytest.mark.parametrize("case", denorm_expected_add_mapping_op) def test__jaxpr_find_denormalize_mapping__add_sub__proper_mapping(case): typed_jaxpr = jax.make_jaxpr(case["fn"])(1.0) - denorm_map = jaxpr_find_denormalize_mapping( - typed_jaxpr.jaxpr, typed_jaxpr.jaxpr.constvars - ) + constvars = {v: ConstVarInfo(False, True) for v in typed_jaxpr.jaxpr.constvars} + constvar_state = jaxpr_find_constvars(typed_jaxpr.jaxpr, constvars) + + denorm_rec_state = jaxpr_find_denormalize_mapping(typed_jaxpr.jaxpr, constvar_state) + denorm_map = denorm_rec_state[0][0] + invar = typed_jaxpr.jaxpr.invars[0] outvar = typed_jaxpr.jaxpr.outvars[0] @@ -103,13 +106,13 @@ def test__jaxpr_find_denormalize_mapping__add_sub__proper_mapping(case): @pytest.mark.parametrize("case", denorm_linear_op_propagating) def test__jaxpr_find_denormalize_mapping__linear_op_propagating__proper_mapping(case): typed_jaxpr = jax.make_jaxpr(case["fn"])(1.0) - denorm_map = jaxpr_find_denormalize_mapping( - typed_jaxpr.jaxpr, typed_jaxpr.jaxpr.constvars - ) - invar = typed_jaxpr.jaxpr.invars[0] + constvars = {v: ConstVarInfo(False, True) for v in typed_jaxpr.jaxpr.constvars} + constvar_state = jaxpr_find_constvars(typed_jaxpr.jaxpr, constvars) - print(typed_jaxpr) + denorm_rec_state = jaxpr_find_denormalize_mapping(typed_jaxpr.jaxpr, constvar_state) + denorm_map = denorm_rec_state[0][0] + invar = typed_jaxpr.jaxpr.invars[0] # Proper mapping of the output to the input. assert len(denorm_map) == 1 map_op, map_invar = list(denorm_map.values())[0] From 939612968d5b46d86af185da16108ebee89d9ab4 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sat, 13 Feb 2021 19:15:35 +0000 Subject: [PATCH 25/46] wip --- mcx/core/jaxpr_ops.py | 1 - tests/core/jaxpr_ops_test.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index b195dc65..cd57e51f 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -308,7 +308,6 @@ def jaxpr_find_constvars( """ # Start with the collection of input constants. const_state = copy.copy(constvars) - print(jaxpr) const_rec_state = jaxpr_visitor( jaxpr, const_state, diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index 6351e350..8940156d 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -77,7 +77,7 @@ def test__jaxpr_find_denormalize_mapping__add_sub__proper_mapping(case): constvar_state = jaxpr_find_constvars(typed_jaxpr.jaxpr, constvars) denorm_rec_state = jaxpr_find_denormalize_mapping(typed_jaxpr.jaxpr, constvar_state) - denorm_map = denorm_rec_state[0][0] + denorm_map, denorm_valid_vars, _ = denorm_rec_state[0] invar = typed_jaxpr.jaxpr.invars[0] outvar = typed_jaxpr.jaxpr.outvars[0] @@ -87,6 +87,8 @@ def test__jaxpr_find_denormalize_mapping__add_sub__proper_mapping(case): assert outvar in denorm_map assert denorm_map[outvar][0] == case["expected_op"] assert denorm_map[outvar][1] == invar + # Input is a valid denorm variable (which could be propagated in sub-jaxpr). + assert denorm_valid_vars == {invar, outvar} denorm_linear_op_propagating = [ From c3096c7fe2bf5665643afce421eb66a4a606b99d Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sat, 13 Feb 2021 19:17:03 +0000 Subject: [PATCH 26/46] wip --- tests/core/jaxpr_ops_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index 8940156d..74c39d60 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -112,7 +112,7 @@ def test__jaxpr_find_denormalize_mapping__linear_op_propagating__proper_mapping( constvar_state = jaxpr_find_constvars(typed_jaxpr.jaxpr, constvars) denorm_rec_state = jaxpr_find_denormalize_mapping(typed_jaxpr.jaxpr, constvar_state) - denorm_map = denorm_rec_state[0][0] + denorm_map, denorm_valid_vars, _ = denorm_rec_state[0] invar = typed_jaxpr.jaxpr.invars[0] # Proper mapping of the output to the input. @@ -120,6 +120,8 @@ def test__jaxpr_find_denormalize_mapping__linear_op_propagating__proper_mapping( map_op, map_invar = list(denorm_map.values())[0] assert map_op == case["expected_op"] assert map_invar == invar + # Input is a valid denorm variable (which could be propagated in sub-jaxpr). + assert invar in denorm_valid_vars # denorm_non_linear_fn = [ From 028670554887ee02300e693e882b81b932f4a085 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sat, 13 Feb 2021 19:22:10 +0000 Subject: [PATCH 27/46] wip --- mcx/core/jaxpr_ops.py | 17 +++++++++-------- tests/core/jaxpr_ops_test.py | 3 +++ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index cd57e51f..2ecd42a8 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -355,6 +355,10 @@ def jaxpr_denorm_mapping_visitor_fn( def is_var_constant(v: Any) -> bool: return type(v) is jax.core.Literal or v in constvar_state + def denorm_add_sub_op(invar, outvar, replace_op): + denorm_valid_vars.add(invar) + denorm_map_dict[outvar] = (replace_op, invar) + if eqn.primitive in denorm_supported_linear_ops: # Can continue denormalizing inputs if all outputs are in the linear vars collection. if all([o in denorm_valid_vars for o in eqn.outvars]): @@ -363,20 +367,16 @@ def is_var_constant(v: Any) -> bool: lhs_invar, rhs_invar = eqn.invars[0], eqn.invars[1] # Mapping the output to the non-const input. if is_var_constant(lhs_invar): - denorm_valid_vars.add(rhs_invar) - denorm_map_dict[eqn.outvars[0]] = (jax_lax_identity, rhs_invar) + denorm_add_sub_op(rhs_invar, eqn.outvars[0], jax_lax_identity) elif is_var_constant(rhs_invar): - denorm_valid_vars.add(lhs_invar) - denorm_map_dict[eqn.outvars[0]] = (jax_lax_identity, lhs_invar) + denorm_add_sub_op(lhs_invar, eqn.outvars[0], jax_lax_identity) elif eqn.primitive == jax.lax.sub_p and eqn.outvars[0] in denorm_valid_vars: lhs_invar, rhs_invar = eqn.invars[0], eqn.invars[1] # Mapping the output to the non-const input (or the negative). if is_var_constant(lhs_invar): - denorm_valid_vars.add(rhs_invar) - denorm_map_dict[eqn.outvars[0]] = (jax.lax.neg, rhs_invar) + denorm_add_sub_op(rhs_invar, eqn.outvars[0], jax.lax.neg) elif is_var_constant(rhs_invar): - denorm_valid_vars.add(lhs_invar) - denorm_map_dict[eqn.outvars[0]] = (jax_lax_identity, lhs_invar) + denorm_add_sub_op(lhs_invar, eqn.outvars[0], jax_lax_identity) # Re-construct updated state. return (denorm_map_dict, denorm_valid_vars, constvar_full_state) @@ -422,6 +422,7 @@ def jaxpr_find_denormalize_mapping( denormalize_mapping = {} denorm_valid_vars = set(jaxpr.outvars) denorm_state = (denormalize_mapping, denorm_valid_vars, constvar_state) + # NOTE: scanning the jaxpr in reverse order. denorm_rec_state = jaxpr_visitor( jaxpr, denorm_state, diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index 74c39d60..cf162759 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -93,6 +93,7 @@ def test__jaxpr_find_denormalize_mapping__add_sub__proper_mapping(case): denorm_linear_op_propagating = [ {"fn": lambda x: -(x + 1.0), "expected_op": jax_lax_identity}, + {"fn": lambda x: 2.0 * (x + 1.0), "expected_op": jax_lax_identity}, {"fn": lambda x: jnp.expand_dims(1.0 - x, axis=0), "expected_op": jax.lax.neg}, {"fn": lambda x: jnp.reshape(1.0 - x, (1, 1)), "expected_op": jax.lax.neg}, { @@ -111,6 +112,8 @@ def test__jaxpr_find_denormalize_mapping__linear_op_propagating__proper_mapping( constvars = {v: ConstVarInfo(False, True) for v in typed_jaxpr.jaxpr.constvars} constvar_state = jaxpr_find_constvars(typed_jaxpr.jaxpr, constvars) + print(typed_jaxpr) + denorm_rec_state = jaxpr_find_denormalize_mapping(typed_jaxpr.jaxpr, constvar_state) denorm_map, denorm_valid_vars, _ = denorm_rec_state[0] From 7f895d551e8bda02a85d0302b36aef3b4a8d6632 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sat, 13 Feb 2021 19:45:37 +0000 Subject: [PATCH 28/46] wip --- mcx/core/jaxpr_ops.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index 2ecd42a8..c9ade94f 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -355,7 +355,7 @@ def jaxpr_denorm_mapping_visitor_fn( def is_var_constant(v: Any) -> bool: return type(v) is jax.core.Literal or v in constvar_state - def denorm_add_sub_op(invar, outvar, replace_op): + def denorm_linear_op(invar, outvar, replace_op): denorm_valid_vars.add(invar) denorm_map_dict[outvar] = (replace_op, invar) @@ -367,16 +367,16 @@ def denorm_add_sub_op(invar, outvar, replace_op): lhs_invar, rhs_invar = eqn.invars[0], eqn.invars[1] # Mapping the output to the non-const input. if is_var_constant(lhs_invar): - denorm_add_sub_op(rhs_invar, eqn.outvars[0], jax_lax_identity) + denorm_linear_op(rhs_invar, eqn.outvars[0], jax_lax_identity) elif is_var_constant(rhs_invar): - denorm_add_sub_op(lhs_invar, eqn.outvars[0], jax_lax_identity) + denorm_linear_op(lhs_invar, eqn.outvars[0], jax_lax_identity) elif eqn.primitive == jax.lax.sub_p and eqn.outvars[0] in denorm_valid_vars: lhs_invar, rhs_invar = eqn.invars[0], eqn.invars[1] # Mapping the output to the non-const input (or the negative). if is_var_constant(lhs_invar): - denorm_add_sub_op(rhs_invar, eqn.outvars[0], jax.lax.neg) + denorm_linear_op(rhs_invar, eqn.outvars[0], jax.lax.neg) elif is_var_constant(rhs_invar): - denorm_add_sub_op(lhs_invar, eqn.outvars[0], jax_lax_identity) + denorm_linear_op(lhs_invar, eqn.outvars[0], jax_lax_identity) # Re-construct updated state. return (denorm_map_dict, denorm_valid_vars, constvar_full_state) From 542177d0bf3fa4cdd648a1edc1566bca6aec9d01 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sat, 13 Feb 2021 19:59:39 +0000 Subject: [PATCH 29/46] wip --- mcx/core/jaxpr_ops.py | 71 ++++++++++++++++++++++++++++++------ tests/core/jaxpr_ops_test.py | 3 ++ 2 files changed, 62 insertions(+), 12 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index c9ade94f..dfe1a3d2 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -331,14 +331,58 @@ def jaxpr_find_constvars( DenormalizeState, List[Optional[List["DenormalizeRecState"]]] ] -denorm_supported_linear_ops = [ - jax.lax.broadcast_in_dim_p, - jax.lax.broadcast_p, - jax.lax.neg_p, - jax.lax.reshape_p, - jax.lax.squeeze_p, - jax.lax.reduce_sum_p, -] + +def jaxpr_eqn_denorm_propagate_basic_check( + eqn: jax.core.JaxprEqn, state: DenormalizeState +) -> bool: + """fdsfafasd + + fasdfasd + """ + # Base check for back-propagate through a valid op: all outputs must be valid denorm variables. + _, denorm_valid_vars, _ = state + return all([o in denorm_valid_vars for o in eqn.outvars]) + + +def jaxpr_eqn_denorm_propagate_mul_check( + eqn: jax.core.JaxprEqn, state: DenormalizeState +) -> bool: + """fdsfafasd + + fasdfasd + """ + _, denorm_valid_vars, constvar_full_state = state + constvar_state, _ = constvar_full_state + + def is_var_constant(v: Any) -> bool: + return type(v) is jax.core.Literal or v in constvar_state + + # Propagate denormalization if one of the input is a uniform constant. + all_valid_outvars = all([o in denorm_valid_vars for o in eqn.outvars]) + any_invar_const = is_var_constant(eqn.invars[0]) or is_var_constant(eqn.invars[1]) + return all_valid_outvars and any_invar_const + + +def jaxpr_eqn_denorm_propagate_div_check( + eqn: jax.core.JaxprEqn, state: DenormalizeState +) -> bool: + """fdsfafasd + + fasdfasd + """ + + +jaxpr_eqn_denorm_propagate_ops = { + jax.lax.broadcast_in_dim_p: jaxpr_eqn_denorm_propagate_basic_check, + jax.lax.broadcast_p: jaxpr_eqn_denorm_propagate_basic_check, + jax.lax.neg_p: jaxpr_eqn_denorm_propagate_basic_check, + jax.lax.reshape_p: jaxpr_eqn_denorm_propagate_basic_check, + jax.lax.squeeze_p: jaxpr_eqn_denorm_propagate_basic_check, + jax.lax.reduce_sum_p: jaxpr_eqn_denorm_propagate_basic_check, + jax.lax.mul_p: jaxpr_eqn_denorm_propagate_mul_check, +} +""" +""" def jaxpr_denorm_mapping_visitor_fn( @@ -359,10 +403,13 @@ def denorm_linear_op(invar, outvar, replace_op): denorm_valid_vars.add(invar) denorm_map_dict[outvar] = (replace_op, invar) - if eqn.primitive in denorm_supported_linear_ops: - # Can continue denormalizing inputs if all outputs are in the linear vars collection. - if all([o in denorm_valid_vars for o in eqn.outvars]): - denorm_valid_vars |= set(eqn.invars) + if ( + eqn.primitive in jaxpr_eqn_denorm_propagate_ops + and jaxpr_eqn_denorm_propagate_ops[eqn.primitive](eqn, state) + ): + # Can continue denormalizing input variables + denorm_valid_vars |= {v for v in eqn.invars if type(v) is not jax.core.Literal} + elif eqn.primitive == jax.lax.add_p and eqn.outvars[0] in denorm_valid_vars: lhs_invar, rhs_invar = eqn.invars[0], eqn.invars[1] # Mapping the output to the non-const input. diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index cf162759..a64803ab 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -94,6 +94,8 @@ def test__jaxpr_find_denormalize_mapping__add_sub__proper_mapping(case): denorm_linear_op_propagating = [ {"fn": lambda x: -(x + 1.0), "expected_op": jax_lax_identity}, {"fn": lambda x: 2.0 * (x + 1.0), "expected_op": jax_lax_identity}, + {"fn": lambda x: (x + 1.0) * 2.0, "expected_op": jax_lax_identity}, + {"fn": lambda x: (x + 1.0) / 2.0, "expected_op": jax_lax_identity}, {"fn": lambda x: jnp.expand_dims(1.0 - x, axis=0), "expected_op": jax.lax.neg}, {"fn": lambda x: jnp.reshape(1.0 - x, (1, 1)), "expected_op": jax.lax.neg}, { @@ -113,6 +115,7 @@ def test__jaxpr_find_denormalize_mapping__linear_op_propagating__proper_mapping( constvar_state = jaxpr_find_constvars(typed_jaxpr.jaxpr, constvars) print(typed_jaxpr) + print(str(typed_jaxpr.jaxpr.eqns[-1].primitive)) denorm_rec_state = jaxpr_find_denormalize_mapping(typed_jaxpr.jaxpr, constvar_state) denorm_map, denorm_valid_vars, _ = denorm_rec_state[0] From 1c784e9a290f0e4f0dc85f1c4d4398e34ad4bab0 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sat, 13 Feb 2021 20:01:05 +0000 Subject: [PATCH 30/46] wip --- mcx/core/jaxpr_ops.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index dfe1a3d2..198a0e66 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -370,6 +370,16 @@ def jaxpr_eqn_denorm_propagate_div_check( fasdfasd """ + _, denorm_valid_vars, constvar_full_state = state + constvar_state, _ = constvar_full_state + + def is_var_constant(v: Any) -> bool: + return type(v) is jax.core.Literal or v in constvar_state + + # Propagate denormalization if second input is a uniform constant. + all_valid_outvars = all([o in denorm_valid_vars for o in eqn.outvars]) + second_invar_const = is_var_constant(eqn.invars[1]) + return all_valid_outvars and second_invar_const jaxpr_eqn_denorm_propagate_ops = { @@ -380,6 +390,7 @@ def jaxpr_eqn_denorm_propagate_div_check( jax.lax.squeeze_p: jaxpr_eqn_denorm_propagate_basic_check, jax.lax.reduce_sum_p: jaxpr_eqn_denorm_propagate_basic_check, jax.lax.mul_p: jaxpr_eqn_denorm_propagate_mul_check, + jax.lax.div_p: jaxpr_eqn_denorm_propagate_div_check, } """ """ From 85cec9476847fa0cab4d775f604a405a505f5219 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sun, 14 Feb 2021 12:06:17 +0000 Subject: [PATCH 31/46] wip --- tests/core/jaxpr_ops_test.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index a64803ab..f433130a 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -93,6 +93,7 @@ def test__jaxpr_find_denormalize_mapping__add_sub__proper_mapping(case): denorm_linear_op_propagating = [ {"fn": lambda x: -(x + 1.0), "expected_op": jax_lax_identity}, + {"fn": lambda x: x + (x + 1.0), "expected_op": jax_lax_identity}, {"fn": lambda x: 2.0 * (x + 1.0), "expected_op": jax_lax_identity}, {"fn": lambda x: (x + 1.0) * 2.0, "expected_op": jax_lax_identity}, {"fn": lambda x: (x + 1.0) / 2.0, "expected_op": jax_lax_identity}, @@ -114,9 +115,6 @@ def test__jaxpr_find_denormalize_mapping__linear_op_propagating__proper_mapping( constvars = {v: ConstVarInfo(False, True) for v in typed_jaxpr.jaxpr.constvars} constvar_state = jaxpr_find_constvars(typed_jaxpr.jaxpr, constvars) - print(typed_jaxpr) - print(str(typed_jaxpr.jaxpr.eqns[-1].primitive)) - denorm_rec_state = jaxpr_find_denormalize_mapping(typed_jaxpr.jaxpr, constvar_state) denorm_map, denorm_valid_vars, _ = denorm_rec_state[0] @@ -130,21 +128,24 @@ def test__jaxpr_find_denormalize_mapping__linear_op_propagating__proper_mapping( assert invar in denorm_valid_vars -# denorm_non_linear_fn = [ -# {"fn": lambda x: np.sin(x + 1.0)}, -# {"fn": lambda x: np.abs(x + 1.0)}, -# {"fn": lambda x: np.exp(x + 1.0)}, -# {"fn": lambda x: x * (x + 1.0)}, -# ] +denorm_non_linear_fn = [ + {"fn": lambda x: jnp.sin(x + 1.0)}, + {"fn": lambda x: jnp.abs(x + 1.0)}, + {"fn": lambda x: jnp.exp(x + 1.0)}, + {"fn": lambda x: x + jnp.sin(x + 1.0)}, + {"fn": lambda x: x * (x + 1.0)}, +] + +@pytest.mark.parametrize("case", denorm_non_linear_fn) +def test__jaxpr_find_denormalize_mapping__non_linear_fn__empty_mapping(case): + typed_jaxpr = jax.make_jaxpr(case["fn"])(1.0) + constvars = {v: ConstVarInfo(False, True) for v in typed_jaxpr.jaxpr.constvars} + constvar_state = jaxpr_find_constvars(typed_jaxpr.jaxpr, constvars) -# @pytest.mark.parametrize("case", denorm_non_linear_fn) -# def test__jaxpr_find_denormalize_mapping__non_linear_fn__empty_mapping(case): -# typed_jaxpr = jax.make_jaxpr(case["fn"])(1.0) -# denorm_map = jaxpr_find_denormalize_mapping( -# typed_jaxpr.jaxpr, typed_jaxpr.jaxpr.constvars -# ) -# assert len(denorm_map) == 0 + denorm_rec_state = jaxpr_find_denormalize_mapping(typed_jaxpr.jaxpr, constvar_state) + denorm_map, _, _ = denorm_rec_state[0] + assert len(denorm_map) == 0 # denormalize_test_cases = [ From 2853d378d4ee83abb6d7c024ca49151690222d59 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sun, 14 Feb 2021 12:18:47 +0000 Subject: [PATCH 32/46] wip --- mcx/core/jaxpr_ops.py | 11 +++++++++-- tests/core/jaxpr_ops_test.py | 7 ++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index 198a0e66..eb313421 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -389,6 +389,8 @@ def is_var_constant(v: Any) -> bool: jax.lax.reshape_p: jaxpr_eqn_denorm_propagate_basic_check, jax.lax.squeeze_p: jaxpr_eqn_denorm_propagate_basic_check, jax.lax.reduce_sum_p: jaxpr_eqn_denorm_propagate_basic_check, + jax.lax.add_p: jaxpr_eqn_denorm_propagate_basic_check, + jax.lax.sub_p: jaxpr_eqn_denorm_propagate_basic_check, jax.lax.mul_p: jaxpr_eqn_denorm_propagate_mul_check, jax.lax.div_p: jaxpr_eqn_denorm_propagate_div_check, } @@ -418,10 +420,15 @@ def denorm_linear_op(invar, outvar, replace_op): eqn.primitive in jaxpr_eqn_denorm_propagate_ops and jaxpr_eqn_denorm_propagate_ops[eqn.primitive](eqn, state) ): - # Can continue denormalizing input variables + # Valid variables on which to propagate the denormalization. denorm_valid_vars |= {v for v in eqn.invars if type(v) is not jax.core.Literal} + else: + # Make sure we can not propagate the inputs. see for instance `x * sin(x+1)` + invalid_invars = {v for v in eqn.invars if type(v) is not jax.core.Literal} + denorm_valid_vars = denorm_valid_vars - invalid_invars - elif eqn.primitive == jax.lax.add_p and eqn.outvars[0] in denorm_valid_vars: + # Add and sub operations which can be simplified. + if eqn.primitive == jax.lax.add_p and eqn.outvars[0] in denorm_valid_vars: lhs_invar, rhs_invar = eqn.invars[0], eqn.invars[1] # Mapping the output to the non-const input. if is_var_constant(lhs_invar): diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index f433130a..3f94784d 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -94,6 +94,7 @@ def test__jaxpr_find_denormalize_mapping__add_sub__proper_mapping(case): denorm_linear_op_propagating = [ {"fn": lambda x: -(x + 1.0), "expected_op": jax_lax_identity}, {"fn": lambda x: x + (x + 1.0), "expected_op": jax_lax_identity}, + {"fn": lambda x: x - (x + 1.0), "expected_op": jax_lax_identity}, {"fn": lambda x: 2.0 * (x + 1.0), "expected_op": jax_lax_identity}, {"fn": lambda x: (x + 1.0) * 2.0, "expected_op": jax_lax_identity}, {"fn": lambda x: (x + 1.0) / 2.0, "expected_op": jax_lax_identity}, @@ -143,9 +144,13 @@ def test__jaxpr_find_denormalize_mapping__non_linear_fn__empty_mapping(case): constvars = {v: ConstVarInfo(False, True) for v in typed_jaxpr.jaxpr.constvars} constvar_state = jaxpr_find_constvars(typed_jaxpr.jaxpr, constvars) + invar = typed_jaxpr.jaxpr.invars[0] denorm_rec_state = jaxpr_find_denormalize_mapping(typed_jaxpr.jaxpr, constvar_state) - denorm_map, _, _ = denorm_rec_state[0] + denorm_map, denorm_valid_vars, _ = denorm_rec_state[0] + # Not simplifying mapping found. assert len(denorm_map) == 0 + # Denormalization not propagating to the input. + assert invar not in denorm_valid_vars # denormalize_test_cases = [ From ec441248012d73fcdfa454d959a6f1543df030c4 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sun, 14 Feb 2021 14:10:42 +0000 Subject: [PATCH 33/46] wip --- mcx/core/jaxpr_ops.py | 48 +++++++++++++++++++++++++++++++----- tests/core/jaxpr_ops_test.py | 26 +++++++++++++++++++ 2 files changed, 68 insertions(+), 6 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index eb313421..7e1b5c1c 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -407,7 +407,7 @@ def jaxpr_denorm_mapping_visitor_fn( """ # Un-stack input complex input state! denorm_map_dict, denorm_valid_vars, constvar_full_state = state - constvar_state, _ = constvar_full_state + constvar_state, constvar_sub_states = constvar_full_state def is_var_constant(v: Any) -> bool: return type(v) is jax.core.Literal or v in constvar_state @@ -423,9 +423,8 @@ def denorm_linear_op(invar, outvar, replace_op): # Valid variables on which to propagate the denormalization. denorm_valid_vars |= {v for v in eqn.invars if type(v) is not jax.core.Literal} else: - # Make sure we can not propagate the inputs. see for instance `x * sin(x+1)` - invalid_invars = {v for v in eqn.invars if type(v) is not jax.core.Literal} - denorm_valid_vars = denorm_valid_vars - invalid_invars + # Make sure we can not propagate the inputs. see for instance `x + sin(x + 1)`. + denorm_valid_vars -= {v for v in eqn.invars if type(v) is not jax.core.Literal} # Add and sub operations which can be simplified. if eqn.primitive == jax.lax.add_p and eqn.outvars[0] in denorm_valid_vars: @@ -443,6 +442,9 @@ def denorm_linear_op(invar, outvar, replace_op): elif is_var_constant(rhs_invar): denorm_linear_op(lhs_invar, eqn.outvars[0], jax_lax_identity) + # Update the constvar sub-states list, to keep sync. with equations in the jaxpr. + constvar_sub_states = constvar_sub_states[:-1] + constvar_full_state = constvar_state, constvar_sub_states # Re-construct updated state. return (denorm_map_dict, denorm_valid_vars, constvar_full_state) @@ -452,16 +454,50 @@ def jaxpr_denorm_mapping_map_sub_states_fn( ) -> List[DenormalizeState]: """""" denorm_map_dict, denorm_valid_vars, constvar_full_state = state + constvar_state, constvar_sub_states = constvar_full_state + sub_jaxprs = jaxpr_find_sub_jaxprs(eqn) - # TODO: fix properly. - return [state for _ in sub_jaxprs] + assert len(constvar_sub_states[-1]) == len(sub_jaxprs) + + primitive_type = type(eqn.primitive) + if primitive_type == jax.core.CallPrimitive: + # Jit compiled sub-jaxpr: map eqn outputs to sub-jaxpr outputs. + sub_jaxpr, sub_const_state = sub_jaxprs[0], constvar_sub_states[-1][0] + # Map the denorm valid vars to the output of the sub-jaxprs. + denorm_sub_valid_vars = { + sub_outvar + for eqn_outvar, sub_outvar in zip(eqn.outvars, sub_jaxpr.outvars) + if eqn_outvar in denorm_valid_vars + } + denorm_sub_state = ({}, denorm_sub_valid_vars, sub_const_state) + return [denorm_sub_state] + else: + # TODO: support other high primitives. No constants passed at the moment. + denorm_sub_states = [state for _ in sub_jaxprs] + return denorm_sub_states def jaxpr_denorm_mapping_reduce_sub_states_fn( eqn: jax.core.JaxprEqn, state: DenormalizeState, sub_states: List[DenormalizeState] ) -> DenormalizeState: """""" + sub_jaxprs = jaxpr_find_sub_jaxprs(eqn) + assert len(sub_states) == len(sub_jaxprs) + + denorm_map_dict, denorm_valid_vars, constvar_full_state = state + primitive_type = type(eqn.primitive) + if primitive_type == jax.core.CallPrimitive: + # Jit compiled sub-jaxpr: map valid sub-jaxpr inputs to update denorm valid variables. + sub_jaxpr = sub_jaxprs[0] + _, sub_denorm_valid_vars, _ = sub_states[0] + denorm_valid_vars |= { + eqn_invar + for eqn_invar, sub_invar in zip(eqn.invars, sub_jaxpr.invars) + if sub_invar in sub_denorm_valid_vars + } + # TODO: fix properly. + state = denorm_map_dict, denorm_valid_vars, constvar_full_state return state diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index 3f94784d..b0609596 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -129,6 +129,32 @@ def test__jaxpr_find_denormalize_mapping__linear_op_propagating__proper_mapping( assert invar in denorm_valid_vars +denorm_sub_jaxprs_propagating = [ + {"fn": lambda x: jax.jit(lambda y: 1.0 - y)(x), "expected_op": jax.lax.neg}, + # {"fn": lambda x: jnp.full((2,), 2.0) * (1.0 - x), "expected_op": jax.lax.neg}, + # {"fn": lambda x: (1.0 - x) / (jnp.ones((2,)) * 2.0), "expected_op": jax.lax.neg}, +] + + +@pytest.mark.parametrize("case", denorm_sub_jaxprs_propagating) +def test__jaxpr_find_denormalize_mapping__sub_jaxprs_propagating__proper_mapping(case): + typed_jaxpr = jax.make_jaxpr(case["fn"])(1.0) + constvars = {v: ConstVarInfo(False, True) for v in typed_jaxpr.jaxpr.constvars} + constvar_state = jaxpr_find_constvars(typed_jaxpr.jaxpr, constvars) + + denorm_rec_state = jaxpr_find_denormalize_mapping(typed_jaxpr.jaxpr, constvar_state) + denorm_map, denorm_valid_vars, _ = denorm_rec_state[0] + + # Proper mapping of the output to the input. + # assert len(denorm_map) == 1 + # map_op, map_invar = list(denorm_map.values())[0] + # assert map_op == case["expected_op"] + # assert map_invar == invar + # Input is a valid denorm variable (which could be propagated in sub-jaxpr). + invar = typed_jaxpr.jaxpr.invars[0] + assert invar in denorm_valid_vars + + denorm_non_linear_fn = [ {"fn": lambda x: jnp.sin(x + 1.0)}, {"fn": lambda x: jnp.abs(x + 1.0)}, From 15a5ff94cca2224cf4b665d44b31951ba12322fd Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sun, 14 Feb 2021 19:26:32 +0000 Subject: [PATCH 34/46] wip --- mcx/core/jaxpr_ops.py | 55 +++++++++++++++++++++++++++++++------------ 1 file changed, 40 insertions(+), 15 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index 7e1b5c1c..3f6b5378 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -332,21 +332,39 @@ def jaxpr_find_constvars( ] +def jax_is_literal(v: Any) -> bool: + """Is the input variable a Jax core literal?""" + return type(v) is jax.core.Literal + + +def jaxpr_eqn_denorm_no_propagate_invars( + eqn: jax.core.JaxprEqn, state: DenormalizeState +) -> Tuple[Set[jax.core.Var], Set[jax.core.Var]]: + """Default primitive propagation: blocking the denormalization of input variables.""" + invalid_invars = {v for v in eqn.invars if not jax_is_literal(v)} + return set(), invalid_invars + + def jaxpr_eqn_denorm_propagate_basic_check( eqn: jax.core.JaxprEqn, state: DenormalizeState -) -> bool: +) -> Tuple[Set[jax.core.Var], Set[jax.core.Var]]: """fdsfafasd fasdfasd """ # Base check for back-propagate through a valid op: all outputs must be valid denorm variables. + invars = {v for v in eqn.invars if not jax_is_literal(v)} _, denorm_valid_vars, _ = state - return all([o in denorm_valid_vars for o in eqn.outvars]) + check_denorm_propagate = all([o in denorm_valid_vars for o in eqn.outvars]) + if all([o in denorm_valid_vars for o in eqn.outvars]): + return invars, set() + # Default case: blocking back-propagation of denormalization. + return set(), invars def jaxpr_eqn_denorm_propagate_mul_check( eqn: jax.core.JaxprEqn, state: DenormalizeState -) -> bool: +) -> Tuple[Set[jax.core.Var], Set[jax.core.Var]]: """fdsfafasd fasdfasd @@ -357,15 +375,19 @@ def jaxpr_eqn_denorm_propagate_mul_check( def is_var_constant(v: Any) -> bool: return type(v) is jax.core.Literal or v in constvar_state + invars = {v for v in eqn.invars if not jax_is_literal(v)} # Propagate denormalization if one of the input is a uniform constant. all_valid_outvars = all([o in denorm_valid_vars for o in eqn.outvars]) any_invar_const = is_var_constant(eqn.invars[0]) or is_var_constant(eqn.invars[1]) - return all_valid_outvars and any_invar_const + if all_valid_outvars and any_invar_const: + return invars, set() + # Default case: blocking back-propagation of denormalization. + return set(), invars def jaxpr_eqn_denorm_propagate_div_check( eqn: jax.core.JaxprEqn, state: DenormalizeState -) -> bool: +) -> Tuple[Set[jax.core.Var], Set[jax.core.Var]]: """fdsfafasd fasdfasd @@ -376,10 +398,14 @@ def jaxpr_eqn_denorm_propagate_div_check( def is_var_constant(v: Any) -> bool: return type(v) is jax.core.Literal or v in constvar_state + invars = {v for v in eqn.invars if not jax_is_literal(v)} # Propagate denormalization if second input is a uniform constant. all_valid_outvars = all([o in denorm_valid_vars for o in eqn.outvars]) second_invar_const = is_var_constant(eqn.invars[1]) - return all_valid_outvars and second_invar_const + if all_valid_outvars and second_invar_const: + return invars, set() + # Default case: blocking back-propagation of denormalization. + return set(), invars jaxpr_eqn_denorm_propagate_ops = { @@ -416,15 +442,14 @@ def denorm_linear_op(invar, outvar, replace_op): denorm_valid_vars.add(invar) denorm_map_dict[outvar] = (replace_op, invar) - if ( - eqn.primitive in jaxpr_eqn_denorm_propagate_ops - and jaxpr_eqn_denorm_propagate_ops[eqn.primitive](eqn, state) - ): - # Valid variables on which to propagate the denormalization. - denorm_valid_vars |= {v for v in eqn.invars if type(v) is not jax.core.Literal} - else: - # Make sure we can not propagate the inputs. see for instance `x + sin(x + 1)`. - denorm_valid_vars -= {v for v in eqn.invars if type(v) is not jax.core.Literal} + # Check which input variables to keep propagating the denormalization. + eqn_propagate_check_fn = jaxpr_eqn_denorm_propagate_ops.get( + eqn.primitive, jaxpr_eqn_denorm_no_propagate_invars + ) + valid_invars, invalid_invars = eqn_propagate_check_fn(eqn, state) + # Update the global denorm valid vars accordingly. + denorm_valid_vars |= valid_invars + denorm_valid_vars -= invalid_invars # Add and sub operations which can be simplified. if eqn.primitive == jax.lax.add_p and eqn.outvars[0] in denorm_valid_vars: From facafc262a678fa3d336d4778c6c4bd88e269122 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sun, 14 Feb 2021 19:48:07 +0000 Subject: [PATCH 35/46] wip --- mcx/core/jaxpr_ops.py | 35 +++++++++++++++++++++++++++++++++++ tests/core/jaxpr_ops_test.py | 7 +++++++ 2 files changed, 42 insertions(+) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index 3f6b5378..c4e71ed5 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -337,6 +337,13 @@ def jax_is_literal(v: Any) -> bool: return type(v) is jax.core.Literal +def jax_is_non_finite_constant(v: Any, const_state: ConstVarState): + """""" + return (jax_is_literal(v) and not np.isfinite(v.val)) or ( + v in const_state and const_state[v].is_non_finite + ) + + def jaxpr_eqn_denorm_no_propagate_invars( eqn: jax.core.JaxprEqn, state: DenormalizeState ) -> Tuple[Set[jax.core.Var], Set[jax.core.Var]]: @@ -408,6 +415,33 @@ def is_var_constant(v: Any) -> bool: return set(), invars +def jaxpr_eqn_denorm_propagate_select_check( + eqn: jax.core.JaxprEqn, state: DenormalizeState +) -> Tuple[Set[jax.core.Var], Set[jax.core.Var]]: + """fdsfafasd + + fasdfasd + """ + _, denorm_valid_vars, constvar_full_state = state + constvar_state, _ = constvar_full_state + invar_pred, invar_true, invar_false = eqn.invars + all_valid_outvars = all([o in denorm_valid_vars for o in eqn.outvars]) + + if jax_is_non_finite_constant(invar_true, constvar_state) and all_valid_outvars: + valid_vars = set() if jax_is_literal(invar_false) else {invar_false} + invalid_vars = set() if jax_is_literal(invar_pred) else {invar_pred} + return valid_vars, invalid_vars + + if jax_is_non_finite_constant(invar_false, constvar_state) and all_valid_outvars: + valid_vars = set() if jax_is_literal(invar_true) else {invar_true} + invalid_vars = set() if jax_is_literal(invar_pred) else {invar_pred} + return valid_vars, invalid_vars + + # Default case: blocking back-propagation of denormalization. + invars = {v for v in eqn.invars if not jax_is_literal(v)} + return set(), invars + + jaxpr_eqn_denorm_propagate_ops = { jax.lax.broadcast_in_dim_p: jaxpr_eqn_denorm_propagate_basic_check, jax.lax.broadcast_p: jaxpr_eqn_denorm_propagate_basic_check, @@ -419,6 +453,7 @@ def is_var_constant(v: Any) -> bool: jax.lax.sub_p: jaxpr_eqn_denorm_propagate_basic_check, jax.lax.mul_p: jaxpr_eqn_denorm_propagate_mul_check, jax.lax.div_p: jaxpr_eqn_denorm_propagate_div_check, + jax.lax.select_p: jaxpr_eqn_denorm_propagate_select_check, } """ """ diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index b0609596..d2d25228 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -104,6 +104,11 @@ def test__jaxpr_find_denormalize_mapping__add_sub__proper_mapping(case): "fn": lambda x: jnp.squeeze(jnp.expand_dims(1.0 - x, axis=0)), "expected_op": jax.lax.neg, }, + # Typical case of support in distribution logpdf. + { + "fn": lambda x: jax.lax.select(1.0 > 0.0, 1.0 - x, -np.inf), + "expected_op": jax.lax.neg, + }, # {"fn": lambda x: jax.jit(lambda y: 1.0 - y)(x), "expected_op": jax.lax.neg}, # {"fn": lambda x: jnp.full((2,), 2.0) * (1.0 - x), "expected_op": jax.lax.neg}, # {"fn": lambda x: (1.0 - x) / (jnp.ones((2,)) * 2.0), "expected_op": jax.lax.neg}, @@ -119,6 +124,8 @@ def test__jaxpr_find_denormalize_mapping__linear_op_propagating__proper_mapping( denorm_rec_state = jaxpr_find_denormalize_mapping(typed_jaxpr.jaxpr, constvar_state) denorm_map, denorm_valid_vars, _ = denorm_rec_state[0] + print(typed_jaxpr) + invar = typed_jaxpr.jaxpr.invars[0] # Proper mapping of the output to the input. assert len(denorm_map) == 1 From 9717abe0282c0ab6de6e5062be17759cf5c10323 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sun, 14 Feb 2021 19:50:06 +0000 Subject: [PATCH 36/46] wip --- tests/core/jaxpr_ops_test.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index d2d25228..8d4aca8e 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -104,14 +104,13 @@ def test__jaxpr_find_denormalize_mapping__add_sub__proper_mapping(case): "fn": lambda x: jnp.squeeze(jnp.expand_dims(1.0 - x, axis=0)), "expected_op": jax.lax.neg, }, + {"fn": lambda x: jnp.full((2,), 2.0) * (1.0 - x), "expected_op": jax.lax.neg}, + {"fn": lambda x: (1.0 - x) / (jnp.ones((2,)) * 2.0), "expected_op": jax.lax.neg}, # Typical case of support in distribution logpdf. { "fn": lambda x: jax.lax.select(1.0 > 0.0, 1.0 - x, -np.inf), "expected_op": jax.lax.neg, }, - # {"fn": lambda x: jax.jit(lambda y: 1.0 - y)(x), "expected_op": jax.lax.neg}, - # {"fn": lambda x: jnp.full((2,), 2.0) * (1.0 - x), "expected_op": jax.lax.neg}, - # {"fn": lambda x: (1.0 - x) / (jnp.ones((2,)) * 2.0), "expected_op": jax.lax.neg}, ] @@ -124,8 +123,6 @@ def test__jaxpr_find_denormalize_mapping__linear_op_propagating__proper_mapping( denorm_rec_state = jaxpr_find_denormalize_mapping(typed_jaxpr.jaxpr, constvar_state) denorm_map, denorm_valid_vars, _ = denorm_rec_state[0] - print(typed_jaxpr) - invar = typed_jaxpr.jaxpr.invars[0] # Proper mapping of the output to the input. assert len(denorm_map) == 1 From 756781e056f59c38de071386664dd999a4d3af43 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sun, 14 Feb 2021 19:53:40 +0000 Subject: [PATCH 37/46] wip --- mcx/core/jaxpr_ops.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index c4e71ed5..f892f7c9 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -344,7 +344,7 @@ def jax_is_non_finite_constant(v: Any, const_state: ConstVarState): ) -def jaxpr_eqn_denorm_no_propagate_invars( +def jaxpr_denorm_propagate_blocking_eqn( eqn: jax.core.JaxprEqn, state: DenormalizeState ) -> Tuple[Set[jax.core.Var], Set[jax.core.Var]]: """Default primitive propagation: blocking the denormalization of input variables.""" @@ -352,7 +352,7 @@ def jaxpr_eqn_denorm_no_propagate_invars( return set(), invalid_invars -def jaxpr_eqn_denorm_propagate_basic_check( +def jaxpr_denorm_propagate_linear_eqn( eqn: jax.core.JaxprEqn, state: DenormalizeState ) -> Tuple[Set[jax.core.Var], Set[jax.core.Var]]: """fdsfafasd @@ -369,7 +369,7 @@ def jaxpr_eqn_denorm_propagate_basic_check( return set(), invars -def jaxpr_eqn_denorm_propagate_mul_check( +def jaxpr_denorm_propagate_mul_eqn( eqn: jax.core.JaxprEqn, state: DenormalizeState ) -> Tuple[Set[jax.core.Var], Set[jax.core.Var]]: """fdsfafasd @@ -392,7 +392,7 @@ def is_var_constant(v: Any) -> bool: return set(), invars -def jaxpr_eqn_denorm_propagate_div_check( +def jaxpr_denorm_propagate_div_eqn( eqn: jax.core.JaxprEqn, state: DenormalizeState ) -> Tuple[Set[jax.core.Var], Set[jax.core.Var]]: """fdsfafasd @@ -415,7 +415,7 @@ def is_var_constant(v: Any) -> bool: return set(), invars -def jaxpr_eqn_denorm_propagate_select_check( +def jaxpr_denorm_propagate_select_eqn( eqn: jax.core.JaxprEqn, state: DenormalizeState ) -> Tuple[Set[jax.core.Var], Set[jax.core.Var]]: """fdsfafasd @@ -443,17 +443,17 @@ def jaxpr_eqn_denorm_propagate_select_check( jaxpr_eqn_denorm_propagate_ops = { - jax.lax.broadcast_in_dim_p: jaxpr_eqn_denorm_propagate_basic_check, - jax.lax.broadcast_p: jaxpr_eqn_denorm_propagate_basic_check, - jax.lax.neg_p: jaxpr_eqn_denorm_propagate_basic_check, - jax.lax.reshape_p: jaxpr_eqn_denorm_propagate_basic_check, - jax.lax.squeeze_p: jaxpr_eqn_denorm_propagate_basic_check, - jax.lax.reduce_sum_p: jaxpr_eqn_denorm_propagate_basic_check, - jax.lax.add_p: jaxpr_eqn_denorm_propagate_basic_check, - jax.lax.sub_p: jaxpr_eqn_denorm_propagate_basic_check, - jax.lax.mul_p: jaxpr_eqn_denorm_propagate_mul_check, - jax.lax.div_p: jaxpr_eqn_denorm_propagate_div_check, - jax.lax.select_p: jaxpr_eqn_denorm_propagate_select_check, + jax.lax.broadcast_in_dim_p: jaxpr_denorm_propagate_linear_eqn, + jax.lax.broadcast_p: jaxpr_denorm_propagate_linear_eqn, + jax.lax.neg_p: jaxpr_denorm_propagate_linear_eqn, + jax.lax.reshape_p: jaxpr_denorm_propagate_linear_eqn, + jax.lax.squeeze_p: jaxpr_denorm_propagate_linear_eqn, + jax.lax.reduce_sum_p: jaxpr_denorm_propagate_linear_eqn, + jax.lax.add_p: jaxpr_denorm_propagate_linear_eqn, + jax.lax.sub_p: jaxpr_denorm_propagate_linear_eqn, + jax.lax.mul_p: jaxpr_denorm_propagate_mul_eqn, + jax.lax.div_p: jaxpr_denorm_propagate_div_eqn, + jax.lax.select_p: jaxpr_denorm_propagate_select_eqn, } """ """ @@ -479,7 +479,7 @@ def denorm_linear_op(invar, outvar, replace_op): # Check which input variables to keep propagating the denormalization. eqn_propagate_check_fn = jaxpr_eqn_denorm_propagate_ops.get( - eqn.primitive, jaxpr_eqn_denorm_no_propagate_invars + eqn.primitive, jaxpr_denorm_propagate_blocking_eqn ) valid_invars, invalid_invars = eqn_propagate_check_fn(eqn, state) # Update the global denorm valid vars accordingly. From 05c1be41d62a503683668c5188ecc340da1f3022 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sun, 14 Feb 2021 19:54:53 +0000 Subject: [PATCH 38/46] wip --- mcx/core/jaxpr_ops.py | 63 ++----------------------------------------- 1 file changed, 2 insertions(+), 61 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index f892f7c9..85ef20d2 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -442,7 +442,7 @@ def jaxpr_denorm_propagate_select_eqn( return set(), invars -jaxpr_eqn_denorm_propagate_ops = { +jaxpr_eqn_denorm_propagate_rules = { jax.lax.broadcast_in_dim_p: jaxpr_denorm_propagate_linear_eqn, jax.lax.broadcast_p: jaxpr_denorm_propagate_linear_eqn, jax.lax.neg_p: jaxpr_denorm_propagate_linear_eqn, @@ -478,7 +478,7 @@ def denorm_linear_op(invar, outvar, replace_op): denorm_map_dict[outvar] = (replace_op, invar) # Check which input variables to keep propagating the denormalization. - eqn_propagate_check_fn = jaxpr_eqn_denorm_propagate_ops.get( + eqn_propagate_check_fn = jaxpr_eqn_denorm_propagate_rules.get( eqn.primitive, jaxpr_denorm_propagate_blocking_eqn ) valid_invars, invalid_invars = eqn_propagate_check_fn(eqn, state) @@ -595,65 +595,6 @@ def jaxpr_find_denormalize_mapping( return denorm_rec_state -def jaxpr_find_denormalize_mapping_old( - jaxpr: jax.core.Jaxpr, consts: List[jax.core.Var] -) -> Dict[jax.core.Var, Tuple[jax.core.Primitive, jax.core.Var]]: - """Find all assignment simplifications in a JAX expression when denormalizing. - - More specifically, this method is looking to simplify `add` and `sub` operations, with output linear - with respect to the Jaxpr outputs, and where one of the input is constant. It returns the simplified mapping - between input and output of `add`/`sub` ops which can be removed. - - Parameters - ---------- - jaxpr: JAX expression. - consts: List of known constant variables in the JAX expression. - - Returns - ------- - Simplified mapping between `add` output and input (with the proper assignment lax op `identity` or `neg`). - """ - denormalize_mapping = {} - # List of linear ops which can be traversed backward from the outputs. - denorm_supported_linear_ops = [ - jax.lax.broadcast_in_dim_p, - jax.lax.broadcast_p, - jax.lax.neg_p, - jax.lax.reshape_p, - jax.lax.squeeze_p, - jax.lax.reduce_sum_p, - ] - # Collection of variables linear with respect to the Jaxpr final outputs. - linear_vars = set(jaxpr.outvars) - - # Traversing backward the graph of operations. - for eqn in jaxpr.eqns[::-1]: - if eqn.primitive in denorm_supported_linear_ops: - # Can continue denormalizing inputs if all outputs are in the linear vars collection. - if all([o in linear_vars for o in eqn.outvars]): - linear_vars |= set(eqn.invars) - elif eqn.primitive == jax.lax.add_p and eqn.outvars[0] in linear_vars: - lhs_invar, rhs_invar = eqn.invars[0], eqn.invars[1] - # Mapping the output to the non-const input. - if lhs_invar in consts or type(lhs_invar) is jax.core.Literal: - linear_vars.add(rhs_invar) - denormalize_mapping[eqn.outvars[0]] = (jax_lax_identity, rhs_invar) - elif rhs_invar in consts or type(rhs_invar) is jax.core.Literal: - linear_vars.add(lhs_invar) - denormalize_mapping[eqn.outvars[0]] = (jax_lax_identity, lhs_invar) - elif eqn.primitive == jax.lax.sub_p and eqn.outvars[0] in linear_vars: - lhs_invar, rhs_invar = eqn.invars[0], eqn.invars[1] - # Mapping the output to the non-const input (or the negative). - if lhs_invar in consts or type(lhs_invar) is jax.core.Literal: - linear_vars.add(rhs_invar) - denormalize_mapping[eqn.outvars[0]] = (jax.lax.neg, rhs_invar) - elif rhs_invar in consts or type(rhs_invar) is jax.core.Literal: - linear_vars.add(lhs_invar) - denormalize_mapping[eqn.outvars[0]] = (jax_lax_identity, lhs_invar) - - return denormalize_mapping - - def jaxpr_denormalize(jaxpr, consts, *args): """Denormalize a Jaxpr, i.e. removing any normalizing constant added to the output. From 1504f11aaf6e4c12779f03aad25617a74ee5a6aa Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Mon, 15 Feb 2021 19:23:38 +0000 Subject: [PATCH 39/46] wip --- mcx/core/jaxpr_ops.py | 3 +++ tests/core/jaxpr_ops_test.py | 11 +++++++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index 85ef20d2..a690c020 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -556,6 +556,9 @@ def jaxpr_denorm_mapping_reduce_sub_states_fn( if sub_invar in sub_denorm_valid_vars } + # Update the constvar sub-states list, to keep sync. with equations in the jaxpr. + constvar_state, constvar_sub_states = constvar_full_state + constvar_full_state = constvar_state, constvar_sub_states[:-1] # TODO: fix properly. state = denorm_map_dict, denorm_valid_vars, constvar_full_state return state diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index 8d4aca8e..a11aaf42 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -121,7 +121,7 @@ def test__jaxpr_find_denormalize_mapping__linear_op_propagating__proper_mapping( constvar_state = jaxpr_find_constvars(typed_jaxpr.jaxpr, constvars) denorm_rec_state = jaxpr_find_denormalize_mapping(typed_jaxpr.jaxpr, constvar_state) - denorm_map, denorm_valid_vars, _ = denorm_rec_state[0] + denorm_map, denorm_valid_vars, constvar_full_state = denorm_rec_state[0] invar = typed_jaxpr.jaxpr.invars[0] # Proper mapping of the output to the input. @@ -131,6 +131,9 @@ def test__jaxpr_find_denormalize_mapping__linear_op_propagating__proper_mapping( assert map_invar == invar # Input is a valid denorm variable (which could be propagated in sub-jaxpr). assert invar in denorm_valid_vars + # Returned constvar sub states should be empty: consumed in the visitor loop. + _, constvar_sub_states = constvar_full_state + assert len(constvar_sub_states) == 0 denorm_sub_jaxprs_propagating = [ @@ -147,16 +150,20 @@ def test__jaxpr_find_denormalize_mapping__sub_jaxprs_propagating__proper_mapping constvar_state = jaxpr_find_constvars(typed_jaxpr.jaxpr, constvars) denorm_rec_state = jaxpr_find_denormalize_mapping(typed_jaxpr.jaxpr, constvar_state) - denorm_map, denorm_valid_vars, _ = denorm_rec_state[0] + denorm_map, denorm_valid_vars, constvar_full_state = denorm_rec_state[0] # Proper mapping of the output to the input. # assert len(denorm_map) == 1 # map_op, map_invar = list(denorm_map.values())[0] # assert map_op == case["expected_op"] # assert map_invar == invar + # Input is a valid denorm variable (which could be propagated in sub-jaxpr). invar = typed_jaxpr.jaxpr.invars[0] assert invar in denorm_valid_vars + # Returned constvar sub states should be empty: consumed in the visitor loop. + _, constvar_sub_states = constvar_full_state + assert len(constvar_sub_states) == 0 denorm_non_linear_fn = [ From 9ee2b777f02d23349fe59cff5f7026cb06855fe8 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Mon, 15 Feb 2021 19:24:39 +0000 Subject: [PATCH 40/46] wip --- mcx/core/jaxpr_ops.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index a690c020..52f5162c 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -319,7 +319,7 @@ def jaxpr_find_constvars( return const_rec_state -DenormalizeState = Tuple[ +DenormMapState = Tuple[ Dict[jax.core.Var, Tuple[Any, jax.core.Var]], Set[jax.core.Var], ConstVarRecState ] """Denormalization state, combination of: @@ -327,9 +327,7 @@ def jaxpr_find_constvars( - set of variables which can be traverse backward for denormalization; - full recursive const variable state of the Jaxpr. """ -DenormalizeRecState = Tuple[ - DenormalizeState, List[Optional[List["DenormalizeRecState"]]] -] +DenormMapRecState = Tuple[DenormMapState, List[Optional[List["DenormMapRecState"]]]] def jax_is_literal(v: Any) -> bool: @@ -345,7 +343,7 @@ def jax_is_non_finite_constant(v: Any, const_state: ConstVarState): def jaxpr_denorm_propagate_blocking_eqn( - eqn: jax.core.JaxprEqn, state: DenormalizeState + eqn: jax.core.JaxprEqn, state: DenormMapState ) -> Tuple[Set[jax.core.Var], Set[jax.core.Var]]: """Default primitive propagation: blocking the denormalization of input variables.""" invalid_invars = {v for v in eqn.invars if not jax_is_literal(v)} @@ -353,7 +351,7 @@ def jaxpr_denorm_propagate_blocking_eqn( def jaxpr_denorm_propagate_linear_eqn( - eqn: jax.core.JaxprEqn, state: DenormalizeState + eqn: jax.core.JaxprEqn, state: DenormMapState ) -> Tuple[Set[jax.core.Var], Set[jax.core.Var]]: """fdsfafasd @@ -370,7 +368,7 @@ def jaxpr_denorm_propagate_linear_eqn( def jaxpr_denorm_propagate_mul_eqn( - eqn: jax.core.JaxprEqn, state: DenormalizeState + eqn: jax.core.JaxprEqn, state: DenormMapState ) -> Tuple[Set[jax.core.Var], Set[jax.core.Var]]: """fdsfafasd @@ -393,7 +391,7 @@ def is_var_constant(v: Any) -> bool: def jaxpr_denorm_propagate_div_eqn( - eqn: jax.core.JaxprEqn, state: DenormalizeState + eqn: jax.core.JaxprEqn, state: DenormMapState ) -> Tuple[Set[jax.core.Var], Set[jax.core.Var]]: """fdsfafasd @@ -416,7 +414,7 @@ def is_var_constant(v: Any) -> bool: def jaxpr_denorm_propagate_select_eqn( - eqn: jax.core.JaxprEqn, state: DenormalizeState + eqn: jax.core.JaxprEqn, state: DenormMapState ) -> Tuple[Set[jax.core.Var], Set[jax.core.Var]]: """fdsfafasd @@ -461,8 +459,8 @@ def jaxpr_denorm_propagate_select_eqn( def jaxpr_denorm_mapping_visitor_fn( eqn: jax.core.JaxprEqn, - state: DenormalizeState, -) -> DenormalizeState: + state: DenormMapState, +) -> DenormMapState: """pass fdsafas """ @@ -510,8 +508,8 @@ def denorm_linear_op(invar, outvar, replace_op): def jaxpr_denorm_mapping_map_sub_states_fn( - eqn: jax.core.JaxprEqn, state: DenormalizeState -) -> List[DenormalizeState]: + eqn: jax.core.JaxprEqn, state: DenormMapState +) -> List[DenormMapState]: """""" denorm_map_dict, denorm_valid_vars, constvar_full_state = state constvar_state, constvar_sub_states = constvar_full_state @@ -538,8 +536,8 @@ def jaxpr_denorm_mapping_map_sub_states_fn( def jaxpr_denorm_mapping_reduce_sub_states_fn( - eqn: jax.core.JaxprEqn, state: DenormalizeState, sub_states: List[DenormalizeState] -) -> DenormalizeState: + eqn: jax.core.JaxprEqn, state: DenormMapState, sub_states: List[DenormMapState] +) -> DenormMapState: """""" sub_jaxprs = jaxpr_find_sub_jaxprs(eqn) assert len(sub_states) == len(sub_jaxprs) @@ -566,7 +564,7 @@ def jaxpr_denorm_mapping_reduce_sub_states_fn( def jaxpr_find_denormalize_mapping( jaxpr: jax.core.Jaxpr, constvar_state: ConstVarRecState -) -> DenormalizeRecState: +) -> DenormMapRecState: """Find all assignment simplifications in a JAX expression when denormalizing. More specifically, this method is looking to simplify `add` and `sub` operations, with output linear From e1713ae52503f1ee06124311d17c4e8217ceee5a Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Mon, 15 Feb 2021 20:05:51 +0000 Subject: [PATCH 41/46] wip --- mcx/core/jaxpr_ops.py | 104 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index 52f5162c..949ce5ca 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -596,6 +596,110 @@ def jaxpr_find_denormalize_mapping( return denorm_rec_state +DenormRunState = Tuple[Dict[jax.core.Var, Any], DenormMapRecState] +"""Denormalization run state. +""" +DenormRunRecState = Tuple[DenormRunState, List[Optional[List["DenormRunRecState"]]]] + + +def jaxpr_denorm_run_visitor_fn( + eqn: jax.core.JaxprEqn, + state: DenormRunState, +) -> DenormRunState: + """pass + fdsafas + """ + denorm_env, denorm_map_rec_state = state + denorm_mapping, _, _ = denorm_map_rec_state[0] + + def read_env(var): + if type(var) is jax.core.Literal: + return var.val + return read_env[var] + + def write_env(var, val): + denorm_env[var] = val + + if len(eqn.outvars) == 1 and eqn.outvars[0] in denorm_mapping: + # Output registered: skip the primitive and map directly to one of the input. + outvar = eqn.outvars[0] + map_primitive, map_invar = ( + denorm_mapping[outvar][0], + denorm_mapping[outvar][1], + ) + # Mapping the inval to output var (identity or neg). + inval = read_env(map_invar) + outval = map_primitive(inval) + write_env(outvar, outval) + else: + # Usual map: calling the primitive and mapping the output values. + invals = safe_map(read_env, eqn.invars) + outvals = eqn.primitive.bind(*invals, **eqn.params) + if not eqn.primitive.multiple_results: + outvals = [outvals] + safe_map(write_env, eqn.outvars, outvals) + + # Returning updated environment. + return denorm_env, denorm_map_rec_state + + +def jaxpr_denorm_run_map_sub_states_fn( + eqn: jax.core.JaxprEqn, state: DenormRunState +) -> List[DenormRunState]: + """""" + # denorm_map_dict, denorm_valid_vars, constvar_full_state = state + # constvar_state, constvar_sub_states = constvar_full_state + + +def jaxpr_denorm_run_reduce_sub_states_fn( + eqn: jax.core.JaxprEqn, state: DenormRunState, sub_states: List[DenormRunState] +) -> DenormRunState: + """""" + # sub_jaxprs = jaxpr_find_sub_jaxprs(eqn) + # assert len(sub_states) == len(sub_jaxprs) + + +def jaxpr_denormalize_run(jaxpr: jax.core.Jaxpr, consts, *args) -> DenormRunRecState: + """TODO + + Parameters + ---------- + jaxpr: JAX expression. + consts: List of known constant variables in the JAX expression. + + Returns + ------- + Simplified mapping between `add` output and input (with the proper assignment lax op `identity` or `neg`). + """ + # Generate the denormalization simplifying mapping. + constvars = {v: ConstVarInfo(False, True) for v in jaxpr.constvars} + constvar_full_state = jaxpr_find_constvars(jaxpr, constvars) + denorm_map_state = jaxpr_find_denormalize_mapping(jaxpr, constvar_full_state) + + # Initialize the denormalize env state, starting from the input variables. + denormalize_env = {} + + def write_env(var, val): + denormalize_env[var] = val + + # Bind args and consts to denormalization environment. + write_env(jax.core.unitvar, jax.core.unit) + safe_map(write_env, jaxpr.invars, args) + safe_map(write_env, jaxpr.constvars, consts) + + denorm_init_state = (denormalize_env, denorm_map_state) + # NOTE: scanning the jaxpr in reverse order. + denorm_run_state = jaxpr_visitor( + jaxpr, + denorm_init_state, + jaxpr_denorm_run_visitor_fn, + jaxpr_denorm_run_map_sub_states_fn, + jaxpr_denorm_run_reduce_sub_states_fn, + reverse=False, + ) + return denorm_run_state + + def jaxpr_denormalize(jaxpr, consts, *args): """Denormalize a Jaxpr, i.e. removing any normalizing constant added to the output. From 825c79a7d6d43c5ba5d3a0851736b4bb14c666d0 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Wed, 17 Feb 2021 19:51:47 +0000 Subject: [PATCH 42/46] wip --- mcx/core/jaxpr_ops.py | 24 ++++++++---- tests/core/jaxpr_ops_test.py | 74 +++++++++++++++++++++++------------- 2 files changed, 64 insertions(+), 34 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index 949ce5ca..60ea1181 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -615,7 +615,7 @@ def jaxpr_denorm_run_visitor_fn( def read_env(var): if type(var) is jax.core.Literal: return var.val - return read_env[var] + return denorm_env[var] def write_env(var, val): denorm_env[var] = val @@ -623,10 +623,9 @@ def write_env(var, val): if len(eqn.outvars) == 1 and eqn.outvars[0] in denorm_mapping: # Output registered: skip the primitive and map directly to one of the input. outvar = eqn.outvars[0] - map_primitive, map_invar = ( - denorm_mapping[outvar][0], - denorm_mapping[outvar][1], - ) + outvar_mapping = denorm_mapping[outvar] + map_primitive, map_invar = (outvar_mapping[0], outvar_mapping[1]) + print(denorm_mapping, map_primitive, map_invar) # Mapping the inval to output var (identity or neg). inval = read_env(map_invar) outval = map_primitive(inval) @@ -639,6 +638,9 @@ def write_env(var, val): outvals = [outvals] safe_map(write_env, eqn.outvars, outvals) + # Pop first element in the recursive denorm state, to keep in sync. + denorm_map_sub_states = denorm_map_rec_state[1][1:] + denorm_map_rec_state = (denorm_map_rec_state[0], denorm_map_sub_states) # Returning updated environment. return denorm_env, denorm_map_rec_state @@ -687,6 +689,10 @@ def write_env(var, val): safe_map(write_env, jaxpr.invars, args) safe_map(write_env, jaxpr.constvars, consts) + print(denormalize_env) + print(jaxpr) + print(denorm_map_state) + denorm_init_state = (denormalize_env, denorm_map_state) # NOTE: scanning the jaxpr in reverse order. denorm_run_state = jaxpr_visitor( @@ -697,10 +703,12 @@ def write_env(var, val): jaxpr_denorm_run_reduce_sub_states_fn, reverse=False, ) - return denorm_run_state + denorm_outenv = denorm_run_state[0][0] + outvals = [denorm_outenv[v] for v in jaxpr.outvars] + return outvals -def jaxpr_denormalize(jaxpr, consts, *args): +def jaxpr_denormalize_old(jaxpr, consts, *args): """Denormalize a Jaxpr, i.e. removing any normalizing constant added to the output. This method is analysing the Jaxpr graph, simplifying it by skipping any unnecessary constant @@ -770,7 +778,7 @@ def denormalize(logpdf_fn): def wrapped(*args, **kwargs): # TODO: flattening/unflattening of inputs/outputs? closed_jaxpr = jax.make_jaxpr(logpdf_fn)(*args, **kwargs) - out = jaxpr_denormalize(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args) + out = jaxpr_denormalize_run(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args) # Assuming a single output at the moment? return out[0] diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index a11aaf42..fd55cf84 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -190,29 +190,51 @@ def test__jaxpr_find_denormalize_mapping__non_linear_fn__empty_mapping(case): assert invar not in denorm_valid_vars -# denormalize_test_cases = [ -# {"fn": lambda x: x + 1.0, "denorm_fn": lambda x: x, "inval": 2.0}, -# { -# "fn": lambda x: 2.0 - np.sin(x + 1.0), -# "denorm_fn": lambda x: -np.sin(x + 1.0), -# "inval": 2.0, -# }, -# { -# "fn": lambda x: 2.0 - np.sin(x + 1.0), -# "denorm_fn": lambda x: -np.sin(x + 1.0), -# "inval": 2.0, -# }, -# { -# "fn": lambda x: np.sum(x + 2.0), -# "denorm_fn": lambda x: np.sum(x), -# "inval": np.ones((10,)), -# }, -# ] - - -# @pytest.mark.parametrize("case", denormalize_test_cases) -# def test__denormalize__proper_simplication(case): -# denorm_fn = denormalize(case["fn"]) -# expected_denorm_fn = case["denorm_fn"] -# inval = case["inval"] -# assert np.allclose(denorm_fn(inval), expected_denorm_fn(inval)) +denormalize_simple_test_cases = [ + {"fn": lambda x: x + 1.0, "denorm_fn": lambda x: x, "inval": 2.0}, + { + "fn": lambda x: 2.0 - jnp.sin(x + 1.0), + "denorm_fn": lambda x: -jnp.sin(x + 1.0), + "inval": 2.0, + }, + { + "fn": lambda x: 2.0 + x - jnp.sin(x + 1.0), + "denorm_fn": lambda x: x - jnp.sin(x + 1.0), + "inval": 2.0, + }, + { + "fn": lambda x: np.sum(x + 2.0), + "denorm_fn": lambda x: np.sum(x), + "inval": np.ones((10,)), + }, + { + "fn": lambda x: (x + 1.0) * 2.0, + "denorm_fn": lambda x: x * 2.0, + "inval": 3.0, + }, + { + "fn": lambda x: (x + 1.0) / 2.0, + "denorm_fn": lambda x: x / 2.0, + "inval": 3.0, + }, + { + "fn": lambda x: jnp.squeeze(jnp.expand_dims(1.0 - x, axis=0)), + "denorm_fn": lambda x: jnp.squeeze(jnp.expand_dims(-x, axis=0)), + "inval": 3.0, + }, + { + "fn": lambda x: jax.lax.select(1.0 > 0.0, 1.0 - x, -np.inf), + "denorm_fn": lambda x: jax.lax.select(1.0 > 0.0, -x, -np.inf), + "inval": 2.0, + }, +] + + +@pytest.mark.parametrize("case", denormalize_simple_test_cases) +def test__denormalize__simple_methods__proper_simplication(case): + denorm_fn = denormalize(case["fn"]) + expected_denorm_fn = case["denorm_fn"] + inval = case["inval"] + + denorm_fn_outval = denorm_fn(inval) + assert np.allclose(denorm_fn_outval, expected_denorm_fn(inval)) From be7f8e02de3c3075d82f0c8bf1d446366ff68210 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Wed, 17 Feb 2021 19:53:00 +0000 Subject: [PATCH 43/46] wip --- tests/core/jaxpr_ops_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index fd55cf84..4649022e 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -227,6 +227,11 @@ def test__jaxpr_find_denormalize_mapping__non_linear_fn__empty_mapping(case): "denorm_fn": lambda x: jax.lax.select(1.0 > 0.0, -x, -np.inf), "inval": 2.0, }, + { + "fn": lambda x: jax.jit(lambda y: 1.0 - y)(x), + "denorm_fn": lambda x: -x, + "inval": 2.0, + }, ] From c594ff6ec1e09bc303f7372c6f1aa6b9c9a15ceb Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Fri, 19 Feb 2021 20:04:17 +0000 Subject: [PATCH 44/46] wip --- mcx/core/jaxpr_eval.py | 47 ++++++++ mcx/core/jaxpr_ops.py | 224 ++++++++++++++++++----------------- tests/core/jaxpr_ops_test.py | 26 ++-- 3 files changed, 172 insertions(+), 125 deletions(-) create mode 100644 mcx/core/jaxpr_eval.py diff --git a/mcx/core/jaxpr_eval.py b/mcx/core/jaxpr_eval.py new file mode 100644 index 00000000..3ea53a08 --- /dev/null +++ b/mcx/core/jaxpr_eval.py @@ -0,0 +1,47 @@ +import jax.core +import jax.lax + +from typing import Dict, Any +from jax.core import Jaxpr, Literal, Var, unitvar, unit, extract_call_jaxpr +from jax.util import ( + safe_zip, + safe_map, + partial, + curry, + prod, + partialmethod, + tuple_insert, + tuple_delete, +) +import jax.linear_util as lu +from jax._src import source_info_util + + +def eval_jaxpr(jaxpr: Jaxpr, consts, *args): + def read(v): + if type(v) is Literal: + return v.val + else: + return env[v] + + def write(v, val): + env[v] = val + + env: Dict[Var, Any] = {} + write(unitvar, unit) + map(write, jaxpr.constvars, consts) + map(write, jaxpr.invars, args) + for eqn in jaxpr.eqns: + in_vals = map(read, eqn.invars) + call_jaxpr, params = extract_call_jaxpr(eqn.primitive, eqn.params) + if call_jaxpr: + subfuns = [lu.wrap_init(partial(eval_jaxpr, call_jaxpr, ()))] + else: + subfuns = [] + with source_info_util.user_context(eqn.source_info): + ans = eqn.primitive.bind(*(subfuns + in_vals), **params) + if eqn.primitive.multiple_results: + map(write, eqn.outvars, ans) + else: + write(eqn.outvars[0], ans) + return map(read, jaxpr.outvars) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index 60ea1181..10c3a30a 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -10,7 +10,18 @@ from dataclasses import dataclass from functools import wraps -from typing import List, Dict, Optional, Set, Tuple, Any, Type, TypeVar, Callable +from typing import ( + List, + Dict, + Optional, + Set, + Tuple, + Any, + Type, + TypeVar, + Callable, + Generic, +) Array = Any """Generic Array type. @@ -18,11 +29,24 @@ TState = TypeVar("TState") """Generic Jaxpr visitor state. """ -TRecState = Tuple[TState, List[Optional[List["TRecState"]]]] """Full recursive state, representing the visitor state of the Jaxpr as well as the sub-states of all sub-jaxprs. """ + +@dataclass +class RecState(Generic[TState]): + """Jaxpr dataflow visitor recursive state. + + Parameters + ---------- + + """ + + state: TState + sub: List[Optional[List["RecState[TState]"]]] + + jaxpr_high_order_primitives_to_subjaxprs = { jax.lax.cond_p: lambda jxpr: jxpr.params["branches"], jax.lax.while_p: lambda jxpr: ( @@ -89,7 +113,7 @@ def jaxpr_visitor( map_sub_states_fn: Callable[[jax.core.JaxprEqn, TState], List[TState]], reduce_sub_states_fn: Callable[[jax.core.JaxprEqn, TState, List[TState]], TState], reverse: bool = False, -) -> TRecState: +) -> RecState[TState]: """Visitor pattern on a Jaxpr, traversing equations and supporting higher-order primitives with sub-Jaxprs. @@ -105,8 +129,12 @@ def jaxpr_visitor( ------- Output state of the last iteration. """ + # None empty input state: by convention, just return None and an empty list. + if initial_state is None: + return None, [] + state = initial_state - subjaxprs_visit = [] + sub_states_list = [] equations = jaxpr.eqns if not reverse else jaxpr.eqns[::-1] for eqn in equations: @@ -126,13 +154,13 @@ def jaxpr_visitor( for sub_jaxpr, sub_state in zip(sub_jaxprs, init_sub_states) ] # Reduce to update the current state. - state = reduce_sub_states_fn(eqn, state, [v[0] for v in res_sub_states]) - subjaxprs_visit.append(res_sub_states) + state = reduce_sub_states_fn(eqn, state, [v.state for v in res_sub_states]) + sub_states_list.append(res_sub_states) else: # Common Jaxpr equation: apply the visitor and update state. state = visitor_fn(eqn, state) - subjaxprs_visit.append(None) - return state, subjaxprs_visit + sub_states_list.append(None) + return RecState(state, sub_states_list) @dataclass @@ -156,7 +184,7 @@ class ConstVarInfo: ConstVarState = Dict[jax.core.Var, ConstVarInfo] """Const variables visitor state: dictionary associating const variables with their info. """ -ConstVarRecState = Tuple[ConstVarState, List[Optional[List["ConstVarRecState"]]]] +# ConstVarRecState = Tuple[ConstVarState, List[Optional[List["ConstVarRecState"]]]] # Garthoks = Union[Garthok, Iterable['Garthoks']] @@ -294,7 +322,7 @@ def jaxpr_find_constvars_reduce_sub_states_fn( def jaxpr_find_constvars( jaxpr: jax.core.Jaxpr, constvars: Dict[jax.core.Var, ConstVarInfo] -) -> ConstVarRecState: +) -> RecState[ConstVarState]: """Find all intermediates variables in a JAX expression which are expected to be constants. Parameters @@ -320,14 +348,16 @@ def jaxpr_find_constvars( DenormMapState = Tuple[ - Dict[jax.core.Var, Tuple[Any, jax.core.Var]], Set[jax.core.Var], ConstVarRecState + Dict[jax.core.Var, Tuple[Any, jax.core.Var]], + Set[jax.core.Var], + RecState[ConstVarState], ] """Denormalization state, combination of: - dictionary of variable mapping, corresponding to `add` or `sub` ops which can be simplified; - set of variables which can be traverse backward for denormalization; - full recursive const variable state of the Jaxpr. """ -DenormMapRecState = Tuple[DenormMapState, List[Optional[List["DenormMapRecState"]]]] +# DenormMapRecState = Tuple[DenormMapState, List[Optional[List["DenormMapRecState"]]]] def jax_is_literal(v: Any) -> bool: @@ -374,11 +404,10 @@ def jaxpr_denorm_propagate_mul_eqn( fasdfasd """ - _, denorm_valid_vars, constvar_full_state = state - constvar_state, _ = constvar_full_state + _, denorm_valid_vars, constvar_rec_state = state def is_var_constant(v: Any) -> bool: - return type(v) is jax.core.Literal or v in constvar_state + return type(v) is jax.core.Literal or v in constvar_rec_state.state invars = {v for v in eqn.invars if not jax_is_literal(v)} # Propagate denormalization if one of the input is a uniform constant. @@ -397,11 +426,10 @@ def jaxpr_denorm_propagate_div_eqn( fasdfasd """ - _, denorm_valid_vars, constvar_full_state = state - constvar_state, _ = constvar_full_state + _, denorm_valid_vars, constvar_rec_state = state def is_var_constant(v: Any) -> bool: - return type(v) is jax.core.Literal or v in constvar_state + return type(v) is jax.core.Literal or v in constvar_rec_state.state invars = {v for v in eqn.invars if not jax_is_literal(v)} # Propagate denormalization if second input is a uniform constant. @@ -420,8 +448,8 @@ def jaxpr_denorm_propagate_select_eqn( fasdfasd """ - _, denorm_valid_vars, constvar_full_state = state - constvar_state, _ = constvar_full_state + _, denorm_valid_vars, constvar_rec_state = state + constvar_state = constvar_rec_state.state invar_pred, invar_true, invar_false = eqn.invars all_valid_outvars = all([o in denorm_valid_vars for o in eqn.outvars]) @@ -465,11 +493,10 @@ def jaxpr_denorm_mapping_visitor_fn( fdsafas """ # Un-stack input complex input state! - denorm_map_dict, denorm_valid_vars, constvar_full_state = state - constvar_state, constvar_sub_states = constvar_full_state + denorm_map_dict, denorm_valid_vars, constvar_rec_state = state def is_var_constant(v: Any) -> bool: - return type(v) is jax.core.Literal or v in constvar_state + return type(v) is jax.core.Literal or v in constvar_rec_state.state def denorm_linear_op(invar, outvar, replace_op): denorm_valid_vars.add(invar) @@ -501,26 +528,25 @@ def denorm_linear_op(invar, outvar, replace_op): denorm_linear_op(lhs_invar, eqn.outvars[0], jax_lax_identity) # Update the constvar sub-states list, to keep sync. with equations in the jaxpr. - constvar_sub_states = constvar_sub_states[:-1] - constvar_full_state = constvar_state, constvar_sub_states + constvar_sub_states = constvar_rec_state.sub[:-1] + constvar_rec_state = RecState(constvar_rec_state.state, constvar_sub_states) # Re-construct updated state. - return (denorm_map_dict, denorm_valid_vars, constvar_full_state) + return (denorm_map_dict, denorm_valid_vars, constvar_rec_state) def jaxpr_denorm_mapping_map_sub_states_fn( eqn: jax.core.JaxprEqn, state: DenormMapState ) -> List[DenormMapState]: """""" - denorm_map_dict, denorm_valid_vars, constvar_full_state = state - constvar_state, constvar_sub_states = constvar_full_state + denorm_map_dict, denorm_valid_vars, constvar_rec_state = state sub_jaxprs = jaxpr_find_sub_jaxprs(eqn) - assert len(constvar_sub_states[-1]) == len(sub_jaxprs) + assert len(constvar_rec_state.sub[-1]) == len(sub_jaxprs) primitive_type = type(eqn.primitive) if primitive_type == jax.core.CallPrimitive: # Jit compiled sub-jaxpr: map eqn outputs to sub-jaxpr outputs. - sub_jaxpr, sub_const_state = sub_jaxprs[0], constvar_sub_states[-1][0] + sub_jaxpr, sub_const_state = sub_jaxprs[0], constvar_rec_state.sub[-1][0] # Map the denorm valid vars to the output of the sub-jaxprs. denorm_sub_valid_vars = { sub_outvar @@ -542,7 +568,7 @@ def jaxpr_denorm_mapping_reduce_sub_states_fn( sub_jaxprs = jaxpr_find_sub_jaxprs(eqn) assert len(sub_states) == len(sub_jaxprs) - denorm_map_dict, denorm_valid_vars, constvar_full_state = state + denorm_map_dict, denorm_valid_vars, constvar_rec_state = state primitive_type = type(eqn.primitive) if primitive_type == jax.core.CallPrimitive: # Jit compiled sub-jaxpr: map valid sub-jaxpr inputs to update denorm valid variables. @@ -555,16 +581,19 @@ def jaxpr_denorm_mapping_reduce_sub_states_fn( } # Update the constvar sub-states list, to keep sync. with equations in the jaxpr. - constvar_state, constvar_sub_states = constvar_full_state - constvar_full_state = constvar_state, constvar_sub_states[:-1] + constvar_sub_states = constvar_rec_state.sub[:-1] # TODO: fix properly. - state = denorm_map_dict, denorm_valid_vars, constvar_full_state + state = ( + denorm_map_dict, + denorm_valid_vars, + RecState(constvar_rec_state.state, constvar_sub_states), + ) return state def jaxpr_find_denormalize_mapping( - jaxpr: jax.core.Jaxpr, constvar_state: ConstVarRecState -) -> DenormMapRecState: + jaxpr: jax.core.Jaxpr, constvar_state: RecState[ConstVarState] +) -> RecState[DenormMapState]: """Find all assignment simplifications in a JAX expression when denormalizing. More specifically, this method is looking to simplify `add` and `sub` operations, with output linear @@ -596,10 +625,10 @@ def jaxpr_find_denormalize_mapping( return denorm_rec_state -DenormRunState = Tuple[Dict[jax.core.Var, Any], DenormMapRecState] +DenormRunState = Tuple[Dict[jax.core.Var, Any], RecState[DenormMapState]] """Denormalization run state. """ -DenormRunRecState = Tuple[DenormRunState, List[Optional[List["DenormRunRecState"]]]] +# DenormRunRecState = Tuple[DenormRunState, List[Optional[List["DenormRunRecState"]]]] def jaxpr_denorm_run_visitor_fn( @@ -610,7 +639,7 @@ def jaxpr_denorm_run_visitor_fn( fdsafas """ denorm_env, denorm_map_rec_state = state - denorm_mapping, _, _ = denorm_map_rec_state[0] + denorm_mapping, _, _ = denorm_map_rec_state.state def read_env(var): if type(var) is jax.core.Literal: @@ -639,8 +668,8 @@ def write_env(var, val): safe_map(write_env, eqn.outvars, outvals) # Pop first element in the recursive denorm state, to keep in sync. - denorm_map_sub_states = denorm_map_rec_state[1][1:] - denorm_map_rec_state = (denorm_map_rec_state[0], denorm_map_sub_states) + denorm_map_sub_states = denorm_map_rec_state.sub[1:] + denorm_map_rec_state = RecState(denorm_map_rec_state.state, denorm_map_sub_states) # Returning updated environment. return denorm_env, denorm_map_rec_state @@ -649,19 +678,51 @@ def jaxpr_denorm_run_map_sub_states_fn( eqn: jax.core.JaxprEqn, state: DenormRunState ) -> List[DenormRunState]: """""" - # denorm_map_dict, denorm_valid_vars, constvar_full_state = state - # constvar_state, constvar_sub_states = constvar_full_state + sub_jaxprs = jaxpr_find_sub_jaxprs(eqn) + + # Skipping the sub-jaxprs run, directly calling the bind in the reduce fn. + sub_states = [None] * len(sub_jaxprs) + return sub_states + + # denorm_map_dict, denorm_valid_vars, constvar_rec_state = state + # constvar_state, constvar_sub_states = constvar_rec_state def jaxpr_denorm_run_reduce_sub_states_fn( eqn: jax.core.JaxprEqn, state: DenormRunState, sub_states: List[DenormRunState] ) -> DenormRunState: """""" - # sub_jaxprs = jaxpr_find_sub_jaxprs(eqn) - # assert len(sub_states) == len(sub_jaxprs) + denorm_env, denorm_map_rec_state = state + denorm_mapping, _, _ = denorm_map_rec_state[0] + + def read_env(var): + if type(var) is jax.core.Literal: + return var.val + return denorm_env[var] + + def write_env(var, val): + denorm_env[var] = val + + # Usual map: calling the primitive and mapping the output values. + invals = safe_map(read_env, eqn.invars) + print(eqn.primitive, invals, eqn.params, eqn) + + outvals = eqn.primitive.bind(*invals, **eqn.params) + if not eqn.primitive.multiple_results: + outvals = [outvals] + safe_map(write_env, eqn.outvars, outvals) -def jaxpr_denormalize_run(jaxpr: jax.core.Jaxpr, consts, *args) -> DenormRunRecState: + # Pop first element in the recursive denorm state, to keep in sync. + denorm_map_sub_states = denorm_map_rec_state[1][1:] + denorm_map_rec_state = (denorm_map_rec_state[0], denorm_map_sub_states) + # Returning updated environment. + return denorm_env, denorm_map_rec_state + + +def jaxpr_denormalize_run( + jaxpr: jax.core.Jaxpr, consts, *args +) -> RecState[DenormRunState]: """TODO Parameters @@ -675,8 +736,8 @@ def jaxpr_denormalize_run(jaxpr: jax.core.Jaxpr, consts, *args) -> DenormRunRecS """ # Generate the denormalization simplifying mapping. constvars = {v: ConstVarInfo(False, True) for v in jaxpr.constvars} - constvar_full_state = jaxpr_find_constvars(jaxpr, constvars) - denorm_map_state = jaxpr_find_denormalize_mapping(jaxpr, constvar_full_state) + constvar_rec_state = jaxpr_find_constvars(jaxpr, constvars) + denorm_map_state = jaxpr_find_denormalize_mapping(jaxpr, constvar_rec_state) # Initialize the denormalize env state, starting from the input variables. denormalize_env = {} @@ -689,9 +750,9 @@ def write_env(var, val): safe_map(write_env, jaxpr.invars, args) safe_map(write_env, jaxpr.constvars, consts) - print(denormalize_env) - print(jaxpr) - print(denorm_map_state) + # print(denormalize_env) + # print(jaxpr) + # print(denorm_map_state) denorm_init_state = (denormalize_env, denorm_map_state) # NOTE: scanning the jaxpr in reverse order. @@ -703,70 +764,11 @@ def write_env(var, val): jaxpr_denorm_run_reduce_sub_states_fn, reverse=False, ) - denorm_outenv = denorm_run_state[0][0] + denorm_outenv = denorm_run_state.state[0] outvals = [denorm_outenv[v] for v in jaxpr.outvars] return outvals -def jaxpr_denormalize_old(jaxpr, consts, *args): - """Denormalize a Jaxpr, i.e. removing any normalizing constant added to the output. - - This method is analysing the Jaxpr graph, simplifying it by skipping any unnecessary constant - addition, and then it runs the method step-by-step to get the output values. - - Parameters - ---------- - jaxpr: JAX expression. - consts: Values assigned to the Jaxpr constant variables. - args: Input values to the method. - - Returns - ------- - Output values of the denormalized logpdf. - """ - # Denormalized simplification mapping. - denorm_mapping = jaxpr_find_denormalize_mapping(jaxpr, jaxpr.constvars) - # Mapping from variable -> value - env = {} - - def read(var): - # Literals are values baked into the Jaxpr - if type(var) is jax.core.Literal: - return var.val - return env[var] - - def write(var, val): - env[var] = val - - # Bind args and consts to environment - write(jax.core.unitvar, jax.core.unit) - safe_map(write, jaxpr.invars, args) - safe_map(write, jaxpr.constvars, consts) - - # Similar to a classic eval Jaxpr loop, just skipping the op with mapping available - for eqn in jaxpr.eqns: - if len(eqn.outvars) == 1 and eqn.outvars[0] in denorm_mapping: - # Output registered: skip the primitive and map directly to one of the input. - outvar = eqn.outvars[0] - map_primitive, map_invar = ( - denorm_mapping[outvar][0], - denorm_mapping[outvar][1], - ) - # Mapping the inval to output var (identity or neg). - inval = read(map_invar) - outval = map_primitive(inval) - write(outvar, outval) - else: - # Usual map: calling the primitive and mapping the output values. - invals = safe_map(read, eqn.invars) - outvals = eqn.primitive.bind(*invals, **eqn.params) - if not eqn.primitive.multiple_results: - outvals = [outvals] - safe_map(write, eqn.outvars, outvals) - # Read the final result of the Jaxpr from the environment - return safe_map(read, jaxpr.outvars) - - def denormalize(logpdf_fn): """Denormalizing decorator for MCX logpdfs. diff --git a/tests/core/jaxpr_ops_test.py b/tests/core/jaxpr_ops_test.py index 4649022e..eb14743c 100644 --- a/tests/core/jaxpr_ops_test.py +++ b/tests/core/jaxpr_ops_test.py @@ -56,7 +56,7 @@ def test__jaxpr_find_constvars__propagate_constants(case): {v: ConstVarInfo(False, True) for v in typed_jaxpr.jaxpr.constvars} ) - constvars, _ = jaxpr_find_constvars(typed_jaxpr.jaxpr, constvars) + constvars = jaxpr_find_constvars(typed_jaxpr.jaxpr, constvars).state for outvar in typed_jaxpr.jaxpr.outvars: assert outvar in constvars assert constvars[outvar] == expected_const_info @@ -77,7 +77,7 @@ def test__jaxpr_find_denormalize_mapping__add_sub__proper_mapping(case): constvar_state = jaxpr_find_constvars(typed_jaxpr.jaxpr, constvars) denorm_rec_state = jaxpr_find_denormalize_mapping(typed_jaxpr.jaxpr, constvar_state) - denorm_map, denorm_valid_vars, _ = denorm_rec_state[0] + denorm_map, denorm_valid_vars, _ = denorm_rec_state.state invar = typed_jaxpr.jaxpr.invars[0] outvar = typed_jaxpr.jaxpr.outvars[0] @@ -121,7 +121,7 @@ def test__jaxpr_find_denormalize_mapping__linear_op_propagating__proper_mapping( constvar_state = jaxpr_find_constvars(typed_jaxpr.jaxpr, constvars) denorm_rec_state = jaxpr_find_denormalize_mapping(typed_jaxpr.jaxpr, constvar_state) - denorm_map, denorm_valid_vars, constvar_full_state = denorm_rec_state[0] + denorm_map, denorm_valid_vars, constvar_rec_state = denorm_rec_state.state invar = typed_jaxpr.jaxpr.invars[0] # Proper mapping of the output to the input. @@ -132,8 +132,7 @@ def test__jaxpr_find_denormalize_mapping__linear_op_propagating__proper_mapping( # Input is a valid denorm variable (which could be propagated in sub-jaxpr). assert invar in denorm_valid_vars # Returned constvar sub states should be empty: consumed in the visitor loop. - _, constvar_sub_states = constvar_full_state - assert len(constvar_sub_states) == 0 + assert len(constvar_rec_state.sub) == 0 denorm_sub_jaxprs_propagating = [ @@ -150,7 +149,7 @@ def test__jaxpr_find_denormalize_mapping__sub_jaxprs_propagating__proper_mapping constvar_state = jaxpr_find_constvars(typed_jaxpr.jaxpr, constvars) denorm_rec_state = jaxpr_find_denormalize_mapping(typed_jaxpr.jaxpr, constvar_state) - denorm_map, denorm_valid_vars, constvar_full_state = denorm_rec_state[0] + denorm_map, denorm_valid_vars, constvar_rec_state = denorm_rec_state.state # Proper mapping of the output to the input. # assert len(denorm_map) == 1 @@ -162,8 +161,7 @@ def test__jaxpr_find_denormalize_mapping__sub_jaxprs_propagating__proper_mapping invar = typed_jaxpr.jaxpr.invars[0] assert invar in denorm_valid_vars # Returned constvar sub states should be empty: consumed in the visitor loop. - _, constvar_sub_states = constvar_full_state - assert len(constvar_sub_states) == 0 + assert len(constvar_rec_state.sub) == 0 denorm_non_linear_fn = [ @@ -183,7 +181,7 @@ def test__jaxpr_find_denormalize_mapping__non_linear_fn__empty_mapping(case): invar = typed_jaxpr.jaxpr.invars[0] denorm_rec_state = jaxpr_find_denormalize_mapping(typed_jaxpr.jaxpr, constvar_state) - denorm_map, denorm_valid_vars, _ = denorm_rec_state[0] + denorm_map, denorm_valid_vars, _ = denorm_rec_state.state # Not simplifying mapping found. assert len(denorm_map) == 0 # Denormalization not propagating to the input. @@ -227,11 +225,11 @@ def test__jaxpr_find_denormalize_mapping__non_linear_fn__empty_mapping(case): "denorm_fn": lambda x: jax.lax.select(1.0 > 0.0, -x, -np.inf), "inval": 2.0, }, - { - "fn": lambda x: jax.jit(lambda y: 1.0 - y)(x), - "denorm_fn": lambda x: -x, - "inval": 2.0, - }, + # { + # "fn": lambda x: jax.jit(lambda y: 1.0 - y)(x), + # "denorm_fn": lambda x: -x, + # "inval": jnp.array(2.0), + # }, ] From 66133f529af9ac9752dec9964a28c13615ec8174 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Fri, 19 Feb 2021 20:08:15 +0000 Subject: [PATCH 45/46] wip --- mcx/core/jaxpr_ops.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index 10c3a30a..32b036d5 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -46,6 +46,18 @@ class RecState(Generic[TState]): state: TState sub: List[Optional[List["RecState[TState]"]]] + def pop_first_sub(self) -> Optional[List["RecState[TState]"]]: + """Pop the first element in the sub-states collection.""" + sub_states = self.sub[0] + self.sub = self.sub[1:] + return sub_states + + def pop_last_sub(self) -> Optional[List["RecState[TState]"]]: + """Pop the first element in the sub-states collection.""" + sub_states = self.sub[-1] + self.sub = self.sub[:-1] + return sub_states + jaxpr_high_order_primitives_to_subjaxprs = { jax.lax.cond_p: lambda jxpr: jxpr.params["branches"], @@ -527,9 +539,8 @@ def denorm_linear_op(invar, outvar, replace_op): elif is_var_constant(rhs_invar): denorm_linear_op(lhs_invar, eqn.outvars[0], jax_lax_identity) - # Update the constvar sub-states list, to keep sync. with equations in the jaxpr. - constvar_sub_states = constvar_rec_state.sub[:-1] - constvar_rec_state = RecState(constvar_rec_state.state, constvar_sub_states) + # Always update the constvar sub-states list, to keep sync. with equations in the jaxpr. + constvar_rec_state.pop_last_sub() # Re-construct updated state. return (denorm_map_dict, denorm_valid_vars, constvar_rec_state) @@ -668,8 +679,7 @@ def write_env(var, val): safe_map(write_env, eqn.outvars, outvals) # Pop first element in the recursive denorm state, to keep in sync. - denorm_map_sub_states = denorm_map_rec_state.sub[1:] - denorm_map_rec_state = RecState(denorm_map_rec_state.state, denorm_map_sub_states) + denorm_map_rec_state.pop_first_sub() # Returning updated environment. return denorm_env, denorm_map_rec_state @@ -714,8 +724,7 @@ def write_env(var, val): safe_map(write_env, eqn.outvars, outvals) # Pop first element in the recursive denorm state, to keep in sync. - denorm_map_sub_states = denorm_map_rec_state[1][1:] - denorm_map_rec_state = (denorm_map_rec_state[0], denorm_map_sub_states) + denorm_map_rec_state.pop_first_sub() # Returning updated environment. return denorm_env, denorm_map_rec_state From d8d11804d8a0ba31706440125516cf516816f5da Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Fri, 19 Feb 2021 20:12:47 +0000 Subject: [PATCH 46/46] wip --- mcx/core/jaxpr_ops.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/mcx/core/jaxpr_ops.py b/mcx/core/jaxpr_ops.py index 32b036d5..19227a01 100644 --- a/mcx/core/jaxpr_ops.py +++ b/mcx/core/jaxpr_ops.py @@ -26,12 +26,10 @@ Array = Any """Generic Array type. """ + TState = TypeVar("TState") """Generic Jaxpr visitor state. """ -"""Full recursive state, representing the visitor state of the Jaxpr as well as -the sub-states of all sub-jaxprs. -""" @dataclass @@ -196,10 +194,6 @@ class ConstVarInfo: ConstVarState = Dict[jax.core.Var, ConstVarInfo] """Const variables visitor state: dictionary associating const variables with their info. """ -# ConstVarRecState = Tuple[ConstVarState, List[Optional[List["ConstVarRecState"]]]] - - -# Garthoks = Union[Garthok, Iterable['Garthoks']] def get_variable_const_info(v: Any, state: ConstVarState) -> ConstVarInfo: @@ -369,7 +363,6 @@ def jaxpr_find_constvars( - set of variables which can be traverse backward for denormalization; - full recursive const variable state of the Jaxpr. """ -# DenormMapRecState = Tuple[DenormMapState, List[Optional[List["DenormMapRecState"]]]] def jax_is_literal(v: Any) -> bool: @@ -639,7 +632,6 @@ def jaxpr_find_denormalize_mapping( DenormRunState = Tuple[Dict[jax.core.Var, Any], RecState[DenormMapState]] """Denormalization run state. """ -# DenormRunRecState = Tuple[DenormRunState, List[Optional[List["DenormRunRecState"]]]] def jaxpr_denorm_run_visitor_fn(