Is this just magic? What is Numba doing to make code run quickly?
When you add the jit decorator (or function call), Numba examines the code in the function and then tries to compile it using the LLVM compiler. LLVM takes Numba's translation of the Python code and compiles it into something like assembly code, which is a set of very low-level and very fast instructions.
Let's create a small, simple example function to poke around in:
from numba import jit
@jit
def add(a, b):
    return a + b
add(1, 1)
2
Now that we've run add once, it has been compiled and we can check out what happened behind the scenes. Use the inspect_types method to see how Numba translated the function.
add.inspect_types()
add (int64, int64)
--------------------------------------------------------------------------------
# File: <ipython-input-2-1c683d2d00ee>
# --- LINE 1 ---
# label 0
# del b
# del a
# del $0.3
@jit
# --- LINE 2 ---
def add(a, b):
# --- LINE 3 ---
# a = arg(0, name=a) :: int64
# b = arg(1, name=b) :: int64
# $0.3 = a + b :: int64
# $0.4 = cast(value=$0.3) :: int64
# return $0.4
return a + b
================================================================================
That is a bit more complicated than our original line, and in fact there's even more complicated stuff going on behind the scenes, but we won't go into that right now. For now, just recognize that Numba examines the code we wrote, infers a type for every value (the ":: int64" annotations above), and translates the function into a more detailed representation that can be efficiently compiled into a super-fast version.
However...
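This translation business is hard and Numba isn't perfect. If it encounters something that it doesn't understand, the function will still run, but it will operate in what is called "object mode". This is fine, except object mode can be really, really slow. (This automatic fallback is the behavior of the Numba version used in this walkthrough; starting with Numba 0.59, jit defaults to nopython mode and raises an error instead of falling back.)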
So what can we do to avoid object mode?
Well, first, there's a list of supported features that Numba understands that you can browse at your leisure (but do this later): http://numba.pydata.org/numba-doc/latest/reference/pysupported.html
Forcing nopython mode
The opposite of the slow "object mode" is called nopython mode. That's a kind of confusing name, but it is what it is. The important thing to remember is that nopython mode is when Numba is fast, so that's what we want.
But how do we know what "mode" Numba is using?
That's a good question. We don't always know, and we can't know ahead of time, but we do have one helper to look out for us.
If we specify nopython=True, then Numba will throw an exception and fail to compile when it can't make a function work in nopython mode. Then we can try to rewrite that function until it can compile.
Here's a quick example. First import numpy and the linalg module from scipy.
import numpy
from scipy import linalg
Define a random square array:
a = numpy.random.random((5, 5))
Now write a function to pass that array to the linalg QR decomposition method:
def qr_decomposition(a):
    return linalg.decomp_qr.qr(a)
Now let's try it out:
qr_decomposition(a)
(array([[-0.61251218, 0.3080602 , 0.61705293, -0.29224924, -0.25251496],
[-0.51533223, -0.15379307, -0.29646637, -0.32818401, 0.71776273],
[-0.55845413, -0.50433313, -0.22256529, 0.49326552, -0.37540764],
[-0.10076405, 0.33415965, -0.67259288, -0.44155593, -0.48044886],
[-0.19296921, 0.71793594, -0.1715718 , 0.60712694, 0.22201543]]),
array([[-1.58468704, -1.26245499, -1.01284515, -1.29478127, -0.61076935],
[ 0. , 0.62312427, -0.04924057, 0.17049891, 0.09324721],
[ 0. , 0. , -0.77693122, -0.43097318, -0.72632735],
[ 0. , 0. , 0. , -0.46093787, -0.15692113],
[ 0. , 0. , 0. , 0. , -0.05197661]]))
It works! Ok, now let's try to jit it:
qr_jit = jit()(qr_decomposition)
qr_jit(a)
(array([[-0.61251218, 0.3080602 , 0.61705293, -0.29224924, -0.25251496],
[-0.51533223, -0.15379307, -0.29646637, -0.32818401, 0.71776273],
[-0.55845413, -0.50433313, -0.22256529, 0.49326552, -0.37540764],
[-0.10076405, 0.33415965, -0.67259288, -0.44155593, -0.48044886],
[-0.19296921, 0.71793594, -0.1715718 , 0.60712694, 0.22201543]]),
array([[-1.58468704, -1.26245499, -1.01284515, -1.29478127, -0.61076935],
[ 0. , 0.62312427, -0.04924057, 0.17049891, 0.09324721],
[ 0. , 0. , -0.77693122, -0.43097318, -0.72632735],
[ 0. , 0. , 0. , -0.46093787, -0.15692113],
[ 0. , 0. , 0. , 0. , -0.05197661]]))
It works! Or did it? What if we try to add the nopython=True flag?
(Also, remember how we talked about that super weird second set of parentheses? Here's where it comes in.)
qr_jit = jit(nopython=True)(qr_decomposition)
Prepare for a very long error message...
qr_jit(a)
---------------------------------------------------------------------------
UntypedAttributeError Traceback (most recent call last)
<ipython-input-12-b8e9f2ef8dc2> in <module>()
----> 1 qr_jit(a)
/home/gil/anaconda/lib/python3.5/site-packages/numba/dispatcher.py in _compile_for_args(self, *args, **kws)
307 for i, err in failed_args))
308 e.patch_message(msg)
--> 309 raise e
310
311 def inspect_llvm(self, signature=None):
/home/gil/anaconda/lib/python3.5/site-packages/numba/dispatcher.py in _compile_for_args(self, *args, **kws)
284 argtypes.append(self.typeof_pyval(a))
285 try:
--> 286 return self.compile(tuple(argtypes))
287 except errors.TypingError as e:
288 # Intercept typing error that may be due to an argument
/home/gil/anaconda/lib/python3.5/site-packages/numba/dispatcher.py in compile(self, sig)
530
531 self._cache_misses[sig] += 1
--> 532 cres = self._compiler.compile(args, return_type)
533 self.add_overload(cres)
534 self._cache.save_overload(sig, cres)
/home/gil/anaconda/lib/python3.5/site-packages/numba/dispatcher.py in compile(self, args, return_type)
79 impl,
80 args=args, return_type=return_type,
---> 81 flags=flags, locals=self.locals)
82 # Check typing error if object mode is used
83 if cres.typing_error is not None and not flags.enable_pyobject:
/home/gil/anaconda/lib/python3.5/site-packages/numba/compiler.py in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library)
682 pipeline = Pipeline(typingctx, targetctx, library,
683 args, return_type, flags, locals)
--> 684 return pipeline.compile_extra(func)
685
686
/home/gil/anaconda/lib/python3.5/site-packages/numba/compiler.py in compile_extra(self, func)
346 self.lifted = ()
347 self.lifted_from = None
--> 348 return self._compile_bytecode()
349
350 def compile_ir(self, func_ir, lifted=(), lifted_from=None):
/home/gil/anaconda/lib/python3.5/site-packages/numba/compiler.py in _compile_bytecode(self)
647 """
648 assert self.func_ir is None
--> 649 return self._compile_core()
650
651 def _compile_ir(self):
/home/gil/anaconda/lib/python3.5/site-packages/numba/compiler.py in _compile_core(self)
634
635 pm.finalize()
--> 636 res = pm.run(self.status)
637 if res is not None:
638 # Early pipeline completion
/home/gil/anaconda/lib/python3.5/site-packages/numba/compiler.py in run(self, status)
233 # No more fallback pipelines?
234 if is_final_pipeline:
--> 235 raise patched_exception
236 # Go to next fallback pipeline
237 else:
/home/gil/anaconda/lib/python3.5/site-packages/numba/compiler.py in run(self, status)
225 try:
226 event(stage_name)
--> 227 stage()
228 except _EarlyPipelineCompletion as e:
229 return e.result
/home/gil/anaconda/lib/python3.5/site-packages/numba/compiler.py in stage_nopython_frontend(self)
434 self.args,
435 self.return_type,
--> 436 self.locals)
437
438 with self.fallback_context('Function "%s" has invalid return type'
/home/gil/anaconda/lib/python3.5/site-packages/numba/compiler.py in type_inference_stage(typingctx, interp, args, return_type, locals)
783
784 infer.build_constraint()
--> 785 infer.propagate()
786 typemap, restype, calltypes = infer.unify()
787
/home/gil/anaconda/lib/python3.5/site-packages/numba/typeinfer.py in propagate(self, raise_errors)
759 if errors:
760 if raise_errors:
--> 761 raise errors[0]
762 else:
763 return errors
/home/gil/anaconda/lib/python3.5/site-packages/numba/typeinfer.py in propagate(self, typeinfer)
126 lineno=loc.line):
127 try:
--> 128 constraint(typeinfer)
129 except TypingError as e:
130 errors.append(e)
/home/gil/anaconda/lib/python3.5/site-packages/numba/typeinfer.py in __call__(self, typeinfer)
447 attrty = typeinfer.context.resolve_getattr(ty, self.attr)
448 if attrty is None:
--> 449 raise UntypedAttributeError(ty, self.attr, loc=self.inst.loc)
450 else:
451 typeinfer.add_type(self.target, attrty, loc=self.loc)
UntypedAttributeError: Failed at nopython (nopython frontend)
Unknown attribute 'qr' of type Module(<module 'scipy.linalg.decomp_qr' from '/home/gil/anaconda/lib/python3.5/site-packages/scipy/linalg/decomp_qr.py'>)
File "<ipython-input-7-aa69e1f11031>", line 2
[1] During: typing of get attribute at <ipython-input-7-aa69e1f11031> (2)
Ack
Yes, that's a very long and intimidating-looking error message, but just focus on the last few lines, specifically "Failed at nopython" and "Unknown attribute 'qr'". That's Numba telling us that it has no idea what the qr function in scipy.linalg.decomp_qr is, so it can't try to accelerate it.
It worked the first time because it was in object mode, but we just asked Numba to force nopython mode and it tried (and failed).
nopython mode is so useful that people got tired of typing it out all of the time, so there's a shortcut!
from numba import njit
njit is exactly the same as jit but it always forces nopython=True. And like jit, you can use it in a function call, or as a decorator. Let's try it out on the simple add function we started out with:
def add(a, b):
    return a + b
Function call:
add_jit = njit(add) # no extra parentheses needed with `njit`
add_jit(3, 4)
7
Decorator:
@njit
def add(a, b):
    return a + b
add(4, 6)
10
And that's it!
Unless you have very specific requirements, we recommend always using njit over jit, so you can guarantee that you're taking advantage of all of Numba's speedup power.