diff --git a/Cargo.lock b/Cargo.lock index 03fd74c60664..ea94bd74a0de 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1374,6 +1374,19 @@ dependencies = [ "wyz", ] +[[package]] +name = "blake3" +version = "1.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8ee0c1824c4dea5b5f81736aff91bae041d2c07ee1192bec91054e10e3e601e" +dependencies = [ + "arrayref", + "arrayvec", + "cc", + "cfg-if", + "constant_time_eq", +] + [[package]] name = "block-buffer" version = "0.9.0" @@ -2044,6 +2057,12 @@ version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" +[[package]] +name = "constant_time_eq" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" + [[package]] name = "convert_case" version = "0.6.0" @@ -6369,6 +6388,7 @@ checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da" dependencies = [ "base64 0.22.1", "bytes", + "futures-channel", "futures-core", "futures-util", "http", @@ -9088,6 +9108,7 @@ dependencies = [ "alloy-rlp", "assert_matches", "bincode", + "blake3", "codspeed-criterion-compat", "futures-util", "itertools 0.13.0", @@ -9096,6 +9117,7 @@ dependencies = [ "pprof", "rand 0.8.5", "rayon", + "reqwest", "reth-chainspec", "reth-codecs", "reth-config", @@ -9109,6 +9131,7 @@ dependencies = [ "reth-execution-errors", "reth-execution-types", "reth-exex", + "reth-fs-util", "reth-network-p2p", "reth-network-peers", "reth-primitives", @@ -9121,8 +9144,10 @@ dependencies = [ "reth-static-file", "reth-storage-errors", "reth-testing-utils", + "reth-tracing", "reth-trie", "reth-trie-db", + "serde", "tempfile", "thiserror 2.0.11", "tokio", diff --git a/Cargo.toml b/Cargo.toml index 0e6a39084b2e..4a6f90d7380b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -485,6 +485,7 @@ backon = { version = "1.2", default-features = false, features = [ ] } bincode = "1.3" bitflags = "2.4" +blake3 = "1.5.5" boyer-moore-magiclen = "0.2.16" bytes = { version = "1.5", default-features = false } cfg-if = "1.0" diff --git a/crates/stages/stages/Cargo.toml b/crates/stages/stages/Cargo.toml index e7114eeb16ac..2b519558c078 100644 --- a/crates/stages/stages/Cargo.toml +++ b/crates/stages/stages/Cargo.toml @@ -22,6 +22,7 @@ reth-db-api.workspace = true reth-etl.workspace = true reth-evm.workspace = true reth-exex.workspace = true +reth-fs-util.workspace = true reth-network-p2p.workspace = true reth-primitives = { workspace = true, features = ["secp256k1"] } reth-primitives-traits = { workspace = true, features = [ @@ -57,6 +58,12 @@ rayon.workspace = true num-traits = "0.2.15" tempfile = { workspace = true, optional = true } bincode.workspace = true +blake3.workspace = true +reqwest = { workspace = true, default-features = false, features = [ + "rustls-tls-native-roots", + "blocking" +] } +serde = { workspace = true, features = ["derive"] } [dev-dependencies] # reth @@ -75,6 +82,7 @@ reth-testing-utils.workspace = true reth-trie = { workspace = true, features = ["test-utils"] } reth-provider = { workspace = true, features = ["test-utils"] } reth-network-peers.workspace = true +reth-tracing.workspace = true alloy-rlp.workspace = true itertools.workspace = true diff --git a/crates/stages/stages/src/stages/mod.rs b/crates/stages/stages/src/stages/mod.rs index 33a4d76a11f9..142452aa5344 100644 --- a/crates/stages/stages/src/stages/mod.rs +++ b/crates/stages/stages/src/stages/mod.rs @@ -17,6 +17,8 @@ mod index_storage_history; /// Stage for computing state root. mod merkle; mod prune; +/// The s3 download stage +mod s3; /// The sender recovery stage. mod sender_recovery; /// The transaction lookup stage @@ -32,6 +34,7 @@ pub use index_account_history::*; pub use index_storage_history::*; pub use merkle::*; pub use prune::*; +pub use s3::*; pub use sender_recovery::*; pub use tx_lookup::*; diff --git a/crates/stages/stages/src/stages/s3/downloader/error.rs b/crates/stages/stages/src/stages/s3/downloader/error.rs new file mode 100644 index 000000000000..b9fc8cd99619 --- /dev/null +++ b/crates/stages/stages/src/stages/s3/downloader/error.rs @@ -0,0 +1,30 @@ +use alloy_primitives::B256; +use reth_fs_util::FsPathError; + +#[derive(Debug, thiserror::Error)] +pub enum DownloaderError { + /// Requires a valid `total_size` {0} + #[error("requires a valid total_size")] + InvalidMetadataTotalSize(Option), + #[error("tried to access chunk on index {0}, but there's only {1} chunks")] + /// Invalid chunk access + InvalidChunk(usize, usize), + // File hash mismatch. + #[error("file hash does not match the expected one {0} != {1} ")] + InvalidFileHash(B256, B256), + // Empty content length returned from the server. + #[error("metadata got an empty content length from server")] + EmptyContentLength, + /// Reqwest error + #[error(transparent)] + FsPath(#[from] FsPathError), + /// Reqwest error + #[error(transparent)] + Reqwest(#[from] reqwest::Error), + /// Std Io error + #[error(transparent)] + StdIo(#[from] std::io::Error), + /// Bincode error + #[error(transparent)] + Bincode(#[from] bincode::Error), +} diff --git a/crates/stages/stages/src/stages/s3/downloader/fetch.rs b/crates/stages/stages/src/stages/s3/downloader/fetch.rs new file mode 100644 index 000000000000..e538e7a37d0f --- /dev/null +++ b/crates/stages/stages/src/stages/s3/downloader/fetch.rs @@ -0,0 +1,190 @@ +use super::{ + error::DownloaderError, + meta::Metadata, + worker::{worker_fetch, WorkerRequest, WorkerResponse}, +}; +use alloy_primitives::B256; +use reqwest::{blocking::Client, header::CONTENT_LENGTH}; +use std::{ + collections::HashMap, + fs::{File, OpenOptions}, + io::BufReader, + path::Path, + sync::mpsc::channel, +}; +use tracing::{debug, error}; + +/// Downloads file from url to data file path. +/// +/// If a `file_hash` is passed, it will verify it at the end. +/// +/// ## Details +/// +/// 1) A [`Metadata`] file is created or opened in `{target_dir}/download/{filename}.metadata`. It +/// tracks the download progress including total file size, downloaded bytes, chunk sizes, and +/// ranges that still need downloading. Allows for resumability. +/// 2) The target file is preallocated with the total size of the file in +/// `{target_dir}/download/{filename}`. +/// 3) Multiple `workers` are spawned for downloading of specific chunks of the file. +/// 4) `Orchestrator` manages workers, distributes chunk ranges, and ensures the download progresses +/// efficiently by dynamically assigning tasks to workers as they become available. +/// 5) Once the file is downloaded: +/// * If `file_hash` is `Some`, verifies its blake3 hash. +/// * Deletes the metadata file +/// * Moves downloaded file to target directory. +pub fn fetch( + filename: &str, + target_dir: &Path, + url: &str, + mut concurrent: u64, + file_hash: Option, +) -> Result<(), DownloaderError> { + // Create a temporary directory to download files to, before moving them to target_dir. + let download_dir = target_dir.join("download"); + reth_fs_util::create_dir_all(&download_dir)?; + + let data_file = download_dir.join(filename); + let mut metadata = metadata(&data_file, url)?; + if metadata.is_done() { + return Ok(()) + } + + // Ensure the file is preallocated so we can download it concurrently + { + let file = OpenOptions::new() + .create(true) + .truncate(true) + .read(true) + .write(true) + .open(&data_file)?; + + if file.metadata()?.len() as usize != metadata.total_size { + debug!(target: "sync::stages::s3::downloader", ?data_file, length = metadata.total_size, "Preallocating space."); + file.set_len(metadata.total_size as u64)?; + } + } + + while !metadata.is_done() { + // Find the missing file chunks and the minimum number of workers required + let missing_chunks = metadata.needed_ranges(); + concurrent = concurrent + .min(std::thread::available_parallelism()?.get() as u64) + .min(missing_chunks.len() as u64); + + // Create channels for communication between workers and orchestrator + let (orchestrator_tx, orchestrator_rx) = channel(); + + // Initiate workers + for worker_id in 0..concurrent { + let orchestrator_tx = orchestrator_tx.clone(); + let data_file = data_file.clone(); + let url = url.to_string(); + std::thread::spawn(move || { + if let Err(error) = worker_fetch(worker_id, &orchestrator_tx, data_file, url) { + let _ = orchestrator_tx.send(WorkerResponse::Err { worker_id, error }); + } + }); + } + + // Drop the sender to allow the loop processing to exit once all workers are done + drop(orchestrator_tx); + + let mut workers = HashMap::new(); + let mut missing_chunks = missing_chunks.into_iter(); + + // Distribute chunk ranges to workers when they free up + while let Ok(worker_msg) = orchestrator_rx.recv() { + debug!(target: "sync::stages::s3::downloader", ?worker_msg, "received message from worker"); + + let available_worker = match worker_msg { + WorkerResponse::Ready { worker_id, tx } => { + workers.insert(worker_id, tx); + worker_id + } + WorkerResponse::DownloadedChunk { worker_id, chunk_index, written_bytes } => { + metadata.update_chunk(chunk_index, written_bytes)?; + worker_id + } + WorkerResponse::Err { worker_id, error } => { + error!(target: "sync::stages::s3::downloader", ?worker_id, "Worker found an error: {:?}", error); + return Err(error) + } + }; + + let worker = workers.get(&available_worker).expect("should exist"); + match missing_chunks.next() { + Some((chunk_index, (start, end))) => { + let _ = worker.send(WorkerRequest::Download { chunk_index, start, end }); + } + None => { + let _ = worker.send(WorkerRequest::Finish); + } + } + } + } + + if let Some(file_hash) = file_hash { + check_file_hash(&data_file, &file_hash)?; + } + + // Move downloaded file to desired directory. + metadata.delete()?; + reth_fs_util::rename(data_file, target_dir.join(filename))?; + + Ok(()) +} + +/// Creates a metadata file used to keep track of the downloaded chunks. Useful on resuming after a +/// shutdown. +fn metadata(data_file: &Path, url: &str) -> Result { + if Metadata::file_path(data_file).exists() { + debug!(target: "sync::stages::s3::downloader", ?data_file, "Loading metadata "); + return Metadata::load(data_file) + } + + let client = Client::new(); + let resp = client.head(url).send()?; + let total_length: usize = resp + .headers() + .get(CONTENT_LENGTH) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse().ok()) + .ok_or(DownloaderError::EmptyContentLength)?; + + debug!(target: "sync::stages::s3::downloader", ?data_file, "Creating metadata "); + + Metadata::builder(data_file).with_total_size(total_length).build() +} + +/// Ensures the file on path has the expected blake3 hash. +fn check_file_hash(path: &Path, expected: &B256) -> Result<(), DownloaderError> { + let mut reader = BufReader::new(File::open(path)?); + let mut hasher = blake3::Hasher::new(); + std::io::copy(&mut reader, &mut hasher)?; + + let file_hash = hasher.finalize(); + if file_hash.as_bytes() != expected { + return Err(DownloaderError::InvalidFileHash(file_hash.as_bytes().into(), *expected)) + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use alloy_primitives::b256; + + #[test] + fn test_download() { + reth_tracing::init_test_tracing(); + + let b3sum = b256!("81a7318f69fc1d6bb0a58a24af302f3b978bc75a435e4ae5d075f999cd060cfd"); + let url = "https://link.testfile.org/500MB"; + + let file = tempfile::NamedTempFile::new().unwrap(); + let filename = file.path().file_name().unwrap().to_str().unwrap(); + let target_dir = file.path().parent().unwrap(); + fetch(filename, target_dir, url, 4, Some(b3sum)).unwrap(); + } +} diff --git a/crates/stages/stages/src/stages/s3/downloader/meta.rs b/crates/stages/stages/src/stages/s3/downloader/meta.rs new file mode 100644 index 000000000000..c2ef754f7aff --- /dev/null +++ b/crates/stages/stages/src/stages/s3/downloader/meta.rs @@ -0,0 +1,161 @@ +use super::error::DownloaderError; +use serde::{Deserialize, Serialize}; +use std::{ + fs::File, + path::{Path, PathBuf}, +}; +use tracing::info; + +/// Tracks download progress and manages chunked downloads for resumable file transfers. +#[derive(Debug, Serialize, Deserialize)] +pub struct Metadata { + /// Total file size + pub total_size: usize, + /// Total file size + pub downloaded: usize, + /// Download chunk size. Default 150MB. + pub chunk_size: usize, + /// Each chunk remaining download range. + chunks: Vec>, + /// Path with the stored metadata. + #[serde(skip)] + path: PathBuf, +} + +impl Metadata { + /// Build a [`Metadata`] using [`MetadataBuilder`]. + pub fn builder(data_file: &Path) -> MetadataBuilder { + MetadataBuilder::new(Self::file_path(data_file)) + } + + /// Returns the metadata file path of a data file: `{data_file}.metadata` + pub fn file_path(data_file: &Path) -> PathBuf { + data_file.with_file_name(format!( + "{}.metadata", + data_file.file_name().unwrap_or_default().to_string_lossy() + )) + } + + /// Returns a list of all chunks with their remaining ranges to be downloaded. + /// + /// Returns a list of `(chunk_index, (start, end))` + pub fn needed_ranges(&self) -> Vec<(usize, (usize, usize))> { + self.chunks + .iter() + .enumerate() + .filter(|(_, remaining)| remaining.is_some()) + .map(|(index, remaining)| (index, remaining.expect("qed"))) + .collect() + } + + /// Updates a downloaded chunk. + pub fn update_chunk( + &mut self, + index: usize, + downloaded_bytes: usize, + ) -> Result<(), DownloaderError> { + self.downloaded += downloaded_bytes; + + let num_chunks = self.chunks.len(); + if index >= self.chunks.len() { + return Err(DownloaderError::InvalidChunk(index, num_chunks)) + } + + // Update chunk with downloaded range + if let Some((mut start, end)) = self.chunks[index] { + start += downloaded_bytes; + if start > end { + self.chunks[index] = None; + } else { + self.chunks[index] = Some((start, end)); + } + } + + let file = self.path.file_stem().unwrap_or_default().to_string_lossy().into_owned(); + info!( + target: "sync::stages::s3::downloader", + file, + "{}/{}", self.downloaded / 1024 / 1024, self.total_size / 1024 / 1024); + + self.commit() + } + + /// Commits the [`Metadata`] to file. + pub fn commit(&self) -> Result<(), DownloaderError> { + Ok(reth_fs_util::atomic_write_file(&self.path, |file| { + bincode::serialize_into(file, &self) + })?) + } + + /// Loads a [`Metadata`] file from disk using the target data file. + pub fn load(data_file: &Path) -> Result { + Ok(bincode::deserialize_from(File::open(Self::file_path(data_file))?)?) + } + + /// Returns true if we have downloaded all chunks. + pub fn is_done(&self) -> bool { + !self.chunks.iter().any(|c| c.is_some()) + } + + /// Deletes [`Metadata`] file from disk. + pub fn delete(self) -> Result<(), DownloaderError> { + Ok(reth_fs_util::remove_file(&self.path)?) + } +} + +/// A builder that can configure [Metadata] +#[derive(Debug)] +pub struct MetadataBuilder { + /// Path with the stored metadata. + metadata_path: PathBuf, + /// Total file size + total_size: Option, + /// Download chunk size. Default 150MB. + chunk_size: usize, +} + +impl MetadataBuilder { + const fn new(metadata_path: PathBuf) -> Self { + Self { + metadata_path, + total_size: None, + chunk_size: 150 * (1024 * 1024), // 150MB + } + } + + pub const fn with_total_size(mut self, total_size: usize) -> Self { + self.total_size = Some(total_size); + self + } + + pub const fn with_chunk_size(mut self, chunk_size: usize) -> Self { + self.chunk_size = chunk_size; + self + } + + /// Returns a [Metadata] if + pub fn build(&self) -> Result { + match &self.total_size { + Some(total_size) if *total_size > 0 => { + let chunks = (0..*total_size) + .step_by(self.chunk_size) + .map(|start| { + Some((start, (start + self.chunk_size).min(*total_size).saturating_sub(1))) + }) + .collect(); + + let metadata = Metadata { + path: self.metadata_path.clone(), + total_size: *total_size, + downloaded: 0, + chunk_size: self.chunk_size, + chunks, + }; + metadata.commit()?; + + Ok(metadata) + } + _ => Err(DownloaderError::InvalidMetadataTotalSize(self.total_size)), + } + } +} diff --git a/crates/stages/stages/src/stages/s3/downloader/mod.rs b/crates/stages/stages/src/stages/s3/downloader/mod.rs new file mode 100644 index 000000000000..bee3b3957b4f --- /dev/null +++ b/crates/stages/stages/src/stages/s3/downloader/mod.rs @@ -0,0 +1,10 @@ +//! Provides functionality for downloading files in chunks from a remote source. It supports +//! concurrent downloads, resuming interrupted downloads, and file integrity verification. + +mod error; +mod fetch; +mod meta; +mod worker; + +pub use fetch::fetch; +pub use meta::Metadata; diff --git a/crates/stages/stages/src/stages/s3/downloader/worker.rs b/crates/stages/stages/src/stages/s3/downloader/worker.rs new file mode 100644 index 000000000000..e79af18f498f --- /dev/null +++ b/crates/stages/stages/src/stages/s3/downloader/worker.rs @@ -0,0 +1,69 @@ +use reqwest::{blocking::Client, header::RANGE}; +use std::{ + fs::OpenOptions, + io::{BufWriter, Read, Seek, SeekFrom}, + path::PathBuf, + sync::mpsc::{channel, Sender}, +}; +use tracing::debug; + +use super::error::DownloaderError; + +/// Responses sent by a worker. +#[derive(Debug)] +pub(crate) enum WorkerResponse { + /// Worker has been spawned and awaiting work. + Ready { worker_id: u64, tx: Sender }, + /// Worker has downloaded + DownloadedChunk { worker_id: u64, chunk_index: usize, written_bytes: usize }, + /// Worker has encountered an error. + Err { worker_id: u64, error: DownloaderError }, +} + +/// Requests sent to a worker. +#[derive(Debug)] +pub(crate) enum WorkerRequest { + /// Requests a range to be downloaded. + Download { chunk_index: usize, start: usize, end: usize }, + /// Signals a worker exit. + Finish, +} + +/// Downloads requested chunk ranges to the data file. +pub(crate) fn worker_fetch( + worker_id: u64, + orchestrator_tx: &Sender, + data_file: PathBuf, + url: String, +) -> Result<(), DownloaderError> { + let client = Client::new(); + let mut data_file = BufWriter::new(OpenOptions::new().write(true).open(data_file)?); + + // Signals readiness to download + let (tx, rx) = channel::(); + let _ = orchestrator_tx.send(WorkerResponse::Ready { worker_id, tx }); + + while let Ok(req) = rx.recv() { + debug!(target: "sync::stages::s3::downloader", worker_id, ?req, "received from orchestrator"); + + match req { + WorkerRequest::Download { chunk_index, start, end } => { + data_file.seek(SeekFrom::Start(start as u64))?; + + let mut response = + client.get(&url).header(RANGE, format!("bytes={}-{}", start, end)).send()?; + + let written_bytes = std::io::copy(response.by_ref(), &mut data_file)? as usize; + + let _ = orchestrator_tx.send(WorkerResponse::DownloadedChunk { + worker_id, + chunk_index, + written_bytes, + }); + } + WorkerRequest::Finish => break, + } + } + + Ok(()) +} diff --git a/crates/stages/stages/src/stages/s3/filelist.rs b/crates/stages/stages/src/stages/s3/filelist.rs new file mode 100644 index 000000000000..683c4a208862 --- /dev/null +++ b/crates/stages/stages/src/stages/s3/filelist.rs @@ -0,0 +1,21 @@ +use alloy_primitives::B256; + +/// File list to be downloaded with their hashes. +pub(crate) static DOWNLOAD_FILE_LIST: [[(&str, B256); 3]; 2] = [ + [ + ("static_file_transactions_0_499999", B256::ZERO), + ("static_file_transactions_0_499999.off", B256::ZERO), + ("static_file_transactions_0_499999.conf", B256::ZERO), + // ("static_file_blockmeta_0_499999", B256::ZERO), + // ("static_file_blockmeta_0_499999.off", B256::ZERO), + // ("static_file_blockmeta_0_499999.conf", B256::ZERO), + ], + [ + ("static_file_transactions_500000_999999", B256::ZERO), + ("static_file_transactions_500000_999999.off", B256::ZERO), + ("static_file_transactions_500000_999999.conf", B256::ZERO), + // ("static_file_blockmeta_500000_999999", B256::ZERO), + // ("static_file_blockmeta_500000_999999.off", B256::ZERO), + // ("static_file_blockmeta_500000_999999.conf", B256::ZERO), + ], +]; diff --git a/crates/stages/stages/src/stages/s3/mod.rs b/crates/stages/stages/src/stages/s3/mod.rs new file mode 100644 index 000000000000..9acddb300626 --- /dev/null +++ b/crates/stages/stages/src/stages/s3/mod.rs @@ -0,0 +1,187 @@ +mod downloader; +pub use downloader::{fetch, Metadata}; + +mod filelist; +use filelist::DOWNLOAD_FILE_LIST; + +use reth_db::{cursor::DbCursorRW, tables, transaction::DbTxMut}; +use reth_primitives::StaticFileSegment; +use reth_provider::{ + BlockBodyIndicesProvider, DBProvider, StageCheckpointReader, StageCheckpointWriter, + StaticFileProviderFactory, +}; +use reth_stages_api::{ + ExecInput, ExecOutput, Stage, StageCheckpoint, StageError, StageId, UnwindInput, UnwindOutput, +}; + +/// S3 `StageId` +const S3_STAGE_ID: StageId = StageId::Other("S3"); + +/// The s3 stage. +#[derive(Default, Debug, Clone)] +#[non_exhaustive] +pub struct S3Stage; + +impl Stage for S3Stage +where + Provider: DBProvider + + StaticFileProviderFactory + + StageCheckpointReader + + StageCheckpointWriter, +{ + fn id(&self) -> StageId { + S3_STAGE_ID + } + + fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result + where + Provider: DBProvider + + StaticFileProviderFactory + + StageCheckpointReader + + StageCheckpointWriter, + { + let cfg = "localhost:8000"; + let checkpoint = provider.get_stage_checkpoint(S3_STAGE_ID)?.unwrap_or_default(); + let static_file_provider = provider.static_file_provider(); + let mut tx_block_cursor = provider.tx_ref().cursor_write::()?; + + for block_range_files in &DOWNLOAD_FILE_LIST { + let (_, block_range) = + StaticFileSegment::parse_filename(block_range_files[0].0).expect("qed"); + + if block_range.end() <= checkpoint.block_number { + continue + } + + for (filename, file_hash) in block_range_files { + if static_file_provider.directory().join(filename).exists() { + // TODO: check hash if the file already exists? + continue + } + + fetch( + filename, + static_file_provider.directory(), + &format!("{cfg}/{filename}"), + std::thread::available_parallelism()?.get() as u64, + Some(*file_hash), + ) + .unwrap(); // TODO add DownloadError to StageError + } + + // Re-initializes the provider to detect the new additions + static_file_provider.initialize_index()?; + + // Populate TransactionBlock table + for block_number in block_range.start()..=block_range.end() { + // TODO: should be error if none + if let Some(indice) = static_file_provider.block_body_indices(block_number)? { + if indice.tx_count() > 0 { + tx_block_cursor.append(indice.last_tx_num(), &block_number)?; + } + } + } + + let checkpoint = + StageCheckpoint { block_number: block_range.end(), stage_checkpoint: None }; + provider.save_stage_checkpoint(StageId::Bodies, checkpoint)?; + provider.save_stage_checkpoint(S3_STAGE_ID, checkpoint)?; + } + + Ok(ExecOutput { checkpoint: StageCheckpoint::new(input.target()), done: true }) + } + + fn unwind( + &mut self, + _provider: &Provider, + input: UnwindInput, + ) -> Result { + // TODO + Ok(UnwindOutput { checkpoint: StageCheckpoint::new(input.unwind_to) }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_utils::{ + ExecuteStageTestRunner, StageTestRunner, TestRunnerError, + TestStageDB, UnwindStageTestRunner, + }; + use reth_primitives::SealedHeader; + use reth_testing_utils::{ + generators, + generators::{random_header, random_header_range}, + }; + + // stage_test_suite_ext!(S3TestRunner, s3); + + #[derive(Default)] + struct S3TestRunner { + db: TestStageDB, + } + + impl StageTestRunner for S3TestRunner { + type S = S3Stage; + + fn db(&self) -> &TestStageDB { + &self.db + } + + fn stage(&self) -> Self::S { + S3Stage + } + } + + impl ExecuteStageTestRunner for S3TestRunner { + type Seed = Vec; + + fn seed_execution(&mut self, input: ExecInput) -> Result { + let start = input.checkpoint().block_number; + let mut rng = generators::rng(); + let head = random_header(&mut rng, start, None); + self.db.insert_headers_with_td(std::iter::once(&head))?; + + // use previous progress as seed size + let end = input.target.unwrap_or_default() + 1; + + if start + 1 >= end { + return Ok(Vec::default()) + } + + let mut headers = random_header_range(&mut rng, start + 1..end, head.hash()); + self.db.insert_headers_with_td(headers.iter())?; + headers.insert(0, head); + Ok(headers) + } + + fn validate_execution( + &self, + input: ExecInput, + output: Option, + ) -> Result<(), TestRunnerError> { + if let Some(output) = output { + assert!(output.done, "stage should always be done"); + assert_eq!( + output.checkpoint.block_number, + input.target(), + "stage progress should always match progress of previous stage" + ); + } + Ok(()) + } + } + + impl UnwindStageTestRunner for S3TestRunner { + fn validate_unwind(&self, _input: UnwindInput) -> Result<(), TestRunnerError> { + Ok(()) + } + } + + #[test] + fn parse_files() { + for block_range_files in &DOWNLOAD_FILE_LIST { + let (_, _) = StaticFileSegment::parse_filename(block_range_files[0].0).expect("qed"); + } + } +}