Skip to content

Commit

Permalink
add rpc timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
chzyer committed Dec 5, 2024
1 parent 69c8a11 commit 5906991
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 81 deletions.
4 changes: 3 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

189 changes: 111 additions & 78 deletions crates/prover/src/api.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
use std::collections::BTreeMap;
use std::future::IntoFuture;
use std::sync::Arc;
use std::time::Instant;
use std::time::{Duration, Instant};

use crate::types::{DaApiServer, ProverV1ApiServer, ProverV2ApiServer};
use crate::{Collector, DaItemLockStatus, DaManager, Metadata, TaskManager, BUILD_TAG};

use alloy::primitives::Bytes;
use async_trait::async_trait;
use automata_sgx_sdk::dcap::dcap_quote;
use base::eth::{Eth, Keypair};
use base::format::debug;
use base::thread::wait_timeout;
use base::trace::Alive;
use base::eth::{Eth, Keypair};
use jsonrpsee::core::RpcResult;
use jsonrpsee::types::{ErrorObject, ErrorObjectOwned};
use jsonrpsee::RpcModule;
Expand All @@ -34,6 +36,7 @@ pub struct ProverApi {
pub pob_da: Arc<DaManager<Vec<Pob>>>,
pub metrics: Arc<Collector>,
pub keypair: Keypair,
pub request_timeout: Option<Duration>,

pub scroll: ScrollBatchVerifier,
pub linea: LineaBatchVerifier,
Expand All @@ -58,49 +61,70 @@ impl ProverApi {
pub fn err<M: Into<String>>(&self, code: i32, msg: M) -> ErrorObjectOwned {
ErrorObject::owned(code, msg, None::<()>)
}
}

#[async_trait]
impl ProverV1ApiServer for ProverApi {
async fn generate_attestation_report(&self, req: Bytes) -> RpcResult<Bytes> {
let mut data = [0_u8; 64];
if req.len() > 64 {
return Err(self.err(14002, "invalid report data"));
async fn wait<T, F>(&self, f: F) -> RpcResult<T>
where
F: IntoFuture<Output = Result<T, ErrorObjectOwned>>,
{
match wait_timeout(self.request_timeout, f).await {
Ok(n) => n,
Err(_) => Err(ErrorObject::owned(
14502,
format!("request timeout"),
None::<()>,
)),
}
data[64 - req.len()..].copy_from_slice(&req);
data[0..12].copy_from_slice(&[0_u8; 12]);
data[12..32].copy_from_slice(self.keypair.address().as_slice());

log::info!("report data: {:?}", Bytes::copy_from_slice(&data));
}

async fn inner_generate_context(
&self,
start_block: u64,
end_block: u64,
ty: u64,
) -> RpcResult<SuccinctPobList> {
let ty = TaskType::from_u64(ty);

let start = Instant::now();
let result = match ty {
TaskType::Scroll => self
.scroll
.generate_context(start_block, end_block)
.await
.map_err(jsonrpc_err(14004))?,
TaskType::Linea => self
.linea
.generate_context(start_block, end_block)
.await
.map_err(jsonrpc_err(14004))?,
TaskType::Other(_) => return Err(self.err(14005, format!("unknown task: {:?}", ty))),
};

let result = dcap_quote(data);
let pob_list = SuccinctPobList::compress(&result);
let gen_ctx_time = start.elapsed().as_millis() as f64;

self.pob_da
.put(pob_list.hash, Arc::new(result), POB_EXPIRED_SECS);

let data = serde_json::to_vec(&pob_list).unwrap();
self.metrics
.gen_attestation_report_ms
.pob_size
.lock()
.unwrap()
.set([], start.elapsed().as_millis() as f64);

match result {
Ok(quote) => Ok(quote.into()),
Err(err) => {
let msg = format!("generate report failed: {:?}", err);
return Err(self.err(14003, msg));
}
}
}

async fn get_poe(&self, tx_hash: B256) -> RpcResult<PoeResponse> {
self.prove_task_with_sample(tx_hash, None, self.sampling, TaskType::Scroll.u64())
.await
.set([ty.name()], data.len() as _);
self.metrics
.counter_gen_ctx
.lock()
.unwrap()
.inc([ty.name()]);
self.metrics
.gauge_gen_ctx_ms
.lock()
.unwrap()
.set([ty.name()], gen_ctx_time);
Ok(pob_list)
}
}

#[async_trait]
impl ProverV2ApiServer for ProverApi {
async fn prove_task(&self, params: ProveTaskParams) -> RpcResult<PoeResponse> {
async fn inner_prove_task(&self, params: ProveTaskParams) -> RpcResult<PoeResponse> {
let ty = TaskType::from_opu64(params.task_type);

let pob_list = self
Expand All @@ -124,7 +148,10 @@ impl ProverV2ApiServer for ProverApi {
let result = match ty {
TaskType::Scroll => self
.scroll
.prove(pob_list.as_slice(), params.batch().map_err(jsonrpc_err(14001))?)
.prove(
pob_list.as_slice(),
params.batch().map_err(jsonrpc_err(14001))?,
)
.await
.map_err(debug),
TaskType::Linea => self.linea.prove(&pob_list, params).await.map_err(debug),
Expand Down Expand Up @@ -155,13 +182,57 @@ impl ProverV2ApiServer for ProverApi {
poe_signature: Some(sig.to_vec().into()),
})
}
}

#[async_trait]
impl ProverV1ApiServer for ProverApi {
async fn generate_attestation_report(&self, req: Bytes) -> RpcResult<Bytes> {
let mut data = [0_u8; 64];
if req.len() > 32 {
return Err(self.err(14002, "invalid report data"));
}
data[32 - req.len()..].copy_from_slice(&req);
data[12..32].copy_from_slice(self.keypair.address().as_slice());

log::info!("report data: {:?}", data);

let start = Instant::now();

let result = dcap_quote(data);

self.metrics
.gen_attestation_report_ms
.lock()
.unwrap()
.set([], start.elapsed().as_millis() as f64);

match result {
Ok(quote) => Ok(quote.into()),
Err(err) => {
let msg = format!("generate report failed: {:?}", err);
return Err(self.err(14003, msg));
}
}
}

async fn get_poe(&self, tx_hash: B256) -> RpcResult<PoeResponse> {
self.wait(self.prove_task_with_sample(tx_hash, None, self.sampling, TaskType::Scroll.u64()))
.await
}
}

#[async_trait]
impl ProverV2ApiServer for ProverApi {
async fn prove_task(&self, params: ProveTaskParams) -> RpcResult<PoeResponse> {
self.wait(self.inner_prove_task(params)).await
}

async fn prove_task_without_context(
&self,
task_data: Bytes,
ty: u64,
) -> RpcResult<PoeResponse> {
self.prove_task_with_sample(B256::default(), Some(task_data), 0, ty)
self.wait(self.prove_task_with_sample(B256::default(), Some(task_data), 0, ty))
.await
}

Expand All @@ -171,46 +242,8 @@ impl ProverV2ApiServer for ProverApi {
end_block: u64,
ty: u64,
) -> RpcResult<SuccinctPobList> {
let ty = TaskType::from_u64(ty);

let start = Instant::now();
let result = match ty {
TaskType::Scroll => self
.scroll
.generate_context(start_block, end_block)
.await
.map_err(jsonrpc_err(14004))?,
TaskType::Linea => self
.linea
.generate_context(start_block, end_block)
.await
.map_err(jsonrpc_err(14004))?,
TaskType::Other(_) => return Err(self.err(14005, format!("unknown task: {:?}", ty))),
};

let pob_list = SuccinctPobList::compress(&result);
let gen_ctx_time = start.elapsed().as_millis() as f64;

self.pob_da
.put(pob_list.hash, Arc::new(result), POB_EXPIRED_SECS);

let data = serde_json::to_vec(&pob_list).unwrap();
self.metrics
.pob_size
.lock()
.unwrap()
.set([ty.name()], data.len() as _);
self.metrics
.counter_gen_ctx
.lock()
.unwrap()
.inc([ty.name()]);
self.metrics
.gauge_gen_ctx_ms
.lock()
.unwrap()
.set([ty.name()], gen_ctx_time);
Ok(pob_list)
self.wait(self.inner_generate_context(start_block, end_block, ty))
.await
}

async fn metadata(&self) -> RpcResult<Metadata> {
Expand Down Expand Up @@ -284,7 +317,7 @@ impl ProverApi {
println!("task: {:?}", batch_task);

let pob_list = self
.generate_context(
.inner_generate_context(
batch_task.start().unwrap(),
batch_task.end().unwrap(),
ty.u64(),
Expand All @@ -297,7 +330,7 @@ impl ProverApi {
);

let poe = self
.prove_task(ProveTaskParams {
.inner_prove_task(ProveTaskParams {
batch: Some(task_data),
pob_hash: pob_list.hash,
start: None,
Expand Down
3 changes: 2 additions & 1 deletion crates/prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ pub use task_manager::*;
mod metrics;
pub use metrics::*;

use base::{eth::Keypair, trace::Alive};
use base::eth::Eth;
use base::{eth::Keypair, trace::Alive};
use jsonrpsee::{
server::{tower, ServerBuilder, TlsLayer},
Methods,
Expand Down Expand Up @@ -82,6 +82,7 @@ pub async fn entrypoint() {
pobda_task_mgr: Arc::new(TaskManager::new(100)),
pob_da: Arc::new(DaManager::new()),
metrics: collector.clone(),
request_timeout: Some(Duration::from_secs(300)),
keypair,
};

Expand Down
9 changes: 8 additions & 1 deletion crates/prover/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ pub struct Config {

#[serde(default = "default_l2_timeout_secs")]
pub l2_timeout_secs: u64,
#[serde(default = "default_req_timeout_secs")]
pub req_timeout_secs: u64,
}

impl Config {
Expand Down Expand Up @@ -79,6 +81,10 @@ pub fn get_timeout(timeout_secs: u64) -> Option<Duration> {
}
}

fn default_req_timeout_secs() -> u64 {
300
}

fn default_l2_timeout_secs() -> u64 {
60
}
Expand All @@ -89,7 +95,8 @@ pub trait ProverV2Api {
async fn prove_task(&self, arg: ProveTaskParams) -> RpcResult<PoeResponse>;

#[method(name = "proveTaskWithoutContext")]
async fn prove_task_without_context(&self, task_data: Bytes, ty: u64) -> RpcResult<PoeResponse>;
async fn prove_task_without_context(&self, task_data: Bytes, ty: u64)
-> RpcResult<PoeResponse>;

#[method(name = "genContext")]
async fn generate_context(
Expand Down

0 comments on commit 5906991

Please sign in to comment.