
Commit

commit everything
sxlijin committed Jul 13, 2024
1 parent fa70f6e commit f97deb4
Showing 7 changed files with 290 additions and 163 deletions.
4 changes: 3 additions & 1 deletion engine/baml-runtime/src/tracing/mod.rs
@@ -179,7 +179,9 @@ impl BamlTracer {
}

if let Some(tracer) = &self.tracer {
tracer.submit(response.to_log_schema(&self.options, event_chain, tags, span))?;
tracer
.submit(response.to_log_schema(&self.options, event_chain, tags, span))
.context("Error while submitting span for delivery")?;
guard.finalize();
Ok(Some(span_id))
} else {
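
The change above wraps the submit call with anyhow's Context, so a failed delivery surfaces as "Error while submitting span for delivery: ..." instead of a bare error. A minimal standalone sketch of that pattern; deliver and submit_span are hypothetical stand-ins for the tracer internals:

use anyhow::{Context, Result};

// Hypothetical stand-in for the underlying delivery call, which can fail.
fn deliver(payload: &str) -> std::io::Result<()> {
    Err(std::io::Error::new(std::io::ErrorKind::Other, payload.to_owned()))
}

// Sketch of the pattern in the diff: .context(...) wraps the source error in an
// anyhow::Error with a human-readable prefix before `?` propagates it upward.
fn submit_span(payload: &str) -> Result<()> {
    deliver(payload).context("Error while submitting span for delivery")?;
    Ok(())
}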
294 changes: 178 additions & 116 deletions engine/baml-runtime/src/tracing/threaded_tracer.rs
@@ -1,6 +1,17 @@
use anyhow::Result;
use std::sync::{mpsc, Arc, Mutex};
use tokio::sync::watch;
use futures::future::join_all;
use std::{
cell::RefCell,
ops::DerefMut,
sync::{
mpsc::{self, RecvTimeoutError},
Arc, Mutex,
},
};
use tokio::{
runtime::{self, Handle},
sync::{oneshot, watch},
};
use web_time::{Duration, Instant};

use crate::{
@@ -14,9 +25,8 @@ use super::api_wrapper::{core_types::LogSchema, APIConfig, APIWrapper, BoundaryA
const MAX_TRACE_SEND_CONCURRENCY: usize = 10;

enum TxEventSignal {
#[allow(dead_code)]
Stop,
Flush(u128),
Noop,
Stop(oneshot::Sender<TraceStats>),
Submit(LogSchema),
}

@@ -27,36 +37,14 @@ enum ProcessorStatus {

struct DeliveryThread {
api_config: Arc<APIWrapper>,
span_rx: mpsc::Receiver<TxEventSignal>,
stop_tx: watch::Sender<ProcessorStatus>,
rt: tokio::runtime::Runtime,
max_batch_size: usize,
max_concurrency: Arc<tokio::sync::Semaphore>,
stats: TraceStats,
}

impl DeliveryThread {
fn new(
api_config: APIWrapper,
span_rx: mpsc::Receiver<TxEventSignal>,
stop_tx: watch::Sender<ProcessorStatus>,
max_batch_size: usize,
stats: TraceStats,
) -> Self {
let rt = tokio::runtime::Runtime::new().unwrap();

Self {
api_config: Arc::new(api_config),
span_rx,
stop_tx,
rt,
max_batch_size,
max_concurrency: tokio::sync::Semaphore::new(MAX_TRACE_SEND_CONCURRENCY).into(),
stats,
}
}

async fn process_batch(&self, batch: Vec<LogSchema>) {
// TODO: this needs to submit stuff to the runtime
fn process_batch(&self, rt: &Handle, batch: Vec<LogSchema>) {
let work = batch
.into_iter()
.map(|work| {
@@ -91,52 +79,75 @@ impl DeliveryThread {
})
.collect::<Vec<_>>();

// Wait for all the futures to complete
futures::future::join_all(work).await;
rt.spawn(join_all(work));
}

fn run(&self) {
let mut batch = Vec::with_capacity(self.max_batch_size);
let mut now = Instant::now();
fn run(&mut self, mut span_rx: std::sync::mpsc::Receiver<TxEventSignal>) {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.thread_name("tracing-delivery")
.build()
.expect("Failed to start tracing thread");

log::debug!("[DeliveryThread] starting");
for i in 0..10 {
log::trace!(
"[DeliveryThread] startup loop: {} is receiver closed? {}",
i,
"no idea"
);
std::thread::sleep(Duration::from_secs(1));
}
loop {
log::trace!("[DeliveryThread] looping");
// Try to fill the batch up to max_batch_size
let (batch_full, flush, exit) =
match self.span_rx.recv_timeout(Duration::from_millis(100)) {
Ok(TxEventSignal::Submit(work)) => {
self.stats.guard().submit();
batch.push(work);
(batch.len() >= self.max_batch_size, None, false)
}
Ok(TxEventSignal::Flush(id)) => (false, Some(id), false),
Ok(TxEventSignal::Stop) => (false, None, true),
Err(mpsc::RecvTimeoutError::Timeout) => (false, None, false),
Err(mpsc::RecvTimeoutError::Disconnected) => (false, None, true),
};

let time_trigger = now.elapsed().as_millis() >= 1000;

let should_process_batch =
(batch_full || flush.is_some() || exit || time_trigger) && !batch.is_empty();

// Send events every 1 second or when the batch is full
if should_process_batch {
self.rt
.block_on(self.process_batch(std::mem::take(&mut batch)));
}

if should_process_batch || time_trigger {
now = std::time::Instant::now();
}

if let Some(id) = flush {
match self.stop_tx.send(ProcessorStatus::Done(id)) {
Ok(_) => {}
Err(e) => {
log::error!("Error sending flush signal: {:?}", e);
}
// let exit = match span_rx.blocking_recv() {
// Some(TxEventSignal::Submit(work)) => {
// log::debug!("[DeliveryThread] SUBMIT received");
// self.stats.guard().submit();
// self.process_batch(rt.handle(), vec![work]);

// false
// }
// Some(TxEventSignal::Noop) => {
// log::trace!("[DeliveryThread] NOOP recv");
// false
// }
// Some(TxEventSignal::Stop(sender)) => {
// let _ = sender.send(self.stats.clone());
// log::trace!("[DeliveryThread] STOP recv");
// true
// }
// None => true,
// };
let exit = match span_rx.recv_timeout(Duration::from_secs(1)) {
Ok(TxEventSignal::Submit(work)) => {
log::debug!("[DeliveryThread] SUBMIT received");
self.process_batch(rt.handle(), vec![work]);

false
}
}
Ok(TxEventSignal::Noop) => {
log::trace!("[DeliveryThread] NOOP recv");
false
}
Ok(TxEventSignal::Stop(sender)) => {
log::trace!("[DeliveryThread] STOP recv");
let _ = sender.send(self.stats.clone());
true
}
Err(RecvTimeoutError::Timeout) => {
log::trace!("[DeliveryThread] Error receiving from channel: timeout");
false
}
Err(RecvTimeoutError::Disconnected) => {
log::trace!("[DeliveryThread] Error receiving from channel: disconnected");
true
}
};

if exit {
log::trace!("[DeliveryThread] exiting");
return;
}
}
@@ -145,72 +156,108 @@ impl DeliveryThread {

pub(super) struct ThreadedTracer {
api_config: Arc<APIWrapper>,
span_tx: mpsc::Sender<TxEventSignal>,
stop_rx: watch::Receiver<ProcessorStatus>,
// span_tx: tokio::sync::mpsc::UnboundedSender<TxEventSignal>,
span_tx: std::sync::mpsc::Sender<TxEventSignal>,
// stop_rx: watch::Receiver<ProcessorStatus>,
#[allow(dead_code)]
join_handle: std::thread::JoinHandle<()>,
join_handle: Option<std::thread::JoinHandle<()>>,
log_event_callback: Arc<Mutex<Option<LogEventCallbackSync>>>,
stats: TraceStats,
}

impl ThreadedTracer {
fn start_worker(
api_config: APIWrapper,
max_batch_size: usize,
stats: TraceStats,
) -> (
mpsc::Sender<TxEventSignal>,
watch::Receiver<ProcessorStatus>,
std::thread::JoinHandle<()>,
) {
let (span_tx, span_rx) = mpsc::channel();
let (stop_tx, stop_rx) = watch::channel(ProcessorStatus::Active);
let join_handle = std::thread::spawn(move || {
DeliveryThread::new(api_config, span_rx, stop_tx, max_batch_size, stats).run();
});
pub fn new(api_config: &APIWrapper, max_batch_size: usize, stats: TraceStats) -> Self {
// let (span_tx, span_rx) = tokio::sync::mpsc::unbounded_channel();
let (span_tx, span_rx) = std::sync::mpsc::channel();
// let (stop_tx, stop_rx) = watch::channel(ProcessorStatus::Active);

(span_tx, stop_rx, join_handle)
}
let api_config = Arc::new(api_config.clone());

pub fn new(api_config: &APIWrapper, max_batch_size: usize, stats: TraceStats) -> Self {
let (span_tx, stop_rx, join_handle) =
Self::start_worker(api_config.clone(), max_batch_size, stats.clone());
let mut t = DeliveryThread {
api_config: api_config.clone(),
max_batch_size,
max_concurrency: tokio::sync::Semaphore::new(MAX_TRACE_SEND_CONCURRENCY).into(),
stats: stats.clone(),
};

std::thread::spawn(move || {
t.run(span_rx);
});

Self {
api_config: Arc::new(api_config.clone()),
api_config,
span_tx,
stop_rx,
join_handle,
//stop_rx,
join_handle: None,
log_event_callback: Arc::new(Mutex::new(None)),
stats,
}
}

pub fn flush(&self) -> Result<()> {
let id = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis();
self.span_tx.send(TxEventSignal::Flush(id))?;

let flush_start = Instant::now();

while flush_start.elapsed() < Duration::from_secs(60) {
{
match *self.stop_rx.borrow() {
ProcessorStatus::Active => {}
ProcessorStatus::Done(r_id) if r_id >= id => {
return Ok(());
}
ProcessorStatus::Done(id) => {
// Old flush, ignore
}
}
// let id = std::time::SystemTime::now()
// .duration_since(std::time::UNIX_EPOCH)
// .unwrap()
// .as_millis();
// log::debug!("Asking delivery thread to flush events");
// self.span_tx.send(TxEventSignal::Flush(id))?;

// let flush_start = Instant::now();

// while flush_start.elapsed() < Duration::from_secs(60) {
// {
// match *self.stop_rx.borrow() {
// ProcessorStatus::Active => {}
// ProcessorStatus::Done(r_id) if r_id >= id => {
// return Ok(());
// }
// ProcessorStatus::Done(id) => {
// // Old flush, ignore
// }
// }
// }
// std::thread::sleep(Duration::from_millis(100));
// }

// anyhow::bail!("BatchProcessor worker thread did not finish in time")
self.shutdown()
}

// pub fn shutdown(&self) -> Result<()> {
// let mut locked = self.runtime.lock().unwrap();
// match *locked {
// Some(ref t) => log::debug!(
// "Asking delivery thread to stop, runtime status is {:#?}",
// t.metrics()
// ),
// None => {
// log::debug!("Asking delivery thread to stop, runtime has already been shutdown")
// }
// }
// self.span_tx.send(TxEventSignal::Stop)?;

// let Some(runtime) = std::mem::take(locked.deref_mut()) else {
// anyhow::bail!("ThreadedTracer has already been shutdown");
// };
// runtime.shutdown_timeout(Duration::from_secs(13));
// Ok(())
// }

pub fn shutdown(&self) -> Result<()> {
let (tx, rx) = oneshot::channel();

self.span_tx.send(TxEventSignal::Stop(tx))?;

match rx.blocking_recv() {
Ok(stats) => {
log::debug!("Received stats from delivery thread");
}
std::thread::sleep(Duration::from_millis(100));
}
Err(e) => {
log::error!("Error receiving handle from delivery thread: {:?}", e);
}
};

anyhow::bail!("BatchProcessor worker thread did not finish in time")
Ok(())
}

pub fn set_log_event_callback(&self, log_event_callback: LogEventCallbackSync) {
@@ -221,6 +268,21 @@ impl ThreadedTracer {
}

pub fn submit(&self, mut event: LogSchema) -> Result<()> {
log::debug!("submitting NOOPs during trace.submit");
for _ in 0..3 {
match self.span_tx.send(TxEventSignal::Noop) {
Ok(_) => {
log::debug!("NOOP sent to delivery thread");
if let Some(join_handle) = &self.join_handle {
join_handle.thread().unpark();
}
}
Err(e) => {
log::error!("Error sending NOOP to delivery thread: {:?}", e);
}
}
}

let callback = self.log_event_callback.lock().unwrap();
if let Some(ref callback) = *callback {
let event = event.clone();
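
Taken together, the threaded_tracer.rs rewrite replaces the old flush/watch machinery with a plain std::thread that owns a current-thread tokio runtime, drains a std::sync::mpsc channel via recv_timeout, and shuts down through a oneshot handshake: the caller sends TxEventSignal::Stop carrying a oneshot::Sender and blocks on the reply for the final TraceStats. A simplified, self-contained sketch of that handshake, assuming stand-in types (Signal instead of TxEventSignal, a u64 counter instead of TraceStats) and no batching:

use std::sync::mpsc;
use std::time::Duration;
use tokio::sync::oneshot;

// Reduced stand-ins for TxEventSignal and TraceStats.
enum Signal {
    Submit(String),
    Stop(oneshot::Sender<u64>),
}

fn spawn_delivery_thread(rx: mpsc::Receiver<Signal>) -> std::thread::JoinHandle<()> {
    std::thread::spawn(move || {
        // The worker owns a current-thread runtime, as in DeliveryThread::run.
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("Failed to start tracing thread");
        let mut delivered: u64 = 0;
        loop {
            match rx.recv_timeout(Duration::from_secs(1)) {
                Ok(Signal::Submit(event)) => {
                    // The real DeliveryThread spawns HTTP deliveries onto `rt`;
                    // here we just drive a trivial future to mark where that happens.
                    rt.block_on(async move { drop(event) });
                    delivered += 1;
                }
                Ok(Signal::Stop(reply)) => {
                    // Acknowledge shutdown by sending stats back over the oneshot.
                    let _ = reply.send(delivered);
                    return;
                }
                Err(mpsc::RecvTimeoutError::Timeout) => continue,
                Err(mpsc::RecvTimeoutError::Disconnected) => return,
            }
        }
    })
}

fn shutdown(tx: &mpsc::Sender<Signal>) -> anyhow::Result<u64> {
    let (reply_tx, reply_rx) = oneshot::channel();
    tx.send(Signal::Stop(reply_tx))?;
    // blocking_recv is fine here because the caller is not inside a tokio runtime.
    Ok(reply_rx.blocking_recv()?)
}

Joining the handle returned by spawn_delivery_thread after shutdown returns guarantees the worker has fully exited, which is presumably why join_handle stays on ThreadedTracer as an Option.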