Introducing BucketMul

There's a saying that data structures matter more than algorithms. This is certainly true in this case.

We'll start with a 12x12 example weight matrix of a model. Although it breaks the formatting, it's the simplest way to demonstrate the concept.

It's transposed from the regular implementation; hence, each full row (not column) is multiplied by a single element of the input vector. \[ \begin{bmatrix} .46 & .87 & -.19 & .27 & .18 & -.39 & -.29 & -.62 & -.81 & -.34 & -.84 & .33 \\ -.87 & .11 & .03 & .5 & .43 & .87 & -.49 & .59 & .5 & -.42 & -.23 & .02 \\ -.44 & .35 & .76 & .85 & -.5 & -.4 & -.26 & .05 & -.37 & .0 & -.36 & -.07 \\ ... \end{bmatrix} \]

Now, let's focus on the first row. Elements from this row will be multiplied by the first element in the state vector. \[ \begin{bmatrix} .46 & .87 & -.19 & .27 & .18 & -.39 & -.29 & -.62 & -.81 & -.34 & -.84 & .33 \\ \end{bmatrix} \]

In our example, we'll use buckets of size 4. Let's split the row into three such buckets: \[ \begin{bmatrix} .46 & .87 & -.19 & .27 \end{bmatrix} \begin{bmatrix} .18 & -.39 & -.29 & -.62 \end{bmatrix} \begin{bmatrix} -.81 & -.34 & -.84 & .33 \end{bmatrix} \] Now, let's sort the elements within each bucket by absolute value, largest first, keeping each element's original index within its bucket.

\[ \begin{bmatrix} .87 \scriptstyle \searrow 1 & .46 \scriptstyle \searrow 0 & .27 \scriptstyle \searrow 3 & -.19 \scriptstyle \searrow 2 \end{bmatrix} \]\[ \begin{bmatrix} -.62 \scriptstyle \searrow 3 & -.39 \scriptstyle \searrow 1 & -.29 \scriptstyle \searrow 2 & .18 \scriptstyle \searrow 0 \end{bmatrix} \]\[ \begin{bmatrix} -.84 \scriptstyle \searrow 2 & -.81 \scriptstyle \searrow 0 & -.34 \scriptstyle \searrow 1 & .33 \scriptstyle \searrow 3 \end{bmatrix} \]

Transpose...

\[ \begin{bmatrix} .87 \scriptstyle \searrow 1 \\ .46 \scriptstyle \searrow 0 \\ .27 \scriptstyle \searrow 3 \\ -.19 \scriptstyle \searrow 2 \end{bmatrix} \begin{bmatrix} -.62 \scriptstyle \searrow 3 \\ -.39 \scriptstyle \searrow 1 \\ -.29 \scriptstyle \searrow 2 \\ .18 \scriptstyle \searrow 0 \end{bmatrix} \begin{bmatrix} -.84 \scriptstyle \searrow 2 \\ -.81 \scriptstyle \searrow 0 \\ -.34 \scriptstyle \searrow 1 \\ .33 \scriptstyle \searrow 3 \end{bmatrix} \]

And reshape.

\[\begin{bmatrix} .87 \scriptstyle \searrow 1 & -.62 \scriptstyle \searrow 3 & -.84 \scriptstyle \searrow 2 \end{bmatrix}\] \[\begin{bmatrix} .46 \scriptstyle \searrow 0 & -.39 \scriptstyle \searrow 1 & -.81 \scriptstyle \searrow 0 \end{bmatrix}\] \[\begin{bmatrix} .27 \scriptstyle \searrow 3 & -.29 \scriptstyle \searrow 2 & -.34 \scriptstyle \searrow 1 \end{bmatrix}\] \[\begin{bmatrix} -.19 \scriptstyle \searrow 2 & .18 \scriptstyle \searrow 0 & .33 \scriptstyle \searrow 3 \end{bmatrix}\]

Thx AK. I know - a bit confusing, but that is how it works.

As an exercise, I recommend you figure out, for a given number, what its original position in the vector was. E.g. .33's position is (colNo * 4 + idx) = 2 * 4 + 3 = 11, where colNo is the column of the reshaped row and idx is the number next to the arrow. This is the position in the output vector this weight belongs to.
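The split, sort, transpose and reshape steps above can be sketched in a few lines of plain Python. This is only an illustration, not the engine's actual preprocessing code: `bucketize_row` and `original_position` are names I made up for this sketch, and it assumes the row length is divisible by the bucket size.

```python
def bucketize_row(row, b_size):
    """Split a row into buckets, sort each by |value| (largest first),
    then regroup so that output row r holds the r-th largest element
    of every bucket, stored as (weight, index_within_bucket)."""
    buckets = [row[i:i + b_size] for i in range(0, len(row), b_size)]
    sorted_buckets = [
        sorted(((w, idx) for idx, w in enumerate(bucket)),
               key=lambda pair: -abs(pair[0]))
        for bucket in buckets
    ]
    # "Transpose and reshape": one output row per rank.
    return [[sb[rank] for sb in sorted_buckets] for rank in range(b_size)]


def original_position(col_no, idx, b_size):
    """Recover where a weight sat in the original row."""
    return col_no * b_size + idx


row = [.46, .87, -.19, .27, .18, -.39, -.29, -.62, -.81, -.34, -.84, .33]
bucket_rows = bucketize_row(row, 4)
for bucket_row in bucket_rows:
    print(bucket_row)
print(original_position(2, 3, 4))  # position of .33 in the original row
```

The first printed row is the `.87, -.62, -.84` row from the example, and `.33` maps back to position 11.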

Let's now calculate averages of absolute values of each bucket row. \[\begin{bmatrix} .87 \scriptstyle \searrow 1 & -.62 \scriptstyle \searrow 3 & -.84 \scriptstyle \searrow 2 \end{bmatrix} \rightarrow avg. abs. 0.777\] \[\begin{bmatrix} .46 \scriptstyle \searrow 0 & -.39 \scriptstyle \searrow 1 & -.81 \scriptstyle \searrow 0 \end{bmatrix} \rightarrow avg. abs. 0.553\] \[\begin{bmatrix} .27 \scriptstyle \searrow 3 & -.29 \scriptstyle \searrow 2 & -.34 \scriptstyle \searrow 1 \end{bmatrix} \rightarrow avg. abs. 0.3\] \[\begin{bmatrix} -.19 \scriptstyle \searrow 2 & .18 \scriptstyle \searrow 0 & .33 \scriptstyle \searrow 3 \end{bmatrix} \rightarrow avg. abs. 0.233 \]

Look what happened here!

Now, we have the input row divided into buckets with decreasing average scores. It's not perfect - for example .33 from the last row should be in a row higher, but in practice it's good enough.
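To double-check the averages, here is a minimal sketch that computes them for the four bucket rows of the first matrix row. The helper name `bucket_stat` is mine, not something from the engine.

```python
def bucket_stat(bucket_row):
    """Average absolute weight of one bucket row of (weight, idx) pairs."""
    return sum(abs(w) for w, _ in bucket_row) / len(bucket_row)


# The four bucket rows from the example above.
bucket_rows = [
    [(.87, 1), (-.62, 3), (-.84, 2)],
    [(.46, 0), (-.39, 1), (-.81, 0)],
    [(.27, 3), (-.29, 2), (-.34, 1)],
    [(-.19, 2), (.18, 0), (.33, 3)],
]
stats = [bucket_stat(r) for r in bucket_rows]
print([round(s, 3) for s in stats])
```

Rounded to three places this gives 0.777, 0.553, 0.3 and 0.233, so the averages really do decrease from the first bucket row to the last.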

Keep in mind, however, that this was only the first row. We have additional rows in our source matrix W.

We will proceed similarly with the other rows, interleaving them within the output structure. Later, each bucket row will be multiplied by the matching element of v: rows labelled v_0 by its first element, rows labelled v_1 by its second, and so on. \[ v_0 : \begin{bmatrix} .87 \scriptstyle \searrow 1 & -.62 \scriptstyle \searrow 3 & -.84 \scriptstyle \searrow 2 \end{bmatrix} \rightarrow avg. abs. 0.777 \] \[ v_1 : \begin{bmatrix} -0.87 \scriptstyle \searrow 0 & 0.87 \scriptstyle \searrow 1 & 0.5 \scriptstyle \searrow 0 \end{bmatrix} \rightarrow avg. abs. 0.747 \] \[ ... \] \[ v_0 : \begin{bmatrix} .46 \scriptstyle \searrow 0 & -.39 \scriptstyle \searrow 1 & -.81 \scriptstyle \searrow 0 \end{bmatrix} \rightarrow avg. abs. 0.553 \] \[ v_1 : \begin{bmatrix} 0.5 \scriptstyle \searrow 3 & 0.59 \scriptstyle \searrow 3 & -0.42 \scriptstyle \searrow 1 \end{bmatrix} \rightarrow avg. abs. 0.503 \] \[...\] \[ v_0 : \begin{bmatrix} .27 \scriptstyle \searrow 3 & -.29 \scriptstyle \searrow 2 & -.34 \scriptstyle \searrow 1 \end{bmatrix} \rightarrow avg. abs. 0.3 \] \[ v_1 : \begin{bmatrix} 0.11 \scriptstyle \searrow 1 & -0.49 \scriptstyle \searrow 2 & -0.23 \scriptstyle \searrow 2 \end{bmatrix} \rightarrow avg. abs. 0.28 \] \[...\] \[ v_0 : \begin{bmatrix} -.19 \scriptstyle \searrow 2 & .18 \scriptstyle \searrow 0 & .33 \scriptstyle \searrow 3 \end{bmatrix} \rightarrow avg. abs. 0.233 \] \[ v_1 : \begin{bmatrix} 0.03 \scriptstyle \searrow 2 & 0.43 \scriptstyle \searrow 0 & 0.02 \scriptstyle \searrow 3 \end{bmatrix} \rightarrow avg. abs. 0.16 \]

The output matrix will have the shape [12*4, 12/4] = [48, 3], or in the real world [inDim * bSize, outDim / bSize].
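Here is a rough sketch of the interleaving for a whole matrix, using only the three rows visible in the example (so inDim is 3 here, not 12). The function name `build_bucket_list` and the row ordering `rank * inDim + inputRow` are my reading of the layout shown above, not code taken from the engine.

```python
def build_bucket_list(W, b_size):
    """W is the transposed weight matrix as a list of rows.
    Returns (bucket_list, bucket_stats) with inDim * b_size rows,
    ordered so that row id = rank * inDim + input_row."""
    in_dim = len(W)
    per_row = []
    for row in W:
        buckets = [row[i:i + b_size] for i in range(0, len(row), b_size)]
        sorted_buckets = [
            sorted(((w, idx) for idx, w in enumerate(b)),
                   key=lambda pair: -abs(pair[0]))
            for b in buckets
        ]
        per_row.append([[sb[r] for sb in sorted_buckets] for r in range(b_size)])

    bucket_list, bucket_stats = [], []
    for rank in range(b_size):          # biggest weights first...
        for i in range(in_dim):         # ...interleaved across input rows
            bucket_row = per_row[i][rank]
            bucket_list.append(bucket_row)
            bucket_stats.append(sum(abs(w) for w, _ in bucket_row) / len(bucket_row))
    return bucket_list, bucket_stats


W = [
    [.46, .87, -.19, .27, .18, -.39, -.29, -.62, -.81, -.34, -.84, .33],
    [-.87, .11, .03, .5, .43, .87, -.49, .59, .5, -.42, -.23, .02],
    [-.44, .35, .76, .85, -.5, -.4, -.26, .05, -.37, .0, -.36, -.07],
]
bucket_list, bucket_stats = build_bucket_list(W, 4)
print(len(bucket_list), len(bucket_list[0]))  # inDim * bSize rows, outDim / bSize columns
```

With 3 input rows and buckets of 4 this yields 12 rows of 3 entries each, which is the [inDim * bSize, outDim / bSize] shape in miniature.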

Let's call this list our bucket list, and the averages - bucket stats.

These, along with probes from the previous chapter, are the three structures we need from the preprocessing stage.

We don't need to calculate them efficiently: we only do this once per model, and that's it.

Oh, and by the way - since the bucket list is organised such that the least important weights are at the end, unlike traditional matrices, we don't need to load the whole thing into memory. We can just skip however many of the last rows we want at load time, and in practice, if it's 20-30%, the model may not even notice.
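Skipping the tail is just a slice. A minimal sketch, assuming the bucket list and stats are plain Python lists in the order described above; `truncate_bucket_list` and `skip_fraction` are hypothetical names for illustration, and the data here is a placeholder with the 48x3 shape of the 12x12 example.

```python
def truncate_bucket_list(bucket_list, bucket_stats, skip_fraction):
    """Drop the trailing skip_fraction of rows, i.e. the lowest-ranked weights."""
    keep = len(bucket_list) - int(len(bucket_list) * skip_fraction)
    return bucket_list[:keep], bucket_stats[:keep]


# Placeholder data: 48 rows of 3 (weight, idx) pairs.
bucket_list = [[(0.0, 0)] * 3 for _ in range(48)]
bucket_stats = [0.0] * 48

kept_list, kept_stats = truncate_bucket_list(bucket_list, bucket_stats, 0.25)
print(len(kept_list))  # 36 of the 48 rows remain
```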

There you have it: ad hoc distillation. You're welcome.

Inference time!

Let's revisit the original algorithm from the previous chapter, the one that drew laughter from children.


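# get_probes, sortMatrixRows and topK are the helpers
# assumed from the previous chapter
def precompute(W):
  W = W.T
  probes = get_probes(W)
  W_idx, W_val = sortMatrixRows(W)
  return W_idx, W_val, probes

def approxMul(v, W_idx, W_val, probes):
  cutoff_chart = v * probes
  cutoff = topK(cutoff_chart, effort)

  # the rest is the same as before
  O = zeros(outDim)
  for el_v, row_idx, row_val in zip(v, W_idx, W_val):
    el_cutoff = cutoff / abs(el_v)
    for el_w, idx_w in zip(row_val, row_idx):
      if abs(el_w) > el_cutoff:
        O[idx_w] += el_v * el_w
      else:
        # other els in the row are smaller,
        # so we can skip checking them
        break
  return O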

The initial steps remain unchanged: we select an effort level and determine a cutoff. What changes is how we do the multiplications.

We'll split the algorithm into two parts.

The first one I call "calculating dispatch".

  
for rowId in range(len(bucketList)):
  v_val = v[rowId % inDim]
  # keep the bucket only if its typical product clears the cutoff
  if stat[rowId] * abs(v_val) > cutoff:
      dispatchList.append((rowId, v_val))

This process filters our bucket list, creating a dispatch list that pairs each rowId with its corresponding value for multiplication. This step is efficiently parallelizable.

Dispatch contains a list of buckets that will be multiplied, along with the value to be multiplied by them. Once we have created the dispatch, we can discard the source vector.
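A runnable sketch of the dispatch step, using the bucket stats of the v_0 and v_1 rows from the example. I'm assuming here that a bucket is dispatched when its stat times |v| exceeds the cutoff, mirroring the `el_w > el_cutoff` test of the earlier algorithm; the function name `calc_dispatch` and the input values are made up for illustration.

```python
def calc_dispatch(v, bucket_stats, in_dim, cutoff):
    """Pair each surviving bucket row id with the input value it will be
    multiplied by. Row ids are interleaved: input index = row_id % in_dim."""
    dispatch = []
    for row_id, stat in enumerate(bucket_stats):
        v_val = v[row_id % in_dim]
        if stat * abs(v_val) > cutoff:
            dispatch.append((row_id, v_val))
    return dispatch


# Stats for two input rows (v_0, v_1), four ranks each, interleaved.
bucket_stats = [0.777, 0.747, 0.553, 0.503, 0.3, 0.28, 0.233, 0.16]
v = [1.0, -0.5]          # illustrative input vector
dispatch = calc_dispatch(v, bucket_stats, in_dim=2, cutoff=0.25)
print(dispatch)
```

With these numbers, five of the eight bucket rows make it into the dispatch. Each loop iteration is independent of the others, which is why this step parallelizes so easily.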

And now, the bucketMul itself:

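for (bucket_id, v_val) in dispatchList:
    for colNo, (weight, idx) in enumerate(buckets[bucket_id]):
       # colNo * bSize + idx is the weight's original output position
       output_vec[colNo * bSize + idx] += weight * v_val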
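To convince yourself that the loop does what it should, here is a small CPU-only sketch on the bucket rows of the first matrix row. It assumes each weight's output position is colNo * bSize + idx, as in the exercise earlier; `bucket_mul` is my own name, and this is nothing like the real GPU kernel.

```python
def bucket_mul(dispatch, bucket_list, b_size, out_dim):
    """Accumulate weight * v_val into the output for every dispatched bucket row."""
    out = [0.0] * out_dim
    for row_id, v_val in dispatch:
        for col_no, (weight, idx) in enumerate(bucket_list[row_id]):
            out[col_no * b_size + idx] += weight * v_val
    return out


row = [.46, .87, -.19, .27, .18, -.39, -.29, -.62, -.81, -.34, -.84, .33]
bucket_list = [
    [(.87, 1), (-.62, 3), (-.84, 2)],
    [(.46, 0), (-.39, 1), (-.81, 0)],
    [(.27, 3), (-.29, 2), (-.34, 1)],
    [(-.19, 2), (.18, 0), (.33, 3)],
]

# Dispatching every bucket row reproduces the exact product row * 2.0.
full = bucket_mul([(r, 2.0) for r in range(4)], bucket_list, 4, 12)
print(full == [w * 2.0 for w in row])

# Dispatching only the two strongest rows touches half of the outputs.
partial = bucket_mul([(0, 2.0), (1, 2.0)], bucket_list, 4, 12)
print([i for i, x in enumerate(partial) if x != 0.0])
```

The partial run skips the two weakest bucket rows, so the smaller weights simply never contribute to the output.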

But wait, how is this different from before? Now, we need to delve into more technical details and discuss the GPU implementation.
