-
-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathssm_scan.c
132 lines (117 loc) · 3.76 KB
/
ssm_scan.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#include <stdio.h>
#include "ggml.h"
#include "ggml-cpu.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
/*
 * Minimal example that builds and runs a single ggml_ssm_scan operation.
 *
 * All tensors are F32 and are allocated from one ggml context. The input
 * tensors are explicitly initialised so that the computation never reads
 * indeterminate memory, then a one-node graph is computed on one thread and
 * the size of the result is printed.
 *
 * Returns 0 on success and 1 if the context cannot be created or the graph
 * computation does not report GGML_STATUS_SUCCESS.
 */
int main(int argc, char **argv) {
    // The example takes no command line arguments.
    (void) argc;
    (void) argv;

    printf("GGML ssm_scan example\n");

    struct ggml_init_params params = {
        .mem_size   = 16*1024*1024,
        .mem_buffer = NULL,
    };
    struct ggml_context* ctx = ggml_init(params);
    // Guard against a failed initialisation before the context is used.
    if (ctx == NULL) {
        printf("could not initialize ggml context\n");
        return 1;
    }

    // d_inner is the dimension of the inner layer (after the projection layer).
    int d_inner = 8;
    // seq_len is the length of the input sequence.
    int seq_len = 4;
    // d_state is the dimension of the state vector
    int d_state = 16;

    // s is the current state of the system.
    //
    //         d_state
    //  0   [0  ...  15]
    //  1   [0  ...  15]
    //  2   [0  ...  15]
    //  3   [0  ...  15]   d_inner
    //  4   [0  ...  15]
    //  5   [0  ...  15]
    //  6   [0  ...  15]
    //  7   [0  ...  15]
    //
    struct ggml_tensor* s = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_state, d_inner);
    ggml_set_name(s, "s");
    ggml_set_zero(s);
    // ggml_nelements and ne[] are 64-bit integers; cast so that the argument
    // type matches the %lld conversion on every platform.
    printf("s nelements: %lld:\n", (long long) ggml_nelements(s));

    // x is the output of the input->projection->convolution->silu.
    //
    //              d_inner
    //  token 0   [0  ...  7]
    //  token 1   [0  ...  7]   seq_len
    //  token 2   [0  ...  7]
    //  token 3   [0  ...  7]
    //
    struct ggml_tensor* x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_inner, seq_len);
    ggml_set_name(x, "x");
    // Only one element is set below, so clear the rest first instead of
    // relying on the contents of freshly created tensor data.
    ggml_set_zero(x);
    printf("x nelements: %lld:\n", (long long) ggml_nelements(x));
    printf("x ne[0]: %lld:\n", (long long) x->ne[0]);
    printf("x ne[1]: %lld:\n", (long long) x->ne[1]);
    printf("x ne[2]: %lld:\n", (long long) x->ne[2]);
    ggml_set_f32_nd(x, 0, 0, 0, 0, 1.0f);

    // dt is the delta; there is one delta value per inner channel for each
    // token.
    //
    //              d_inner
    //  token 0   [0  ...  7]
    //  token 1   [0  ...  7]   seq_len
    //  token 2   [0  ...  7]
    //  token 3   [0  ...  7]
    //
    struct ggml_tensor* dt = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_inner, seq_len);
    ggml_set_name(dt, "delta");
    // Placeholder contents: this example has no real model weights, but the
    // scan must still operate on defined values.
    ggml_set_zero(dt);

    // A is the learned state transition matrix.
    //
    //         d_state
    //  0   [0  ...  15]
    //  1   [0  ...  15]
    //  2   [0  ...  15]
    //  3   [0  ...  15]   d_inner
    //  4   [0  ...  15]
    //  5   [0  ...  15]
    //  6   [0  ...  15]
    //  7   [0  ...  15]
    //
    struct ggml_tensor* A = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_state, d_inner);
    ggml_set_name(A, "A");
    ggml_set_zero(A);

    // B is the dynamic (not here but in a real Mamba model) input transition matrix.
    //
    //              d_state
    //  token 0   [0  ...  15]
    //  token 1   [0  ...  15]   seq_len
    //  token 2   [0  ...  15]
    //  token 3   [0  ...  15]
    //
    struct ggml_tensor* B = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_state, seq_len);
    ggml_set_name(B, "B");
    ggml_set_zero(B);

    // C is the output transition matrix.
    //
    //              d_state
    //  token 0   [0  ...  15]
    //  token 1   [0  ...  15]   seq_len
    //  token 2   [0  ...  15]
    //  token 3   [0  ...  15]
    //
    struct ggml_tensor* C = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_state, seq_len);
    ggml_set_name(C, "C");
    ggml_set_zero(C);

    // y is the result of the scan operation. Which is a one dimensional tensor
    // with the number of elements of x plus the number of elements in s
    // (32 + 128 = 160 here).
    //
    //  [0  ... 31 ... 159]
    //
    struct ggml_tensor* y = ggml_ssm_scan(ctx, s, x, dt, A, B, C);
    ggml_set_name(y, "y");

    struct ggml_cgraph* c_graph = ggml_new_graph(ctx);
    ggml_build_forward_expand(c_graph, y);

    int n_threads = 1;
    enum ggml_status st = ggml_graph_compute_with_ctx(ctx, c_graph, n_threads);
    if (st != GGML_STATUS_SUCCESS) {
        printf("could not compute graph\n");
        // Release the context on the failure path as well.
        ggml_free(ctx);
        return 1;
    }

    printf("y nelements: %lld:\n", (long long) ggml_nelements(y));
    printf("y ne[0]: %lld:\n", (long long) y->ne[0]);

    ggml_free(ctx);
    return 0;
}