
Commit

fixed model to model_id in vllm and llamafile
sariola committed Oct 8, 2024
1 parent f0598d3 commit 1b56a93
Showing 3 changed files with 15 additions and 16 deletions.
2 changes: 1 addition & 1 deletion flow_judge/models/huggingface.py
@@ -107,7 +107,7 @@ def __init__(

         if self.device == "cpu":
             logger.warning(
-                "Running the FlowJudgeHFModel on CPU may result in longer inference times."
+                "Running Hf on CPU may result in longer inference times."
             )

         self.batch_size = 1  # Default to 1, will be updated in batch_generate

16 changes: 8 additions & 8 deletions flow_judge/models/llamafile.py
@@ -60,7 +60,7 @@ class Llamafile(BaseFlowJudgeModel, AsyncBaseFlowJudgeModel):
     """Combined FlowJudge model class for Llamafile supporting both sync and async operations.

     Args:
-        model (str, optional): The model ID to use. Defaults to "sariola/flow-judge-llamafile".
+        model_id (str, optional): The model ID to use. Defaults to "sariola/flow-judge-llamafile".
         generation_params (Dict[str, Any], optional): Generation parameters.
         cache_dir (str, optional): Directory to cache the model. Defaults to "~/.cache/flow-judge".
         port (int, optional): Port to run the Llamafile server on. Defaults to 8085.
@@ -73,7 +73,7 @@ class Llamafile(BaseFlowJudgeModel, AsyncBaseFlowJudgeModel):

     def __init__(
         self,
-        model: str = None,
+        model_id: str = None,
         generation_params: Dict[str, Any] = None,
         cache_dir: str = os.path.expanduser("~/.cache/flow-judge"),
         port: int = 8085,
@@ -94,21 +94,21 @@ def __init__(

         default_model_id = "sariola/flow-judge-llamafile"

-        if model is not None and model != default_model_id:
+        if model_id is not None and model_id != default_model_id:
             warnings.warn(
-                f"The model '{model}' is not officially supported. "
+                f"The model '{model_id}' is not officially supported. "
                 f"This library is designed for the '{default_model_id}' model. "
                 "Using other models may lead to unexpected behavior, and we do not handle "
                 "GitHub issues for unsupported models. Proceed with caution.",
                 UserWarning
             )

-        model = model or default_model_id
+        model_id = model_id or default_model_id

         generation_params = GenerationParams(**(generation_params or {}))

         config = LlamafileConfig(
-            model_id=model,
+            model_id=model_id,
             generation_params=generation_params,
             model_filename="flow-judge.llamafile",
             cache_dir=cache_dir,
@@ -120,7 +120,7 @@ def __init__(
             **kwargs,
         )

-        super().__init__(model, "llamafile", config.generation_params, **kwargs)
+        super().__init__(model_id, "llamafile", config.generation_params, **kwargs)

         try:
             self.generation_params = config.generation_params
@@ -147,7 +147,7 @@ def __init__(
             self.llamafile_server_kwargs = config.llamafile_server_kwargs

         self.metadata = {
-            "model_id": model,
+            "model_id": model_id,
             "model_type": "llamafile",
         }

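For downstream code, the practical effect of the Llamafile change is that the model is selected with the model_id keyword instead of model; the parameter was renamed rather than aliased, so the old model keyword no longer selects the model. A minimal usage sketch, assuming the module path flow_judge.models.llamafile matches the file layout above and that the llamafile binary can be fetched into the default cache directory:

    from flow_judge.models.llamafile import Llamafile

    # After this commit the keyword is model_id; omitting it falls back to
    # "sariola/flow-judge-llamafile", and any other value emits a UserWarning
    # about unsupported models.
    judge = Llamafile(
        model_id="sariola/flow-judge-llamafile",
        port=8085,
    )
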
13 changes: 6 additions & 7 deletions flow_judge/models/vllm.py
@@ -54,7 +54,7 @@ class Vllm(BaseFlowJudgeModel, AsyncBaseFlowJudgeModel):

     def __init__(
         self,
-        model: str = None,
+        model_id: str = None,
         generation_params: Dict[str, Any] = None,
         quantized: bool = True,
         exec_async: bool = False,
@@ -70,24 +70,23 @@ def __init__(

         default_model_id = "flowaicom/Flow-Judge-v0.1"

-        if model is not None and model != default_model_id:
+        if model_id is not None and model_id != default_model_id:
             warnings.warn(
-                f"The model '{model}' is not officially supported. "
+                f"The model '{model_id}' is not officially supported. "
                 f"This library is designed for the '{default_model_id}' model. "
                 "Using other models may lead to unexpected behavior, and we do not handle "
                 "GitHub issues for unsupported models. Proceed with caution.",
                 UserWarning
             )

-        model = model or default_model_id
-        # Only append "-AWQ" if it's the default model and quantization is enabled
-        model_id = f"{model}-AWQ" if quantized and model == default_model_id else model
+        model_id = model_id or default_model_id
+        model_id = f"{model_id}-AWQ" if quantized and model_id == default_model_id else model_id

         generation_params = GenerationParams(**(generation_params or {}))

         config = VllmConfig(model_id=model_id, generation_params=generation_params, quantization=quantized, exec_async=exec_async, **kwargs)

-        super().__init__(model, "vllm", config.generation_params, **kwargs)
+        super().__init__(model_id, "vllm", config.generation_params, **kwargs)

         self.exec_async = exec_async
         self.generation_params = config.generation_params

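The vLLM wrapper follows the same rename, and the AWQ suffix logic now operates on model_id. A short sketch of how the constructor resolves the model, assuming the module path flow_judge.models.vllm and an environment where vLLM can load the model; "my-org/custom-judge" below is a hypothetical placeholder:

    from flow_judge.models.vllm import Vllm

    # With quantized=True (the default) and no explicit model_id, the model
    # resolves to "flowaicom/Flow-Judge-v0.1-AWQ".
    judge = Vllm()

    # A non-default model_id is passed through unchanged (no "-AWQ" suffix is
    # appended) and triggers the UserWarning about unsupported models.
    custom = Vllm(model_id="my-org/custom-judge", quantized=False)

Restricting the "-AWQ" suffix to the default model avoids silently rewriting user-supplied model IDs that may not have a published AWQ variant.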
