The softmax function is used frequently in neural networks, including in large language models like ChatGPT. It converts a list of numbers into a probability distribution, i.e., non-negative values that add up to 1.

Softmax does this by exponentiating each value and then dividing by the sum of all the exponentiated values.
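In symbols (writing the inputs generically as \(x_1, \dots, x_n\), since the exact values don't matter here):

\[
\operatorname{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}
\]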

Using the previous example, here is how the softmax was computed:

Obviously, softmax uses the exponential function, i.e., base \(e\) operations!

However, in real-world implementations, softmax is often modified to use base 2 instead of base \(e\), because hardware can compute base 2 exponentials much faster than base \(e\) exponentials:
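\[
\operatorname{softmax}(x)_i = \frac{2^{x_i/\ln 2}}{\sum_{j=1}^{n} 2^{x_j/\ln 2}}
\]

Here \(1/\ln 2 = \log_2 e \approx 1.4427\) is the rescaling factor applied to each input.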

This is a standard optimization in FlashAttention, where the softmax exponentials are a non-negligible cost of the inner loop. The trick is to rescale the inputs by \(1/\ln 2\) so that exp2 can be used directly instead of exp.
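As a sanity check, here's a minimal NumPy sketch (not FlashAttention's actual kernel code; the function names are made up for illustration) showing that the rescaled base-2 version matches the base \(e\) version:

```python
import numpy as np

def softmax_base_e(x):
    # Reference softmax: exponentiate each value, divide by the sum.
    z = x - x.max()               # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def softmax_base_2(x):
    # Same computation, but rescale by 1/ln(2) so exp2 can replace exp.
    inv_ln2 = 1.0 / np.log(2.0)   # = log2(e) ≈ 1.4427
    z = (x - x.max()) * inv_ln2
    e = np.exp2(z)
    return e / e.sum()

x = np.array([1.0, 2.0, 3.0])
print(softmax_base_e(x))  # [0.09003057 0.24472847 0.66524096]
print(softmax_base_2(x))  # same values, up to floating-point rounding
```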

So how do we convert from base \(e\) to base 2? We could just use the exponential base change formula¹, but it's fun (and satisfying) to derive it ourselves.

We want to find some expression \(y\) so that \(e^x = 2^y\). We'll take the natural log of both sides to cancel out the \(e\) and use the logarithmic power rule to bring the \(y\) out.
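Carrying that out:

\[
\begin{aligned}
e^x &= 2^y \\
\ln(e^x) &= \ln(2^y) \\
x &= y \ln 2 \\
y &= \frac{x}{\ln 2}
\end{aligned}
\]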

We can now plug \(y\) back in:
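\[
e^x = 2^{x/\ln 2}
\]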

So that's why FlashAttention scales inputs by \(1/\ln 2\)!

  1. \(a^b = c^{b \log_c a}\)