Updating API to be more in line with previous plan
See: #68
brianjjones committed Jan 9, 2023
1 parent d7d52af commit 90eb701
Showing 4 changed files with 124 additions and 59 deletions.
48 changes: 28 additions & 20 deletions rust/examples/classification-example/src/main.rs
@@ -11,44 +11,52 @@ pub fn main() {
     let weights = fs::read("fixture/mobilenet.bin").unwrap();
     println!("Read graph weights, size in bytes: {}", weights.len());

-    let context = wasi_nn::backend_init(
-        &[&xml.into_bytes(), &weights],
-        wasi_nn::GRAPH_ENCODING_OPENVINO,
-        wasi_nn::EXECUTION_TARGET_CPU,
-    ).unwrap();
+    let xmlbytes = xml.as_bytes();
+    let weightbytes = weights.as_slice();
+    let mut builders: Vec<&[u8]> = vec![xmlbytes, weightbytes];

-    println!("Created wasi-nn execution context with ID: {}", context);
+    let my_graph = wasi_nn::WasiNnGraph::load(
+        builders.into_iter(),
+        wasi_nn::GRAPH_ENCODING_OPENVINO,
+        wasi_nn::EXECUTION_TARGET_CPU,
+    );
+    let graph = my_graph.unwrap();
+    let mut context = graph.get_execution_context();
+
+    // TODO: Need to add code to the Wasmtime side to get the input / output tensor shapes
+    let intypes = graph.get_input_types();
+    let outtypes = graph.get_output_types();

     // Load a tensor that precisely matches the graph input tensor (see
     // `fixture/frozen_inference_graph.xml`).
     for i in 0..5 {
         let filename: String = format!("{}{}{}", "fixture/images/", i, ".jpg");
         // Convert the image. If it fails just exit
-        let tensor_data = convert_image_to_bytes(&filename, 224, 224, TensorType::F32, ColorOrder::BGR).or_else(|e| {
-            Err(e)
-        }).unwrap();
+        let tensor_data =
+            convert_image_to_bytes(&filename, 224, 224, TensorType::F32, ColorOrder::BGR)
+                .or_else(|e| Err(e))
+                .unwrap();

         println!("Read input tensor, size in bytes: {}", tensor_data.len());

-        wasi_nn::set_input_tensor(
-            context,
-            0,
-            &[1, 3, 224, 224],
-            wasi_nn::TENSOR_TYPE_F32,
-            &tensor_data,
-        ).unwrap();
+        let tensor = wasi_nn::Tensor {
+            dimensions: &[1, 3, 224, 224],
+            type_: wasi_nn::TENSOR_TYPE_F32,
+            data: &tensor_data,
+        };
+
+        context.set_input(0, tensor);

         // Execute the inference and get the output.
         let mut output_buffer = vec![0f32; 1001];
-        let _written = wasi_nn::execute(
-            context,
+        let _res = context.compute();
+        let _wrote = context.get_output(
             0,
             &mut output_buffer[..] as *mut [f32] as *mut u8,
             (output_buffer.len() * 4).try_into().unwrap(),
-        ).unwrap();
+        );

         println!("Executed graph inference");

         let results = sort_results(&output_buffer);
         println!("Found results, sorted top 5: {:?}", &results[..5]);

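Pieced together, the added (+) lines above give the new high-level flow. The sketch below is a minimal reconstruction, not part of the commit: the XML path fixture/model.xml and the zeroed stand-in image buffer are assumptions for illustration, while the tensor shape, output size, and call sequence come from the example itself.

    use std::fs;

    pub fn main() {
        // The XML path here is an assumption; the weights path matches the example.
        let xml = fs::read_to_string("fixture/model.xml").unwrap();
        let weights = fs::read("fixture/mobilenet.bin").unwrap();
        let builders: Vec<&[u8]> = vec![xml.as_bytes(), weights.as_slice()];

        // Load the graph once, then pull an execution context from it.
        let graph = wasi_nn::WasiNnGraph::load(
            builders.into_iter(),
            wasi_nn::GRAPH_ENCODING_OPENVINO,
            wasi_nn::EXECUTION_TARGET_CPU,
        )
        .unwrap();
        let mut context = graph.get_execution_context();

        // Stand-in input: 1 x 3 x 224 x 224 f32 values, passed as raw bytes.
        let tensor_data = vec![0u8; 3 * 224 * 224 * 4];
        let tensor = wasi_nn::Tensor {
            dimensions: &[1, 3, 224, 224],
            type_: wasi_nn::TENSOR_TYPE_F32,
            data: &tensor_data,
        };
        context.set_input(0, tensor).unwrap();
        context.compute().unwrap();

        // Copy the 1001 class scores out through the raw-pointer interface.
        let mut output_buffer = vec![0f32; 1001];
        context
            .get_output(
                0,
                &mut output_buffer[..] as *mut [f32] as *mut u8,
                (output_buffer.len() * 4).try_into().unwrap(),
            )
            .unwrap();
        println!("top score: {}", output_buffer.iter().cloned().fold(f32::MIN, f32::max));
    }
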
39 changes: 0 additions & 39 deletions rust/src/generated.rs
@@ -279,42 +279,3 @@ pub mod wasi_ephemeral_nn {
         pub fn compute(arg0: i32) -> i32;
     }
 }
-
-pub fn backend_init(
-    builder: GraphBuilderArray<'_>,
-    encoding: GraphEncoding,
-    target: ExecutionTarget,) -> Result<GraphExecutionContext, NnErrno> {
-    unsafe {
-        init_execution_context(load(builder, encoding, target)?)
-    }
-}
-
-pub fn set_input_tensor(
-    context: GraphExecutionContext,
-    index: u32,
-    dimensions: &[u32],
-    ttype: TensorType,
-    tdata: TensorData,
-) -> Result<(), NnErrno> {
-    let tensor = Tensor {
-        dimensions: dimensions,
-        type_: ttype,
-        data: &tdata,
-    };
-
-    unsafe {
-        set_input(context, index, tensor)
-    }
-}
-
-pub fn execute(
-    context: GraphExecutionContext,
-    index: u32,
-    out_buffer: *mut u8,
-    out_buffer_max_size: BufferSize,
-) -> Result<BufferSize, NnErrno> {
-    unsafe {
-        compute(context)?;
-        get_output(context, index, out_buffer, out_buffer_max_size)
-    }
-}
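
Of the removed helpers, execute was the only one that fused two steps (compute, then get_output). Under the new API that behavior becomes two method calls on ExecutionContext; a hypothetical one-for-one replacement (the name execute_equivalent is invented here, with error handling deferred to the caller) would be:

    // Equivalent of the removed execute(ctx, index, out_ptr, out_len):
    fn execute_equivalent(
        ctx: &mut wasi_nn::ExecutionContext,
        index: u32,
        out_buffer: *mut u8,
        out_buffer_max_size: wasi_nn::BufferSize,
    ) -> Result<wasi_nn::BufferSize, wasi_nn::NnErrno> {
        ctx.compute()?;
        ctx.get_output(index, out_buffer, out_buffer_max_size)
    }
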
94 changes: 94 additions & 0 deletions rust/src/hl.rs
@@ -0,0 +1,94 @@
+use crate::generated::*;
+
+pub struct WasiNnGraph {
+    execution_ctx: ExecutionContext,
+    input_types: TensorTypes,
+    output_types: TensorTypes,
+}
+
+impl WasiNnGraph {
+    pub fn load<'a>(
+        data: impl Iterator<Item = &'a [u8]>,
+        encoding: GraphEncoding,
+        target: ExecutionTarget,
+    ) -> Result<Self, NnErrno> {
+        unsafe {
+            let builders: Vec<&'a [u8]> = data.map(|x| x).collect();
+
+            if builders.len() > 1 {
+                let ctx = init_execution_context(load(&builders.as_slice(), encoding, target)?);
+                match ctx {
+                    Ok(ctx) => Ok(Self {
+                        execution_ctx: ExecutionContext::new(ctx),
+                        input_types: TensorTypes::new(),
+                        output_types: TensorTypes::new(),
+                    }),
+                    Err(ctx) => Err(ctx),
+                }
+            } else {
+                Err(NN_ERRNO_MISSING_MEMORY)
+            }
+        }
+    }
+
+    pub fn get_execution_context(&self) -> ExecutionContext {
+        self.execution_ctx
+    }
+
+    pub fn get_input_types(&self) -> &TensorTypes {
+        &self.input_types
+    }
+
+    pub fn get_output_types(&self) -> &TensorTypes {
+        &self.output_types
+    }
+}
+
+pub struct TensorTypes {
+    ttypes: Vec<TensorType>,
+}
+
+impl TensorTypes {
+    pub fn new() -> Self {
+        Self { ttypes: Vec::new() }
+    }
+
+    pub fn len(&self) -> usize {
+        self.ttypes.len()
+    }
+    pub fn get(&self, index: u32) -> Result<TensorType, String> {
+        if index < self.ttypes.len() as u32 {
+            Ok(self.ttypes[index as usize])
+        } else {
+            Err(format!("Invalid index {}", index))
+        }
+    }
+}
+
+#[derive(Copy, Clone, Debug)]
+pub struct ExecutionContext {
+    execution_ctx: GraphExecutionContext,
+}
+
+impl ExecutionContext {
+    pub fn new(ctx: GraphExecutionContext) -> Self {
+        Self { execution_ctx: ctx }
+    }
+
+    pub fn set_input(&mut self, index: u32, tensor: Tensor) -> Result<(), NnErrno> {
+        unsafe { set_input(self.execution_ctx, index, tensor) }
+    }
+
+    pub fn compute(&mut self) -> Result<(), NnErrno> {
+        unsafe { compute(self.execution_ctx) }
+    }
+
+    pub fn get_output(
+        &self,
+        index: u32,
+        out_buffer: *mut u8,
+        out_buffer_max_size: BufferSize,
+    ) -> Result<BufferSize, NnErrno> {
+        unsafe { get_output(self.execution_ctx, index, out_buffer, out_buffer_max_size) }
+    }
+}
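
Note that get_output still exposes the raw byte pointer and byte length from the generated bindings. A caller-side helper like this hypothetical sketch (not part of the commit; the name read_f32_output and the size_of::<f32>() byte arithmetic mirror the times-four math in the example) keeps the cast in one place:

    fn read_f32_output(
        context: &wasi_nn::ExecutionContext,
        index: u32,
        len: usize,
    ) -> Result<Vec<f32>, wasi_nn::NnErrno> {
        let mut out = vec![0f32; len];
        // get_output writes raw bytes, so pass the buffer as *mut u8 with its size in bytes.
        context.get_output(
            index,
            out.as_mut_ptr() as *mut u8,
            (len * std::mem::size_of::<f32>()).try_into().unwrap(),
        )?;
        Ok(out)
    }

With it, the example's output step shrinks to something like let scores = read_f32_output(&context, 0, 1001)?;.
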
2 changes: 2 additions & 0 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
 mod generated;
 pub use generated::*;
+mod hl;
+pub use hl::*;
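
Since lib.rs re-exports both modules, the generated low-level items and the new wrapper share one crate-root namespace. A usage sketch (imports as used in the example above):

    use wasi_nn::{
        WasiNnGraph,                                   // from hl.rs
        Tensor, TENSOR_TYPE_F32,                       // from generated.rs
        GRAPH_ENCODING_OPENVINO, EXECUTION_TARGET_CPU,
    };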
