From slow to SIMD: A Go optimization story
So, there's this function. It's called a lot. More importantly, all those calls are on the critical path of a key user interaction. Let's talk about making it fast.
Spoiler: it's a dot product.
Some background (or skip to the juicy stuff)
At Sourcegraph, we're working on a Code AI tool named Cody. In order for Cody to answer questions well, we need to give it enough context to work with. One of the ways we do this is by leveraging embeddings.
For our purposes, an embedding is a vector representation of a chunk of text. Embeddings are constructed in such a way that semantically similar pieces of text have more similar vectors. When Cody needs more information to answer a query, we run a similarity search over the embeddings to fetch a set of related chunks of code and feed them to Cody to improve the relevance of its answers.
The piece relevant to this blog post is that similarity metric, which is the function that determines how similar two vectors are. For similarity search, a common metric is cosine similarity. However, for normalized vectors (vectors with unit magnitude), the dot product yields a ranking that's equivalent to cosine similarity. To run a search, we calculate the dot product for every embedding in our data set and keep the top results. And since we cannot start execution of the LLM until we get the necessary context, optimizing this step is crucial.
You might be thinking: why not just use an indexed vector DB? Outside of adding yet another piece of infra that we need to manage, the construction of an index adds latency and increases resource requirements. Additionally, standard nearest-neighbor indexes only provide approximate retrieval, which adds another layer of fuzziness compared to a more easily explainable exhaustive search. Given that, we decided to invest a little in our hand-rolled solution to see how far we could push it.
The target
This is a simple Go implementation of a function that calculates the dot product of two vectors. My goal is to outline the journey I took to optimize this function, and to share some tools I picked up along the way.
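A sketch of that naive version (the benchmarked DotNaive may differ in minor details):

```go
func DotNaive(a, b []float32) float32 {
	sum := float32(0)
	for i := 0; i < len(a) && i < len(b); i++ {
		sum += a[i] * b[i]
	}
	return sum
}
```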
Unless otherwise stated, all benchmarks are run on an Intel Xeon Platinum 8481C 2.70GHz CPU (a c3-highcpu-44 GCE VM). The code in this blog post can all be found in runnable form here.
Loop unrolling
Modern CPUs do this thing called instruction pipelining, where they can run multiple instructions simultaneously as long as there are no data dependencies between them. A data dependency just means that the input of one instruction depends on the output of another.
In our simple implementation, we have data dependencies between our loop iterations. A couple, in fact. Both `i` and `sum` have a read/write pair each iteration, meaning an iteration cannot start executing until the previous one is finished.
A common method of squeezing more out of our CPUs in situations like this is known as loop unrolling. The basic idea is to rewrite our loop so more of our relatively-high-latency multiply instructions can execute simultaneously. Additionally, it amortizes the fixed loop costs (increment and compare) across multiple operations.
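Here's a sketch of an unroll-by-4 version (the benchmarked DotUnroll4 may differ in details):

```go
// Assumes len(a) == len(b) and that the length is a multiple of 4
// (true for our 1536-dimension vectors).
func DotUnroll4(a, b []float32) float32 {
	sum := float32(0)
	for i := 0; i < len(a); i += 4 {
		s0 := a[i] * b[i]
		s1 := a[i+1] * b[i+1]
		s2 := a[i+2] * b[i+2]
		s3 := a[i+3] * b[i+3]
		sum += s0 + s1 + s2 + s3
	}
	return sum
}
```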
In our unrolled code, the dependencies between multiply instructions are removed, enabling the CPU to take more advantage of pipelining. This increases our throughput by 37% compared to our naive implementation.
[benchmark results: DotNaive vs. DotUnroll4]
Note that we can actually improve this slightly more by twiddling with the number of iterations we unroll. On the benchmark machine, 8 seemed to be optimal, but on my laptop, 4 performs best. However, the improvement is quite platform dependent and fairly minimal, so for the rest of the post, I'll stick with an unroll depth of 4 for readability.
Bounds-checking elimination
In order to keep out-of-bounds slice accesses from becoming a security vulnerability (like the famous Heartbleed exploit), the Go compiler inserts checks before each read. You can check it out in the generated assembly (look for `runtime.panic`).
The compiled code makes it look like we wrote something like this:
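Conceptually, something like the following (the function name is mine, and the checks are collapsed for illustration):

```go
// A hand-written approximation, not actual compiler output. In reality the
// compiler emits a separate check per indexing expression; they're collapsed
// into one here for brevity.
func DotUnroll4WithChecks(a, b []float32) float32 {
	sum := float32(0)
	for i := 0; i < len(a); i += 4 {
		if i+3 >= len(a) || i+3 >= len(b) {
			panic("runtime error: index out of range")
		}
		s0 := a[i] * b[i]
		s1 := a[i+1] * b[i+1]
		s2 := a[i+2] * b[i+2]
		s3 := a[i+3] * b[i+3]
		sum += s0 + s1 + s2 + s3
	}
	return sum
}
```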
In a hot loop like this, even with modern branch prediction, the additional branches per iteration can add up to a pretty significant performance penalty. This is especially true in our case because the inserted jumps limit how much we can take advantage of pipelining.
If we can convince the compiler that these reads can never be out of bounds, it won't insert these runtime checks. This technique is known as "bounds-checking elimination", and the same patterns can apply to languages other than Go.
In theory, we should be able to do all checks once, outside the loop, and the compiler would be able to determine that all the slice indexing is safe. However, I couldn't find the right combination of checks to convince the compiler that what I'm doing is safe. I landed on a combination of asserting the lengths are equal and moving all the bounds checking to the top of the loop. This was enough to hit nearly the speed of the bounds-check-free version.
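A sketch of the shape that worked (the benchmarked DotBCE may differ in details):

```go
// Assumes the length is a multiple of 4, as before.
func DotBCE(a, b []float32) float32 {
	if len(a) != len(b) {
		panic("slices must have equal lengths")
	}

	sum := float32(0)
	for i := 0; i < len(a); i += 4 {
		// The slice expression does one bounds check at the top of the
		// iteration; the compiler knows the result has length 4, so the
		// four reads below need no further checks.
		aTmp := a[i : i+4 : i+4]
		bTmp := b[i : i+4 : i+4]
		s0 := aTmp[0] * bTmp[0]
		s1 := aTmp[1] * bTmp[1]
		s2 := aTmp[2] * bTmp[2]
		s3 := aTmp[3] * bTmp[3]
		sum += s0 + s1 + s2 + s3
	}
	return sum
}
```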
Minimizing the bounds checking nets a 9% improvement. Consistently non-zero, but nothing to write home about.
[benchmark results: DotNaive vs. DotUnroll4 vs. DotBCE]
This technique translates well to many memory-safe compiled languages like Rust.
Exercise for the reader: why is it significant that we slice like `a[i:i+4:i+4]` rather than just `a[i:i+4]`?
Quantization
We've improved single-core search throughput by ~50% at this point, but now we've hit a new bottleneck: memory usage. Our vectors have 1536 dimensions. With 4-byte elements, that comes out to 6KiB per vector, and we generate roughly a million vectors per GiB of code. That adds up quickly. We had a few customers come to us with some massive monorepos, and we wanted to reduce our memory usage so we could support those cases more cheaply.
One possible mitigation would be to move the vectors to disk, but loading them from disk at search time can add significant latency, especially on slow disks. Instead, we chose to compress our vectors with `int8` quantization.

There are plenty of ways to compress vectors, but we'll be talking about integer quantization, which is relatively simple but effective. The idea is to reduce the precision of the 4-byte `float32` vector elements by converting them to 1-byte `int8`s.

I won't get into exactly how we decide to do the translation between `float32` and `int8`, since that's a pretty deep topic, but suffice it to say our function now looks like the following:
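Something like this, in sketch form (the exact signature of the benchmarked DotInt8BCE is an assumption on my part):

```go
// Same shape as DotBCE, but over int8 elements, accumulating into an
// int32 to avoid overflow.
func DotInt8BCE(a, b []int8) int32 {
	if len(a) != len(b) {
		panic("slices must have equal lengths")
	}

	sum := int32(0)
	for i := 0; i < len(a); i += 4 {
		aTmp := a[i : i+4 : i+4]
		bTmp := b[i : i+4 : i+4]
		s0 := int32(aTmp[0]) * int32(bTmp[0])
		s1 := int32(aTmp[1]) * int32(bTmp[1])
		s2 := int32(aTmp[2]) * int32(bTmp[2])
		s3 := int32(aTmp[3]) * int32(bTmp[3])
		sum += s0 + s1 + s2 + s3
	}
	return sum
}
```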
This change yields a 4x reduction in memory usage at the cost of some accuracy (which we carefully measured, but is irrelevant to this blog post).
Unfortunately, re-running the benchmarks shows our search speed regressed a bit from the change. Taking a look at the generated assembly (with `go tool compile -S`), there are some new instructions for converting `int8` to `int32`, which might explain the difference. I didn't dig too deep though, since all our performance improvements up to this point become irrelevant in the next section.
[benchmark results: DotNaive vs. DotUnroll4 vs. DotBCE vs. DotInt8BCE]
SIMD
The speed improvements so far were nice, but still not enough for our largest customers. So we started dabbling with more dramatic approaches.
I always love an excuse to play with SIMD. And this problem seemed like the perfect nail for that hammer.
For those unfamiliar, SIMD stands for "Single Instruction Multiple Data". Just like it says, it lets you run an operation over a bunch of pieces of data with a single instruction. As an example, to add two `int32` vectors element-wise, we could add the elements one by one with the `ADD` instruction, or we could use the `VPADDD` instruction to add eight pairs at a time (sixteen with AVX-512) with roughly the same latency (depending on the architecture).
We have a problem though. Go does not expose SIMD intrinsics like C or Rust. We have two options here: write it in C and use Cgo, or write it by hand for Go's assembler. I try hard to avoid Cgo whenever possible for many reasons that are not at all original, but one of those reasons is that Cgo imposes a performance penalty, and performance of this snippet is paramount. Also, getting my hands dirty with some assembly sounds fun, so that's what I'm going to do.
I want this routine to be reasonably portable, so I'm going to restrict myself to only AVX2 instructions, which are supported on most `x86_64` server CPUs these days. We can use runtime feature detection to fall back to a slower option in pure Go.
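As a sketch of what that dispatch can look like (the names here are illustrative; golang.org/x/sys/cpu is one common way to do the feature detection):

```go
package dot

import "golang.org/x/sys/cpu"

// dotAVX2Asm stands in for the hand-written assembly routine (declared in
// an _amd64.s file in the real package); it's left as a nil function value
// here so the sketch stands alone.
var dotAVX2Asm func(a, b []int8) int32

// Dot dispatches to the fastest implementation the current CPU supports,
// falling back to pure Go otherwise.
func Dot(a, b []int8) int32 {
	if cpu.X86.HasAVX2 && dotAVX2Asm != nil {
		return dotAVX2Asm(a, b)
	}
	return dotGeneric(a, b)
}

// dotGeneric is the pure-Go fallback.
func dotGeneric(a, b []int8) int32 {
	var sum int32
	for i := range a {
		sum += int32(a[i]) * int32(b[i])
	}
	return sum
}
```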
Full code for DotAVX2
The core loop of the implementation depends on three main instructions:

- `VPMOVSXBW`, which loads `int8`s into a vector of `int16`s
- `VPMADDWD`, which multiplies two `int16` vectors element-wise, then adds together adjacent pairs to produce a vector of `int32`s
- `VPADDD`, which accumulates the resulting `int32` vector into our running sum
`VPMADDWD` is a real heavy lifter here. By combining the multiply and add steps into one, not only does it save instructions, it also helps us avoid overflow issues by simultaneously widening the result to an `int32`.
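To make its semantics concrete, here's a scalar model of a single 256-bit VPMADDWD (the function is purely illustrative):

```go
// Scalar model of a single 256-bit VPMADDWD: multiply corresponding int16
// lanes, then add adjacent products together, widening to int32.
// (Purely illustrative; this is not the real assembly.)
func vpmaddwdModel(a, b [16]int16) [8]int32 {
	var out [8]int32
	for i := 0; i < 16; i += 2 {
		out[i/2] = int32(a[i])*int32(b[i]) + int32(a[i+1])*int32(b[i+1])
	}
	return out
}
```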
Let's see what this earned us.
[benchmark results: DotNaive vs. DotUnroll4 vs. DotBCE vs. DotInt8BCE vs. DotAVX2]
Woah, that's a 530% increase in throughput from our previous best! SIMD for the win 🚀
Now, it wasn't all sunshine and rainbows. Hand-writing assembly in Go is weird. It uses a custom assembler, which means its assembly language looks just-different-enough-to-be-confusing compared to the assembly snippets you usually find online. It has some weird quirks like changing the order of instruction operands or using different names for instructions. Some instructions don't even have names in the Go assembler and can only be used via their binary encoding. Shameless plug: I found sourcegraph.com invaluable for finding examples of Go assembly to draw from.
That said, compared to Cgo, there are some nice benefits. Debugging still works well: the assembly can be stepped through and registers can be inspected using `delve`. There are no extra build steps (a C toolchain doesn't need to be set up). It's easy to set up a pure-Go fallback so cross-compilation still works. And common problems are caught by `go vet`.
SIMD...but bigger
Previously, we limited ourselves to AVX2, but what if we didn't? The VNNI extension to AVX-512 added the `VPDPBUSD` instruction, which computes the dot product on `int8` vectors rather than `int16`s. This means we can process four times as many elements in a single instruction, because we don't have to convert to `int16` first and our vector width doubles with AVX-512!
The only problem is that the instruction requires one vector to be signed bytes and the other to be unsigned bytes. Both of our vectors are signed. We can employ a trick from Intel's developer guide to help us out. Given two `int8` elements, aₙ and bₙ, we do the element-wise calculation as aₙ * (bₙ + 128) - aₙ * 128. The aₙ * 128 term is the overshoot from adding 128 to bump bₙ into `u8` range. We keep track of that separately and subtract it at the end. Each of the operations in that expression can be vectorized.
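In scalar form, the trick looks something like this (purely illustrative; the real implementation does this with vector instructions):

```go
// Scalar model of the trick: a*b == a*(b+128) - a*128, and b+128 always
// fits in a uint8, which is what VPDPBUSD needs.
func dotUnsignedTrick(a, b []int8) int32 {
	var sum int32        // accumulates a[i] * (b[i] + 128)
	var correction int32 // accumulates the a[i] * 128 overshoot
	for i := range a {
		bu := int32(uint8(int16(b[i]) + 128)) // bump b[i] into unsigned range
		sum += int32(a[i]) * bu
		correction += int32(a[i]) * 128
	}
	return sum - correction
}
```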
Full code for DotVNNI
This implementation yielded another 21% improvement. Not bad!
[benchmark results: DotNaive vs. DotUnroll4 vs. DotBCE vs. DotInt8BCE vs. DotAVX2 vs. DotVNNI]
What's next?
Well, I'm pretty happy with a 9.3x increase in throughput and a 4x reduction in memory usage, so I'll probably leave it here.
The real life answer here is probably "use an index". There is a ton of good work out there focused on making nearest neighbor search fast, and there are plenty of batteries-included vector DBs that make it pretty easy to deploy.
However, if you want some fun food for thought, a colleague of mine built a proof-of-concept dot product on the GPU.
Bonus material
- If you haven't used benchstat, you should. It's great. Super simple statistical comparison of benchmark results.
- Don't miss the compiler explorer, which is an extremely useful tool for digging into compiler codegen.
- There's also that time I got nerd sniped into implementing a version with ARM NEON, which made for some interesting comparisons.
- If you haven't come across it, the Agner Fog Instruction Tables make for some great reference material for low-level optimizations. For this work, I used them to help grok differences in instruction latencies and why some instructions pipeline better than others.