standardized genparams & non-supported model warning & llamafile quant kv + fa
sariola committed Oct 8, 2024
1 parent 992ca46 commit a54bbee
Showing 4 changed files with 130 additions and 67 deletions.
10 changes: 9 additions & 1 deletion flow_judge/models/common.py
@@ -2,6 +2,8 @@
from enum import Enum
from typing import Any, Dict, Optional

from pydantic import BaseModel


class BaseFlowJudgeModel(ABC):
"""Base class for all FlowJudge models."""
@@ -56,14 +58,20 @@ async def _async_batch_generate(
"""Generate responses for multiple prompts asynchronously."""
pass

class GenerationParams(BaseModel):
temperature: float = 0.1
top_p: float = 0.95
max_new_tokens: int = 1000
do_sample: bool = True


class ModelType(Enum):
"""Enum for the type of model."""

TRANSFORMERS = "transformers"
VLLM = "vllm"
VLLM_ASYNC = "vllm_async"
LLAMAFILE = "llamafile" # Add this line
LLAMAFILE = "llamafile"


class Engine(Enum):
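GenerationParams now lives in common.py, so every backend validates the same generation fields with the same defaults. A minimal sketch of how the shared model is consumed, assuming nothing beyond the fields shown above (the override values are illustrative):

from flow_judge.models.common import GenerationParams

# Unset fields fall back to the defaults defined in common.py.
params = GenerationParams(temperature=0.3, max_new_tokens=500)  # illustrative overrides

print(params.top_p)         # 0.95, default retained
print(params.do_sample)     # True, default retained
print(params.model_dump())  # plain dict, the form LlamafileConfig hands to ModelConfig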
30 changes: 16 additions & 14 deletions flow_judge/models/huggingface.py
@@ -1,10 +1,9 @@
import logging
import os
from typing import Any, Dict, Optional
import warnings
from typing import Any, Dict, List

from pydantic import BaseModel

from flow_judge.models.common import BaseFlowJudgeModel, ModelConfig, ModelType
from flow_judge.models.common import BaseFlowJudgeModel, ModelConfig, ModelType, GenerationParams

try:
import torch
@@ -20,13 +19,6 @@
logger = logging.getLogger(__name__)


class GenerationParams(BaseModel):
temperature: float = 0.1
top_p: float = 0.95
max_new_tokens: int = 1000
do_sample: bool = True


class HfConfig(ModelConfig):
def __init__(
self,
@@ -64,6 +56,16 @@ def __init__(
)

default_model_id = "flowaicom/Flow-Judge-v0.1"

if model_id is not None and model_id != default_model_id:
warnings.warn(
f"The model '{model_id}' is not officially supported. "
f"This library is designed for the '{default_model_id}' model. "
"Using other models may lead to unexpected behavior, and we do not handle "
"GitHub issues for unsupported models. Proceed with caution.",
UserWarning
)

model_id = model_id or default_model_id

generation_params = GenerationParams(**(generation_params or {}))

@@ -118,7 +120,7 @@ def __init__(
"pip install flow-judge[...,hf]",
) from e

def _determine_batch_size(self, prompts: list[str]) -> int:
def _determine_batch_size(self, prompts: List[str]) -> int:
"""Determine an appropriate batch size based on available GPU memory and eval_inputs."""
if self.device == "cpu":
return 1 # Default to 1 for CPU
@@ -193,8 +195,8 @@ def _generate(self, prompt: str) -> str:
return generated_text.strip()

def _batch_generate(
self, prompts: list[str], use_tqdm: bool = True, **kwargs: Any
) -> list[str]:
self, prompts: List[str], use_tqdm: bool = True, **kwargs: Any
) -> List[str]:
"""Generate responses for multiple prompts using batching."""
all_results = []

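Both backends now emit the same UserWarning whenever a caller requests anything other than the supported model. A self-contained sketch of the guard added to the Hugging Face and Llamafile constructors (check_model_id is an illustrative helper, not a function in the library):

import warnings
from typing import Optional

DEFAULT_MODEL_ID = "flowaicom/Flow-Judge-v0.1"

def check_model_id(model_id: Optional[str]) -> str:
    # Mirrors the guard in the diff: warn on unsupported IDs, use the default when none is given.
    if model_id is not None and model_id != DEFAULT_MODEL_ID:
        warnings.warn(
            f"The model '{model_id}' is not officially supported. "
            f"This library is designed for the '{DEFAULT_MODEL_ID}' model. "
            "Using other models may lead to unexpected behavior, and we do not handle "
            "GitHub issues for unsupported models. Proceed with caution.",
            UserWarning,
        )
    return model_id or DEFAULT_MODEL_ID

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    check_model_id("org/custom-model")  # hypothetical model ID, triggers the warning
    assert any(issubclass(w.category, UserWarning) for w in caught)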
108 changes: 75 additions & 33 deletions flow_judge/models/llamafile.py
@@ -9,12 +9,14 @@

import requests
from tqdm import tqdm
import warnings

from flow_judge.models.common import (
AsyncBaseFlowJudgeModel,
BaseFlowJudgeModel,
ModelConfig,
ModelType,
GenerationParams,
)

try:
@@ -34,33 +36,51 @@ class LlamafileConfig(ModelConfig):
def __init__(
self,
model_id: str,
generation_params: Dict[str, Any],
generation_params: GenerationParams,
model_filename: str = "flow-judge.llamafile",
cache_dir: str = os.path.expanduser("~/.cache/flow-judge"),
port: int = 8085,
disable_kv_offload: bool = False,
llamafile_kvargs: str = "",
quantized_kv: bool = True,
flash_attn: bool = True,
llamafile_server_kwargs: Dict[str, Any] = None,
**kwargs: Any,
):
super().__init__(model_id, ModelType.LLAMAFILE, generation_params, **kwargs)
super().__init__(model_id, ModelType.LLAMAFILE, generation_params.model_dump(), **kwargs)
self.model_filename = model_filename
self.cache_dir = cache_dir
self.port = port
self.disable_kv_offload = disable_kv_offload
self.llamafile_kvargs = llamafile_kvargs
self.quantized_kv = quantized_kv
self.flash_attn = flash_attn
self.llamafile_server_kwargs = llamafile_server_kwargs or {}


class Llamafile(BaseFlowJudgeModel, AsyncBaseFlowJudgeModel):
"""Combined FlowJudge model class for Llamafile supporting both sync and async operations."""
"""Combined FlowJudge model class for Llamafile supporting both sync and async operations.
Args:
model (str, optional): The model ID to use. Defaults to "sariola/flow-judge-llamafile".
generation_params (Dict[str, Any], optional): Generation parameters.
cache_dir (str, optional): Directory to cache the model. Defaults to "~/.cache/flow-judge".
port (int, optional): Port to run the Llamafile server on. Defaults to 8085.
disable_kv_offload (bool, optional): Whether to disable KV offloading. Defaults to False.
quantized_kv (bool, optional): Whether to enable Quantized KV. Defaults to True.
flash_attn (bool, optional): Whether to enable Flash Attention. Defaults to True.
llamafile_server_kwargs (Dict[str, Any], optional): Additional keyword arguments for the Llamafile server.
**kwargs: Additional keyword arguments.
"""

def __init__(
self,
model: str = None,
generation_params: dict[str, Any] = None,
generation_params: Dict[str, Any] = None,
cache_dir: str = os.path.expanduser("~/.cache/flow-judge"),
port: int = 8085,
disable_kv_offload: bool = False,
llamafile_kvargs: str = "",
quantized_kv: bool = True,
flash_attn: bool = True,
llamafile_server_kwargs: Dict[str, Any] = None,
**kwargs: Any,
):
"""Initialize the FlowJudge Llamafile model."""
@@ -73,20 +93,19 @@ def __init__(
)

default_model_id = "sariola/flow-judge-llamafile"

if model is not None and model != default_model_id:
warnings.warn(
f"The model '{model}' is not officially supported. "
f"This library is designed for the '{default_model_id}' model. "
"Using other models may lead to unexpected behavior, and we do not handle "
"GitHub issues for unsupported models. Proceed with caution.",
UserWarning
)

model = model or default_model_id

Check warning on line 106 in flow_judge/models/llamafile.py

View check run for this annotation

Codecov / codecov/patch

flow_judge/models/llamafile.py#L106

Added line #L106 was not covered by tests

default_generation_params = {
"temperature": 0.1,
"top_p": 0.95,
"max_tokens": 2000,
"context_size": 8192,
"gpu_layers": 34,
"thread_count": os.cpu_count() or 1,
"batch_size": 32, # here batch doesn't mean parallel requests, it's just the batch size for the llamafile server
"max_concurrent_requests": 1,
"stop": ["<|endoftext|>"],
}
generation_params = generation_params or default_generation_params
generation_params = GenerationParams(**(generation_params or {}))

Check warning on line 108 in flow_judge/models/llamafile.py

View check run for this annotation

Codecov / codecov/patch

flow_judge/models/llamafile.py#L108

Added line #L108 was not covered by tests

config = LlamafileConfig(

Check warning on line 110 in flow_judge/models/llamafile.py

View check run for this annotation

Codecov / codecov/patch

flow_judge/models/llamafile.py#L110

Added line #L110 was not covered by tests
model_id=model,
Expand All @@ -95,14 +114,16 @@ def __init__(
cache_dir=cache_dir,
port=port,
disable_kv_offload=disable_kv_offload,
llamafile_kvargs=llamafile_kvargs,
quantized_kv=quantized_kv,
flash_attn=flash_attn,
llamafile_server_kwargs=llamafile_server_kwargs,
**kwargs,
)

super().__init__(model, "llamafile", generation_params, **kwargs)
super().__init__(model, "llamafile", config.generation_params, **kwargs)

try:
self.generation_params = generation_params
self.generation_params = config.generation_params
self.cache_dir = config.cache_dir
self.model_repo = config.model_id.split("/")[0]
self.model_filename = config.model_filename
@@ -121,7 +142,9 @@ def __init__(
self._context_depth = 0

self.disable_kv_offload = config.disable_kv_offload
self.llamafile_kvargs = config.llamafile_kvargs
self.quantized_kv = config.quantized_kv
self.flash_attn = config.flash_attn
self.llamafile_server_kwargs = config.llamafile_server_kwargs

self.metadata = {
"model_id": model,
@@ -209,16 +232,35 @@ def start_llamafile_server(self):
logging.error(f"Llamafile at {llamafile_path} is not executable")
raise PermissionError(f"Llamafile at {llamafile_path} is not executable")

command = f"sh -c '{llamafile_path} --server --host 127.0.0.1 --port {self.port} -c {self.generation_params.get('context_length', 8192)} -ngl {self.generation_params.get('gpu_layers', 34)} --temp {self.generation_params.get('temperature', 0.1)} -n {self.generation_params.get('max_tokens', 1000)} --threads {self.generation_params.get('thread_count', os.cpu_count() or 1)} --nobrowser -b {self.generation_params.get('batch_size', 1)} --parallel {self.generation_params.get('max_concurrent_requests', 1)} --cont-batching'"

if self.generation_params.get("disable_kv_offload", False):
command = f"sh -c '{llamafile_path} --server --host 127.0.0.1 --port {self.port} " \

Check warning on line 235 in flow_judge/models/llamafile.py

View check run for this annotation

Codecov / codecov/patch

flow_judge/models/llamafile.py#L235

Added line #L235 was not covered by tests
f"-c {self.generation_params.get('context_size', 8192)} " \
f"-ngl {self.generation_params.get('gpu_layers', 34)} " \
f"--temp {self.generation_params['temperature']} " \
f"-n {self.generation_params['max_new_tokens']} " \
f"--threads {self.generation_params.get('thread_count', os.cpu_count() or 1)} " \
f"--nobrowser -b {self.generation_params.get('batch_size', 32)} " \
f"--parallel {self.generation_params.get('max_concurrent_requests', 1)} " \
f"--cont-batching'"

if self.disable_kv_offload:
command += " -nkvo"
logging.info("KV offloading disabled")

extra_args = self.generation_params.get("llamafile_kvargs", "")
if extra_args:
command += f" {extra_args}"
logging.info(f"Additional arguments added: {extra_args}")
if self.quantized_kv:
command += " -ctk q4_0 -ctv q4_0"
logging.info("Quantized KV enabled")

if self.flash_attn:
command += " -fa"
logging.info("Flash Attention enabled")

if self.quantized_kv and not self.flash_attn:
raise LlamafileError("Quantized KV is enabled but Flash Attention is disabled. This configuration won't function.")

# Add any additional server arguments
for key, value in self.llamafile_server_kwargs.items():
command += f" --{key} {value}"
logging.info(f"Additional server argument added: --{key} {value}")

logging.info(f"Starting llamafile server with command: {command}")

@@ -307,9 +349,9 @@ async def _async_generate(self, prompt: str) -> str:

def _get_generation_params(self):
return {
"max_tokens": self.generation_params.get("max_tokens", 1000),
"top_p": self.generation_params.get("top_p", 0.95),
"temperature": self.generation_params.get("temperature", 0.1),
"max_tokens": self.generation_params['max_new_tokens'],
"top_p": self.generation_params['top_p'],
"temperature": self.generation_params['temperature'],
"stop": self.generation_params.get("stop", ["<|endoftext|>"]),
}

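The new Llamafile options map directly onto llamafile server switches: disable_kv_offload adds -nkvo, quantized_kv adds -ctk q4_0 -ctv q4_0, flash_attn adds -fa, and any llamafile_server_kwargs become extra --key value pairs; quantized KV without Flash Attention is rejected before the server starts. A condensed sketch of that flag assembly (build_server_flags is illustrative and omits the host, port, context, and threading options the real command also carries; the library raises LlamafileError rather than ValueError):

from typing import Any, Dict, Optional

def build_server_flags(
    disable_kv_offload: bool = False,
    quantized_kv: bool = True,
    flash_attn: bool = True,
    server_kwargs: Optional[Dict[str, Any]] = None,
) -> str:
    if quantized_kv and not flash_attn:
        # Same incompatibility check the commit enforces before launching the server.
        raise ValueError("Quantized KV is enabled but Flash Attention is disabled.")
    flags = ""
    if disable_kv_offload:
        flags += " -nkvo"
    if quantized_kv:
        flags += " -ctk q4_0 -ctv q4_0"
    if flash_attn:
        flags += " -fa"
    for key, value in (server_kwargs or {}).items():
        flags += f" --{key} {value}"
    return flags

print(build_server_flags())                              # " -ctk q4_0 -ctv q4_0 -fa"
print(build_server_flags(server_kwargs={"threads": 8}))  # appends " --threads 8" (hypothetical extra arg)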

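_get_generation_params now reads from the standardized parameter dict, renaming max_new_tokens to the max_tokens field the llamafile completion endpoint expects. A sketch of that mapping using the GenerationParams defaults (the stop-token fallback is taken from the diff; GenerationParams itself defines no stop field):

from flow_judge.models.common import GenerationParams

gen = GenerationParams().model_dump()  # standardized fields shared across backends
request_params = {
    "max_tokens": gen["max_new_tokens"],         # renamed for the completion API
    "top_p": gen["top_p"],
    "temperature": gen["temperature"],
    "stop": gen.get("stop", ["<|endoftext|>"]),  # falls back to the default stop token
}
print(request_params)
# {'max_tokens': 1000, 'top_p': 0.95, 'temperature': 0.1, 'stop': ['<|endoftext|>']}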