Higher-Order Functions
A higher-order function is a function that takes another function as an
argument or returns one as its result. You met the idea in
Functions as Values: sorted is
higher-order because it accepts a key function. Three higher-order
functions come up so often that Python ships them ready-made: map and
filter are built-ins, and reduce lives in the functools module. Each
takes a function and a sequence and gives back a result, with no
explicit loop and no mutable counter in sight.
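Both halves of the definition can be seen in a few lines. The sketch below is illustrative only; make_scaler and apply_twice are hypothetical names, not part of Python. The first returns a function, and the second takes one as an argument:

```python
from typing import Callable

# Hypothetical helper: a higher-order function that *returns* a function.
def make_scaler(factor: float) -> Callable[[float], float]:
    """Return a new function that multiplies its argument by factor."""
    def scale(x: float) -> float:
        return x * factor
    return scale

# Hypothetical helper: a higher-order function that *takes* a function.
def apply_twice(f: Callable[[float], float], x: float) -> float:
    """Apply f to x, then apply f again to that result."""
    return f(f(x))

double = make_scaler(2)
print(double(5))               # 10
print(apply_twice(double, 5))  # 20
```

Because scale remembers factor after make_scaler has returned, each call to make_scaler produces an independent function.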
map: Transform Every Element
map(f, seq) applies f to every element of seq. It returns a
lazy iterator, so wrap it in list to see the results:
>>> def fahrenheit(c):
...     return c * 9 / 5 + 32
...
>>> temps = [-5, 0, 12, 100]
>>> list(map(fahrenheit, temps))
[23.0, 32.0, 53.6, 212.0]
This is the same transformation a list comprehension expresses, and for most code a comprehension reads more clearly:
>>> nums = [1, 2, 3, 4, 5]
>>> list(map(lambda n: n * n, nums))
[1, 4, 9, 16, 25]
>>> [n * n for n in nums]
[1, 4, 9, 16, 25]
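"Lazy" means map computes each value only when something asks for it, and the iterator can be walked through only once. A short demonstration, reusing the fahrenheit function from earlier:

```python
def fahrenheit(c):
    return c * 9 / 5 + 32

temps = [-5, 0, 12, 100]
converted = map(fahrenheit, temps)  # no conversion has happened yet

print(next(converted))  # 23.0: the first value, computed on demand
print(list(converted))  # [32.0, 53.6, 212.0]: only the values not yet consumed
print(list(converted))  # []: a map iterator can be consumed only once
```

If you need the results more than once, store them with list(...) the first time.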
The guidance from List Comprehensions still
holds: prefer the comprehension when you would otherwise write a small
lambda, and reach for map when you already have a named function to
apply (map(fahrenheit, temps) reads better than repeating its body).
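The "named function" case covers more than functions you write yourself. Built-ins such as int and methods such as str.upper are ordinary one-argument functions, so they drop straight into map; the sample strings below are made up for illustration:

```python
raw = ["3", "14", "15", "92"]

# int is already a named one-argument function, so map needs no lambda.
numbers = list(map(int, raw))
print(numbers)  # [3, 14, 15, 92]

# The comprehension has to spell out the call itself.
same = [int(s) for s in raw]
print(numbers == same)  # True

# Accessed through the class, str.upper behaves like a plain function.
print(list(map(str.upper, ["ok", "go"])))  # ['OK', 'GO']
```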
filter: Keep the Elements That Qualify
filter(predicate, seq) keeps only the elements for which predicate
returns a true value. A predicate is just a function that answers yes or
no, usually by returning True or False:
>>> def is_freezing(c):
...     return c <= 0
...
>>> temps = [-5, 0, 12, 100]
>>> list(filter(is_freezing, temps))
[-5, 0]
Again there is a comprehension form, [c for c in temps if c <= 0], and
again it is usually the clearer of the two. The value of knowing map
and filter is partly that you will meet them in other people’s code and
in other languages, and partly that they name the two most common shapes of
data processing: transform each element, and keep some elements.
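The two shapes compose naturally: keep some elements, then transform each survivor. A minimal sketch reusing fahrenheit and is_freezing shows the nested map/filter form next to the equivalent comprehension, plus a reminder that filter only needs a truthy result:

```python
def fahrenheit(c):
    return c * 9 / 5 + 32

def is_freezing(c):
    return c <= 0

temps = [-5, 0, 12, 100]

# Keep the freezing readings, then convert each one.
freezing_f = list(map(fahrenheit, filter(is_freezing, temps)))
print(freezing_f)  # [23.0, 32.0]

# The same pipeline written as a single comprehension.
print([fahrenheit(c) for c in temps if is_freezing(c)])  # [23.0, 32.0]

# With None as the predicate, filter keeps the elements that are themselves truthy.
print(list(filter(None, [0, 3, "", "hi", None])))  # [3, 'hi']
```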
reduce: Combine Everything Into One Value
The third shape is combining the elements into a single value: a sum, a
product, a maximum. functools.reduce(f, seq, start) folds the sequence
up by repeatedly applying f to the running result and the next element:
>>> from functools import reduce
>>> reduce(lambda total, n: total * n, [1, 2, 3, 4, 5], 1)
120
>>> reduce(lambda total, n: total + n, [1, 2, 3, 4, 5], 0)
15
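To make "folding" concrete, here is a minimal sketch of the loop that reduce performs, assuming the three-argument form with a start value. my_reduce is a hypothetical name for illustration, not something Python provides:

```python
from functools import reduce

def my_reduce(f, seq, start):
    """Hypothetical hand-written fold: the loop that reduce hides."""
    total = start
    for item in seq:
        # Combine the running result with the next element.
        total = f(total, item)
    return total

nums = [1, 2, 3, 4, 5]
# Computes ((((1 * 1) * 2) * 3) * 4) * 5.
print(my_reduce(lambda total, n: total * n, nums, 1))  # 120
print(reduce(lambda total, n: total * n, nums, 1))     # 120
```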
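The start value (1 for a product, 0 for a sum) should be the identity
for the operation: the value that leaves any x unchanged, since
1 * x == x and 0 + x == x. Starting from the identity never disturbs the
result, and it is the natural answer for an empty sequence, which is
exactly what reduce returns in that case. Most common reductions already
have a dedicated built-in, such as sum, max, min, all, and any, and you
should prefer those when they exist. reduce is for the cases that do
not. A product makes a small example (Python 3.8 and later also offer
math.prod for it):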
from functools import reduce

def product(numbers: list) -> int:
    """Multiply every number together, starting from 1."""
    return reduce(lambda total, n: total * n, numbers, 1)
A wrapper like product is worth writing because the name says what the
fold means, where the bare reduce(lambda ...) makes the reader work it
out. That is the recurring lesson of higher-order functions: they let you
package a pattern of computation behind a meaningful name.
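The role of the start value is easiest to see at the edges. This sketch uses operator.mul and operator.add in place of the lambdas; they are the standard library's named versions of * and +:

```python
import math
import operator
from functools import reduce

# With a start value, an empty sequence simply gives the start value back.
print(reduce(operator.mul, [], 1))  # 1
print(reduce(operator.add, [], 0))  # 0

# Because the start value is an identity, it does not change a real result.
print(reduce(operator.mul, [2, 3, 7], 1))  # 42

# Without a start value, reduce has nothing to return for an empty sequence.
try:
    reduce(operator.mul, [])
except TypeError:
    print("TypeError: empty sequence and no start value")

# Python 3.8+ covers the product case directly.
print(math.prod([2, 3, 7]))  # 42
print(math.prod([]))         # 1
```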
accumulate: Keep the Running Results
reduce throws away its work in progress and returns only the final
value. Often you want the whole trail of intermediate results instead — a
running total, a running maximum, a bank balance after each transaction.
Scala calls this scanLeft because it scans from the left, accumulating as
it goes; Python provides it as itertools.accumulate.
Where reduce returns one number, accumulate yields the accumulator
after each step:
>>> from itertools import accumulate
>>> list(accumulate([1, 2, 3, 4, 5])) # running sums
[1, 3, 6, 10, 15]
>>> from functools import reduce
>>> reduce(lambda a, b: a + b, [1, 2, 3, 4, 5]) # reduce keeps only the last
15
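A scan is the same loop as a fold, except that it yields the accumulator at every step instead of keeping only the last one. Here is a minimal sketch; running is a hypothetical generator written for illustration. Like accumulate called without an initial value, it seeds the running result with the first element:

```python
import operator
from functools import reduce
from itertools import accumulate

def running(f, seq):
    """Hypothetical hand-written scan: yield the accumulator after each step."""
    it = iter(seq)
    try:
        total = next(it)  # the first element seeds the running result
    except StopIteration:
        return            # an empty input produces no running results
    yield total
    for item in it:
        total = f(total, item)
        yield total

nums = [1, 2, 3, 4, 5]
print(list(running(operator.add, nums)))  # [1, 3, 6, 10, 15]
print(list(accumulate(nums)))             # [1, 3, 6, 10, 15]

# The last running result is exactly what reduce returns.
print(list(accumulate(nums))[-1], reduce(operator.add, nums))  # 15 15
```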
By default accumulate adds, but you can pass any function of two
arguments to combine differently, and an initial value to seed the
running result — which is exactly Scala’s scanLeft(initial)(op):
>>> import operator
>>> from itertools import accumulate
>>> list(accumulate([1, 2, 3, 4, 5], operator.mul)) # running products
[1, 2, 6, 24, 120]
>>> list(accumulate([10, -4, 3, -8], initial=0)) # balance after each change
[0, 10, 6, 9, 1]
The last example reads as a tiny bank statement: start at 0, then show
the balance after a deposit of 10, a withdrawal of 4, and so on.
Naming the pattern keeps the intent clear:
from itertools import accumulate

def running_total(amounts: list) -> list:
    """Return the balance after each amount, starting from 0.

    Like Scala's scanLeft, this keeps every intermediate sum, not just the
    final one that reduce would give.
    """
    return list(accumulate(amounts, initial=0))
So reduce answers “what is the final total?” and accumulate answers
“what was the total at every point along the way?” Reach for the scan
whenever the history of the computation is as interesting as its result.
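The running maximum mentioned earlier is one such case. Because the built-in max accepts two arguments, it can serve directly as accumulate's combining function. The readings below are made-up sample values:

```python
from itertools import accumulate

readings = [3, 7, 2, 9, 4, 9, 1]  # made-up sample data

# The highest reading seen so far, after each new reading.
highs = list(accumulate(readings, max))
print(highs)  # [3, 7, 7, 9, 9, 9, 9]

# The final running maximum is the overall maximum, which reduce or max would give.
print(highs[-1], max(readings))  # 9 9
```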