Matrix Multiply using Lookup Tables, Part 2: Why AVX-512 Makes It Fast
Part 1 (Matrix Multiply using Lookup Tables) showed the
basic idea. This post is about how to make it go vroom.
Code is at lut_mm.
Note: Claude Fable wrote the code. I just gave it the high level idea. It's way more fluent with SIMD and matrix kernels than I am.
Quick Review
- Many ternary weights {-1, 0, 1} can fit into a byte.
- We can treat groups of weights as indices into a lookup table.
- We DON'T have to unpack them for inference.
Motivation
How many trits can we fit into a byte?
5 Trits!
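The count works because 3^5 = 243 is at most 256. Here is a minimal C sketch of one way to pack five trits into a byte. It assumes a balanced-ternary code stored as a signed byte; the actual lut_mm layout may differ, and `pack5` / `unpack5` are illustrative names.

```c
#include <assert.h>
#include <stdint.h>

/* Pack five trits (each -1, 0, or +1) into one signed byte.
   Assumed encoding: balanced ternary, code = sum of t[i] * 3^i,
   which always lands in [-121, 121]. */
static int8_t pack5(const int8_t t[5]) {
    int code = 0;
    for (int i = 4; i >= 0; --i) {
        assert(t[i] >= -1 && t[i] <= 1);
        code = code * 3 + t[i];
    }
    return (int8_t)code;
}

/* Inverse of pack5: recover the five trits from the signed code. */
static void unpack5(int8_t code, int8_t t[5]) {
    int c = code;
    for (int i = 0; i < 5; ++i) {
        int r = ((c % 3) + 3) % 3;     /* remainder in 0..2 */
        int trit = (r == 2) ? -1 : r;  /* 2 stands for -1 in balanced ternary */
        t[i] = (int8_t)trit;
        c = (c - trit) / 3;            /* exact division */
    }
}
```

Under this assumed encoding, negating every trit negates the code, so the sign of the byte and its magnitude can be handled separately.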
How does Bitnet pack trits?
3 trits per 5 bits.
Why? That's noticeably less dense: 8/5 = 1.6 bits per trit for five trits per byte versus 5/3 ≈ 1.67 bits per trit for BitNet. It seems like a strange choice of numbers.
The reason is that AVX2 makes it harder to do bigger lookup tables.
Okay, then how are we going to make this work?
AVX-512 to the rescue! Big vectors big results!
First some spicy results!
The headline result, on one core of a Ryzen 9 9950X (Zen 5), for the 2080 x 2048 weight shape (the first row of the size table in the appendix):
| Implementation | Gop/s | vs dense |
|---|---|---|
| dense int8 GEMM (MSVC auto-vectorized AVX2) | 84 | 1.0x |
| BitNet TL2 (AVX2 pshufb kernel) | 182 | 2.2x |
| BitNet TL2 widened to 512-bit (our port) | 343 | 4.1x |
| dense int8 with AVX-512 VNNI (vpdpbusd) | 555 | 6.6x |
| ours: 5 trits/byte + AVX-512 lookup | 704 | 8.4x |
Every implementation produces bit-identical int32 output.
Why Is This Even Hard?
A lookup table only helps if the lookup is fast — and reading from memory isn't. Even a gather pulls about one entry per cycle, slower than just doing the multiply-adds it was meant to replace. Lookups win only when the whole table lives in registers and you index it with a shuffle, which reads dozens of entries at once. So it all comes down to one question: how big a table can a single shuffle index?
Limitations of AVX2
AVX2's shuffle, vpshufb, does 32 byte-lookups at once but indexes only a
16-entry table per lane. That tiny budget is what shapes Microsoft's
BitNet TL2 format: three trits have 3^3 = 27
values, which is too many, but the sign symmetry from Part 1 halves
that to 14 magnitudes, which fits. The price is real: signs ride in a
separate bitstream, and each int16 result is split across two byte tables
because vpshufb returns bytes.
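To make the byte-table cost concrete, here is a scalar C sketch of a 16-entry lookup that can only return bytes. It is a model of the constraint, not BitNet's code; `split_table` and `lookup_signed` are hypothetical names.

```c
#include <assert.h>
#include <stdint.h>

/* Split 16 int16 entries into a low-byte table and a high-byte table,
   mirroring the constraint that a byte shuffle can only return bytes. */
static void split_table(const int16_t t[16], uint8_t lo[16], uint8_t hi[16]) {
    for (int i = 0; i < 16; ++i) {
        uint16_t u = (uint16_t)t[i];
        lo[i] = (uint8_t)(u & 0xFF);
        hi[i] = (uint8_t)(u >> 8);
    }
}

/* One lookup: two byte fetches (one per table), reassembled into an int16,
   then negated if the separately stored sign bit is set. */
static int16_t lookup_signed(const uint8_t lo[16], const uint8_t hi[16],
                             unsigned idx, int negative) {
    assert(idx < 16);
    uint16_t u = (uint16_t)(lo[idx] | ((unsigned)hi[idx] << 8));
    int16_t v = (int16_t)u;
    return (int16_t)(negative ? -v : v);
}
```

Each result costs two table reads plus a sign fetched from elsewhere, which is the overhead the paragraph above describes.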
What AVX-512 Changes
AVX-512BW gives us a different lookup primitive: vpermt2w, exposed in
intrinsics as _mm512_permutex2var_epi16. It indexes 16-bit words from a
table spanning two 512-bit registers. One such permute covers 64 entries;
two permutes plus a blend cover 128 entries.
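For readers less familiar with the instruction, here is a scalar C model of the two-register word permute and of the 128-entry lookup built from it. It assumes the table is laid out as 128 consecutive int16 values split across four 32-lane registers; the function names are illustrative.

```c
#include <assert.h>
#include <stdint.h>

#define LANES 32  /* 16-bit lanes in one 512-bit register */

/* Scalar model of a two-register word permute: the low 6 bits of each index
   pick one of 64 words, with bit 5 choosing between registers a and b. */
static void permute2_model(const int16_t a[LANES], const uint16_t idx[LANES],
                           const int16_t b[LANES], int16_t out[LANES]) {
    for (int i = 0; i < LANES; ++i) {
        unsigned sel = idx[i] & 63u;  /* higher index bits are ignored */
        out[i] = (sel & 32u) ? b[sel & 31u] : a[sel & 31u];
    }
}

/* 128-entry lookup: two 64-entry permutes plus a blend on index bit 6. */
static void lookup128_model(const int16_t table[128], const uint16_t idx[LANES],
                            int16_t out[LANES]) {
    int16_t r0[LANES], r1[LANES];
    permute2_model(table, idx, table + 32, r0);       /* entries 0..63 */
    permute2_model(table + 64, idx, table + 96, r1);  /* entries 64..127 */
    for (int i = 0; i < LANES; ++i) {
        assert(idx[i] < 128);
        out[i] = (idx[i] & 64u) ? r1[i] : r0[i];
    }
}
```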
Now redo the packing arithmetic:
| Trits per index | Combinations | Magnitudes after mirror symmetry | Fits? |
|---|---|---|---|
| 3 | 27 | 14 | 16-entry AVX2 pshufb |
| 5 | 243 | 122 | 128-entry AVX-512 pair |
| 6 | 729 | 365 | no |
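The counts in the table can be reproduced with a few lines of C. This is only a sketch of the arithmetic; the helper names are made up for illustration.

```c
#include <assert.h>

/* Number of distinct codes for n ternary digits: 3^n. */
static int combos(int n_trits) {
    assert(n_trits >= 0 && n_trits <= 12);
    int c = 1;
    for (int i = 0; i < n_trits; ++i) c *= 3;
    return c;
}

/* Codes pair up as (w, -w) except the all-zero code, so a table only
   needs one entry per pair plus zero: (3^n + 1) / 2 magnitudes. */
static int magnitudes(int n_trits) {
    return (combos(n_trits) + 1) / 2;
}

/* Smallest power-of-two table that holds every magnitude. */
static int table_entries_needed(int n_trits) {
    int need = magnitudes(n_trits), size = 1;
    while (size < need) size *= 2;
    return size;
}
```

Three trits need a 16-entry table and five need 128 entries, while six would need 512, well beyond what a pair of two-register word permutes can reach.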
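This is the important coincidence: five trits give 3^5 = 243 codes, which fit in one byte (243 ≤ 256), and mirror symmetry leaves 122 magnitudes, which fit in the 128-entry table that a pair of vpermt2w permutes can index.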
So the packed byte never has to be unpacked: it's loaded, used directly as a table index, signed, and accumulated. One row, 32 columns:
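```c
// Needs <immintrin.h> and AVX-512BW. Assumed context: packed_bytes is a
// __m256i holding 32 packed codes, T[0..3] are the four table registers,
// and acc is an int16 accumulator register.
const __m512i c64  = _mm512_set1_epi16(64);   // hoisted out of the loop
const __m512i zero = _mm512_setzero_si512();  // hoisted out of the loop

__m512i   v   = _mm512_cvtepi8_epi16(packed_bytes);          // 32 packed codes, sign-extended
__mmask32 neg = _mm512_movepi16_mask(v);                     // sign bit of each code
__m512i   m   = _mm512_abs_epi16(v);                         // magnitude index, 0..121
__mmask32 hi  = _mm512_test_epi16_mask(m, c64);              // index >= 64
__m512i   r0  = _mm512_permutex2var_epi16(T[0], m, T[1]);    // entries 0..63
__m512i   r1  = _mm512_permutex2var_epi16(T[2], m, T[3]);    // entries 64..121
__m512i   r   = _mm512_mask_blend_epi16(hi, r0, r1);         // pick r1 where index >= 64
r   = _mm512_mask_sub_epi16(r, neg, zero, r);                // apply sign: 0 - r where neg
acc = _mm512_add_epi16(acc, r);
```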
Nine instructions, 32 results, 160 ternary multiply-adds. Apart from loading the packed codes, the loop never touches memory: the tables stay in registers.
Two mask-register tricks pull their weight. The packed byte already
carries its sign, so vpmovw2m lifts the sign bits into a mask and a
masked subtract negates just those lanes — no separate sign stream. The
same masking handles the last N % 32 columns, so there's no scalar tail
loop (N=2047 runs within ~5% of N=2048).
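Both mask tricks can be modeled in scalar C, with a 32-bit integer standing in for a mask register. This is a sketch of the idea under that assumption, not the kernel's code, and the function names are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

#define LANES 32

/* One bit per live column in a 32-column block: all ones for a full block,
   the low bits only for the final partial block. */
static uint32_t column_mask(int cols_remaining) {
    assert(cols_remaining > 0);
    if (cols_remaining >= LANES) return 0xFFFFFFFFu;
    return (1u << cols_remaining) - 1u;
}

/* Scalar model of lifting each lane's sign bit into a mask. */
static uint32_t sign_mask(const int16_t v[LANES]) {
    uint32_t m = 0;
    for (int i = 0; i < LANES; ++i)
        if (v[i] < 0) m |= (uint32_t)1 << i;
    return m;
}

/* Scalar model of a masked subtract from zero: negate only selected lanes. */
static void masked_negate(int16_t r[LANES], uint32_t mask) {
    for (int i = 0; i < LANES; ++i)
        if ((mask >> i) & 1u) r[i] = (int16_t)-r[i];
}
```

For N = 2047 the final block has 2047 % 32 = 31 live columns, so its mask has the low 31 bits set.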
Building The Table Cheaply
Each group of five activations needs its own 122-entry table, rebuilt every group — and building those tables naively is surprisingly expensive, easily a third of the runtime. But the table factors: split a code into its top three trits and bottom two, and the result is the sum of the halves,
so build a 27-entry H and a 9-entry L and add — like reading a
two-digit number off its tens and ones. Both are linear in the activations
(a few multiplies and adds), and vpermw assembles the full table with
constant indices. The build drops to a rounding error.
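The factoring can be checked with a scalar sketch. It assumes a balanced-ternary code in which the magnitude m splits as m = 9*h + l, with h covering the top three trits and l the bottom two; the real kernel's index layout may differ, and all names here are illustrative.

```c
#include <assert.h>
#include <stdint.h>

/* Decode a balanced-ternary value into n trits (assumed encoding). */
static void decode_trits(int code, int n, int t[]) {
    for (int i = 0; i < n; ++i) {
        int r = ((code % 3) + 3) % 3;
        t[i] = (r == 2) ? -1 : r;
        code = (code - t[i]) / 3;
    }
}

/* Build the 122-entry magnitude table for one group of five activations
   as T[m] = H[h] + L[l], where m = 9*h + l and l lies in -4..4. */
static void build_table(const int8_t a[5], int16_t T[122]) {
    int16_t H[27], L[9];
    int t[3];
    for (int h = -13; h <= 13; ++h) {  /* 27 entries: top three trits */
        decode_trits(h, 3, t);
        H[h + 13] = (int16_t)(t[0] * a[2] + t[1] * a[3] + t[2] * a[4]);
    }
    for (int l = -4; l <= 4; ++l) {    /* 9 entries: bottom two trits */
        decode_trits(l, 2, t);
        L[l + 4] = (int16_t)(t[0] * a[0] + t[1] * a[1]);
    }
    for (int m = 0; m <= 121; ++m) {
        int h = (m + 4) / 9;           /* nearest multiple of 9; only h >= 0 is read */
        int l = m - 9 * h;             /* remainder in -4..4 */
        T[m] = (int16_t)(H[h + 13] + L[l + 4]);
    }
}

/* Reference: direct dot product of the decoded trits with the activations. */
static int direct_dot(int code, const int8_t a[5]) {
    int t[5], s = 0;
    decode_trits(code, 5, t);
    for (int i = 0; i < 5; ++i) s += t[i] * a[i];
    return s;
}
```

The build costs 27 + 9 small dot products and 122 additions, instead of 122 five-term dot products.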
Two more wins: share the weight decode across four activation rows, and accumulate in int16, flushing to int32 every 51 groups (safe: 51 × 5 × 128 = 32,640 ≤ 32,767).
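The flush interval follows from a worst-case bound: each group contributes at most 5 × 128 in magnitude for int8 activations. A small sketch of that arithmetic and of a flushing accumulator, with hypothetical helper names:

```c
#include <assert.h>
#include <stdint.h>

/* Largest number of groups that can be summed in an int16 accumulator
   before it could overflow, given the worst-case activation magnitude. */
static int max_groups_before_flush(int trits_per_group, int max_abs_activation) {
    assert(trits_per_group > 0 && max_abs_activation > 0);
    int worst_per_group = trits_per_group * max_abs_activation;  /* every trit is +/-1 */
    return INT16_MAX / worst_per_group;
}

/* Sum per-group contributions in int16, spilling to int32 every flush_every groups. */
static int32_t sum_with_flush(const int16_t *contrib, int n_groups, int flush_every) {
    assert(flush_every > 0);
    int32_t total = 0;
    int16_t acc = 0;
    for (int g = 0; g < n_groups; ++g) {
        acc = (int16_t)(acc + contrib[g]);
        if ((g + 1) % flush_every == 0) { total += acc; acc = 0; }
    }
    return total + acc;
}
```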
The Takeaway
It all comes down to one number: how big a table a single shuffle can
index. AVX-512's 512-bit registers are what make that table big enough —
vpermt2w reaches 128 word-sized entries across a register pair, exactly
enough for the 122 magnitudes of a five-trit code, while mask registers
and vpermw handle the signs, tails, and table build along the way. That
capability turns the Part 1 idea into a 704 Gop/s kernel on one core: 3.9x
over BitNet's AVX2 TL2 kernel, at 1.6 bits per weight. And the packing
keeps paying off as matrices grow — see the notes below.
Appendix
What about multiple threads? Every number here is single-threaded, to keep the comparison clean — but we ran the multithreaded experiments too. The kernel parallelizes across output rows with no shared state, and on this 16-core chip it scales to roughly 3.1–4.4 Top/s at 8 threads (4–6x the single-core figure) before memory bandwidth caps it.
Can it go faster if you restrict the activations? Yes. Cap activations
at ±25 and every five-trit table entry fits in an int8 (5 × 25 = 125 ≤ 127), so a
single AVX-512 VBMI byte permute (vpermt2b) indexes all 122 entries at
once, replacing the int16 kernel's two word permutes and a blend. That
takes the headline shape from 704 to ~1110 Gop/s, about 1.6x, still
bit-exact within the restricted range. The catch is that it's no longer a
general int8 matmul: ±25 is roughly 5-bit activations, so it's a real
tradeoff rather than a free win. (In the repo as lut_mm_i8lut, behind
--act-max 25.)
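Why 25? A table entry is a signed sum of five activations, so its magnitude is at most five times the activation cap, and that has to stay within int8. A tiny sketch of the check (the helper name is hypothetical):

```c
#include <assert.h>
#include <stdint.h>

/* Largest activation magnitude for which a signed sum of
   trits_per_group activations still fits in an int8 table entry. */
static int max_activation_for_int8_entries(int trits_per_group) {
    assert(trits_per_group > 0);
    return INT8_MAX / trits_per_group;
}
```

With five trits per group this gives 25: 5 × 25 = 125 fits, while 5 × 26 = 130 does not.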
Isn't this just a register-width advantage over BitNet's AVX2 kernel?
We ported the TL2 design up to 512-bit ourselves (bitnet_tl2@512b in the
headline table); at the same width our kernel is still ~2x faster (704 vs
343). So the gap is the five-trit format, not the bit width — and it comes
from a few rounds of optimization against a kernel BitNet has tuned for
years.
What about VNNI? AVX-512 VNNI is the strongest dense int8 baseline on
this CPU — vpdpbusd does dense int8 dot products directly, the
instruction you'd bet on to beat a lookup table. At cache-resident sizes
the LUT kernel only edges it, but the gap widens fast as the matrix grows:
| weight shape | dense weights | packed weights | VNNI Gop/s | LUT Gop/s | LUT advantage |
|---|---|---|---|---|---|
| 2080 x 2048 | 4.3 MB | 0.85 MB | 555 | 704 | 1.27x |
| 4160 x 4096 | 17 MB | 3.4 MB | 485 | 718 | 1.5x |
| 8320 x 8192 | 68 MB | 13.6 MB | 127 | 690 | 5.4x |
| 16640 x 16384 | 273 MB | 54.5 MB | 123 | 582 | 4.7x |
Dense int8 weights are 5x larger than packed ternary, so they fall out of cache 5x sooner. Through 4160 x 4096 (17 MB dense) both forms fit this CPU's 32 MB L3 and the gap stays at or below ~1.5x. At 8320 x 8192 the 68 MB of dense weights spill to DRAM while the 13.6 MB packed form stays resident, so VNNI drops to memory speed and the gap jumps past 5x. With a single activation row (GEMV, the token-decode shape, pure weight bandwidth) it's 426 vs 77. A packed format's cache cliff sits wherever the packed weights stop fitting: denser packing pushes it further out, and the lookups keep it fast until then.
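The footprint numbers are easy to reproduce. The sketch below assumes one byte per dense weight, five trits per packed byte along the first dimension, and a 32 MB cache counted as 32 × 1024 × 1024 bytes; the function names are illustrative.

```c
#include <assert.h>
#include <stdint.h>

/* Bytes for dense int8 weights: one byte per weight. */
static int64_t dense_weight_bytes(int64_t rows, int64_t cols) {
    assert(rows > 0 && cols > 0);
    return rows * cols;
}

/* Bytes for packed ternary weights at five trits per byte
   (assumes packing along rows, rounded up to a whole byte). */
static int64_t packed_weight_bytes(int64_t rows, int64_t cols) {
    assert(rows > 0 && cols > 0);
    return ((rows + 4) / 5) * cols;
}

/* Nonzero if a working set of the given size fits in the cache. */
static int fits_in_cache(int64_t bytes, int64_t cache_bytes) {
    return bytes <= cache_bytes;
}
```

For 8320 x 8192 this gives about 68 MB dense and 13.6 MB packed, so only the packed form fits a 32 MB L3.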