Add control flow callbacks for JIT-compiled frameworks

I write a lot of Jax code too. I sympathize with wanting to write jitted code more cleanly, but yours is one of many proposals to make DSLs look more like plain Python. One problem is that you lose the ability to see what the Python code is actually doing: an ordinary-looking if statement would no longer be an ordinary Python branch.

A couple of points on your code:

  • I assume that count_found_sums is a jax.Array (a traced value under jit); otherwise you probably don't need cond, since an ordinary Python if would work.
  • You should probably use dataclasses to bundle up all the loose variables. As it stands, your code is harder to read than it needs to be.
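
Roughly what I mean, as a minimal sketch. Only count_found_sums comes from your snippet; SearchState, total, record_sum, reset, step, and limit are hypothetical names I made up for illustration. The loose variables go into a frozen dataclass registered as a pytree, so the whole bundle can be passed through jax.jit and jax.lax.cond as one value.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclass(frozen=True)
class SearchState:
    # Hypothetical bundle of what were loose variables; every field is an array leaf.
    count_found_sums: jax.Array
    total: jax.Array

    def tree_flatten(self):
        # Children are the array leaves; there is no static auxiliary data.
        return (self.count_found_sums, self.total), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def record_sum(state: SearchState) -> SearchState:
    # Hypothetical "true" branch: count one more found sum.
    return SearchState(state.count_found_sums + 1, state.total)


def reset(state: SearchState) -> SearchState:
    # Hypothetical "false" branch: reset the counter, keep everything else.
    return SearchState(jnp.zeros_like(state.count_found_sums), state.total)


@jax.jit
def step(state: SearchState, limit: jax.Array) -> SearchState:
    # count_found_sums is traced under jit, so a Python `if` on it would fail.
    # lax.cond chooses a branch at run time; both branches take and return
    # a SearchState with the same structure and dtypes.
    return jax.lax.cond(state.count_found_sums < limit, record_sum, reset, state)


state = SearchState(jnp.array(0, dtype=jnp.int32), jnp.array(10.0))
state = step(state, jnp.array(3, dtype=jnp.int32))
```

Replacing the cond inside step with a Python if would raise a tracer conversion error under jit, which is the case where cond is actually needed. And because the branches pass a single SearchState around, adding a field later doesn't change any call sites.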

I think it would help to brainstorm other solutions, e.g., creating your own DSL or introducing better abstractions in your own Jax library.
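
As one example of the kind of library-level abstraction I mean, here is a sketch of a small helper. Counter, cond_replace, and observe are all hypothetical names, not an existing Jax API. The helper applies a dataclasses.replace update only where a traced predicate holds, which reads close to an if statement while remaining ordinary, inspectable Python.

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Counter:
    # Hypothetical state bundle; every field is an array leaf.
    found: jax.Array
    misses: jax.Array

    def tree_flatten(self):
        return (self.found, self.misses), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def cond_replace(pred, state, **changes):
    """Hypothetical helper: return `state` with `changes` applied where `pred` is True.

    Works for any dataclass registered as a pytree. Both the old and the
    updated values are computed, and jnp.where selects leaf by leaf.
    """
    updated = dataclasses.replace(state, **changes)
    return jax.tree_util.tree_map(
        lambda new, old: jnp.where(pred, new, old), updated, state
    )


@jax.jit
def observe(state: Counter, hit: jax.Array) -> Counter:
    # Reads almost like an `if`, but stays traceable under jit.
    state = cond_replace(hit, state, found=state.found + 1)
    return cond_replace(~hit, state, misses=state.misses + 1)
```

The trade-off versus lax.cond is that jnp.where computes both versions of every field, so a helper like this only makes sense for cheap updates. For expensive branches, lax.cond is still the better tool, since it can skip the untaken branch (except under vmap, where it is converted to a select).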