Skip to content

Commit

Permalink
remove duplicated broadcast_shape_binary_op
Browse files Browse the repository at this point in the history
  • Loading branch information
LLukas22 committed Oct 28, 2023
1 parent dd06ff8 commit b58056f
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 46 deletions.
2 changes: 1 addition & 1 deletion candle-core/src/shape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ impl Shape {

/// Check whether the two shapes are compatible for broadcast, and if it is the case return the
/// broadcasted shape. This is to be used for binary pointwise ops.
pub(crate) fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
let lhs = self;
let lhs_dims = lhs.dims();
let rhs_dims = rhs.dims();
Expand Down
10 changes: 5 additions & 5 deletions candle-pyo3/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,6 @@ extern crate accelerate_src;

use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};

mod utils;

use utils::broadcast_shapes;

pub fn wrap_err(err: ::candle::Error) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}"))
}
Expand Down Expand Up @@ -726,7 +722,11 @@ impl PyTensor {
compare(&self.0, &rhs.0)
} else {
// We broadcast manually here because `candle.cmp` does not support automatic broadcasting
let broadcast_shape = broadcast_shapes(&self.0, &rhs.0).map_err(wrap_err)?;
let broadcast_shape = self
.0
.shape()
.broadcast_shape_binary_op(rhs.0.shape(), "cmp")
.map_err(wrap_err)?;
let broadcasted_lhs = self.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;
let broadcasted_rhs = rhs.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;

Expand Down
40 changes: 0 additions & 40 deletions candle-pyo3/src/utils.rs

This file was deleted.

0 comments on commit b58056f

Please sign in to comment.