diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi index aef0707d51..b0f05de591 100644 --- a/candle-pyo3/py_src/candle/__init__.pyi +++ b/candle-pyo3/py_src/candle/__init__.pyi @@ -324,6 +324,12 @@ class Tensor: """ pass + def gather(self, index, dim): + """ + Gathers values along an axis specified by dim. + """ + pass + def get(self, index: int) -> Tensor: """ Gets the value at the specified index. diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 7b9a741340..e0d3bf300f 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -448,6 +448,12 @@ impl PyTensor { Ok(PyTensor(self.0.index_select(rhs, dim).map_err(wrap_err)?)) } + /// Gathers values along an axis specified by dim. + fn gather(&self, index: &Self, dim: i64) -> PyResult { + let dim = actual_dim(self, dim).map_err(wrap_err)?; + Ok(PyTensor(self.0.gather(index, dim).map_err(wrap_err)?)) + } + #[pyo3(text_signature = "(self, rhs:Tensor)")] /// Performs a matrix multiplication between the two tensors. /// &RETURNS&: Tensor