Add control flow callbacks for JIT-compiled frameworks

Hello Python community,
First, I'll say that Python is the best language in the world, but it's quite difficult to make it performant (as everybody knows).
So we have come up with just-in-time compilation for performance: offloading work to C++ and CUDA code, which makes it run extremely fast. Jax is currently one of the fastest frameworks for numerical computation in the world. In my experience it beats even well-optimized C++ code in, for example, structural analysis by an order of magnitude.

Thing is, JIT-compiled code is difficult to write. The JIT cannot compile `if` statements because they are untraceable: jax has no way of knowing about them.

We have to write atrocities like these:

import jax.numpy as np
from jax.lax import cond

def headfunc(count_found_sums: bool, multiply: int):
    # a, b, c, d, e, f are defined higher up.
    def func1(a: int, b: int, c: np.ndarray, _d, _e, _f):
        total = a + b
        count_found_items = np.count_nonzero(c == total)
        return count_found_items

    def func2(_a, _b, _c, d: np.ndarray, e: np.ndarray, f: np.ndarray):
        matches_e = d == e
        matches_f = d == f
        overlap = np.logical_and(matches_e, matches_f)
        return np.count_nonzero(overlap)

    result = cond(  # cond is the `if` statement in Jax.
        count_found_sums,
        func1,
        func2,
        a, b, c, d, e, f,  # notice: we have to pass everything, because both branches receive the same operands
    )
    return result

Which is equivalent to:

def headfunc(count_found_sums: bool, multiply: int):
    # a, b, c, d, e, f are defined higher up.
    if count_found_sums:
        total = a + b
        count_found_items = np.count_nonzero(c == total)
        return count_found_items
    else:
        matches_e = d == e
        matches_f = d == f
        overlap = np.logical_and(matches_e, matches_f)
        return np.count_nonzero(overlap)

# **7 lines shorter**

On thousands of lines of code, the codebase becomes about 50% longer because `if` conditions are this much harder to write - in my use case at least.
PyTorch has something similar: its JIT also can't trace `if`, and they have to write something comparable. (I've only read about it online, though, never worked with it.)

Proposal:
To allow much simpler JIT-compiled code, add a flag to the Python launcher that triggers a callback every time compilation finds a Python `if` (and preferably `for`) statement.

something like this:
python3 main.py --flag-control-flow
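
To make that concrete, from jax's side the hook might look roughly like this. Every name here is hypothetical (there is no set_if_callback or anything like it in CPython today); designing the real interface would be part of the work:

# Purely hypothetical sketch - no such hook exists in CPython today.
# With --flag-control-flow, the interpreter would hand each `if` statement to a
# registered callback instead of executing it, so a tracing framework could
# turn the branch into lax.cond without the user rewriting anything.
from jax import lax

def jax_if_callback(condition, true_branch, false_branch):
    # true_branch / false_branch would arrive as zero-argument callables that
    # the interpreter builds from the original suite bodies (an assumption).
    return lax.cond(condition, true_branch, false_branch)

# hypothetical registration call, exact spelling to be designed:
# sys.set_if_callback(jax_if_callback)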

I personally had to create an 8k SLOC codebase of performance-critical scientific compute code recently, and let me tell you, there is every reason to add this if Python wants to be fast for scientists and machine-learning developers (which it does).

Jax developer jakevdp (Jake Vanderplas) will probably know the implementation side a little better than I do, since I'm not a developer of jax, although I do use it a lot.


I code a lot in Jax too. I sympathize with wanting to write Jitted code more cleanly, but your proposal is one of many proposals to make DSLs more Python-like. One problem is that you lose the ability to see what the Python code is actually doing.

Couple points on your code:

  • I assume that count_found_sums has type jax.Array or else you probably don’t need cond.
  • You probably should be using dataclasses to bundle up all the loose variables. Your code is harder to read than it needs to be.

I think it would help to brainstorm other solutions. E.g., creating your own DSL, or introducing better abstractions in your own Jax library.

  • I assume that count_found_sums has type jax.Array or else you probably don’t need cond.
  • You probably should be using dataclasses to bundle up all the loose variables. Your code is harder to read than it needs to be.

Yes, but it's a minimal reproducible example. I usually have 5-20 operations per such cond; I kept this one small so non-jax folks can understand what's happening.

I'm also writing a 20k SLOC codebase and look… 309 cond statements are only half of it; there will be at least 600 by the time I finish.

I'm rewriting a big scicomp library from C++ to jax, and it's, well, necessary.

creating your own DSL

I could code in the Mojo language, which is fast too (although not as fast as Jax yet) and is missing some critical features. But why, when Python is already a kind of jack-of-all-trades? It's simply a matter of adding a few features to Python that would make it usable for accelerators.

The flag I'm suggesting will not interfere with the performance of existing code, and I think it will be fairly straightforward to implement.

The “new feature” you’re proposing is a gigantic change to the language with a huge cost. You really need to brainstorm something else IMO.

The flag I'm suggesting will not interfere with the performance of existing code, and I think it will be fairly straightforward to implement.

Performance isn’t the only consideration. My guess is that it will not be a popular idea.

Yes, but my point is that you could recast your code in such a way that you don’t need Python to change at all. How about something like:

def f(a, b, c): ...
def g(d, e): ...
retval = super_cond(f, g)  # automatically pulls a,b,c,d,e out of the scope and sends them to f and g.
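
Something in that direction can even be done today with frame inspection - here's a rough, untested sketch (the explicit pred argument and all of the plumbing are my own assumptions, not an existing jax API):

# Hypothetical super_cond: pulls each branch's arguments out of the caller's
# local scope by parameter name, then dispatches through lax.cond.
import inspect

from jax import lax

def super_cond(pred, true_fun, false_fun):
    caller_locals = inspect.currentframe().f_back.f_locals

    def gather(fun):
        # Look up every parameter name of `fun` among the caller's locals.
        return tuple(caller_locals[name]
                     for name in inspect.signature(fun).parameters)

    true_args, false_args = gather(true_fun), gather(false_fun)
    # lax.cond hands the same operands to both branches, so each wrapper
    # simply ignores the other branch's operand tuple.
    return lax.cond(
        pred,
        lambda t, f: true_fun(*t),
        lambda t, f: false_fun(*f),
        true_args,
        false_args,
    )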

Note that it is absolutely possible for JIT compilers to deal with `if` statements, see Numba. They look at the bytecode, which of course has other issues, most notably potentially gigantic rewrites every Python release because the bytecode changes. But IMO making this kind of approach more stable is a better idea than what is suggested in the OP.
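
For illustration, this is roughly what that looks like with Numba - a plain Python `if` works inside an @njit function because the bytecode itself is compiled (a small sketch of mine, not code from this thread):

import numpy as np
from numba import njit

@njit
def count_matches(count_found_sums, a, b, c, d, e, f):
    # An ordinary Python `if` is fine here: Numba compiles the bytecode,
    # so no cond-style helper is needed.
    if count_found_sums:
        return np.count_nonzero(c == a + b)
    else:
        return np.count_nonzero((d == e) & (d == f))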

Yeah, and PyTorch has the ability to do that too. In my experience that’s a horribly flimsy solution. Jax’s approach is a lot better.

Incidentally, I've written many thousands of lines of Jax code and haven't needed many conds at all. I'm not sure why the OP needs so many of them. Are you not using jnp.where for all your math code?
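
For reference, this is what that style looks like on the OP's example - both branches are computed unconditionally and jnp.where selects the result, so no cond is needed (my rewrite, not the OP's code):

import jax.numpy as jnp

def headfunc(count_found_sums, a, b, c, d, e, f):
    found_sums = jnp.count_nonzero(c == a + b)
    overlap = jnp.count_nonzero(jnp.logical_and(d == e, d == f))
    return jnp.where(count_found_sums, found_sums, overlap)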

With the right implementation it shouldn't be. Should it?

The “new feature” you’re proposing is a gigantic change to the language with a huge cost. You really need to brainstorm something else IMO.

Look, I’m not a pro in C, but I think…
Steps:

  1. Create a separate compilation mode that pulls in alternative `if` and `for` handling only when this flag is set;
  2. Whenever the flag is passed, use the `if` and `for` code from the flagged version rather than the usual one; I think it can be compiled somehow to be optimal.

Kind of like XLA chooses GPU or TPU backends for its calls, Python could choose “backends” for `for` and `if` statements.

Which is simple and fast.

Like I said, the cost isn’t only speed. The cost is the human cost of maintaining this, the human cost of debugging things every time something doesn’t work to know whether the option is on and whether it’s causing the problem, the cost of debugging Jax code when this option is on, etc.

The cost is that you’re making Python itself way more complicated. That’s what I think people will balk at. Essentially, you’re asking the Python developers to write the DSL that you want.

Also, like I said, I don’t see why you even need this. I’ve written tens of thousands of lines of Jax code and haven’t needed anything like this. Can you shed some more light on why your proposal is the simplest way of writing this?

jnp.where is separate from those counts. So is vmap - I have 133 calls to vmap in the code, and even 43 fori loops… although I'm sure the fori's can be refactored; I haven't touched that code in months.

Nevertheless, it will be no fewer than 400 control-flow constructs by the end.

So a lot of control flow.

It's computationally expensive to run on CPU, which is why I'm moving it to Jax. Something of that range takes 50+ hours to compute on a CPU, and I need thousands of those simulations. GPUs can be 10-100x faster, and I'm basically moving C++ code to jax because it's fast, simpler than C++, and works on any GPU automatically.

I don't think the “backend” implementation I suggested above is very complicated to implement or maintain; it's basically under 300 SLOC, maybe a little more.

Additionally, Python is the go-to language for modern ML development; all of the ML infrastructure is built on it, and I say that as a scicomp+ML engineer. It's not really a DSL when that many people use it this way (mostly very smart people, too).

Python is very commonly used for data science/ML, and it's the best usage of the language. The option I'm suggesting is for the tens of thousands of people who would benefit from it.

Sounds like you have a prototype working? Can you show it?

Or are you just assuming based on your incomplete knowledge of how Python is implemented?

Also, I still am not really sure how your proposal will work. Will it change what byte code is generated? Could you show an example of how you imagine this feature being used?

How about a different interface? Command line flags are not very flexible and have many downsides. Could this be toggled locally via function calls or something like a decorator?
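
Something decorator-shaped is already possible in user space with an AST rewrite, for example. This is a rough, untested sketch with made-up names (control_flow, _IfToCond) that only handles the simplest if/else-return shape, just to show what the interface could look like:

import ast
import inspect
import textwrap

from jax import lax

def _no_args():
    # Argument list for a zero-argument lambda node.
    return ast.arguments(posonlyargs=[], args=[], vararg=None,
                         kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[])

class _IfToCond(ast.NodeTransformer):
    # Only handles `if ...: return A` / `else: return B`; a real tool would
    # also need assignments, nesting, `for` loops, and so on.
    def visit_If(self, node):
        self.generic_visit(node)
        if (len(node.body) == 1 and isinstance(node.body[0], ast.Return)
                and len(node.orelse) == 1 and isinstance(node.orelse[0], ast.Return)):
            true_fn = ast.Lambda(args=_no_args(), body=node.body[0].value)
            false_fn = ast.Lambda(args=_no_args(), body=node.orelse[0].value)
            call = ast.Call(
                func=ast.Attribute(value=ast.Name(id="lax", ctx=ast.Load()),
                                   attr="cond", ctx=ast.Load()),
                args=[node.test, true_fn, false_fn],
                keywords=[])
            return ast.Return(value=call)
        return node

def control_flow(fn):
    # Parse the decorated function's source, rewrite its if/else returns into
    # lax.cond, and recompile it against the original globals.
    tree = ast.parse(textwrap.dedent(inspect.getsource(fn)))
    tree.body[0].decorator_list = []  # don't re-apply @control_flow on exec
    tree = ast.fix_missing_locations(_IfToCond().visit(tree))
    namespace = dict(fn.__globals__, lax=lax)
    exec(compile(tree, filename="<control_flow>", mode="exec"), namespace)
    return namespace[fn.__name__]

# usage sketch:
# @control_flow
# def pick(pred, x, y):
#     if pred:
#         return x + y
#     else:
#         return x - y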