diff --git a/Cargo.toml b/Cargo.toml index b87be607..c171e8fc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,8 +36,6 @@ serde = { version = "1", features = ["derive"] } thiserror = { version = "1", optional = true } serde_json = { version = "1", optional = true } -tokio = { version = "1", features = ["sync"], optional = true } - [features] default = [] -js-tracer = ["boa_engine", "boa_gc", "tokio", "thiserror", "serde_json"] +js-tracer = ["boa_engine", "boa_gc", "thiserror", "serde_json"] diff --git a/src/tracing/js/bindings.rs b/src/tracing/js/bindings.rs index f3f3c051..7ef499ec 100644 --- a/src/tracing/js/bindings.rs +++ b/src/tracing/js/bindings.rs @@ -6,7 +6,7 @@ use crate::tracing::{ address_to_buf, bytes_to_address, bytes_to_hash, from_buf, to_bigint, to_buf, to_buf_value, }, - JsDbRequest, TransactionContext, + TransactionContext, }, types::CallKind, }; @@ -22,10 +22,10 @@ use revm::{ opcode::{PUSH0, PUSH32}, OpCode, SharedMemory, Stack, }, - primitives::{AccountInfo, State, KECCAK_EMPTY}, + primitives::{AccountInfo, Bytecode, State, KECCAK_EMPTY}, + DatabaseRef, }; -use std::{cell::RefCell, rc::Rc, sync::mpsc::channel}; -use tokio::sync::mpsc; +use std::{cell::RefCell, rc::Rc}; /// A macro that creates a native function that returns via [JsValue::from] macro_rules! js_value_getter { @@ -54,8 +54,8 @@ macro_rules! js_value_capture_getter { }; } -/// A reference to a value that can be garbagae collected, but will not give access to the value if -/// it has been dropped. +/// A wrapper for a value that can be garbage collected, but will not give access to the value if +/// it has been dropped via its guard. /// /// This is used to allow the JS tracer functions to access values at a certain point during /// inspection by ref without having to clone them and capture them in the js object. @@ -70,18 +70,29 @@ macro_rules! js_value_capture_getter { /// This type supports garbage collection of (rust) references and prevents access to the value if /// it has been dropped. #[derive(Debug, Clone)] -pub(crate) struct GuardedNullableGcRef { +pub(crate) struct GuardedNullableGc { /// The lifetime is a lie to make it possible to use a reference in boa which requires 'static - inner: Rc>>, + inner: Rc>>>, } -impl GuardedNullableGcRef { - /// Creates a garbage collectible reference to the given reference. +impl GuardedNullableGc { + /// Creates a garbage collectible value to the given reference. + /// + /// SAFETY; the caller must ensure that the guard is dropped before the value is dropped. + pub(crate) fn r#ref(val: &Val) -> (Self, GcGuard<'_, Val>) { + Self::new(Guarded::Ref(val)) + } + + /// Creates a garbage collectible value to the given reference. /// /// SAFETY; the caller must ensure that the guard is dropped before the value is dropped. - pub(crate) fn new(val: &Val) -> (Self, RefGuard<'_, Val>) { + pub(crate) fn r#owned<'a>(val: Val) -> (Self, GcGuard<'a, Val>) { + Self::new(Guarded::Owned(val)) + } + + fn new(val: Guarded<'_, Val>) -> (Self, GcGuard<'_, Val>) { let inner = Rc::new(RefCell::new(Some(val))); - let guard = RefGuard { inner: Rc::clone(&inner) }; + let guard = GcGuard { inner: Rc::clone(&inner) }; // SAFETY: guard enforces that the value is removed from the refcell before it is dropped let this = Self { inner: unsafe { std::mem::transmute(inner) } }; @@ -94,25 +105,35 @@ impl GuardedNullableGcRef { where F: FnOnce(&Val) -> R, { - self.inner.borrow().map(f) + self.inner.borrow().as_ref().map(|val| match val { + Guarded::Ref(val) => f(val), + Guarded::Owned(val) => f(val), + }) } } -impl Finalize for GuardedNullableGcRef {} +impl Finalize for GuardedNullableGc {} -unsafe impl Trace for GuardedNullableGcRef { +unsafe impl Trace for GuardedNullableGc { empty_trace!(); } -/// Guard the inner references, once this value is dropped the inner reference is also removed. +/// A value that is either a reference or an owned value. +#[derive(Debug)] +enum Guarded<'a, T> { + Ref(&'a T), + Owned(T), +} + +/// Guard the inner value, once this value is dropped the inner value is also removed. /// -/// This type guarantees that it never outlives the wrapped reference. +/// This type guarantees that it never outlives the wrapped value. #[derive(Debug)] -pub(crate) struct RefGuard<'a, Val> { - inner: Rc>>, +pub(crate) struct GcGuard<'a, Val> { + inner: Rc>>>, } -impl<'a, Val> Drop for RefGuard<'a, Val> { +impl<'a, Val> Drop for GcGuard<'a, Val> { fn drop(&mut self) { self.inner.borrow_mut().take(); } @@ -196,12 +217,12 @@ impl StepLog { /// Represents the memory object #[derive(Debug, Clone)] -pub(crate) struct MemoryRef(pub(crate) GuardedNullableGcRef); +pub(crate) struct MemoryRef(pub(crate) GuardedNullableGc); impl MemoryRef { /// Creates a new stack reference - pub(crate) fn new(mem: &SharedMemory) -> (Self, RefGuard<'_, SharedMemory>) { - let (inner, guard) = GuardedNullableGcRef::new(mem); + pub(crate) fn new(mem: &SharedMemory) -> (Self, GcGuard<'_, SharedMemory>) { + let (inner, guard) = GuardedNullableGc::r#ref(mem); (MemoryRef(inner), guard) } @@ -288,12 +309,12 @@ unsafe impl Trace for MemoryRef { /// Represents the state object #[derive(Debug, Clone)] -pub(crate) struct StateRef(pub(crate) GuardedNullableGcRef); +pub(crate) struct StateRef(pub(crate) GuardedNullableGc); impl StateRef { /// Creates a new stack reference - pub(crate) fn new(state: &State) -> (Self, RefGuard<'_, State>) { - let (inner, guard) = GuardedNullableGcRef::new(state); + pub(crate) fn new(state: &State) -> (Self, GcGuard<'_, State>) { + let (inner, guard) = GuardedNullableGc::r#ref(state); (StateRef(inner), guard) } @@ -308,6 +329,27 @@ unsafe impl Trace for StateRef { empty_trace!(); } +/// Represents the database +#[derive(Debug, Clone)] +pub(crate) struct GcDb(pub(crate) GuardedNullableGc); + +impl GcDb +where + DB: DatabaseRef + 'static, +{ + /// Creates a new stack reference + fn new<'a>(db: DB) -> (Self, GcGuard<'a, DB>) { + let (inner, guard) = GuardedNullableGc::owned(db); + (GcDb(inner), guard) + } +} + +impl Finalize for GcDb {} + +unsafe impl Trace for GcDb { + empty_trace!(); +} + /// Represents the opcode object #[derive(Debug)] pub(crate) struct OpObj(pub(crate) u8); @@ -367,12 +409,12 @@ impl From for OpObj { /// Represents the stack object #[derive(Debug)] -pub(crate) struct StackRef(pub(crate) GuardedNullableGcRef); +pub(crate) struct StackRef(pub(crate) GuardedNullableGc); impl StackRef { /// Creates a new stack reference - pub(crate) fn new(stack: &Stack) -> (Self, RefGuard<'_, Stack>) { - let (inner, guard) = GuardedNullableGcRef::new(stack); + pub(crate) fn new(stack: &Stack) -> (Self, GcGuard<'_, Stack>) { + let (inner, guard) = GuardedNullableGc::r#ref(stack); (StackRef(inner), guard) } @@ -680,39 +722,51 @@ impl EvmContext { } /// DB is the object that allows the js inspector to interact with the database. -#[derive(Debug, Clone)] pub(crate) struct EvmDbRef { - state: StateRef, - to_db: mpsc::Sender, + inner: Rc, } impl EvmDbRef { - /// Creates a new DB reference - pub(crate) fn new( - state: &State, - to_db: mpsc::Sender, - ) -> (Self, RefGuard<'_, State>) { - let (state, guard) = StateRef::new(state); - let this = Self { state, to_db }; + /// Creates a new evm and db JS object. + pub(crate) fn new<'a, 'b, DB>(state: &'a State, db: &'b DB) -> (Self, EvmDbGuard<'a, 'b>) + where + DB: DatabaseRef, + DB::Error: std::fmt::Display, + { + let (state, state_guard) = StateRef::new(state); + + // SAFETY: + // + // boa requires 'static lifetime for all objects. + // As mention in the `Safety` section of [GuardedNullableGc] the caller of this function + // needs to guarantee that the passed-in lifetime is sufficiently long for the lifetime of + // the guard. + let db = JsDb(db); + let js_db = unsafe { + std::mem::transmute::< + Box + '_>, + Box + 'static>, + >(Box::new(db)) + }; + + let (db, db_guard) = GcDb::new(js_db); + + let inner = EvmDbRefInner { state, db }; + let this = Self { inner: Rc::new(inner) }; + let guard = EvmDbGuard { _state_guard: state_guard, _db_guard: db_guard }; (this, guard) } fn read_basic(&self, address: JsValue, ctx: &mut Context<'_>) -> JsResult> { let buf = from_buf(address, ctx)?; let address = bytes_to_address(buf); - if let acc @ Some(_) = self.state.get_account(&address) { + if let acc @ Some(_) = self.inner.state.get_account(&address) { return Ok(acc); } - let (tx, rx) = channel(); - if self.to_db.try_send(JsDbRequest::Basic { address, resp: tx }).is_err() { - return Err(JsError::from_native( - JsNativeError::error() - .with_message(format!("Failed to read address {address:?} from database",)), - )); - } - match rx.recv() { - Ok(Ok(maybe_acc)) => Ok(maybe_acc), + let res = self.inner.db.0.with_inner(|db| db.basic_ref(address)); + match res { + Some(Ok(maybe_acc)) => Ok(maybe_acc), _ => Err(JsError::from_native( JsNativeError::error() .with_message(format!("Failed to read address {address:?} from database",)), @@ -727,16 +781,13 @@ impl EvmDbRef { return JsArrayBuffer::new(0, ctx); } - let (tx, rx) = channel(); - if self.to_db.try_send(JsDbRequest::Code { code_hash, resp: tx }).is_err() { - return Err(JsError::from_native( - JsNativeError::error() - .with_message(format!("Failed to read code hash {code_hash:?} from database",)), - )); - } - - let code = match rx.recv() { - Ok(Ok(code)) => code, + let res = self + .inner + .db + .0 + .with_inner(|db| db.code_by_hash_ref(code_hash).map(|code| code.bytecode)); + let code = match res { + Some(Ok(code)) => code, _ => { return Err(JsError::from_native(JsNativeError::error().with_message(format!( "Failed to read code hash {code_hash:?} from database", @@ -759,19 +810,10 @@ impl EvmDbRef { let buf = from_buf(slot, ctx)?; let slot = bytes_to_hash(buf); - let (tx, rx) = channel(); - if self - .to_db - .try_send(JsDbRequest::StorageAt { address, index: slot.into(), resp: tx }) - .is_err() - { - return Err(JsError::from_native(JsNativeError::error().with_message(format!( - "Failed to read state for {address:?} at {slot:?} from database", - )))); - } + let res = self.inner.db.0.with_inner(|db| db.storage_ref(address, slot.into())); - let value = match rx.recv() { - Ok(Ok(value)) => value, + let value = match res { + Some(Ok(value)) => value, _ => { return Err(JsError::from_native(JsNativeError::error().with_message(format!( "Failed to read state for {address:?} at {slot:?} from database", @@ -871,11 +913,61 @@ unsafe impl Trace for EvmDbRef { empty_trace!(); } +impl Clone for EvmDbRef { + fn clone(&self) -> Self { + Self { inner: Rc::clone(&self.inner) } + } +} + +/// DB is the object that allows the js inspector to interact with the database. +struct EvmDbRefInner { + state: StateRef, + db: GcDb + 'static>>, +} + +/// Guard the inner references, once this value is dropped the inner reference is also removed. +/// +/// This ensures that the guards are dropped within the lifetime of the borrowed values. +pub(crate) struct EvmDbGuard<'a, 'b> { + _state_guard: GcGuard<'a, State>, + _db_guard: GcGuard<'b, Box + 'static>>, +} + +/// A wrapper Database for the JS context. +pub(crate) struct JsDb(DB); + +impl DatabaseRef for JsDb +where + DB: DatabaseRef, + DB::Error: std::fmt::Display, +{ + type Error = String; + + fn basic_ref(&self, _address: Address) -> Result, Self::Error> { + self.0.basic_ref(_address).map_err(|e| e.to_string()) + } + + fn code_by_hash_ref(&self, _code_hash: B256) -> Result { + self.0.code_by_hash_ref(_code_hash).map_err(|e| e.to_string()) + } + + fn storage_ref(&self, _address: Address, _index: U256) -> Result { + self.0.storage_ref(_address, _index).map_err(|e| e.to_string()) + } + + fn block_hash_ref(&self, _number: U256) -> Result { + self.0.block_hash_ref(_number).map_err(|e| e.to_string()) + } +} + #[cfg(test)] mod tests { - use super::*; - use crate::tracing::js::builtins::BIG_INT_JS; use boa_engine::{object::builtins::JsArrayBuffer, property::Attribute, Source}; + use revm::db::{CacheDB, EmptyDB}; + + use crate::tracing::js::builtins::BIG_INT_JS; + + use super::*; #[test] fn test_contract() { @@ -929,4 +1021,97 @@ mod tests { let input = buffer.take().unwrap(); assert_eq!(input, contract.input); } + + #[test] + fn test_evm_db_gc() { + let mut context = Context::default(); + + let result = context + .eval(Source::from_bytes( + "( + function(db, addr) {return db.exists(addr) } + ) + " + .to_string() + .as_bytes(), + )) + .unwrap(); + assert!(result.is_callable()); + + let f = result.as_callable().unwrap(); + + let mut db = CacheDB::new(EmptyDB::new()); + let state = State::default(); + { + let (db, guard) = EvmDbRef::new(&state, &db); + let addr = Address::default(); + let addr = JsValue::from(addr.to_string()); + let db = db.into_js_object(&mut context).unwrap(); + let res = f.call(&result, &[db.clone().into(), addr.clone()], &mut context).unwrap(); + assert!(!res.as_boolean().unwrap()); + + // drop the db which also drops any GC values + drop(guard); + let res = f.call(&result, &[db.clone().into(), addr.clone()], &mut context); + assert!(res.is_err()); + } + let addr = Address::default(); + db.insert_account_info(addr, Default::default()); + + { + let (db, guard) = EvmDbRef::new(&state, &db); + let addr = JsValue::from(addr.to_string()); + let db = db.into_js_object(&mut context).unwrap(); + let res = f.call(&result, &[db.clone().into(), addr.clone()], &mut context).unwrap(); + + // account exists + assert!(res.as_boolean().unwrap()); + + // drop the db which also drops any GC values + drop(guard); + let res = f.call(&result, &[db.clone().into(), addr.clone()], &mut context); + assert!(res.is_err()); + } + } + + #[test] + fn test_evm_db_gc_captures() { + let mut context = Context::default(); + + let res = context + .eval(Source::from_bytes( + r"({ + setup: function(db) {this.db = db;}, + result: function(addr) {return this.db.exists(addr) } + }) + " + .to_string() + .as_bytes(), + )) + .unwrap(); + + let obj = res.as_object().unwrap(); + + let result_fn = obj.get("result", &mut context).unwrap().as_object().cloned().unwrap(); + let setup_fn = obj.get("setup", &mut context).unwrap().as_object().cloned().unwrap(); + + let db = CacheDB::new(EmptyDB::new()); + let state = State::default(); + { + let (db_ref, guard) = EvmDbRef::new(&state, &db); + let js_db = db_ref.into_js_object(&mut context).unwrap(); + let _res = setup_fn.call(&(obj.clone().into()), &[js_db.into()], &mut context).unwrap(); + assert!(obj.get("db", &mut context).unwrap().is_object()); + + let addr = Address::default(); + let addr = JsValue::from(addr.to_string()); + let res = result_fn.call(&(obj.clone().into()), &[addr.clone()], &mut context).unwrap(); + assert!(!res.as_boolean().unwrap()); + + // drop the guard which also drops any GC values + drop(guard); + let res = result_fn.call(&(obj.clone().into()), &[addr], &mut context); + assert!(res.is_err()); + } + } } diff --git a/src/tracing/js/builtins.rs b/src/tracing/js/builtins.rs index c03fa808..2a97157b 100644 --- a/src/tracing/js/builtins.rs +++ b/src/tracing/js/builtins.rs @@ -49,7 +49,13 @@ pub(crate) fn from_buf(val: JsValue, context: &mut Context<'_>) -> JsResult, /// Keeps track of the current call stack. call_stack: Vec, - /// sender half of a channel to communicate with the database service. - to_db_service: mpsc::Sender, /// Marker to track whether the precompiles have been registered. precompiles_registered: bool, } @@ -92,12 +89,8 @@ impl JsInspector { /// /// This also accepts a sender half of a channel to communicate with the database service so the /// DB can be queried from inside the inspector. - pub fn new( - code: String, - config: serde_json::Value, - to_db_service: mpsc::Sender, - ) -> Result { - Self::with_transaction_context(code, config, to_db_service, Default::default()) + pub fn new(code: String, config: serde_json::Value) -> Result { + Self::with_transaction_context(code, config, Default::default()) } /// Creates a new inspector from a javascript code snippet. See also [Self::new]. @@ -107,7 +100,6 @@ impl JsInspector { pub fn with_transaction_context( code: String, config: serde_json::Value, - to_db_service: mpsc::Sender, transaction_context: TransactionContext, ) -> Result { // Instantiate the execution context @@ -177,7 +169,6 @@ impl JsInspector { exit_fn, step_fn, call_stack: Default::default(), - to_db_service, precompiles_registered: false, }) } @@ -207,18 +198,32 @@ impl JsInspector { /// Calls the result function and returns the result as [serde_json::Value]. /// /// Note: This is supposed to be called after the inspection has finished. - pub fn json_result( + pub fn json_result( &mut self, res: ResultAndState, env: &Env, - ) -> Result { - Ok(self.result(res, env)?.to_json(&mut self.ctx)?) + db: &DB, + ) -> Result + where + DB: DatabaseRef, + ::Error: std::fmt::Display, + { + Ok(self.result(res, env, db)?.to_json(&mut self.ctx)?) } /// Calls the result function and returns the result. - pub fn result(&mut self, res: ResultAndState, env: &Env) -> Result { + pub fn result( + &mut self, + res: ResultAndState, + env: &Env, + db: &DB, + ) -> Result + where + DB: DatabaseRef, + ::Error: std::fmt::Display, + { let ResultAndState { result, state } = res; - let (db, _db_guard) = EvmDbRef::new(&state, self.to_db_service.clone()); + let (db, _db_guard) = EvmDbRef::new(&state, db); let gas_used = result.gas_used(); let mut to = None; @@ -367,15 +372,15 @@ impl JsInspector { impl Inspector for JsInspector where - DB: Database, + DB: Database + DatabaseRef, + ::Error: std::fmt::Display, { fn step(&mut self, interp: &mut Interpreter<'_>, data: &mut EVMData<'_, DB>) { if self.step_fn.is_none() { return; } - let (db, _db_guard) = - EvmDbRef::new(&data.journaled_state.state, self.to_db_service.clone()); + let (db, _db_guard) = EvmDbRef::new(&data.journaled_state.state, &*data.db); let (stack, _stack_guard) = StackRef::new(&interp.stack); let (memory, _memory_guard) = MemoryRef::new(interp.shared_memory); @@ -412,8 +417,7 @@ where } if matches!(interp.instruction_result, return_revert!()) { - let (db, _db_guard) = - EvmDbRef::new(&data.journaled_state.state, self.to_db_service.clone()); + let (db, _db_guard) = EvmDbRef::new(&data.journaled_state.state, &*data.db); let (stack, _stack_guard) = StackRef::new(&interp.stack); let (memory, _memory_guard) = MemoryRef::new(interp.shared_memory); @@ -604,34 +608,6 @@ impl TransactionContext { } } -/// Request variants to be sent from the inspector to the database -#[derive(Debug, Clone)] -pub enum JsDbRequest { - /// Bindings for [Database::basic] - Basic { - /// The address of the account to be loaded - address: Address, - /// The response channel - resp: std::sync::mpsc::Sender, String>>, - }, - /// Bindings for [Database::code_by_hash] - Code { - /// The code hash of the code to be loaded - code_hash: B256, - /// The response channel - resp: std::sync::mpsc::Sender>, - }, - /// Bindings for [Database::storage] - StorageAt { - /// The address of the account - address: Address, - /// Index of the storage slot - index: U256, - /// The response channel - resp: std::sync::mpsc::Sender>, - }, -} - /// Represents an active call #[derive(Debug)] struct CallStackItem {