fix #172 (#173)
- steel_attention.metal (new) was missing from the build
davidkoski authored Dec 5, 2024
1 parent 7f02cd8 commit 70dbb62
Showing 3 changed files with 60 additions and 2 deletions.
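
For context, issue #172 showed up when calling MLXFast.scaledDotProductAttention, which is backed by the steel attention kernel file added to the build here. The sketch below mirrors the new test further down; it is a minimal reproduction only, and the imports shown (MLX, MLXFast, MLXRandom) are assumptions based on the test target's setup rather than taken from this diff.

import Foundation
import MLX
import MLXFast
import MLXRandom

// Minimal sketch of the call exercised by issue #172 (shapes follow the new test;
// head dimension 64 matches one of the instantiated kernel sizes below).
let B = 2
let H = 24
let L = 129
let Dk = 64
let scale = 1.0 / sqrt(Float(Dk))

let q = MLXRandom.normal([B, H, L, Dk]).asType(.float16)
let k = MLXRandom.normal([B, H, L, Dk]).asType(.float16)
let v = MLXRandom.normal([B, H, L, Dk]).asType(.float16)

// Before this fix, steel_attention.metal was missing from the built Metal library,
// so evaluating this call could fail to find the attention kernel at runtime.
let out = MLXFast.scaledDotProductAttention(
    queries: q, keys: k, values: v, scale: scale, mask: nil,
    memoryEfficientThreshold: 2)
eval(out)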
31 changes: 31 additions & 0 deletions steel/attn/kernels/steel_attention.metal
@@ -0,0 +1,31 @@
// Copyright © 2024 Apple Inc.

// clang-format off
#include "../../../utils.h"

#include "../../../steel/attn/attn.h"
#include "../../../steel/attn/kernels/steel_attention.h"

#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn) \
  template [[host_name("steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd "_wm" #wm "_wn" #wn)]] \
  [[kernel]] void attention<dtype, bq, bk, bd, wm, wn, float>( \
      const device dtype* Q [[buffer(0)]], \
      const device dtype* K [[buffer(1)]], \
      const device dtype* V [[buffer(2)]], \
      device dtype* O [[buffer(3)]], \
      const constant AttnParams* params [[buffer(4)]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]);

#define instantiate_attn_shapes_helper(iname, itype) \
  instantiate_attn(iname, itype, 32, 16, 128, 4, 1) \
  instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \
  instantiate_attn(iname, itype, 32, 32, 64, 4, 1)

instantiate_attn_shapes_helper(float16, half);
instantiate_attn_shapes_helper(bfloat16, bfloat16_t);

instantiate_attn_shapes_helper(float32, float);
// clang-format on
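
Each instantiate_attn expansion above registers one specialized kernel whose [[host_name(...)]] encodes the type name and tile parameters, e.g. steel_attention_float16_bq32_bk16_bd128_wm4_wn1 for the first float16 shape. A small, purely illustrative Swift helper (hypothetical, not part of the library) that reproduces the same naming pattern:

// Hypothetical helper that mirrors the [[host_name(...)]] string built by the
// instantiate_attn macro above; for illustration of the naming scheme only.
func steelAttentionKernelName(tname: String, bq: Int, bk: Int, bd: Int, wm: Int, wn: Int) -> String {
    "steel_attention_\(tname)_bq\(bq)_bk\(bk)_bd\(bd)_wm\(wm)_wn\(wn)"
}

// Prints "steel_attention_float16_bq32_bk16_bd128_wm4_wn1"
print(steelAttentionKernelName(tname: "float16", bq: 32, bk: 16, bd: 128, wm: 4, wn: 1))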
26 changes: 26 additions & 0 deletions Tests/MLXTests/MLXFastKernelTests.swift
@@ -70,4 +70,30 @@ class MLXFastKernelTests: XCTestCase {
        XCTAssertTrue(allClose(out[0], full([2, 2], values: 14.0484)).all().item())
        XCTAssertTrue(allClose(out[1], full([3, 2], values: -2)).all().item())
    }

    func testFastSDPA() {
        // https://github.com/ml-explore/mlx-swift/issues/172
        // this just makes sure that MLXFast.scaledDotProductAttention is
        // callable in the various cases, based on
        // https://github.com/ml-explore/mlx/blob/main/python/tests/test_fast_sdpa.py#L65-L87

        let Dk = 64
        let scale = 1.0 / sqrt(Float(Dk))
        let dTypes = [DType.float32, DType.float16]
        for SEQUENCE_LENGTH in [63, 129, 400] {
            for dtype in dTypes {
                let B = 2
                let H = 24
                let q = MLXRandom.normal([B, H, SEQUENCE_LENGTH, Dk]).asType(dtype)
                let k = MLXRandom.normal([B, H, SEQUENCE_LENGTH, Dk]).asType(dtype)
                let v = MLXRandom.normal([B, H, SEQUENCE_LENGTH, Dk]).asType(dtype)

                let result = MLXFast.scaledDotProductAttention(
                    queries: q, keys: k, values: v, scale: scale, mask: nil,
                    memoryEfficientThreshold: 2)

                eval(result)
            }
        }
    }
}
5 changes: 3 additions & 2 deletions tools/fix-metal-includes.sh
@@ -23,11 +23,12 @@ KERNEL_LIST=" \
arg_reduce.metal \
conv.metal \
gemv.metal \
+ layer_norm.metal \
random.metal \
rms_norm.metal \
- layer_norm.metal \
rope.metal \
- scaled_dot_product_attention.metal"
+ scaled_dot_product_attention.metal \
+ steel/attn/kernels/steel_attention.metal"

# We fixup all the header files AND the listed kernel files
HEADERS=$(find "${KERNELS_DIR}" -name "*.h")