diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index ac00a97997..beaa945534 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -203,7 +203,7 @@ impl Shape {
 
     /// Check whether the two shapes are compatible for broadcast, and if it is the case return the
     /// broadcasted shape. This is to be used for binary pointwise ops.
-    pub(crate) fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
+    pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
         let lhs = self;
         let lhs_dims = lhs.dims();
         let rhs_dims = rhs.dims();
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 40c908806e..4b75a6871e 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -19,10 +19,6 @@ extern crate accelerate_src;
 
 use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
 
-mod utils;
-
-use utils::broadcast_shapes;
-
 pub fn wrap_err(err: ::candle::Error) -> PyErr {
     PyErr::new::<PyValueError, _>(format!("{err:?}"))
 }
@@ -726,7 +722,11 @@ impl PyTensor {
             compare(&self.0, &rhs.0)
         } else {
             // We broadcast manually here because `candle.cmp` does not support automatic broadcasting
-            let broadcast_shape = broadcast_shapes(&self.0, &rhs.0).map_err(wrap_err)?;
+            let broadcast_shape = self
+                .0
+                .shape()
+                .broadcast_shape_binary_op(rhs.0.shape(), "cmp")
+                .map_err(wrap_err)?;
             let broadcasted_lhs = self.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;
             let broadcasted_rhs = rhs.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;
             compare(&broadcasted_lhs, &broadcasted_rhs)
diff --git a/candle-pyo3/src/utils.rs b/candle-pyo3/src/utils.rs
deleted file mode 100644
index 5fd1df3137..0000000000
--- a/candle-pyo3/src/utils.rs
+++ /dev/null
@@ -1,40 +0,0 @@
-use ::candle::{Error as CandleError, Result as CandleResult};
-use candle::Shape;
-
-/// Tries to broadcast the `rhs` shape into the `lhs` shape.
-pub fn broadcast_shapes(lhs: &::candle::Tensor, rhs: &::candle::Tensor) -> CandleResult<Shape> {
-    // see `Shape.broadcast_shape_binary_op`
-    let lhs_dims = lhs.dims();
-    let rhs_dims = rhs.dims();
-    let lhs_ndims = lhs_dims.len();
-    let rhs_ndims = rhs_dims.len();
-    let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
-    let mut bcast_dims = vec![0; bcast_ndims];
-    for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
-        let rev_idx = bcast_ndims - idx;
-        let l_value = if lhs_ndims < rev_idx {
-            1
-        } else {
-            lhs_dims[lhs_ndims - rev_idx]
-        };
-        let r_value = if rhs_ndims < rev_idx {
-            1
-        } else {
-            rhs_dims[rhs_ndims - rev_idx]
-        };
-        *bcast_value = if l_value == r_value {
-            l_value
-        } else if l_value == 1 {
-            r_value
-        } else if r_value == 1 {
-            l_value
-        } else {
-            return Err(CandleError::BroadcastIncompatibleShapes {
-                src_shape: lhs.shape().clone(),
-                dst_shape: rhs.shape().clone(),
-            }
-            .bt());
-        }
-    }
-    Ok(Shape::from(bcast_dims))
-}
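Note: with `broadcast_shape_binary_op` made public above, the broadcast shape can be computed on `Shape` directly from outside candle-core, which is what the pyo3 change relies on instead of the deleted `utils::broadcast_shapes` helper. A minimal sketch of the resulting call pattern (the function name `broadcast_cmp_shape` and the concrete shapes are illustrative, not part of the change):

use candle::{DType, Device, Result, Tensor};

fn broadcast_cmp_shape() -> Result<()> {
    let dev = Device::Cpu;
    let a = Tensor::zeros((4, 1, 3), DType::F32, &dev)?;
    let b = Tensor::zeros((2, 3), DType::F32, &dev)?;
    // Dimensions are aligned right-to-left; each pair must be equal or
    // contain a 1, so (4, 1, 3) and (2, 3) broadcast to (4, 2, 3).
    let shape = a.shape().broadcast_shape_binary_op(b.shape(), "cmp")?;
    assert_eq!(shape.dims(), &[4, 2, 3]);
    // `broadcast_as` can then materialize both operands at that shape,
    // mirroring what the `cmp` path in candle-pyo3 now does.
    let _lhs = a.broadcast_as(&shape)?;
    let _rhs = b.broadcast_as(&shape)?;
    Ok(())
}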