I code a lot in JAX too. I sympathize with wanting to write jitted code more cleanly, but your proposal is one of many attempts to make DSLs more Python-like. One problem is that you lose the ability to see what the Python code is actually doing.
A couple of points on your code:
- I assume that `count_found_sums` has type `jax.Array`; otherwise you probably don't need `cond`.
- You should probably be using dataclasses to bundle up all the loose variables. Your code is harder to read than it needs to be.
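Since the original code isn't shown here, the following is only a minimal sketch of both points under stated assumptions. The class `SearchState`, its field `best_sum`, the function `record`, and the update logic are all hypothetical; only the name `count_found_sums` comes from the post. The loose variables are bundled into a frozen dataclass registered as a pytree (via `jax.tree_util.register_pytree_node_class`), so it can be passed through `jax.jit` and `jax.lax.cond` as one value. The `cond` is there because `count_found_sums` is a traced `jax.Array` inside `jit`. If it were a plain Python `int` known at trace time, an ordinary `if` would do.

```python
import dataclasses

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class SearchState:
    """Hypothetical bundle of loose loop variables, registered as a pytree."""

    count_found_sums: jax.Array  # int32 scalar; traced under jit
    best_sum: jax.Array          # float32 scalar; traced under jit

    def tree_flatten(self):
        # Every field is an array leaf; there is no static auxiliary data.
        return (self.count_found_sums, self.best_sum), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def record(state: SearchState, candidate: jax.Array) -> SearchState:
    # Under jit, state.count_found_sums is a traced jax.Array, so a Python
    # `if` on it cannot be evaluated at trace time; lax.cond picks at runtime.
    def first(s):
        # First sum found: take the candidate as-is.
        return dataclasses.replace(s, best_sum=candidate)

    def later(s):
        # Subsequent sums: keep the running maximum.
        return dataclasses.replace(s, best_sum=jnp.maximum(s.best_sum, candidate))

    state = jax.lax.cond(state.count_found_sums == 0, first, later, state)
    return dataclasses.replace(state, count_found_sums=state.count_found_sums + 1)


# Usage: three candidates; the first one seeds best_sum even though it is negative.
state = SearchState(jnp.asarray(0, dtype=jnp.int32), jnp.asarray(0.0, dtype=jnp.float32))
for value in (-2.0, 5.0, 1.0):
    state = record(state, jnp.asarray(value, dtype=jnp.float32))
```

Both `cond` branches return a `SearchState` with the same structure and dtypes, which `lax.cond` requires; `dataclasses.replace` keeps that explicit without threading each variable by hand.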
I think it would help to brainstorm other solutions, e.g., creating your own DSL or introducing better abstractions in your own JAX library.
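As one illustration of the "better abstractions in your own library" route, here is a hedged sketch of a small helper that keeps the static-versus-traced distinction visible at the call site. The names `branch`, `halve_or_double`, and `scale` are hypothetical and not part of JAX; the helper uses a plain Python `if` when the predicate is a concrete Python `bool`, and falls back to `jax.lax.cond` when the predicate is a traced array.

```python
import jax
import jax.numpy as jnp


def branch(pred, true_fun, false_fun, operand):
    """Hypothetical helper (not a JAX API): plain Python `if` for a concrete
    Python bool known at trace time, jax.lax.cond for a traced predicate."""
    if isinstance(pred, bool):
        # Resolved while tracing; only the chosen branch is staged out.
        return true_fun(operand) if pred else false_fun(operand)
    # Traced predicate: both branches are traced, one is selected at runtime.
    return jax.lax.cond(pred, true_fun, false_fun, operand)


@jax.jit
def halve_or_double(x):
    # x.sum() > 0 is a traced boolean array here, so lax.cond is used.
    return branch(x.sum() > 0, lambda v: v / 2, lambda v: v * 2, x)


def scale(x, negate: bool):
    # `negate` is an ordinary Python bool (static under jit), so no cond is staged.
    return branch(negate, lambda v: -v, lambda v: v, x)


traced_result = halve_or_double(jnp.array([2.0, 4.0]))
static_result = jax.jit(scale, static_argnums=1)(jnp.array([1.0, -3.0]), True)
```

Because the dispatch lives in an ordinary, named function, a reader can still see exactly which path runs, rather than relying on new syntax that rewrites `if` statements behind the scenes.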