From 5899f4496e873d726cbb1bbfbc5ccc1e7af405ec Mon Sep 17 00:00:00 2001
From: rizar
Date: Tue, 7 Jan 2025 03:04:06 +0000
Subject: [PATCH] should fix the loss of data

---
 examples/rl_gsm8k/orchestrate_rl.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/rl_gsm8k/orchestrate_rl.py b/examples/rl_gsm8k/orchestrate_rl.py
index 237b85c8..fbaca406 100644
--- a/examples/rl_gsm8k/orchestrate_rl.py
+++ b/examples/rl_gsm8k/orchestrate_rl.py
@@ -430,15 +430,15 @@ def main(cfg: DictConfig):
         ]
 
         start_basemodel_logprobs = time.time()
-        all_traces = all_results["train"]["training_samples"]
+        training_samples = all_results["train"]["training_samples"]
         with ThreadPoolExecutor(
             max_workers=cfg.get_logprobs_workers_per_gpu * torch.cuda.device_count()
         ) as executor:
             chunk_size = 64
             futures = []
-            for chunk_id, chunk_offset in enumerate(range(0, len(all_traces), chunk_size)):
+            for chunk_id, chunk_offset in enumerate(range(0, len(training_samples), chunk_size)):
                 ref_llm = ref_llms[chunk_id % len(ref_llms)]
-                chunk = all_traces[chunk_offset: chunk_offset + chunk_size]
+                chunk = training_samples[chunk_offset: chunk_offset + chunk_size]
                 futures.append(
                     executor.submit(batch_annotate_traces_with_ref_logprobs, ref_llm, chunk)
                 )
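
For context, the hunk above splits the training samples into chunks of 64 and dispatches each chunk to one of several reference LLM endpoints in round-robin order through a ThreadPoolExecutor. Below is a minimal standalone sketch of that dispatch pattern only; fan_out, annotate_chunk, and workers are hypothetical stand-ins for the repository's batch_annotate_traces_with_ref_logprobs and ref_llms, not code from this patch.

    # Sketch of the chunked, round-robin fan-out pattern (placeholder names).
    from concurrent.futures import ThreadPoolExecutor, as_completed


    def annotate_chunk(worker: str, chunk: list[int]) -> list[int]:
        # Placeholder for the real per-chunk annotation call.
        return [x * 2 for x in chunk]


    def fan_out(samples: list[int], workers: list[str], chunk_size: int = 64) -> list[int]:
        results: list[int] = []
        with ThreadPoolExecutor(max_workers=len(workers)) as executor:
            futures = []
            for chunk_id, offset in enumerate(range(0, len(samples), chunk_size)):
                # Round-robin: chunk i goes to worker i % len(workers).
                worker = workers[chunk_id % len(workers)]
                chunk = samples[offset : offset + chunk_size]
                futures.append(executor.submit(annotate_chunk, worker, chunk))
            for future in as_completed(futures):
                results.extend(future.result())
        return results


    if __name__ == "__main__":
        # Two placeholder "workers", 200 samples -> 4 chunks of up to 64.
        print(len(fan_out(list(range(200)), ["gpu0", "gpu1"])))  # prints 200

The rename in the patch keeps this chunking logic unchanged; it only gives the list a distinct name (training_samples) so it is no longer shadowed or overwritten elsewhere under the old all_traces name.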