From 2f608858d69c70b384d88bec8dd73834df1182ea Mon Sep 17 00:00:00 2001 From: Ionut Mihalcea Date: Tue, 26 Nov 2024 23:10:09 +0100 Subject: [PATCH] Onnx Support for Sign operation #2641 (#2642) * Support for Sign operation #2641 * Apply rustfmt. --------- Co-authored-by: Laurent --- candle-onnx/tests/ops.rs | 41 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index b444fb64b0..4023d0c008 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -5855,3 +5855,44 @@ fn test_xor() -> Result<()> { } Ok(()) } + +#[test] +fn test_sign_operation() -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Sign".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![INPUT_X.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + + let mut inputs: HashMap = HashMap::new(); + inputs.insert( + INPUT_X.to_string(), + Tensor::new(vec![-2f32, -1., 0., 1., 2.], &Device::Cpu)?, + ); + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + assert_eq!( + z.to_dtype(candle::DType::I64)?.to_vec1::()?.to_vec(), + vec![-1, -1, 0, 1, 1] + ); + Ok(()) +}