Compile-time computation has obvious appeal: any work you do when you build the program is already done when you run it—hence, faster programs.
This sounds great in theory, but a problem soon arises in practice. Most of the things you want to compute, you won’t know until you run the program. (That’s why we have programs, instead of just recording the final answer: we want to be able to deal with a wide variety of inputs, on demand.)
That’s what this post is about: bridging the gap between what we can know when we build the program, and what we need to know when we run the program.
I won’t bury the lede; here is the basic strategy:
- Compute all values we could possibly need, up to some maximum.
- Store the results in specific memory locations when we build the program.
- Access these memory locations when we run the program.
There are a few twists. We’ll use constexpr for the functions that tell us where to store and find the values: this lets us call the exact same functions when we build and run the program, and also lets us write tests for them. We’ll also protect our users from the gory details by exposing only a polished, lightweight, usable class, which communicates its limitations in a natural way.
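To make the three steps concrete before tackling the real problem, here is a minimal one-dimensional sketch using factorials. The names kMaxN, FactorialTable, make_factorial_table, kFactorials, and factorial are hypothetical illustrations of mine and do not appear in the rest of the post; the sketch assumes C++14 or later.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical illustration, not the post's code: factorials up to 20!
// (the largest factorial that fits in 64 bits).
constexpr std::size_t kMaxN = 20;

struct FactorialTable
{
    std::uint64_t values[kMaxN + 1];
};

// Step 1: compute every value we could possibly need, up to the maximum.
constexpr FactorialTable make_factorial_table()
{
    FactorialTable table{};
    table.values[0] = 1;
    for (std::size_t i = 1; i <= kMaxN; ++i)
    {
        table.values[i] = table.values[i - 1] * i;
    }
    return table;
}

// Step 2: store the results in fixed memory locations at build time.
constexpr FactorialTable kFactorials = make_factorial_table();

// Step 3: reading a value at run time is a plain array access.
inline std::uint64_t factorial(std::size_t n)
{
    return kFactorials.values[n];
}

static_assert(kFactorials.values[5] == 120, "5! should be 120");
```

The rest of the post follows the same shape, with the added wrinkle that the values are indexed by two numbers instead of one.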
The problem of interest
So, what are we actually computing?
Well, I’m writing poker software. I very often need to know how many ways I can choose a certain number of objects out of a larger collection. For example: how many poker hands are there? As many as there are ways to choose 5 cards out of a 52 card deck. This is “52 choose 5”, usually written \(\binom{52}{5}\), and it’s the kind of quantity we want to compute.
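For a sense of scale, the familiar closed-form formula (which the rest of this post does not use) gives the number of poker hands directly:

\[
\binom{52}{5} = \frac{52 \cdot 51 \cdot 50 \cdot 49 \cdot 48}{5 \cdot 4 \cdot 3 \cdot 2 \cdot 1} = \frac{311{,}875{,}200}{120} = 2{,}598{,}960.
\]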
A couple of properties make this a nice application for some compile-time programming.
- First, we know the maximum when we build the program: we’ll never be choosing from more than 52 things. As I said, compile-time computation needs to know a maximum number ahead of time, and it’s very appealing that our limit is a natural feature of the problem we’re solving, and not artificial.
- Second, it’s a 2-dimensional example (how many total things, and how many things we choose). Almost every other example uses either factorials or Fibonacci numbers, which are 1-dimensional. It’s obvious how to lay those out in memory; the two-dimensional example adds a pleasing wrinkle to the problem.
So, this is the problem at hand: compute all combinatorics results for up to 52 items at compile time, and make them conveniently available at runtime.
The post is rather long, as it’s meant as a complete guide for readers looking to do their own compile-time computation (including code samples, and my commentary on the design decisions). Less ambitious readers might rather skim than read in depth. That said, some aspects that may appeal to the template-averse crowd include:
- my utter delight at rediscovering the simple elegance of the formula for computing combinatorics (so simple, a child could understand it!);
- a surprisingly concise, fast, and accurate square-root function;
- and, the final API we expose to the users—I hope it communicates its abilities and limitations in a natural way.
Let’s begin.
Computations for specific values of \(N\) and \(K\)
We’ll start with the core code—the part which does the actual computations.
// The recurrence relation (see Pascal's Triangle).
template <size_t N, size_t K>
struct Choose : std::integral_constant<uint64_t,
                                       Choose<N - 1, K - 1>::value +
                                           Choose<N - 1, K>::value>
{
    static_assert(K <= N, "Cannot choose more than N items");
};
(The static_assert is there for good hygiene: now, nonsense values won’t even attempt to compile.)
This recurrence relation lets us build more complicated answers out of simpler ones. Let’s pause and take a moment to appreciate just how simple, transparent, and elegant it is.
We want to know how many ways we can choose \(k\) items from \(n\), written as \(\binom{n}{k}\). Let’s imagine we already crunched all the numbers for fewer than \(n\) items. Well, every subset of \(k\) items from \(n\) will either include that new \(n\)th item, or it won’t.
- If the subset does include it, then there are \((k-1)\) items left to choose from the original \((n-1)\); this is just \(\binom{n-1}{k-1}\).
- If the subset doesn’t include it, then we have to choose all \(k\) items from the previous \((n-1)\); this is \(\binom{n-1}{k}\).
Since these groups can’t overlap, we simply add these two counts to get the total, and the famous equation falls right out:
\[\binom{n}{k} = \binom{n-1}{k-1} + \binom{n-1}{k}.\]
That’s the mathematical elegance here. The computational elegance is that we get the benefits of dynamic programming for free. In other words, the compiler computes every \((n,k)\) pair only once, and automatically reuses the answer after that (because it is encoded in a type, and each type is defined only once). This is a very good thing, since otherwise, the time to compute each number would be proportional to the number itself. \(\binom{n}{k}\) gets pretty big, pretty fast. If we let \(n\) go up to only, say, 52 (since there are 52 cards in the deck), there are some individual answers which would take half a quadrillion steps to compute!
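To see why the reuse matters, here is a small sketch that counts how many function calls a plain recursive implementation, with no caching at all, would make. The helper naive_calls is a hypothetical name of mine, not part of the post's code. Every chain of calls bottoms out in a base case worth 1, so the number of calls grows in step with the answer itself.

```cpp
#include <cstdint>

// Hypothetical helper: the number of calls a naive, uncached recursive
// n-choose-k would make, counting the initial call itself.
constexpr std::uint64_t naive_calls(std::uint64_t n, std::uint64_t k)
{
    return (k == 0 || k == n)
               ? 1
               : 1 + naive_calls(n - 1, k - 1) + naive_calls(n - 1, k);
}

// 10-choose-5 is 252; the naive recursion needs 2 * 252 - 1 calls.
static_assert(naive_calls(10, 5) == 503, "calls grow with the answer");
```

By contrast, the template version instantiates each distinct (n, k) type once; for n up to 10 that is at most 66 types.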
So, we have a computationally efficient implementation of a mathematically elegant relationship. That’s great, but so far, it won’t actually compute anything. We’ve defined our computation in terms of simpler values, but we’ve never said what those simpler values are. We need a base case—or rather, two.
// Base case for when we choose all the items.
template <size_t N>
struct Choose<N, N> : std::integral_constant<uint64_t, 1u>
{
};
// Base case for when we choose none of the items.
template <size_t N>
struct Choose<N, 0> : std::integral_constant<uint64_t, 1u>
{
};
The base cases for \(n\)-choose-\(k\) occur when we choose all of the items, or none of them. In either case, the answer is 1: there is exactly one way to do that.
This is now good enough to actually start computing some values. If we write, say, Choose<52, 5>::value, the compiler will replace it with 2598960 everywhere in our program, just as if we’d hardcoded the constant. In fact, it’s good enough to compute almost all the values, with one exception.
Consider \(\binom{0}{0}\): the number of ways to choose 0 items out of 0. Yes, it is well-defined; it happens to be 1, as we might have guessed from either of the rules above. But the compiler can’t figure out which of these rules to use (even though they both happen to give the same answer). We have to tell it explicitly:
// Base case for when there are no items.
template <>
struct Choose<0, 0> : std::integral_constant<uint64_t, 1u>
{
};
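Putting the three specializations together with the general template, a few compile-time checks confirm the values. This is a sketch: the includes, the std:: qualifiers, and the static_assert checks at the end are my additions, and the template code is repeated from above so that the block stands alone.

```cpp
#include <cstddef>
#include <cstdint>
#include <type_traits>

// General case: the recurrence relation.
template <std::size_t N, std::size_t K>
struct Choose : std::integral_constant<std::uint64_t,
                                       Choose<N - 1, K - 1>::value +
                                           Choose<N - 1, K>::value>
{
    static_assert(K <= N, "Cannot choose more than N items");
};

// Base case: choose all the items.
template <std::size_t N>
struct Choose<N, N> : std::integral_constant<std::uint64_t, 1u>
{
};

// Base case: choose none of the items.
template <std::size_t N>
struct Choose<N, 0> : std::integral_constant<std::uint64_t, 1u>
{
};

// Base case: no items at all (resolves the ambiguity between the two above).
template <>
struct Choose<0, 0> : std::integral_constant<std::uint64_t, 1u>
{
};

static_assert(Choose<0, 0>::value == 1, "one way to choose nothing from nothing");
static_assert(Choose<4, 2>::value == 6, "3 + 3 from the previous row");
static_assert(Choose<52, 5>::value == 2598960, "number of 5-card poker hands");
```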
You could be forgiven for wondering whether we really need this special case at all: when in our code are we ever going to compute \(\binom{0}{0}\)? Spoiler alert: we won’t, but if we eliminate this bit of code, we’ll need to add ugly, hacky special cases for the rest of the code we’ll still need to write. It’s simpler just to cover all the cases now.
How do we use it?
There’s a problem here.
We can compute any values we want for \(\binom{n}{k}\) at compile time (that is, when the program gets built). But we won’t know which values we actually want until the program gets run! In other words, we can handle something like Choose<52, 5>, no sweat. But if we have variables \(n = 52\) and \(k = 5\) when our program’s running, we can’t compute Choose<n, k>, because everything inside the <...> has to be known before the program gets run.
Here’s a two-step strategy to make this actually useful:
- Decide ahead of time which values we could possibly need;
- Figure out how to store them so we can actually get them when we run the program.
The first step is easy. This program is always choosing cards, and there are 52 of them. If we store all the values for every \(n\) and \(k\) up to \(n = 52\), we’ll be all set. This works out to 1431 different values: not remotely taxing.
For the second step, a simple array seems natural. It’s not totally trivial, because we’ll need a two-way mapping.
- For any given slot index \(i\), which \((n, k)\) pair is stored there? (We’ll use this when we write the answers.)
- For each \((n,k)\) pair, which of the 1431 slots \(i\) should we store its value in? (We’ll use this when we read the answers.)
Index math
I decided to use the most natural mapping: first do \(\binom{0}{0}\), then both of the \(\binom{1}{k}\), then all 3 of the \(\binom{2}{k}\), etc. We see the pattern here: the first index for a given \(n\) is just \((1 + 2 + ... + n) = n(n+1)/2\). After that, we just need to add \(k\):
constexpr size_t index(size_t n, size_t k) { return n * (n + 1) / 2 + k; }
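A few spot checks make the layout concrete. The static_assert lines are my additions, and I have assumed the function lives in the internal namespace that the class later in the post refers to; the function body itself is unchanged.

```cpp
#include <cstddef>

namespace internal
{
constexpr std::size_t index(std::size_t n, std::size_t k)
{
    return n * (n + 1) / 2 + k;
}
}  // namespace internal

// Row n starts right after the 1 + 2 + ... + n values of rows 0 to n - 1.
static_assert(internal::index(0, 0) == 0, "row 0 holds a single value");
static_assert(internal::index(1, 0) == 1 && internal::index(1, 1) == 2,
              "row 1 holds two values");
static_assert(internal::index(2, 0) == 3 && internal::index(2, 2) == 5,
              "row 2 holds three values");
// The last slot for a 52-card deck, so 1431 slots in total.
static_assert(internal::index(52, 52) == 1430, "last slot for n = 52");
```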
Inverting this mapping is reasonably straightforward. We notice that index() is quadratic in \(n\), so we can solve for it using the quadratic formula. Then once we have \(n\), we know \(k\) is just how far we are past the first index for that \(n\):
\[ \begin{aligned} n(i) &= \left\lfloor \frac{\sqrt{1 + 8i} - 1}{2} \right\rfloor \\ k(i) &= i - \mathtt{index}(n(i), 0) \end{aligned} \]
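The formula for \(n(i)\) comes from asking for the largest \(n\) whose first slot is at or before \(i\). Spelled out for \(n \ge 0\) (this derivation is my reconstruction of the step the text summarizes):

\[
\begin{aligned}
\frac{n(n+1)}{2} \le i
&\iff n^2 + n - 2i \le 0 \\
&\iff n \le \frac{-1 + \sqrt{1 + 8i}}{2}.
\end{aligned}
\]

Taking the largest integer that satisfies the bound gives the floor. For example, \(i = 4\) gives \(n = \lfloor (\sqrt{33} - 1)/2 \rfloor = \lfloor 2.37\ldots \rfloor = 2\) and \(k = 4 - 3 = 1\), so slot 4 holds \(\binom{2}{1}\).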
Here, we hit a slight detour. It turns out that we can’t use C++’s standard sqrt() function at compile time: the standard doesn’t declare it constexpr (at least not before C++26), because of its side effects, such as setting errno. We’ll have to roll our own! Fortunately, this isn’t too hard:
constexpr double sqrt_newton_raphson(double x, double curr, double prev)
{
    return (curr == prev) ? curr
                          : sqrt_newton_raphson(x, 0.5 * (curr + x / curr), curr);
}
constexpr double sqrt(double x)
{
    return (x >= 0. && x < std::numeric_limits<double>::max())
               ? sqrt_newton_raphson(x, x, 0.)
               : std::numeric_limits<double>::quiet_NaN();
}
Such a neat algorithm! Not only is it fast, accurate, and easy to understand, but it’s also one of the few instances where comparing floating point numbers for exact equality actually makes sense (as opposed to being a huge mistake, a bug waiting to happen).
My favourite aspect of this function is how transparent the logic is. Forget the calculus used to derive it, and just look at what it’s actually doing. We have a guess for the square root; call our current guess curr. If curr were the true square root, then dividing our number by curr would give us curr again. Instead, it gives us another number, which is too high when curr is too low, and vice versa. The square root must be in-between these two numbers; so, we take their average as our next guess.
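To watch the averaging at work, here is a short sketch that isolates a single update. The helper newton_step is a hypothetical name of mine; the post's code performs the same update inline.

```cpp
// Hypothetical helper: one update of the square-root iteration for x.
constexpr double newton_step(double x, double curr)
{
    return 0.5 * (curr + x / curr);
}

// Finding the square root of 9, starting from the guess 9:
// 9 / 9 = 1 is too low and 9 is too high, so the next guess is 5.
static_assert(newton_step(9.0, 9.0) == 5.0, "average of 9 and 1");

// With the guess 5, 9 / 5 = 1.8 and 5 again bracket the true root, 3.
static_assert(9.0 / 5.0 < 3.0 && 3.0 < 5.0, "guesses bracket the root");

// Once the guess is exact, the update returns it unchanged, which is
// what makes the (curr == prev) stopping test work.
static_assert(newton_step(9.0, 3.0) == 3.0, "3 is a fixed point");
```

Carried further, the guesses run 9, 5, 3.4, roughly 3.0235, roughly 3.00009, and so on, closing in on 3 quickly.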
With a compiler-compatible sqrt() in our toolbox, we can close the loop on our index functions:
constexpr size_t n(size_t i) {
    return static_cast<size_t>((sqrt(1 + 8 * i) - 1) / 2);
}
constexpr size_t k(size_t i) { return i - index(n(i), 0); }
Pro tip: this is a great opportunity to add a few unit tests to check the round-trip identity (i.e., making sure that \(\mathtt{index}(n(i), k(i)) = i\)).
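One way to write such tests is as compile-time checks, so that a broken mapping stops the build. The sketch below walks every (n, k) pair up to a maximum n and checks that its slot maps back to the same pair, which implies the identity above. It assumes C++14 (for loops inside a constexpr function) and places the helpers in an internal namespace; round_trip_holds is a hypothetical name of mine.

```cpp
#include <cstddef>
#include <limits>

namespace internal
{
constexpr std::size_t index(std::size_t n, std::size_t k)
{
    return n * (n + 1) / 2 + k;
}

constexpr double sqrt_newton_raphson(double x, double curr, double prev)
{
    return (curr == prev)
               ? curr
               : sqrt_newton_raphson(x, 0.5 * (curr + x / curr), curr);
}

constexpr double sqrt(double x)
{
    return (x >= 0. && x < std::numeric_limits<double>::max())
               ? sqrt_newton_raphson(x, x, 0.)
               : std::numeric_limits<double>::quiet_NaN();
}

constexpr std::size_t n(std::size_t i)
{
    return static_cast<std::size_t>((sqrt(1 + 8 * i) - 1) / 2);
}

constexpr std::size_t k(std::size_t i) { return i - index(n(i), 0); }

// Hypothetical test helper: true if every pair (a, b) with b <= a <= max_n
// is recovered from its slot by n() and k().
constexpr bool round_trip_holds(std::size_t max_n)
{
    for (std::size_t a = 0; a <= max_n; ++a)
    {
        for (std::size_t b = 0; b <= a; ++b)
        {
            const std::size_t i = index(a, b);
            if (n(i) != a || k(i) != b)
            {
                return false;
            }
        }
    }
    return true;
}
}  // namespace internal

// Covers all 1431 slots needed for a 52-card deck.
static_assert(internal::round_trip_holds(52), "index mapping must round-trip");
```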
Building an array
This is the bridge between compile-time and runtime, the way our precomputed values can actually get used. We need a function which takes in a bunch of indices, and returns the \(\binom{n}{k}\)-values corresponding to those indices, in an array.
// Compile-time populated array with the first consecutive values of N-choose-K
// (as enumerated in the natural ordering).
template <size_t... Is>
constexpr std::array<uint64_t, sizeof...(Is)>
choose_values(std::index_sequence<Is...>)
{
    return std::array<uint64_t, sizeof...(Is)>({Choose<n(Is), k(Is)>::value...});
}
We’re using variadic templates—that is, templates that accept an arbitrary number of parameters. I had always shied away from learning how to use these, as the syntax looked strange and forbidding to me. But working through a simple example like this helped me see they’re actually pretty straightforward. We get an index sequence (whose size and contents are known at compile time), and we build an array of the same size, turning each index into the corresponding \(\binom{n}{k}\) value.
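If the pack-expansion syntax is unfamiliar, a stripped-down sketch may help: the same pattern, but turning each index into its square. The function squares and the variable first_four are hypothetical examples of mine, and the sketch assumes C++14 for std::index_sequence.

```cpp
#include <array>
#include <cstddef>
#include <utility>

// Hypothetical example: builds {0*0, 1*1, 2*2, ...}, one entry per index.
template <std::size_t... Is>
constexpr std::array<std::size_t, sizeof...(Is)>
squares(std::index_sequence<Is...>)
{
    // "(Is * Is)..." expands to "(0 * 0), (1 * 1), (2 * 2), ..." for the
    // indices in the pack.
    return {{(Is * Is)...}};
}

// std::make_index_sequence<4> is the type std::index_sequence<0, 1, 2, 3>.
constexpr auto first_four = squares(std::make_index_sequence<4>());

static_assert(first_four.size() == 4, "one entry per index");
static_assert(first_four[3] == 9, "index 3 maps to 3 * 3");
```

Swapping the expression (Is * Is) for a lookup into the Choose templates gives the function above.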
With this in hand, all that’s left is to pass the right indices.
Getting the “first \(m\) indices”
We’re going to want all the indices from \(0\) up to (but not including) \(m(n_\text{max})\), where \(m\) is some function that tells us how many values there are up to \(n_\text{max}\). It’s pretty straightforward to deduce from the pattern of the first few values:
// Compute the number of N-choose-K values with N at most some maximum value.
constexpr size_t num_values_up_to(size_t n) { return (n + 1) * (n + 2) / 2; }
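Spelling out the pattern (my own derivation, consistent with the 1431 quoted earlier): row \(j\) holds \(j + 1\) values, since \(k\) runs from \(0\) to \(j\), so

\[
m(n_\text{max}) = \sum_{j=0}^{n_\text{max}} (j + 1) = \frac{(n_\text{max} + 1)(n_\text{max} + 2)}{2}, \qquad m(52) = \frac{53 \cdot 54}{2} = 1431.
\]

At 8 bytes per value, the whole table for a 52-card deck takes 11,448 bytes.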
Now we have all we need to create our final array:
// Compile-time populated array with all n-choose-k up to some maximum n.
template <size_t N_max>
constexpr std::array<uint64_t, num_values_up_to(N_max)> choose_values_up_to()
{
    return choose_values(std::make_index_sequence<num_values_up_to(N_max)>());
}
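Here is how all the pieces fit together in one translation unit. This is a sketch under a few assumptions of mine: the helpers live in an internal namespace, the standard headers listed are the ones needed, C++14 or later is available, the array is returned with double braces, and the variable name table is hypothetical. The code is otherwise the post's own, repeated so the block stands alone.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <type_traits>
#include <utility>

namespace internal
{
template <std::size_t N, std::size_t K>
struct Choose : std::integral_constant<std::uint64_t,
                                       Choose<N - 1, K - 1>::value +
                                           Choose<N - 1, K>::value>
{
    static_assert(K <= N, "Cannot choose more than N items");
};
template <std::size_t N>
struct Choose<N, N> : std::integral_constant<std::uint64_t, 1u> {};
template <std::size_t N>
struct Choose<N, 0> : std::integral_constant<std::uint64_t, 1u> {};
template <>
struct Choose<0, 0> : std::integral_constant<std::uint64_t, 1u> {};

constexpr std::size_t index(std::size_t n, std::size_t k)
{
    return n * (n + 1) / 2 + k;
}

constexpr double sqrt_newton_raphson(double x, double curr, double prev)
{
    return (curr == prev)
               ? curr
               : sqrt_newton_raphson(x, 0.5 * (curr + x / curr), curr);
}

constexpr double sqrt(double x)
{
    return (x >= 0. && x < std::numeric_limits<double>::max())
               ? sqrt_newton_raphson(x, x, 0.)
               : std::numeric_limits<double>::quiet_NaN();
}

constexpr std::size_t n(std::size_t i)
{
    return static_cast<std::size_t>((sqrt(1 + 8 * i) - 1) / 2);
}

constexpr std::size_t k(std::size_t i) { return i - index(n(i), 0); }

// One array entry per index in the pack.
template <std::size_t... Is>
constexpr std::array<std::uint64_t, sizeof...(Is)>
choose_values(std::index_sequence<Is...>)
{
    return {{Choose<n(Is), k(Is)>::value...}};
}

constexpr std::size_t num_values_up_to(std::size_t n)
{
    return (n + 1) * (n + 2) / 2;
}

template <std::size_t N_max>
constexpr std::array<std::uint64_t, num_values_up_to(N_max)>
choose_values_up_to()
{
    return choose_values(std::make_index_sequence<num_values_up_to(N_max)>());
}
}  // namespace internal

// The whole table for a 52-card deck, built by the compiler.
constexpr auto table = internal::choose_values_up_to<52>();

static_assert(table.size() == 1431, "one slot per (n, k) pair");
static_assert(table[internal::index(52, 5)] == 2598960, "52 choose 5");
```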
Polish and usability
Calling a function that returns a 1431-valued array is not a very appealing interface, even if all the values were computed at compile time! It would be better to create the array once, at the beginning of the program, and make it easy to access. At the same time, we want to be careful to avoid messy global variables.
A good solution is a well-named templated class, which stores the (const!) array as a private static member variable. The static means that all instances share the same copy of the array, and the private means nobody will be able to mess with it. We can make instances of this class without having to think too hard about it, confident that they’ll be as small as possible.
Assuming all the messy details above are hidden inside of an internal namespace, this would look something like the following.
template <size_t N_max>
class CappedCombinator
{
public:
    uint64_t choose(size_t n, size_t k) const
    {
        assert(n <= N_max);
        assert(k <= n);
        return cached_values_[internal::index(n, k)];
    }
private:
    const static std::array<uint64_t, internal::num_values_up_to(N_max)>
        cached_values_;
};
template <size_t N_max>
const std::array<uint64_t, internal::num_values_up_to(N_max)>
    CappedCombinator<N_max>::cached_values_ =
        internal::choose_values_up_to<N_max>();
The assert()s here are worth commenting on. Some folks feel that they should never be used. I don’t agree: I think this is a proper use of assert(). It guards against programming errors only, not user input (since this class should never take user input directly). And, given that the whole point of this is to make combinatorics as fast as possible, I want the compiler to remove these statements completely when I turn off debug mode (that is, when NDEBUG is defined).
I plan on building separate debug and production versions of my program; assert() is perfect for this.
Using the API
The real test of what I’ve built: how usable is it? How does it look at the call site? Is it easy to use correctly, and hard to use incorrectly? Let’s take a look at some example code.
CappedCombinator<cards::DECK_SIZE> combinator;
// A number which uniquely identifies this poker hand.
size_t hand_index = 0;
size_t start = 0;
size_t hand_cards_left = hand.card_indices.size();
for (const size_t index : hand.card_indices)  // Assume the indices are sorted.
{
    hand_index += (
        combinator.choose(cards::DECK_SIZE - start, hand_cards_left) -
        combinator.choose(cards::DECK_SIZE - index, hand_cards_left));
    start = index + 1;
    --hand_cards_left;
}
Don’t worry too much about the details of what the code is doing, or why we want to do it. The point is that it shows how easy it is to take the numbers we computed at compile time, and use them to compute meaningful quantities with arbitrary values at runtime. The limitations of the class (i.e., don’t pass anything higher than DECK_SIZE) are clearly communicated when we declare the variable. And the only way we can violate those limitations is via a programming bug, not bad user input (assuming that hand.card_indices is generated by some other part of the program, and not directly by the user).
All in all, my first adventure in figuring out how to do useful work at compile time turned out pretty satisfyingly.