#27 Fix obj serialization for saving (#29)
* fix: serialize obj for saving

* fix: add log warning back to stream

* fix: quality improvements to result_writer & change to readme for baseten extras & pyproject build ignore img dir

* fix: rm non-allowed exclude key from pyproject.toml

* feat: vllm pinned & test result writing unit + e2e

* fix: equate eval outputs and inputs

* fix: handle_batch_results docstring

* fix: failing baseten test

---------

Co-authored-by: sariola <karolus@flowrite.com>
minaamshahid and sariola authored Oct 29, 2024
1 parent 9865fc0 commit f143379
Showing 11 changed files with 1,011 additions and 193 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -70,6 +70,7 @@ Extras available:
- `hf` to install Hugging Face Transformers dependencies
- `vllm` to install vLLM dependencies
- `llamafile` to install Llamafile dependencies
- `baseten` to install Baseten dependencies

## Quick Start

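Assuming the package is published as `flow-judge` with standard extras syntax (an assumption, not confirmed by this diff), the new extra would be installed with something like `pip install flow-judge[baseten]`, in the same way as the existing `hf`, `vllm`, and `llamafile` extras.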
2 changes: 1 addition & 1 deletion flow_judge/eval_data_types.py
@@ -17,7 +17,7 @@ class EvalOutput(BaseModel):
"""Output model for evaluation results."""

feedback: str = Field(..., description="Feedback from the evaluation")
score: int = Field(..., description="Numeric score from the evaluation")
score: int | None = Field(..., description="Numeric score from the evaluation")

@classmethod
def parse(cls, response: str, fail_on_parse_error: bool = False) -> "EvalOutput":
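A minimal sketch of what the widened `score` type allows, assuming only the `EvalOutput` model shown above (the import path follows the file name in this diff):

```python
from flow_judge.eval_data_types import EvalOutput

# With `score: int | None`, a placeholder for a failed request can be built
# without inventing a numeric score.
placeholder = EvalOutput(feedback="BasetenError", score=None)
print(placeholder.score)  # None

# Successful parses still carry an int score; judging by the parse-failure
# count later in this diff, a lenient parse (fail_on_parse_error=False) is
# expected to fall back to a sentinel score of -1 on malformed responses.
# parsed = EvalOutput.parse(response=raw_model_response, fail_on_parse_error=False)
```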
108 changes: 82 additions & 26 deletions flow_judge/flow_judge.py
@@ -48,11 +48,18 @@ def _validate_inputs(self, eval_inputs: EvalInput | list[EvalInput]):
else:
validate_eval_input(eval_inputs, self.metric)

def _save_results(self, eval_inputs: list[EvalInput], eval_outputs: list[EvalOutput]):
def _save_results(
self, eval_inputs: list[EvalInput], eval_outputs: list[EvalOutput], append: bool = False
):
"""Save results to disk."""
logger.info(f"Saving results to {self.output_dir}")
logger.info(f"{'Appending' if append else 'Saving'} results to {self.output_dir}")
write_results_to_disk(
eval_inputs, eval_outputs, self.model.metadata, self.metric.name, self.output_dir
eval_inputs,
eval_outputs,
self.model.metadata,
self.metric.name,
self.output_dir,
append=append,
)
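For intuition only, a sketch of how an `append` flag is typically honoured by a JSONL-style result writer; this is not `write_results_to_disk` itself, whose internals are not shown in this diff:

```python
import json
from pathlib import Path

def append_or_overwrite_jsonl(path: Path, records: list[dict], append: bool = False) -> None:
    """Illustrative only: open in 'a' mode when appending so earlier results
    are kept, otherwise 'w' overwrites the file. The JSON-lines layout is an
    assumption about the real writer, not taken from this diff."""
    path.parent.mkdir(parents=True, exist_ok=True)
    mode = "a" if append else "w"
    with path.open(mode, encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record) + "\n")
```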


@@ -122,8 +129,48 @@ def __init__(
if not isinstance(model, AsyncBaseFlowJudgeModel):
raise ValueError("Invalid model type. Use AsyncBaseFlowJudgeModel or its subclasses.")

def _handle_batch_result(
self, batch_result: BatchResult, batch_len: int, fail_on_parse_error: bool
) -> list[EvalOutput]:
"""Handle output parsing for batched results.
Args:
batch_result: The result of the batch from Baseten.
batch_len: The initial batch size derived from the length of Eval Inputs.
fail_on_parse_error: Flag to raise a parse error for the EvalOutput.
Returns:
list[EvalOutput]: A list of eval outputs with score and feedback.
Note:
There might be instances when downstream errors result in missing entries
for the eval outputs. We implement retry strategies where we can, but in
certain instances (such as network failures) errors are inevitable.
To ascertain predictability, we 'fill-in' the errors with empty EvalOutputs.
"""
eval_outputs = [EvalOutput(feedback="BasetenError", score=None)] * batch_len
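# The 'index' reported by Baseten appears to be 1-based (hence `index - 1` below).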
for output in batch_result.successful_outputs:
index = output.get("index")
eval_outputs[index - 1] = EvalOutput.parse(
response=output["response"], fail_on_parse_error=fail_on_parse_error
)

# Log all downstream errors
if len(batch_result.errors) > 0:
logger.warning(
f"Number of Baseten API errors: {len(batch_result.errors)}"
f" of {batch_result.total_requests}."
f" Success rate is {batch_result.success_rate}"
" List of errors: "
)
for error in batch_result.errors:
logger.warning(f"{error.error_type}: {error.error_message}")

return eval_outputs
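To make the fill-in behaviour above concrete, a hedged sketch using a stand-in object in place of a real `BatchResult` (the stand-in, the example response string, and the `judge` instance are assumptions; only the field names and the 1-based `index` come from the code above):

```python
from types import SimpleNamespace

# Stand-in for a Baseten BatchResult with the fields used above.
fake_batch = SimpleNamespace(
    successful_outputs=[
        {"index": 2, "response": "<feedback>Looks good.</feedback><score>4</score>"}
    ],
    errors=[],
    total_requests=3,
    success_rate=1 / 3,
)

# With batch_len=3, list slots 0 and 2 keep the BasetenError placeholder
# (score=None) and slot 1 is replaced by the parsed EvalOutput, because the
# 1-based index 2 maps to position 1.
# eval_outputs = judge._handle_batch_result(fake_batch, batch_len=3, fail_on_parse_error=False)
```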

async def async_evaluate(
self, eval_input: EvalInput, save_results: bool = False
self, eval_input: EvalInput, save_results: bool = False, append: bool = False
) -> EvalOutput | None:
"""Evaluate a single EvalInput object asynchronously."""
try:
@@ -132,14 +179,16 @@ async def async_evaluate(
result = await self.model._async_generate(prompt)
response = result

# If there are Baseten errors we log & return here.
if isinstance(result, FlowJudgeError):
logger.error(f" {result.error_type}: {result.error_message}")
return

eval_output = EvalOutput.parse(response)
if save_results:
await asyncio.to_thread(self._save_results, [eval_input], [eval_output])
logger.info(f"Saving result {'(append)' if append else '(overwrite)'}")
await asyncio.to_thread(
self._save_results, [eval_input], [eval_output], append=append
)
return eval_output
except Exception as e:
logger.error(f"Asynchronous evaluation failed: {e}")
Expand All @@ -151,36 +200,43 @@ async def async_batch_evaluate(
eval_inputs: list[EvalInput],
use_tqdm: bool = True,
save_results: bool = True,
append: bool = False,  # do not append to existing results by default
fail_on_parse_error: bool = False,
) -> list[EvalOutput]:
"""Batch evaluate a list of EvalInput objects asynchronously."""
self._validate_inputs(eval_inputs)
prompts = [self._format_prompt(eval_input) for eval_input in eval_inputs]
batch_result = await self.model._async_batch_generate(prompts, use_tqdm=use_tqdm)
responses = batch_result

if isinstance(responses, BatchResult):
responses = [result["response"] for result in batch_result.successful_outputs]
if len(batch_result.errors) > 0:
logger.warning(
f"Number of Baseten API errors: {len(batch_result.errors)}"
f" of {batch_result.total_requests}."
f" Success rate is {batch_result.success_rate}"
"List of errors: "
)
for error in batch_result.errors:
logger.warning(f"{error.error_type}: {error.error_message}")

eval_outputs = [
EvalOutput.parse(response, fail_on_parse_error=fail_on_parse_error)
for response in responses
]
if isinstance(batch_result, BatchResult):
eval_outputs = self._handle_batch_result(
batch_result=batch_result,
batch_len=len(eval_inputs),
fail_on_parse_error=fail_on_parse_error,
)
else:
eval_outputs = [
EvalOutput.parse(response, fail_on_parse_error=fail_on_parse_error)
for response in batch_result
]
logger.warning(f"{eval_outputs}")
parse_failures = sum(1 for output in eval_outputs if output.score and output.score == -1)

parse_failures = sum(1 for output in eval_outputs if output.score == -1)
if save_results:
await asyncio.to_thread(self._save_results, eval_inputs, eval_outputs)
logger.info(f"Saving {len(eval_outputs)} results")
for i, (eval_input, eval_output) in enumerate(
zip(eval_inputs, eval_outputs, strict=True)
):
await asyncio.to_thread(
self._save_results,
[eval_input],
[eval_output],
append=(append or i > 0), # Append for all but the first, unless append is True
)

if parse_failures > 0:
logger.warning(f"Number of parsing failures: {parse_failures} out of {len(responses)}")
logger.warning(
f"Number of parsing failures: {parse_failures} out of {len(eval_outputs)}"
)

return eval_outputs
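Finally, a hedged end-to-end sketch of how the per-item append saving and the nullable scores surface to a caller. Only `async_batch_evaluate`, its `save_results`/`append`/`fail_on_parse_error` parameters, and the `score`/`feedback` fields come from this diff; the class name `AsyncFlowJudge`, the `Baseten` model adapter, the metric constant, and the `EvalInput` shape are assumptions about the surrounding package:

```python
import asyncio

# Assumed public API - these import paths and names are not confirmed by this diff.
from flow_judge import AsyncFlowJudge, Baseten
from flow_judge.eval_data_types import EvalInput
from flow_judge.metrics import RESPONSE_FAITHFULNESS_5POINT

async def main() -> None:
    judge = AsyncFlowJudge(metric=RESPONSE_FAITHFULNESS_5POINT, model=Baseten())
    eval_inputs = [
        EvalInput(inputs=[{"query": "What is 2 + 2?"}], output={"response": "4"}),
        EvalInput(inputs=[{"query": "Capital of France?"}], output={"response": "Paris"}),
    ]
    # Results are now written one item at a time: with append=False the first
    # item overwrites any existing file and subsequent items are appended.
    outputs = await judge.async_batch_evaluate(eval_inputs, save_results=True, append=False)
    for out in outputs:
        # score can be None when the Baseten request itself failed.
        print(out.score, out.feedback)

if __name__ == "__main__":
    asyncio.run(main())
```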
