Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: enhance contract macro #27

Merged
merged 13 commits into from
Dec 20, 2024
8 changes: 7 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@ resolver = "2"
members = ["eth-riscv-interpreter", "eth-riscv-syscalls", "r55"]
default-members = ["eth-riscv-interpreter", "eth-riscv-syscalls", "r55"]

exclude = ["contract-derive", "erc20", "erc20x", "eth-riscv-runtime"]
exclude = [
"contract-derive",
"erc20",
"erc20x",
"erc20x_standalone",
"eth-riscv-runtime",
]

[workspace.package]
version = "0.1.0"
Expand Down
130 changes: 130 additions & 0 deletions contract-derive/src/helpers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
use alloy_core::primitives::keccak256;
use alloy_sol_types::SolValue;
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{FnArg, Ident, ImplItemMethod, ReturnType, TraitItemMethod};

// Unified method info from `ImplItemMethod` and `TraitItemMethod`
pub struct MethodInfo<'a> {
name: &'a Ident,
args: Vec<syn::FnArg>,
return_type: &'a ReturnType,
}

impl<'a> From<&'a ImplItemMethod> for MethodInfo<'a> {
fn from(method: &'a ImplItemMethod) -> Self {
Self {
name: &method.sig.ident,
args: method.sig.inputs.iter().skip(1).cloned().collect(),
return_type: &method.sig.output,
}
}
}

impl<'a> From<&'a TraitItemMethod> for MethodInfo<'a> {
fn from(method: &'a TraitItemMethod) -> Self {
Self {
name: &method.sig.ident,
args: method.sig.inputs.iter().skip(1).cloned().collect(),
return_type: &method.sig.output,
}
}
}

// Helper function to generate intercate impl from user-defined methods
pub fn generate_interface<T>(
methods: &[&T],
interface_name: &Ident,
) -> quote::__private::TokenStream
where
for<'a> MethodInfo<'a>: From<&'a T>,
{
let methods: Vec<MethodInfo> = methods.iter().map(|&m| MethodInfo::from(m)).collect();

// Generate implementation
let method_impls = methods.iter().map(|method| {
let name = method.name;
let args = &method.args;
let return_type = method.return_type;
let method_selector = u32::from_be_bytes(
keccak256(name.to_string())[..4]
.try_into()
.unwrap_or_default(),
);

// Simply use index for arg names, and extract types
let (arg_names, arg_types): (Vec<_>, Vec<_>) = args
.iter()
.enumerate()
.map(|(i, arg)| {
if let FnArg::Typed(pat_type) = arg {
let ty = &*pat_type.ty;
(format_ident!("arg{}", i), ty)
} else {
panic!("Expected typed arguments");
}
})
.unzip();

let calldata = if arg_names.is_empty() {
quote! {
let mut complete_calldata = Vec::with_capacity(4);
complete_calldata.extend_from_slice(&[
#method_selector.to_be_bytes()[0],
#method_selector.to_be_bytes()[1],
#method_selector.to_be_bytes()[2],
#method_selector.to_be_bytes()[3],
]);
}
} else {
quote! {
let mut args_calldata = (#(#arg_names),*).abi_encode();
let mut complete_calldata = Vec::with_capacity(4 + args_calldata.len());
complete_calldata.extend_from_slice(&[
#method_selector.to_be_bytes()[0],
#method_selector.to_be_bytes()[1],
#method_selector.to_be_bytes()[2],
#method_selector.to_be_bytes()[3],
]);
complete_calldata.append(&mut args_calldata);
}
};

let return_ty = match return_type {
ReturnType::Default => quote! { () },
ReturnType::Type(_, ty) => quote! { #ty },
};

quote! {
pub fn #name(&self, #(#arg_names: #arg_types),*) -> Option<#return_ty> {
use alloy_sol_types::SolValue;
use alloc::vec::Vec;

#calldata

let result = eth_riscv_runtime::call_contract(
self.address,
0_u64,
&complete_calldata,
32_u64
)?;

<#return_ty>::abi_decode(&result, true).ok()
}
}
});

quote! {
pub struct #interface_name {
address: Address,
}

impl #interface_name {
pub fn new(address: Address) -> Self {
Self { address }
}

#(#method_impls)*
}
}
}
197 changes: 63 additions & 134 deletions contract-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@ use alloy_core::primitives::keccak256;
use alloy_sol_types::SolValue;
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{parse_macro_input, Data, DeriveInput, Fields, ImplItem, ItemImpl, ItemTrait, TraitItem};
use syn::{FnArg, ReturnType};
use syn::{
parse_macro_input, Data, DeriveInput, Fields, FnArg, ImplItem, ImplItemMethod, ItemImpl,
ItemTrait, ReturnType, TraitItem,
};

mod helpers;

#[proc_macro_derive(Event, attributes(indexed))]
pub fn event_derive(input: TokenStream) -> TokenStream {
Expand Down Expand Up @@ -100,13 +104,13 @@ pub fn contract(_attr: TokenStream, item: TokenStream) -> TokenStream {
panic!("Expected a struct.");
};

let mut public_methods = Vec::new();
let mut public_methods: Vec<&ImplItemMethod> = Vec::new();

// Iterate over the items in the impl block to find pub methods
for item in input.items.iter() {
if let ImplItem::Method(method) = item {
if let syn::Visibility::Public(_) = method.vis {
public_methods.push(method.clone());
public_methods.push(method);
}
}
}
Expand Down Expand Up @@ -217,42 +221,62 @@ pub fn contract(_attr: TokenStream, item: TokenStream) -> TokenStream {
}
};

// Generate the call method implementation
let call_method = quote! {
use alloy_sol_types::SolValue;
use eth_riscv_runtime::*;
// Generate the interface
let interface_name = format_ident!("I{}", struct_name);
let interface = helpers::generate_interface(&public_methods, &interface_name);

#emit_helper
impl Contract for #struct_name {
fn call(&self) {
self.call_with_data(&msg_data());
}
// Generate the complete output with module structure
let output = quote! {
// Public interface module
pub mod interface {
use super::*;
#interface
}

// Generate the call method implementation privately
// only when not in `interface-only` mode
#[cfg(not(feature = "interface-only"))]
mod implementation {
use super::*;
use alloy_sol_types::SolValue;
use eth_riscv_runtime::*;

#input

#emit_helper

impl Contract for #struct_name {
fn call(&self) {
self.call_with_data(&msg_data());
}

fn call_with_data(&self, calldata: &[u8]) {
let selector = u32::from_be_bytes([calldata[0], calldata[1], calldata[2], calldata[3]]);
let calldata = &calldata[4..];

fn call_with_data(&self, calldata: &[u8]) {
let selector = u32::from_be_bytes([calldata[0], calldata[1], calldata[2], calldata[3]]);
let calldata = &calldata[4..];
match selector {
#( #match_arms )*
_ => revert(),
}

match selector {
#( #match_arms )*
_ => revert(),
return_riscv(0, 0);
}
}

return_riscv(0, 0);
#[eth_riscv_runtime::entry]
fn main() -> ! {
let contract = #struct_name::default();
contract.call();
eth_riscv_runtime::return_riscv(0, 0)
}
}

#[eth_riscv_runtime::entry]
fn main() -> !
{
let contract = #struct_name::default();
contract.call();
eth_riscv_runtime::return_riscv(0, 0)
}
};
// Always export the interface
pub use interface::*;

let output = quote! {
#input
#call_method
// Only export contract impl when not in `interface-only` mode
#[cfg(not(feature = "interface-only"))]
pub use implementation::*;
};

TokenStream::from(output)
Expand Down Expand Up @@ -281,119 +305,24 @@ pub fn interface(_attr: TokenStream, item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as ItemTrait);
let trait_name = &input.ident;

let method_impls: Vec<_> = input
let methods: Vec<_> = input
.items
.iter()
.map(|item| {
if let TraitItem::Method(method) = item {
let method_name = &method.sig.ident;
let selector_bytes = keccak256(method_name.to_string())[..4]
.try_into()
.unwrap_or_default();
let method_selector = u32::from_be_bytes(selector_bytes);

// Extract argument types and names, skipping self
let arg_types: Vec<_> = method
.sig
.inputs
.iter()
.skip(1)
.map(|arg| {
if let FnArg::Typed(pat_type) = arg {
let ty = &*pat_type.ty;
quote! { #ty }
} else {
panic!("Expected typed arguments");
}
})
.collect();
let arg_names: Vec<_> = (0..method.sig.inputs.len() - 1)
.map(|i| format_ident!("arg{}", i))
.collect();

// Get the return type
let return_type = match &method.sig.output {
ReturnType::Default => quote! { () },
ReturnType::Type(_, ty) =>
quote! { #ty },
};

// Generate calldata with different encoding depending on # of args
let args_encoding = if arg_names.is_empty() {
quote! {
let mut complete_calldata = Vec::with_capacity(4);
complete_calldata.extend_from_slice(&[
#method_selector.to_be_bytes()[0],
#method_selector.to_be_bytes()[1],
#method_selector.to_be_bytes()[2],
#method_selector.to_be_bytes()[3],
]);
}
} else if arg_names.len() == 1 {
quote! {
let mut args_calldata = #(#arg_names),*.abi_encode();
let mut complete_calldata = Vec::with_capacity(4 + args_calldata.len());
complete_calldata.extend_from_slice(&[
#method_selector.to_be_bytes()[0],
#method_selector.to_be_bytes()[1],
#method_selector.to_be_bytes()[2],
#method_selector.to_be_bytes()[3],
]);
complete_calldata.append(&mut args_calldata);
}
} else {
quote! {
let mut args_calldata = (#(#arg_names),*).abi_encode();
let mut complete_calldata = Vec::with_capacity(4 + args_calldata.len());
complete_calldata.extend_from_slice(&[
#method_selector.to_be_bytes()[0],
#method_selector.to_be_bytes()[1],
#method_selector.to_be_bytes()[2],
#method_selector.to_be_bytes()[3],
]);
complete_calldata.append(&mut args_calldata);
}
};

Some(quote! {
pub fn #method_name(&self, #(#arg_names: #arg_types),*) -> Option<#return_type> {
use alloy_sol_types::SolValue;
use alloc::vec::Vec;

#args_encoding

// Make the call
let result = eth_riscv_runtime::call_contract(
self.address,
0_u64,
&complete_calldata,
32_u64 // TODO: Figure out how to use SolType to get the return size

)?;

// Decode result
<#return_type>::abi_decode(&result, true).ok()
}
})
method
} else {
panic!("Expected methods arguments");
panic!("Expected methods arguments")
}
})
.collect();

let expanded = quote! {
pub struct #trait_name {
address: Address,
}

impl #trait_name {
pub fn new(address: Address) -> Self {
Self { address }
}
// Generate intreface implementation
let interface = helpers::generate_interface(&methods, trait_name);

#(#method_impls)*
}
let output = quote! {
#interface
};

TokenStream::from(expanded)
TokenStream::from(output)
}
Loading
Loading