Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
b4d63bf
Implement a `denormalize` custom Jaxpr operator simplifying MCX logpd…
balancap Jan 26, 2021
6398e8f
wip
balancap Jan 31, 2021
04cc11f
wip
balancap Feb 6, 2021
ffb7329
wip
balancap Feb 6, 2021
c2c53a4
wip
balancap Feb 6, 2021
5d8b681
wip
balancap Feb 6, 2021
098ca4f
wip
balancap Feb 6, 2021
2f125da
wip
balancap Feb 6, 2021
5369cc7
wip
balancap Feb 6, 2021
3652b60
wip
balancap Feb 7, 2021
cf60bf3
wip
balancap Feb 7, 2021
ba4b560
wip
balancap Feb 7, 2021
3d80f5c
wip
balancap Feb 7, 2021
f39bce6
wip
balancap Feb 7, 2021
29435e7
wip
balancap Feb 7, 2021
9b9f5fb
wip
balancap Feb 8, 2021
89d76d2
wip
balancap Feb 8, 2021
da25e8e
wip
balancap Feb 8, 2021
8113a51
wip
balancap Feb 9, 2021
f1c310e
wip
balancap Feb 9, 2021
d6e452b
wip
balancap Feb 9, 2021
51c7ed5
wip
balancap Feb 10, 2021
c370c66
wip
balancap Feb 11, 2021
6c635e2
wip
balancap Feb 13, 2021
9396129
wip
balancap Feb 13, 2021
c3096c7
wip
balancap Feb 13, 2021
0286705
wip
balancap Feb 13, 2021
7f895d5
wip
balancap Feb 13, 2021
542177d
wip
balancap Feb 13, 2021
1c784e9
wip
balancap Feb 13, 2021
85cec94
wip
balancap Feb 14, 2021
2853d37
wip
balancap Feb 14, 2021
ec44124
wip
balancap Feb 14, 2021
15a5ff9
wip
balancap Feb 14, 2021
facafc2
wip
balancap Feb 14, 2021
9717abe
wip
balancap Feb 14, 2021
756781e
wip
balancap Feb 14, 2021
05c1be4
wip
balancap Feb 14, 2021
1504f11
wip
balancap Feb 15, 2021
9ee2b77
wip
balancap Feb 15, 2021
e1713ae
wip
balancap Feb 15, 2021
825c79a
wip
balancap Feb 17, 2021
be7f8e0
wip
balancap Feb 17, 2021
c594ff6
wip
balancap Feb 19, 2021
66133f5
wip
balancap Feb 19, 2021
d8d1180
wip
balancap Feb 19, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions mcx/core/jaxpr_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import jax.core
import jax.lax

from typing import Dict, Any
from jax.core import Jaxpr, Literal, Var, unitvar, unit, extract_call_jaxpr
from jax.util import (
safe_zip,
safe_map,
partial,
curry,
prod,
partialmethod,
tuple_insert,
tuple_delete,
)
import jax.linear_util as lu
from jax._src import source_info_util


def eval_jaxpr(jaxpr: Jaxpr, consts, *args):
def read(v):
if type(v) is Literal:
return v.val
else:
return env[v]

def write(v, val):
env[v] = val

env: Dict[Var, Any] = {}
write(unitvar, unit)
map(write, jaxpr.constvars, consts)
map(write, jaxpr.invars, args)
for eqn in jaxpr.eqns:
in_vals = map(read, eqn.invars)
call_jaxpr, params = extract_call_jaxpr(eqn.primitive, eqn.params)
if call_jaxpr:
subfuns = [lu.wrap_init(partial(eval_jaxpr, call_jaxpr, ()))]
else:
subfuns = []
with source_info_util.user_context(eqn.source_info):
ans = eqn.primitive.bind(*(subfuns + in_vals), **params)
if eqn.primitive.multiple_results:
map(write, eqn.outvars, ans)
else:
write(eqn.outvars[0], ans)
return map(read, jaxpr.outvars)
Loading