forked from hatsu3/Sanger
-
Notifications
You must be signed in to change notification settings - Fork 3
/
bench_sanger.py
57 lines (47 loc) · 2.29 KB
/
bench_sanger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import argparse
from pathlib import Path
import pandas as pd
def bert_base_gflops(seq_len, hidden_size=768):
    """Estimate the GFLOPs of one self-attention block, split into two stages.

    Every matrix multiply is counted as 2 operations per multiply-accumulate.

    Args:
        seq_len: Number of tokens in the input sequence.
        hidden_size: Model hidden dimension. Defaults to 768, the value the
            original implementation hard-coded for BERT-base.

    Returns:
        A ``(stage1_gflops, stage2_gflops)`` tuple where

        * stage 1 covers the three input linear projections
          (presumably Q, K and V), and
        * stage 2 covers the QK score computation, the PV product and the
          output projection.
    """
    # One dense (seq_len x hidden) @ (hidden x hidden) projection.
    # Kept as exact integer arithmetic until the final division.
    proj_flops = seq_len * hidden_size * hidden_size * 2

    linear_flops = proj_flops * 3
    # Both attention matmuls cost seq_len^2 * hidden MACs.
    qk_flops = seq_len * seq_len * hidden_size * 2
    pv_flops = seq_len * seq_len * hidden_size * 2
    out_proj_flops = proj_flops

    stage1_flops = linear_flops
    stage2_flops = qk_flops + pv_flops + out_proj_flops

    return stage1_flops / 1e9, stage2_flops / 1e9
def calc_sanger_latency(sparsity, load_balance, seq_len):
    """Estimate the latency of one attention block on the modelled accelerator.

    Stage 1 runs at the dense peak throughput of the PE array. Stage 2 runs
    at ``peak / sparsity * load_balance``, so a smaller ``sparsity`` value
    or a larger ``load_balance`` value means higher effective throughput.

    Args:
        sparsity: Positive scaling factor; the stage-2 throughput is divided
            by it. (Presumably the fraction of attention entries that are
            kept -- confirm against the mask generator.)
        load_balance: Positive PE utilisation factor; the stage-2 throughput
            is multiplied by it.
        seq_len: Sequence length forwarded to ``bert_base_gflops``.

    Returns:
        Total latency in seconds (GFLOPs divided by GOPS).

    Raises:
        ValueError: If ``sparsity`` or ``load_balance`` is not strictly
            positive (this also rejects NaN).
    """
    # Validate before dividing: a zero here otherwise surfaces as an opaque
    # ZeroDivisionError for Python floats, or as a silent inf / 0-latency
    # result for NumPy floats such as those produced by a pandas mean.
    if not sparsity > 0:
        raise ValueError(f"sparsity must be > 0, got {sparsity!r}")
    if not load_balance > 0:
        raise ValueError(f"load_balance must be > 0, got {load_balance!r}")

    PE_ARRAY_SIZE = 64 * 16
    STAGE1_GOPS = PE_ARRAY_SIZE * 1 * 2  # pe-size * 1(GHz) * 2(ops/mac) = 2048
    stage2_gops = STAGE1_GOPS / sparsity * load_balance

    stage1_gflops, stage2_gflops = bert_base_gflops(seq_len)

    stage1_lat = stage1_gflops / STAGE1_GOPS
    stage2_lat = stage2_gflops / stage2_gops
    return stage1_lat + stage2_lat
def main(argv=None):
    """Command-line entry point: print the estimated Sanger latency.

    Two modes are supported:

    * ``--sparsity`` together with ``--load_balance``: compute and print a
      single latency for the given values.
    * otherwise: read ``--csv_file``, average its numeric columns, and print
      the average sparsity plus one load-balance / latency pair for each of
      the four load-balance configurations.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--sparsity", default=None, type=float, required=False)
    parser.add_argument("--load_balance", default=None, type=float, required=False)
    parser.add_argument("--seq_len", default=512, type=int, required=False)
    parser.add_argument("--csv_file", default="load_balance.csv", type=str, required=False,
                        help="Path to the csv file generated by gen_sparsity_mask.")
    args = parser.parse_args(argv)

    if args.sparsity is not None:
        # parser.error instead of assert: asserts are stripped under
        # `python -O`, which would let a None load_balance through.
        if args.load_balance is None:
            parser.error("--load_balance is required when --sparsity is given.")
        total_lat = calc_sanger_latency(args.sparsity, args.load_balance, args.seq_len)
        print(f"Sanger Latency: {total_lat * 1000:.3f} ms")
        return

    if not Path(args.csv_file).exists():
        parser.error(f"{args.csv_file} does not exist.")

    # numeric_only=True keeps the column-wise mean working on pandas >= 2.0
    # even if the CSV carries non-numeric columns.
    metrics = pd.read_csv(args.csv_file).mean(numeric_only=True)

    lb_keys = ['50%-no-skip', '50%-skip', '25%-no-skip', '25%-skip']
    missing = [key for key in ['overall-sparsity', *lb_keys] if key not in metrics.index]
    if missing:
        parser.error(
            f"{args.csv_file} is missing numeric column(s): {', '.join(missing)}"
        )

    sparsity = metrics['overall-sparsity']
    print(f"Average Sparsity: {sparsity:.3f}")
    for lb_key in lb_keys:
        load_balance = metrics[lb_key]
        total_lat = calc_sanger_latency(sparsity, load_balance, args.seq_len)
        print(f"Load Balance ({lb_key}): {load_balance:.3f}")
        print(f"Sanger Latency ({lb_key}): {total_lat * 1000:.3f} ms")
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()