Using jit

Now that we know how to find hotspots, how do we improve their performance?

We jit them!

We'll start with a trivial example but get to some more realistic applications shortly.

Array sum

The function below naively sums all the elements of a given 2D array using nested Python loops.

def sum_array(inp):
    J, I = inp.shape

    mysum = 0
    for j in range(J):
        for i in range(I):
            mysum += inp[j, i]

    return mysum

import numpy

arr = numpy.random.random((300, 300))


First, pass the array arr to sum_array to make sure it works (or at least doesn't error out).

sum_array(arr)

45102.033677230997


Now time sum_array and save the timeit results as a baseline to compare against.

plain = %timeit -o sum_array(arr)

100 loops, best of 3: 14.7 ms per loop


Let's get started

from numba import jit


Note: There are two ways to jit a function. Both do exactly the same thing, so choose whichever you prefer.

As a function call

sum_array_numba = jit()(sum_array)


What's up with the weird double ()s? We'll cover that in a little bit.

Now we have a new function, called sum_array_numba, which is the jitted version of sum_array. We can again make sure that it works (and hopefully produces the same result as sum_array).

sum_array_numba(arr)

45102.033677231


Good, that's the same result as the first version, so nothing has gone horribly wrong.

Now let's time and save these results.

jitted = %timeit -o sum_array_numba(arr)

10000 loops, best of 3: 73.9 µs per loop


Wow. 73.9 µs is a lot faster than 14.7 ms. How much faster? Let's see.

plain.best / jitted.best

198.80740381645145


So, a speedup of roughly 200x. Not too shabby. But we're comparing the best runs; what about the worst runs?

plain.worst / jitted.worst

278.7481542534447


Yeah, that's still an incredible speedup.

As a decorator (the more common way)

The second way to jit a function is to use the jit decorator. This syntax is easy to read and makes applying jit to a function trivial.

Note that the only difference in outcome (compared to the function-call method) is that there will be only one function, called sum_array, and it is the Numba-jitted version. The "original" sum_array is no longer available under that name, so this method, while convenient, makes it harder to compare results between "vanilla" and jitted Python.

When should you use one or the other? That's up to you. If I'm investigating whether Numba can help, I use jit as a function call, so I can compare results. Once I've decided to use Numba, I stick with the decorator syntax since it's much prettier (and I don't care if the "original" function is available).

@jit
def sum_array(inp):
    I, J = inp.shape

    mysum = 0
    for i in range(I):
        for j in range(J):
            mysum += inp[i, j]

    return mysum

sum_array(arr)

45102.033677231


So again, we can see that we have the same result. That's good. And timing?

%timeit sum_array(arr)

10000 loops, best of 3: 73.8 µs per loop


As expected, more or less identical to the first jit example.

How does this compare to NumPy?

NumPy, of course, has built-in methods for summing arrays. How does Numba stack up against those?

%timeit arr.sum()

10000 loops, best of 3: 33.9 µs per loop


Right. Remember, NumPy has been hand-tuned over many years to be very, very good at what it does. For simple operations like this one, Numba is unlikely to outperform it, but when things get more complex, Numba can save the day.

Also, take a moment to appreciate that our jitted code, which was compiled on the fly, offers performance within the same order of magnitude as NumPy (about 74 µs versus 34 µs here). That's pretty incredible.

When does Numba compile things?

Numba is a just-in-time (hence, jit) compiler. The very first time you run a Numba-compiled function, there is some one-time overhead while the compilation step takes place. In practice, this is usually not noticeable. You may get a message from timeit that one run was much slower than the others; this is due to the compilation overhead.