diff --git a/flow_judge/models/huggingface.py b/flow_judge/models/huggingface.py
index 5d97516..7e7fa90 100644
--- a/flow_judge/models/huggingface.py
+++ b/flow_judge/models/huggingface.py
@@ -107,7 +107,7 @@ def __init__(
 
         if self.device == "cpu":
             logger.warning(
-                "Running the FlowJudgeHFModel on CPU may result in longer inference times."
+                "Running Hf on CPU may result in longer inference times."
             )
 
         self.batch_size = 1  # Default to 1, will be updated in batch_generate
diff --git a/flow_judge/models/llamafile.py b/flow_judge/models/llamafile.py
index 4bbaad0..7005e71 100644
--- a/flow_judge/models/llamafile.py
+++ b/flow_judge/models/llamafile.py
@@ -60,7 +60,7 @@ class Llamafile(BaseFlowJudgeModel, AsyncBaseFlowJudgeModel):
     """Combined FlowJudge model class for Llamafile supporting both sync and async operations.
 
     Args:
-        model (str, optional): The model ID to use. Defaults to "sariola/flow-judge-llamafile".
+        model_id (str, optional): The model ID to use. Defaults to "sariola/flow-judge-llamafile".
         generation_params (Dict[str, Any], optional): Generation parameters.
         cache_dir (str, optional): Directory to cache the model. Defaults to "~/.cache/flow-judge".
         port (int, optional): Port to run the Llamafile server on. Defaults to 8085.
@@ -73,7 +73,7 @@ class Llamafile(BaseFlowJudgeModel, AsyncBaseFlowJudgeModel):
 
     def __init__(
         self,
-        model: str = None,
+        model_id: str = None,
         generation_params: Dict[str, Any] = None,
         cache_dir: str = os.path.expanduser("~/.cache/flow-judge"),
         port: int = 8085,
@@ -94,21 +94,21 @@ def __init__(
 
         default_model_id = "sariola/flow-judge-llamafile"
 
-        if model is not None and model != default_model_id:
+        if model_id is not None and model_id != default_model_id:
             warnings.warn(
-                f"The model '{model}' is not officially supported. "
+                f"The model '{model_id}' is not officially supported. "
                 f"This library is designed for the '{default_model_id}' model. "
                 "Using other models may lead to unexpected behavior, and we do not handle "
                 "GitHub issues for unsupported models. Proceed with caution.",
                 UserWarning
             )
 
-        model = model or default_model_id
+        model_id = model_id or default_model_id
 
         generation_params = GenerationParams(**(generation_params or {}))
 
         config = LlamafileConfig(
-            model_id=model,
+            model_id=model_id,
             generation_params=generation_params,
             model_filename="flow-judge.llamafile",
             cache_dir=cache_dir,
@@ -120,7 +120,7 @@ def __init__(
             **kwargs,
         )
 
-        super().__init__(model, "llamafile", config.generation_params, **kwargs)
+        super().__init__(model_id, "llamafile", config.generation_params, **kwargs)
 
         try:
             self.generation_params = config.generation_params
@@ -147,7 +147,7 @@ def __init__(
         self.llamafile_server_kwargs = config.llamafile_server_kwargs
 
         self.metadata = {
-            "model_id": model,
+            "model_id": model_id,
             "model_type": "llamafile",
         }
 
diff --git a/flow_judge/models/vllm.py b/flow_judge/models/vllm.py
index 7e923ec..5d4f397 100644
--- a/flow_judge/models/vllm.py
+++ b/flow_judge/models/vllm.py
@@ -54,7 +54,7 @@ class Vllm(BaseFlowJudgeModel, AsyncBaseFlowJudgeModel):
 
     def __init__(
        self,
-        model: str = None,
+        model_id: str = None,
         generation_params: Dict[str, Any] = None,
         quantized: bool = True,
         exec_async: bool = False,
@@ -70,24 +70,23 @@ def __init__(
 
         default_model_id = "flowaicom/Flow-Judge-v0.1"
 
-        if model is not None and model != default_model_id:
+        if model_id is not None and model_id != default_model_id:
             warnings.warn(
-                f"The model '{model}' is not officially supported. "
+                f"The model '{model_id}' is not officially supported. "
                 f"This library is designed for the '{default_model_id}' model. "
                 "Using other models may lead to unexpected behavior, and we do not handle "
                 "GitHub issues for unsupported models. Proceed with caution.",
                 UserWarning
             )
 
-        model = model or default_model_id
-        # Only append "-AWQ" if it's the default model and quantization is enabled
-        model_id = f"{model}-AWQ" if quantized and model == default_model_id else model
+        model_id = model_id or default_model_id
+        model_id = f"{model_id}-AWQ" if quantized and model_id == default_model_id else model_id
 
         generation_params = GenerationParams(**(generation_params or {}))
 
         config = VllmConfig(model_id=model_id, generation_params=generation_params,
                             quantization=quantized, exec_async=exec_async, **kwargs)
 
-        super().__init__(model, "vllm", config.generation_params, **kwargs)
+        super().__init__(model_id, "vllm", config.generation_params, **kwargs)
         self.exec_async = exec_async
         self.generation_params = config.generation_params