
Implementing JIT-compiled functions in quadgk? #71

@nwood99-ctrl


Hello,

Thanks so much for all your work on developing this project! I have a question about how JIT-compiled functions interface with methods like quadgk. In short, I have a function whose structure changes based on a set of switching variables, and I want to integrate it over a set of inputs that occasionally trigger the switch. Declaring the switching variables as static arguments to the JIT compiler handles the switching, but I don't understand how to pass that information into quadgk (or similar methods). For the very simple minimal working example (MWE) below, it's possible to enumerate each case of the switching schedule explicitly as its own function, but in my application this would be extremely cumbersome. If possible, I'd much prefer to handle it through static arguments.
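For comparison, one alternative that avoids static arguments altogether is to select the branch with `jax.lax.switch` instead of a Python `if`. The switching variable then becomes an ordinary traced integer, so in principle it could travel through `args` like any other parameter. This is a sketch, not quadax's recommended approach: it assumes every branch takes and returns arrays of the same shape and dtype, and `funCombinedSwitch` is a hypothetical name. The two branches are the ones from the MWE below.

```python
import jax
import jax.numpy as jnp

def funLeft(t):
    return t * jnp.log(1 + t)

def funRight(t):
    return jnp.exp(-t)

# Hypothetical variant of funCombined: the branch index is a traced value,
# so no static_argnums are needed and the function is compiled once.
@jax.jit
def funCombinedSwitch(t, conditional):
    # lax.switch clamps the index into [0, len(branches) - 1]
    return jax.lax.switch(conditional, [funLeft, funRight], t)

t = jnp.linspace(0.0, 1.0, 5)
print(funCombinedSwitch(t, 0))  # same values as funLeft(t)
print(funCombinedSwitch(t, 1))  # same values as funRight(t)
```

One trade-off: under `jax.vmap` with a batched `conditional`, `lax.switch` evaluates every branch and selects the results, which may cost more than static dispatch when the branches are expensive.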

The MWE code is below:

```python
import numpy as np
import jax
import jax.numpy as jnp
from quadax import quadgk

def funLeft(t):
    return t * jnp.log(1 + t)

def funRight(t):
    return jnp.exp(-t)

def funCombined(t, conditional):
    # pick the branch in Python; this requires `conditional` to be concrete
    if conditional == 0:
        funBranch = lambda t: t * jnp.log(1 + t)
    else:
        funBranch = lambda t: jnp.exp(-t)
    return funBranch(t)

if __name__ == '__main__':
    a, b = 0, 1
    epsabs = epsrel = 1e-5
    # treating each branch as its own separately-enumerated function could
    # potentially get extremely cumbersome as funCombined gets more complex
    print('GROUND TRUTH')
    for ii in range(10):
        if np.mod(ii, 2) == 0:
            y = quadgk(funLeft, [a, b], epsabs=epsabs, epsrel=epsrel)[0]
        else:
            y = quadgk(funRight, [a, b], epsabs=epsabs, epsrel=epsrel)[0]
        print(y)
    # would prefer something more elegant like this
    # (positional argument 1, `conditional`, is marked static)
    funCombined_jit = jax.jit(funCombined, static_argnums=(1,))
    # validate JIT function construction
    print('VALIDATE JIT')
    for ii in range(10):
        conditional = np.mod(ii, 2, dtype=np.int32)
        print(funCombined_jit(ii, conditional))
    # (this crashes with the error below)
    print('TEST INTEGRATE JIT')
    for ii in range(10):
        conditional = np.mod(ii, 2, dtype=np.int32)
        y = quadgk(funCombined_jit, [a, b], args=(conditional,),
                   epsabs=epsabs, epsrel=epsrel)[0]
        print(y)
```

The resulting error is:

```
ValueError: Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>, Traced<int32[]>with<DynamicJaxprTrace>. The error was:
TypeError: unhashable type: 'DynamicJaxprTracer'
```
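The same error can be reproduced without quadax, which helps show where it comes from. The sketch assumes, as the `DynamicJaxprTracer` in the traceback suggests, that quadgk traces the contents of `args`. `toy_integrate` is a hypothetical stand-in for quadgk (a fixed 20-point Gauss-Legendre rule under `jax.jit`), not quadax's actual implementation. Inside the traced region `conditional` is a tracer, and `funCombined_jit` then tries to hash it as a static argument.

```python
import jax
import jax.numpy as jnp
import numpy as np
from numpy.polynomial.legendre import leggauss

def funCombined(t, conditional):
    # Python-level branch: needs a concrete, hashable `conditional`
    if conditional == 0:
        return t * jnp.log(1 + t)
    return jnp.exp(-t)

funCombined_jit = jax.jit(funCombined, static_argnums=(1,))

def toy_integrate(fun, a, b, args=()):
    # Hypothetical stand-in for quadgk: everything in `args` is traced by jit
    nodes, weights = leggauss(20)

    @jax.jit
    def _integrate(a, b, args):
        half, mid = (b - a) / 2, (a + b) / 2
        t = mid + half * jnp.asarray(nodes)
        return half * jnp.sum(jnp.asarray(weights) * fun(t, *args))

    return _integrate(a, b, args)

conditional = np.int32(1)
try:
    toy_integrate(funCombined_jit, 0.0, 1.0, args=(conditional,))
    crashed = False
except (ValueError, TypeError) as err:
    # inside _integrate, `conditional` is a tracer and cannot be hashed
    crashed = True
    print(type(err).__name__)
```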

It seems that quadgk doesn't accept JIT-compiled functions with static arguments as integrands. At the same time, I don't see a way in the API docs to pass static arguments to quadgk, which would otherwise be needed to pass funCombined to it directly. Am I misunderstanding how this should be implemented?
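One possible workaround that keeps the Python-level branching is to bind the switching variable with `functools.partial` before handing the integrand over. The value then never passes through `args` and stays a concrete Python int. The sketch below uses `toy_integrate`, a hypothetical stand-in for quadgk (a fixed 20-point Gauss-Legendre rule under `jax.jit`) that, like quadgk appears to from the traceback, traces everything in `args`. With quadax the analogous call would presumably be `quadgk(functools.partial(funCombined, conditional=int(c)), [a, b], epsabs=epsabs, epsrel=epsrel)`, but that call is an untested assumption.

```python
import functools
import jax
import jax.numpy as jnp
from numpy.polynomial.legendre import leggauss

def funCombined(t, conditional):
    # Python-level branch, resolved at trace time from a concrete value
    if conditional == 0:
        return t * jnp.log(1 + t)
    return jnp.exp(-t)

def toy_integrate(fun, a, b, args=()):
    # Hypothetical stand-in for quadgk: everything in `args` is traced by jit
    nodes, weights = leggauss(20)

    @jax.jit
    def _integrate(a, b, args):
        half, mid = (b - a) / 2, (a + b) / 2
        t = mid + half * jnp.asarray(nodes)
        return half * jnp.sum(jnp.asarray(weights) * fun(t, *args))

    return _integrate(a, b, args)

results = []
for ii in range(4):
    conditional = ii % 2  # plain Python int, never traced
    # bind the switch before integration instead of passing it via args
    integrand = functools.partial(funCombined, conditional=conditional)
    results.append(float(toy_integrate(integrand, 0.0, 1.0)))
print(results)
```

Each `partial` object is a distinct function, so expect the integrand to be retraced on every call; if that overhead matters, building one bound integrand per switching value and reusing it may help, depending on how the integrator caches compiled functions.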
