This is the parallel scan operation for the SSM block in Mamba. There is a standalone example in ssm_scan.c.
static void ggml_compute_forward_ssm_scan_f32(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0]; // s
    const struct ggml_tensor * src1 = dst->src[1]; // x
    const struct ggml_tensor * src2 = dst->src[2]; // dt
    const struct ggml_tensor * src3 = dst->src[3]; // A
    const struct ggml_tensor * src4 = dst->src[4]; // B
    const struct ggml_tensor * src5 = dst->src[5]; // C

    const int ith = params->ith;
    const int nth = params->nth;

    const int64_t nc  = src0->ne[0]; // d_state
    const int64_t nr  = src0->ne[1]; // d_inner
    const int64_t n_t = src1->ne[1]; // number of tokens per sequence
    const int64_t n_s = src0->ne[2]; // number of sequences in the batch

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);
    const int ir  = ir1 - ir0;
In the example I've set the threads to 1 to simplify stepping through the code. This enables us to ignore the threading code.
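Just to make the row partitioning concrete: with nr = 8 and nth = 1 we get dr = (8 + 1 - 1)/1 = 8, ir0 = 0, ir1 = MIN(0 + 8, 8) = 8, and ir = 8, so the single thread processes all 8 inner dimensions. With, say, nth = 3 the rows would instead be split 3, 3, and 2 across the threads.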
(lldb) p nc
(const int64_t) 16
(lldb) p nr
(const int64_t) 8
(lldb) p n_t
(const int64_t) 4
(lldb) p n_s
(const int64_t) 1
Perhaps nc stands for number of channels?
The main loop will loop over all the sequences, which in our case is only 1:
for (int i3 = 0; i3 < n_s; ++i3) {
Next it will loop over all the tokens in each sequence:
for (int i2 = 0; i2 < n_t; ++i2) {
Next we have a pointer to the data in the s tensor:
const float * s0 = (const float *) ((const char *)
src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
Now, the const char cast is there to enable byte-level pointer arithmetic, since the nb fields are strides in bytes (there is a short sketch of this after the diagram below). src0 is the s tensor, which is the current state of the system.
So it will look like something like this:
d_state
0 [0 ... 15]
1 [0 ... 15]
2 [0 ... 15]
3 [0 ... 15] d_inner
4 [0 ... 15]
5 [0 ... 15]
6 [0 ... 15]
7 [0 ... 15]
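As a side note, here is a minimal standalone sketch (not from ggml, just to illustrate the cast) showing that adding nb-style byte strides to a char pointer lands on the same element as ordinary array indexing for a contiguous f32 tensor:

#include <stdio.h>

int main(void) {
    float data[8][16];                 // d_inner rows of d_state floats
    size_t nb1 = 16 * sizeof(float);   // byte stride between rows (nb[1])

    // byte-level arithmetic, as in ggml_compute_forward_ssm_scan_f32
    const float * row3 = (const float *) ((const char *) data + 3*nb1);

    // equivalent element-level indexing for this contiguous case
    const float * row3_alt = &data[3][0];

    printf("same row: %d\n", row3 == row3_alt); // prints 1
    return 0;
}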
And since we are only using one thread we can ignore ir0 as it will be zero:
const float * s0 = (const float *) ((const char *)
src0->data + i3*(src0->nb[2]));
And we only have one sequence, which is represented by i3, so this is also zero at this point:
const float * s0 = (const float *) src0->data;
So for this iteration s0 is simply a pointer to the beginning of the s data.
Next we have our x tensor, which is the input to the SSM block. This is the output of the input embeddings->projection layer->convolution layer->SiLU:
const float * x = (const float *) ((const char *)
src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2]));
This looks like this:
d_inner
token 0 [0 ... 7]
token 1 [0 ... 7] seq_len
token 2 [0 ... 7]
token 3 [0 ... 7]
And again we can do the same simplification as above:
const float * x = (const float *) src1->data;
So x will be a pointer to the beginning of the x data.
Then we have dt (delta):
const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
This has the same shape as x:
d_inner
token 0 [0 ... 7]
token 1 [0 ... 7] seq_len
token 2 [0 ... 7]
token 3 [0 ... 7]
And with the same simplification as above:
const float * dt = (const float *) src2->data;
So dt will be a pointer to the beginning of the dt data.
Next we have the A tensor (state transition matrix):
const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
And A looks like this:
d_state
0 [0 ... 15]
1 [0 ... 15]
2 [0 ... 15]
3 [0 ... 15] d_inner
4 [0 ... 15]
5 [0 ... 15]
6 [0 ... 15]
7 [0 ... 15]
And we'll simplify this as well:
const float * A = (const float *) src3->data;
Next we have the B tensor (input projection matrix):
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
d_state
token 0 [0 ... 15]
token 1 [0 ... 15] seq_len
token 2 [0 ... 15]
token 3 [0 ... 15]
And again we'll simplify this:
const float * B = (const float *) src4->data;
Next we have the C tensor (output transition matrix):
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
And this looks like this:
d_state
token 0 [0 ... 15]
token 1 [0 ... 15] seq_len
token 2 [0 ... 15]
token 3 [0 ... 15]
And we'll simplify this as well:
const float * C = (const float *) src5->data;
Next we have y, which is the output (y = ...):
float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
And this looks like this:
0 [0 ... 159]
Note that this is not const so we can expect it to be updated.
float * y = (float *) dst->data;
So this is currently just a pointer to the dst tensor data, which is our y tensor in the example code:
(lldb) p dst
(ggml_tensor *) 0x000000013e000fc0
(lldb) p dst->name
(char[64]) "y"
(lldb) p dst->ne
(int64_t[4]) ([0] = 160, [1] = 1, [2] = 1, [3] = 1)
So y will be the output of this function. Note that this is a one-dimensional tensor and was created in the ggml_ssm_scan function:
// concatenated y + ssm_states
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
So it looks like this will contain both the output values (y) and the ssm_states.
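In the standalone example this checks out: ggml_nelements(x) = 8*4*1 = 32 and ggml_nelements(s) = 16*8*1 = 128, which adds up to the 160 elements we saw in dst->ne above.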
Next we have an s tensor:
float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
Notice that this is also a pointer to dst data. If we simplify this we get:
float * s = ( float *) (( char *)
dst->data + src1->nb[3]);
(lldb) p src1->nb[3]
(const size_t) 128
So this is a pointer 128 bytes into the dst tensor data. Why 128?
Notice that this is using nb, which is a stride in bytes, so we need to divide by 4: 128/4 = 32 elements. So this is the offset of the s part of the tensor, and it should match the number of elements in the y part of the output tensor. So the output y values are stored first in this tensor, taking up 32 elements (128 bytes), and the s values follow after them.
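To visualize how y and s share the dst buffer in the standalone example (my own sketch of the layout):

dst->data (160 floats total)
+----------------------------+----------------------------------+
| y: 32 floats               | s: 128 floats                    |
| (d_inner * n_t = 8 * 4)    | (d_state * d_inner = 16 * 8)     |
+----------------------------+----------------------------------+
^                            ^
y pointer                    s pointer = dst->data + 128 bytes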
After that we have:
if (i2 > 0) { s0 = s; }
So if we are not at the first token we point s0 at s, the state computed for the previous token (not the output).
After that we have a final loop which iterates over all the dimensions of the inner state (d_inner), in our case 8:
for (int i1 = 0; i1 < ir; ++i1) {
float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
So here we take the delta value for this inner dimension and check if it is less than or equal to 20.0. If it is, we compute the softplus of it: expf raises e to the delta value and log1pf takes the natural logarithm of 1 plus that result. If the delta value is greater than 20.0 we just use it as is, since softplus(x) ≈ x for large x and expf would risk overflowing.
Softplus is defined as:
f(x) = ln(1 + e^x)
or
f(x) = log(1 + exp(x))
In this case the "+ 1" is performed by log1pf (log of 1 plus, for floats).
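A minimal standalone sketch of the same softplus-with-threshold trick (the function name is mine, not from ggml):

#include <math.h>
#include <stdio.h>

// Softplus with the same overflow guard as the scan code: for large inputs
// softplus(x) is numerically indistinguishable from x, and expf(x) would overflow.
static float soft_plus(float x) {
    return x <= 20.0f ? log1pf(expf(x)) : x;
}

int main(void) {
    printf("%f\n", soft_plus(0.0f));   // ln(2) ≈ 0.693147
    printf("%f\n", soft_plus(25.0f));  // 25.000000, passed through unchanged
    return 0;
}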
Next we multiply the delta value with the current x value:
float x_dt = x[i1] * dt_soft_plus;
So we will have one such value per inner dimension, and it is "broadcast" (used for all the d_state entries in the loop below). This is where the input is "mixed" with the delta value, making the time step input dependent.
Next we will iterate over d_state (16 in our case):
float sumf = 0.0f;
for (int i0 = 0; i0 < nc; ++i0) {
    int i = i0 + i1*nc;
    // state = prev_state * dA + dB * x
    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
    // y = rowwise_dotprod(state, C)
    sumf += state * C[i0];
    s[i] = state;
}
y[i1] = sumf;
Let's focus on the state update line:
float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
Just as a reminder, the state space model recurrence is defined as:
h_t = Ā h_{t-1} + B̄ x_t
Where:
Ā (A_bar) = the discretized state transition matrix.
B̄ (B_bar) = the discretized input projection matrix.
x_t = the input at time t.
h_t = the hidden state at time t.
h_{t-1} = the previous hidden state.
Now, let's start from the right-hand side: we multiply x_dt with the input projection entry B[i0], where x_dt is the input value with the delta time step incorporated. Though it might not look like it, this is also the discretization of the input projection matrix. It is done by the multiplication with x_dt, which is possible because for small values of delta the integral ∫(0 to Δt) exp(A*τ) * B dτ can be approximated by Δt * B, so B[i0] * x_dt gives us B̄ x_t.
Then we have the discretization of A, which is done by expf(dt_soft_plus * A[i]). The expf call is the zero-order hold (ZOH) discretization, Ā = exp(Δt * A). This is multiplied by the previous state value s0[i], giving us Ā h_{t-1}.
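To tie the equation and the code together, here is a minimal standalone sketch of the update for one inner dimension of one token (names are mine, not from ggml; the state h is updated in place, whereas the real code reads from s0 and writes to s):

#include <math.h>
#include <stdio.h>

#define D_STATE 16

// One SSM step for a single channel/inner dimension:
// h[i] = h[i]*exp(dt*A[i]) + dt*B[i]*x, then y = dot(h, C).
static float ssm_step_channel(float h[D_STATE], const float A[D_STATE],
                              const float B[D_STATE], const float C[D_STATE],
                              float x, float dt) {
    float x_dt = x * dt;   // delta-scaled input (the B_bar * x_t part)
    float y = 0.0f;
    for (int i0 = 0; i0 < D_STATE; ++i0) {
        h[i0] = h[i0] * expf(dt * A[i0]) + B[i0] * x_dt;  // A_bar*h + B_bar*x
        y += h[i0] * C[i0];                               // output projection
    }
    return y;
}

int main(void) {
    float h[D_STATE] = {0};
    float A[D_STATE], B[D_STATE], C[D_STATE];
    for (int i = 0; i < D_STATE; ++i) { A[i] = -1.0f; B[i] = 0.5f; C[i] = 1.0f; }

    // feed two "tokens" through the same channel; the state h carries over
    printf("y token 0: %f\n", ssm_step_channel(h, A, B, C, 1.0f, 0.1f));
    printf("y token 1: %f\n", ssm_step_channel(h, A, B, C, 2.0f, 0.1f));
    return 0;
}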
We can visualize the inner state iteration (over d_state) after we exit the innermost loop:
s[0] = state0
s[1] = state1
s[2] = state2
s[3] = state3
s[4] = state4
s[5] = state5
s[6] = state6
s[7] = state7
s[8] = state8
s[9] = state9
s[10] = state10
s[11] = state11
s[12] = state12
s[13] = state13
s[14] = state14
s[15] = state15
y[0] = weighted sum of all the above states
This completes the first d_inner dimension and we continue with the next one.
This time through the innermost loop the index i will be offset by 16:
for (int i0 = 0; i0 < nc; ++i0) {
    int i = i0 + i1*nc;
s[16] = state16
s[17] = state17
s[18] = state18
s[19] = state19
s[20] = state20
s[21] = state21
s[22] = state22
s[23] = state23
s[24] = state24
s[25] = state25
s[26] = state26
s[27] = state27
s[28] = state28
s[29] = state29
s[30] = state30
s[31] = state31
y[1] = weighted sum of all the above states
This will continue for all ir (d_inner) dimensions. So there will be d_inner * d_state elements in the s part of the tensor, and ir elements per token (ir * n_t in total) in the y part of the tensor.
So if d_inner/ir is 4096, d_state/nc is 16, and n_t (the number of tokens) is 5, we should have:
y = 4096 * 5 = 20480
s = 4096 * 16 = 65536
Total: 86016
(lldb) p ir * n_t
(int64_t) 20480
(lldb) p ir * nc
(int64_t) 65536
(lldb) p (ir * nc) + ir * n_t
(int64_t) 86016
(lldb) p ggml_nelements(dst)
(int64_t) 86016
As we can see the number of elements in the tensor is correct. Notice that the "y" part has one row of 4096 values per token (5 tokens in this case).
Just to touch on the offset of the s pointer again, this time for the full model:
(lldb) p src1->nb[3]
(const size_t) 81920
This had me confused as I forgot that nb is a stride in bytes. We are dealing with floats, so 4 bytes each, and 81920/4 = 20480, which is the number of elements in y. So y and s point into the same tensor, but s starts at an offset of 81920 bytes (20480 floats) from the beginning.
To clarify, ignoring the sequences as we only have one in this case, we loop over the tokens in the sequence (n_t). Then we iterate over all the d_inner dimensions, and for each of these dimensions we iterate over all the states in the system (d_state).
So that gives us the updated state of the system. Each state entry is multiplied by the output transition matrix entry C[i0] and summed into sumf. The state is also stored in s[i] so it can be used for the next iteration (the next token iteration, that is).
So that was the first iteration for the first channel/feature, and we then do the same for the second channel/feature. Notice that we perform this operation for each channel/feature/dimension of the inner state. The diagram below shows a single channel: its 16 d_state entries are each updated by the SSM recurrence and then summed (the dot product with C) into that channel's y value.
     (d_state entries for one channel/inner dimension)
      0    1    2    3    4    5    6    7    8    9    10   11   12   13   14   15
     +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+
     |  | |  | |  | |  | |  | |  | |  | |  | |  | |  | |  | |  | |  | |  | |  | |  |
     +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+
      ↓    ↓    ↓    ↓    ↓    ↓    ↓    ↓    ↓    ↓    ↓    ↓    ↓    ↓    ↓    ↓
     +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+
SSM  |  | |  | |  | |  | |  | |  | |  | |  | |  | |  | |  | |  | |  | |  | |  | |  |
     +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+ +--+
       |    |    |    |    |    |    |    |    |    |    |    |    |    |    |    |
     +-+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
                                            |
                                          +---+
                                          | + |
                                          +---+
                                            |
                                            ↓
                                          +---+
                                          | y |
                                          +---+
And this is for the first token. We then do the same for the rest of the tokens in the sequence, with each token starting from the state left behind by the previous one. The outputs are stored in the y part of the tensor, which is the output passed on to the next layer.
Just to recap, we iterated over the sequences (n_s), then the tokens in each sequence (n_t), then the channels of the inner state (d_inner), and then the states for each channel (d_state).
So there will be a y value for each dimension of the inner state, and the state of the system is updated for each channel as well. One way to think about this is that we are updating the state of the system one token (time step) at a time, moving the system from one state to the next.
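To summarize the whole loop nest, here is a simplified single-threaded reference sketch (my own, not the actual ggml code: contiguous f32 data is assumed, the stride/byte arithmetic is dropped, and the state is updated in place):

#include <math.h>

// s:  {d_state, d_inner}   state, initialized by the caller and updated in place
// x:  {d_inner, n_t}       input per token
// dt: {d_inner, n_t}       delta per token
// A:  {d_state, d_inner}   state transition parameters
// B:  {d_state, n_t}       input projection per token
// C:  {d_state, n_t}       output projection per token
// y:  {d_inner, n_t}       output per token
static void ssm_scan_ref(float * s, const float * x, const float * dt,
                         const float * A, const float * B, const float * C,
                         float * y, int d_state, int d_inner, int n_t) {
    for (int i2 = 0; i2 < n_t; ++i2) {           // tokens (sequences omitted)
        for (int i1 = 0; i1 < d_inner; ++i1) {   // channels / inner dimensions
            float d     = dt[i2*d_inner + i1];
            float dt_sp = d <= 20.0f ? log1pf(expf(d)) : d;   // softplus
            float x_dt  = x[i2*d_inner + i1] * dt_sp;
            float sumf  = 0.0f;
            for (int i0 = 0; i0 < d_state; ++i0) {            // states
                int i = i1*d_state + i0;
                float state = s[i] * expf(dt_sp * A[i]) + B[i2*d_state + i0] * x_dt;
                sumf += state * C[i2*d_state + i0];
                s[i] = state;   // carried over to the next token
            }
            y[i2*d_inner + i1] = sumf;
        }
    }
}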
wip