Hello Python community,
First, I’ll say that Python is the best language in the world, but, as everybody knows, it’s quite difficult to make it performant.
So we have come up with just-in-time (JIT) compilation for performance: Python code is traced and compiled (in JAX’s case, through XLA) into native CPU or CUDA GPU code, which runs extremely fast. In my experience, JAX is one of the fastest frameworks for numerical computation; in structural analysis, for example, it easily beats even well-optimized C++ code by an order of magnitude.
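The thing is, JIT-compiled code is difficult to write. The JIT cannot compile a Python `if` whose condition depends on a traced value, because such statements are untraceable: during tracing, the arguments are abstract placeholders with a shape and dtype but no concrete value, and the `if` itself is executed by the Python interpreter rather than recorded, so JAX has no way of knowing about it.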
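To make the limitation concrete, here is a minimal sketch, assuming a recent JAX release. The names `plain_if`, `traced_if_fails`, `static_version` and `cond_version` are made up for illustration. Under `jax.jit`, an `if` on a traced `flag` raises `jax.errors.ConcretizationTypeError` (recent versions raise its subclass `TracerBoolConversionError`). The two usual workarounds are marking the argument static, which compiles one version per flag value, or rewriting the branch with `jax.lax.cond`.

```python
import jax
import jax.numpy as jnp


def plain_if(flag, x):
    # Ordinary Python control flow on the *value* of `flag`.
    if flag:
        return x + 1
    return x - 1


def traced_if_fails() -> bool:
    """Return True if jax.jit refuses to trace the Python `if` in plain_if."""
    try:
        jax.jit(plain_if)(True, jnp.arange(3))
    except jax.errors.ConcretizationTypeError:
        # Under jit, `flag` is an abstract tracer with no concrete True/False,
        # so `if flag:` cannot be evaluated at trace time.
        return True
    return False


# Workaround 1: treat `flag` as a compile-time constant (one compilation per value).
static_version = jax.jit(plain_if, static_argnums=0)


# Workaround 2: express the branch with lax.cond so both branches are traced.
@jax.jit
def cond_version(flag, x):
    return jax.lax.cond(flag, lambda v: v + 1, lambda v: v - 1, x)


print(traced_if_fails())                                # True
print(static_version(True, jnp.arange(3)).tolist())     # [1, 2, 3]
print(cond_version(False, jnp.arange(3)).tolist())      # [-1, 0, 1]
```

The static route keeps the plain `if` but only works when the flag is known before compilation; `lax.cond` works on runtime values but forces the branch bodies into separate functions.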
We have to write atrocities like this:
```python
import jax.numpy as jnp
from jax.lax import cond


def headfunc(count_found_sums: bool, multiply: int):
    # a, b, c, d, e, f are defined higher.
    def func1(a: int, b: int, c: jnp.ndarray, _d, _e, _f):
        total = a + b
        count_found_items = jnp.count_nonzero(c == total)
        return count_found_items
    def func2(_a, _b, _c, d: jnp.ndarray, e: jnp.ndarray, f: jnp.ndarray):
        diff1 = d == e  # element-wise match masks
        diff2 = d == f
        overlap = jnp.logical_and(diff1, diff2)
        return jnp.count_nonzero(overlap)
    result = cond(  # cond is the `if` statement in JAX.
        count_found_sums,
        func1,
        func2,
        a, b, c, d, e, f,  # notice: we must pass all six, since both branches take the same operands
    )
    return result
```
This is equivalent to the plain Python version (which `jax.jit` cannot trace when `count_found_sums` is a traced value):
```python
import jax.numpy as jnp


def headfunc(count_found_sums: bool, multiply: int):
    # a, b, c, d, e, f are defined higher.
    if count_found_sums:
        total = a + b
        count_found_items = jnp.count_nonzero(c == total)
        return count_found_items
    else:
        diff1 = d == e  # element-wise match masks
        diff2 = d == f
        overlap = jnp.logical_and(diff1, diff2)
        return jnp.count_nonzero(overlap)
```
That is **7 lines shorter**, not counting imports.
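For reference, here is a self-contained, runnable sketch of the same `cond` pattern under `jax.jit`. The concrete values of `a` through `f` are made up for illustration, and the unused `multiply` parameter is dropped. Because `a` through `f` live in the enclosing scope, the branches can simply close over them, so `jax.lax.cond` is called with no operands at all; this trims some boilerplate, but each branch still has to be wrapped in its own function.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the a, b, c, d, e, f that are "defined higher".
a, b = 2, 3
c = jnp.array([5, 1, 5, 7])
d = jnp.array([1, 2, 3, 4])
e = jnp.array([1, 0, 3, 4])
f = jnp.array([1, 2, 0, 0])


@jax.jit
def headfunc(count_found_sums):
    # Both branches close over a..f instead of receiving them as operands,
    # so lax.cond can be called with zero operands.
    def func1():
        return jnp.count_nonzero(c == a + b)

    def func2():
        overlap = jnp.logical_and(d == e, d == f)
        return jnp.count_nonzero(overlap)

    # Both branches return an integer scalar, as lax.cond requires matching outputs.
    return jax.lax.cond(count_found_sums, func1, func2)


print(int(headfunc(True)))   # 2: two entries of c equal a + b = 5
print(int(headfunc(False)))  # 1: only position 0 of d matches both e and f
```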
Across thousands of lines, the code becomes roughly 50% longer because `if` conditions are harder to write, at least in my use case.
PyTorch has a similar problem: its tracing JIT (`torch.jit.trace`) records only the branch that was actually taken, so users have to write something similar. (I have only read about this online, though; I have never worked with it.)
Proposal:
To allow much simpler JIT-compiled code, add a flag to the `python` command that triggers a callback every time the compiler encounters a Python `if` statement (and, preferably, `for` statements too), so that a library such as JAX could take over those statements during tracing.
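Something like this (interpreter options go before the script name, since anything after `main.py` is passed to the script itself):
```shell
python3 --flag-control-flow main.py
```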
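To illustrate what such a callback could receive, here is a rough user-space sketch built on the standard library’s `ast` module. It only approximates the idea: it scans source text instead of hooking into the interpreter, and `ControlFlowHook` and `scan_control_flow` are hypothetical names, not an existing API. A library like JAX could, in principle, use hooks of this kind to decide how to lower each `if` or `for`, for example into `jax.lax.cond` or `jax.lax.fori_loop`.

```python
import ast
from typing import Callable, List, Tuple

# Hypothetical callback signature: (statement kind, line number) -> None.
ControlFlowCallback = Callable[[str, int], None]


class ControlFlowHook(ast.NodeVisitor):
    """Call `callback` for every `if` and `for` statement in a parsed module.

    This simulates the proposed flag at the source level only; CPython has
    no such interpreter option today.
    """

    def __init__(self, callback: ControlFlowCallback) -> None:
        self.callback = callback

    def visit_If(self, node: ast.If) -> None:
        self.callback("if", node.lineno)
        self.generic_visit(node)  # keep walking into nested statements (incl. elif)

    def visit_For(self, node: ast.For) -> None:
        self.callback("for", node.lineno)
        self.generic_visit(node)


def scan_control_flow(source: str) -> List[Tuple[str, int]]:
    """Parse `source` and return (kind, line) for each if/for statement."""
    found: List[Tuple[str, int]] = []
    ControlFlowHook(lambda kind, line: found.append((kind, line))).visit(ast.parse(source))
    return found


example = """
def headfunc(count_found_sums):
    if count_found_sums:
        return 1
    for i in range(3):
        pass
    return 0
"""

print(scan_control_flow(example))  # [('if', 3), ('for', 5)]
```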
I recently had to write an 8k SLOC codebase of performance-critical scientific computing code, and let me tell you: there is every reason to add this if Python wants to be fast for scientists and machine learning developers (which it does).
JAX developer Jake VanderPlas (jakevdp on GitHub) will probably know the implementation side better than I do, since I’m not a JAX developer, although I do use it a lot.