diff --git a/rust/examples/classification-example/src/main.rs b/rust/examples/classification-example/src/main.rs
index 3b3b1a6..c273a4f 100644
--- a/rust/examples/classification-example/src/main.rs
+++ b/rust/examples/classification-example/src/main.rs
@@ -11,44 +11,52 @@ pub fn main() {
     let weights = fs::read("fixture/mobilenet.bin").unwrap();
     println!("Read graph weights, size in bytes: {}", weights.len());
 
-    let context = wasi_nn::backend_init(
-        &[&xml.into_bytes(), &weights],
-        wasi_nn::GRAPH_ENCODING_OPENVINO,
-        wasi_nn::EXECUTION_TARGET_CPU,
-    ).unwrap();
+    let xmlbytes = xml.as_bytes();
+    let weightbytes = weights.as_slice();
+    let builders: Vec<&[u8]> = vec![xmlbytes, weightbytes];
 
-    println!("Created wasi-nn execution context with ID: {}", context);
+    let my_graph = wasi_nn::WasiNnGraph::load(
+        builders.into_iter(),
+        wasi_nn::GRAPH_ENCODING_OPENVINO,
+        wasi_nn::EXECUTION_TARGET_CPU,
+    );
+    let graph = my_graph.unwrap();
+    let mut context = graph.get_execution_context();
+
+    // TODO: Need to add code to the Wasmtime side to get the input / output tensor shapes
+    let intypes = graph.get_input_types();
+    let outtypes = graph.get_output_types();
 
     // Load a tensor that precisely matches the graph input tensor (see
     // `fixture/frozen_inference_graph.xml`).
     for i in 0..5 {
         let filename: String = format!("{}{}{}", "fixture/images/", i, ".jpg");
         // Convert the image. If it fails just exit
-        let tensor_data = convert_image_to_bytes(&filename, 224, 224, TensorType::F32, ColorOrder::BGR).or_else(|e| {
-            Err(e)
-        }).unwrap();
+        let tensor_data =
+            convert_image_to_bytes(&filename, 224, 224, TensorType::F32, ColorOrder::BGR)
+                .unwrap();
 
         println!("Read input tensor, size in bytes: {}", tensor_data.len());
 
-        wasi_nn::set_input_tensor(
-            context,
-            0,
-            &[1, 3, 224, 224],
-            wasi_nn::TENSOR_TYPE_F32,
-            &tensor_data,
-        ).unwrap();
+        let tensor = wasi_nn::Tensor {
+            dimensions: &[1, 3, 224, 224],
+            type_: wasi_nn::TENSOR_TYPE_F32,
+            data: &tensor_data,
+        };
+
+        context.set_input(0, tensor).unwrap();
 
         // Execute the inference and get the output.
         let mut output_buffer = vec![0f32; 1001];
-        let _written = wasi_nn::execute(
-            context,
+        context.compute().unwrap();
+        let _written = context.get_output(
             0,
             &mut output_buffer[..] as *mut [f32] as *mut u8,
             (output_buffer.len() * 4).try_into().unwrap(),
         ).unwrap();
 
         println!("Executed graph inference");
-
         let results = sort_results(&output_buffer);
         println!("Found results, sorted top 5: {:?}", &results[..5]);
diff --git a/rust/src/generated.rs b/rust/src/generated.rs
index 22255d5..51de256 100644
--- a/rust/src/generated.rs
+++ b/rust/src/generated.rs
@@ -279,42 +279,3 @@ pub mod wasi_ephemeral_nn {
         pub fn compute(arg0: i32) -> i32;
     }
 }
-
-pub fn backend_init(
-    builder: GraphBuilderArray<'_>,
-    encoding: GraphEncoding,
-    target: ExecutionTarget,
-) -> Result<GraphExecutionContext, NnErrno> {
-    unsafe {
-        init_execution_context(load(builder, encoding, target)?)
-    }
-}
-
-pub fn set_input_tensor(
-    context: GraphExecutionContext,
-    index: u32,
-    dimensions: &[u32],
-    ttype: TensorType,
-    tdata: TensorData,
-) -> Result<(), NnErrno> {
-    let tensor = Tensor {
-        dimensions: dimensions,
-        type_: ttype,
-        data: &tdata,
-    };
-
-    unsafe {
-        set_input(context, index, tensor)
-    }
-}
-
-pub fn execute(
-    context: GraphExecutionContext,
-    index: u32,
-    out_buffer: *mut u8,
-    out_buffer_max_size: BufferSize,
-) -> Result<BufferSize, NnErrno> {
-    unsafe {
-        compute(context)?;
-        get_output(context, index, out_buffer, out_buffer_max_size)
-    }
-}
diff --git a/rust/src/hl.rs b/rust/src/hl.rs
new file mode 100644
index 0000000..a1e3e6a
--- /dev/null
+++ b/rust/src/hl.rs
@@ -0,0 +1,94 @@
+//! Higher-level wrappers over the raw generated wasi-nn bindings.
+use crate::generated::*;
+
+/// A loaded graph plus its (not yet populated) tensor type metadata.
+pub struct WasiNnGraph {
+    execution_ctx: ExecutionContext,
+    input_types: TensorTypes,
+    output_types: TensorTypes,
+}
+
+impl WasiNnGraph {
+    pub fn load<'a>(
+        data: impl Iterator<Item = &'a [u8]>,
+        encoding: GraphEncoding,
+        target: ExecutionTarget,
+    ) -> Result<Self, NnErrno> {
+        unsafe {
+            let builders: Vec<&'a [u8]> = data.collect();
+
+            // OpenVINO needs at least two builders (model XML plus weights).
+            if builders.len() > 1 {
+                let ctx = init_execution_context(load(builders.as_slice(), encoding, target)?);
+                match ctx {
+                    Ok(ctx) => Ok(Self {
+                        execution_ctx: ExecutionContext::new(ctx),
+                        input_types: TensorTypes::new(),
+                        output_types: TensorTypes::new(),
+                    }),
+                    Err(e) => Err(e),
+                }
+            } else {
+                Err(NN_ERRNO_MISSING_MEMORY)
+            }
+        }
+    }
+
+    pub fn get_execution_context(&self) -> ExecutionContext {
+        self.execution_ctx
+    }
+
+    pub fn get_input_types(&self) -> &TensorTypes {
+        &self.input_types
+    }
+
+    pub fn get_output_types(&self) -> &TensorTypes {
+        &self.output_types
+    }
+}
+
+pub struct TensorTypes {
+    ttypes: Vec<TensorType>,
+}
+
+impl TensorTypes {
+    pub fn new() -> Self {
+        Self { ttypes: Vec::new() }
+    }
+
+    pub fn len(&self) -> usize {
+        self.ttypes.len()
+    }
+
+    pub fn get(&self, index: u32) -> Result<TensorType, String> {
+        if index < self.ttypes.len() as u32 {
+            Ok(self.ttypes[index as usize])
+        } else {
+            Err(format!("Invalid index {}", index))
+        }
+    }
+}
+
+/// A copyable handle to a graph execution context.
+#[derive(Copy, Clone, Debug)]
+pub struct ExecutionContext {
+    execution_ctx: GraphExecutionContext,
+}
+
+impl ExecutionContext {
+    pub fn new(ctx: GraphExecutionContext) -> Self {
+        Self { execution_ctx: ctx }
+    }
+
+    pub fn set_input(&mut self, index: u32, tensor: Tensor) -> Result<(), NnErrno> {
+        unsafe { set_input(self.execution_ctx, index, tensor) }
+    }
+
+    pub fn compute(&mut self) -> Result<(), NnErrno> {
+        unsafe { compute(self.execution_ctx) }
+    }
+
+    pub fn get_output(
+        &self,
+        index: u32,
+        out_buffer: *mut u8,
+        out_buffer_max_size: BufferSize,
+    ) -> Result<BufferSize, NnErrno> {
+        unsafe { get_output(self.execution_ctx, index, out_buffer, out_buffer_max_size) }
+    }
+}
diff --git a/rust/src/lib.rs b/rust/src/lib.rs
index 14ed671..c356b7f 100644
--- a/rust/src/lib.rs
+++ b/rust/src/lib.rs
@@ -1,2 +1,4 @@
 mod generated;
 pub use generated::*;
+mod hl;
+pub use hl::*;
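
Note on the new API surface: every `ExecutionContext` method returns a `Result`, so callers can propagate failures with `?` instead of unwrapping at each step. A minimal sketch of that pattern — `classify` is a hypothetical helper, and the `fixture/model.xml` path is an assumption, since the XML read sits above the hunk shown here:

    use std::convert::TryInto;
    use std::fs;

    // Hypothetical helper showing how the Result-returning methods compose
    // with `?`. The model XML path is an assumption, not part of this diff.
    fn classify(tensor_data: &[u8]) -> Result<Vec<f32>, wasi_nn::NnErrno> {
        let xml = fs::read_to_string("fixture/model.xml").unwrap();
        let weights = fs::read("fixture/mobilenet.bin").unwrap();
        let builders: Vec<&[u8]> = vec![xml.as_bytes(), weights.as_slice()];

        let graph = wasi_nn::WasiNnGraph::load(
            builders.into_iter(),
            wasi_nn::GRAPH_ENCODING_OPENVINO,
            wasi_nn::EXECUTION_TARGET_CPU,
        )?;
        let mut context = graph.get_execution_context();

        context.set_input(0, wasi_nn::Tensor {
            dimensions: &[1, 3, 224, 224],
            type_: wasi_nn::TENSOR_TYPE_F32,
            data: tensor_data,
        })?;
        context.compute()?;

        // 1001 f32 class scores, matching the example's output buffer.
        let mut output = vec![0f32; 1001];
        context.get_output(
            0,
            output.as_mut_ptr() as *mut u8,
            (output.len() * 4).try_into().unwrap(),
        )?;
        Ok(output)
    }

With this shape, the example's image loop could report which input failed rather than panicking mid-run.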
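
On the TODO in the example: `get_input_types`/`get_output_types` currently return the empty `TensorTypes` placeholders built in `WasiNnGraph::load`. Once the Wasmtime side exposes tensor metadata, the population step could look roughly like the sketch below; `get_input_tensor_count` and `get_input_tensor_type` are hypothetical host functions that do not exist in the spec yet, and `load` would also need to keep the `Graph` handle it currently passes straight to `init_execution_context`:

    // Hypothetical sketch for hl.rs. Assumes host functions
    // `get_input_tensor_count` and `get_input_tensor_type`, neither of which
    // exists in wasi-nn today; `Graph` is the handle returned by `load`.
    unsafe fn query_input_types(graph: Graph) -> Result<TensorTypes, NnErrno> {
        let mut types = TensorTypes::new();
        for index in 0..get_input_tensor_count(graph)? {
            // Record the element type the backend reports for each input.
            types.ttypes.push(get_input_tensor_type(graph, index)?);
        }
        Ok(types)
    }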