Skip to content

Commit

Permalink
add scatter add (#2656)
Browse files Browse the repository at this point in the history
  • Loading branch information
zachcp authored Dec 1, 2024
1 parent dba7a9c commit 6f715f9
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 0 deletions.
1 change: 1 addition & 0 deletions candle-core/src/metal_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1284,6 +1284,7 @@ impl BackendStorage for MetalStorage {
(DType::U8, DType::F32) => "sa_u8_f32",
(DType::U8, DType::F16) => "sa_u8_f16",
(DType::U8, DType::BF16) => "sa_u8_bf16",
(DType::U32, DType::U32) => "sa_u32_u32",
(DType::U32, DType::F32) => "sa_u32_f32",
(DType::U32, DType::F16) => "sa_u32_f16",
(DType::U32, DType::BF16) => "sa_u32_bf16",
Expand Down
1 change: 1 addition & 0 deletions candle-metal-kernels/src/indexing.metal
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ GATHER_OP(gather_u32_u32, uint, uint)
SCATTER_ADD_OP(sa_u32_f32, uint32_t, float)
SCATTER_ADD_OP(sa_u8_f32, uint8_t, float)
SCATTER_ADD_OP(sa_i64_f32, int64_t, float)
SCATTER_ADD_OP(sa_u32_u32, uint32_t, uint32_t)
SCATTER_ADD_OP(sa_u32_f16, uint32_t, half)
SCATTER_ADD_OP(sa_u8_f16, uint8_t, half)
SCATTER_ADD_OP(sa_i64_f16, int64_t, half)
Expand Down

0 comments on commit 6f715f9

Please sign in to comment.