Using jit
We now know how to find hotspots, so how do we improve their performance? We jit them!
We'll start with a trivial example but get to some more realistic applications shortly.
Array sum
The function below is a naive implementation that sums all the elements of a given array.
def sum_array(inp):
    J, I = inp.shape
    # this is a bad idea
    mysum = 0
    for j in range(J):
        for i in range(I):
            mysum += inp[j, i]
    return mysum
import numpy
arr = numpy.random.random((300, 300))
First, hand the array arr off to sum_array to make sure it works (or at least doesn't error out).
sum_array(arr)
45102.033677230997
Now run sum_array under timeit and save the results as a baseline to compare against.
plain = %timeit -o sum_array(arr)
100 loops, best of 3: 14.7 ms per loop
Let's get started
from numba import jit
Note: There are two ways to jit a function. They do exactly the same thing, so you can choose whichever you prefer.
As a function call
sum_array_numba = jit()(sum_array)
What's up with the weird double ()s? We'll cover that in a little bit.
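As a minimal preview ahead of the fuller explanation: jit() called with no arguments returns a decorator, and that decorator is what actually gets applied to sum_array. Splitting the one-liner in two makes the double parentheses easier to read:

make_jitted = jit()                       # first (): build a decorator with default options
sum_array_numba = make_jitted(sum_array)  # second (): apply it to our function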
Now we have a new function, called sum_array_numba, which is the jitted version of sum_array. We can again make sure that it works (and hopefully produces the same result as sum_array).
sum_array_numba(arr)
45102.033677231
Good, that's the same result as the first version, so nothing has gone horribly wrong.
Now let's time and save these results.
jitted = %timeit -o sum_array_numba(arr)
10000 loops, best of 3: 73.9 µs per loop
Wow. 73.9 µs is a lot faster than 14.7 ms... How much faster? Let's see.
plain.best / jitted.best
198.80740381645145
So, a factor of roughly 200x. Not too shabby. But we're comparing the best runs; what about the worst runs?
plain.worst / jitted.worst
278.7481542534447
Yeah, that's still an incredible speedup.
(more commonly) As a decorator
The second way to jit a function is to use the jit decorator. This syntax is very easy to handle and makes applying jit to a function trivial.

Note that the only difference in terms of outcome (compared to the other jit method) is that there will be only one function, called sum_array, which is a Numba jitted function. The "original" sum_array will no longer exist, so this method, while convenient, doesn't allow you to compare results between "vanilla" and jitted Python.

When should you use one or the other? That's up to you. If I'm investigating whether Numba can help, I use jit as a function call so I can compare results. Once I've decided to use Numba, I stick with the decorator syntax, since it's much prettier (and I don't care whether the "original" function is available).
@jit
def sum_array(inp):
    I, J = inp.shape
    mysum = 0
    for i in range(I):
        for j in range(J):
            mysum += inp[i, j]
    return mysum
sum_array(arr)
45102.033677231
So again, we can see that we have the same result. That's good. And timing?
%timeit sum_array(arr)
10000 loops, best of 3: 73.8 µs per loop
As expected, more or less identical to the first jit example.
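One aside, in case you miss having the original around: a jitted function keeps a reference to the undecorated Python function on its py_func attribute (true for the Numba versions I've used), so you can still reach the vanilla version for comparison:

# .py_func holds the original Python function that @jit wrapped
sum_array.py_func(arr)   # calls the slow, uncompiled version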
How does this compare to NumPy?
NumPy, of course, has built-in methods for summing arrays. How does Numba stack up against those?
%timeit arr.sum()
10000 loops, best of 3: 33.9 µs per loop
Right. Remember, NumPy has been hand-tuned over many years to be very, very good at what it does. For simple operations, Numba is not going to outperform it, but when things get more complex Numba can save the day.
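To make "more complex" a little more concrete, here's a sketch of my own (not part of the original benchmarks): an expression like ((a - b) ** 2).sum() makes NumPy allocate temporary arrays for a - b and again for its square, while a jitted loop can fuse the whole computation into a single pass with no temporaries.

@jit
def sum_sq_diff(a, b):
    # One fused pass over both arrays: no temporaries for (a - b) or its square
    I, J = a.shape
    total = 0.0
    for i in range(I):
        for j in range(J):
            d = a[i, j] - b[i, j]
            total += d * d
    return total

For patterns like this, where the NumPy version churns through several intermediate arrays, the jitted loop can come out ahead; as always, time it yourself before believing any blanket claim.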
Also, take a moment to appreciate that our jitted code, which was compiled on the fly, offers performance in the same order of magnitude as NumPy. That's pretty incredible.
When does numba compile things?
numba is a just-in-time (hence, jit) compiler. The very first time you run a numba-compiled function, there will be a little bit of overhead while the compilation step takes place. In practice, this is usually not noticeable. You may get a message from timeit that one "run" was much slower than most; this is due to the compilation overhead.
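If you want to observe the overhead directly, here's a minimal sketch (the function is made up for illustration; your timings will differ):

import time

@jit
def double_it(x):
    return x * 2.0

start = time.perf_counter()
double_it(1.0)    # first call: triggers compilation
print('first call: ', time.perf_counter() - start)

start = time.perf_counter()
double_it(1.0)    # later calls reuse the compiled machine code
print('second call:', time.perf_counter() - start)

The first print should come out orders of magnitude larger than the second; that gap is the compilation step.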