Skip to content

Commit

Permalink
init fix for extras
Browse files Browse the repository at this point in the history
  • Loading branch information
sariola committed Oct 8, 2024
1 parent 60036bc commit 3435564
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 11 deletions.
36 changes: 32 additions & 4 deletions flow_judge/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from flow_judge.eval_data_types import EvalInput, EvalOutput
from flow_judge.flow_judge import AsyncFlowJudge, FlowJudge
from flow_judge.metrics import CustomMetric, Metric, RubricItem, list_all_metrics
from flow_judge.models import Hf, Llamafile, Vllm
from flow_judge.models.common import BaseFlowJudgeModel

Check warning on line 6 in flow_judge/__init__.py

View check run for this annotation

Codecov / codecov/patch

flow_judge/__init__.py#L6

Added line #L6 was not covered by tests
from flow_judge.utils.prompt_formatter import format_rubric, format_user_prompt, format_vars

Expand All @@ -25,10 +24,39 @@
"CustomMetric",
"BaseFlowJudgeModel",
"EvalOutput",
"Llamafile",
"Vllm",
"Hf",
]

# Conditional imports for optional dependencies
# Optional extra: Hugging Face backend. When the extra is missing the
# name is still bound (to None) so downstream code can feature-test it.
# NOTE(review): the same commit edits models/huggingface.py — confirm the
# module path `flow_judge.models.hf` exists; a typo here would be silently
# swallowed by the ImportError handler, leaving Hf as None even when the
# extra is installed.
try:
    from flow_judge.models.hf import Hf
except ImportError:
    Hf = None
else:
    __all__.append("Hf")

Check warning on line 34 in flow_judge/__init__.py

View check run for this annotation

Codecov / codecov/patch

flow_judge/__init__.py#L30-L34

Added lines #L30 - L34 were not covered by tests

# Optional extra: vLLM backend. Bound to None when the extra is absent
# so callers can test availability without a second try/except.
try:
    from flow_judge.models.vllm import Vllm
except ImportError:
    Vllm = None
else:
    __all__.append("Vllm")

Check warning on line 40 in flow_judge/__init__.py

View check run for this annotation

Codecov / codecov/patch

flow_judge/__init__.py#L36-L40

Added lines #L36 - L40 were not covered by tests

# Optional extra: Llamafile backend. Bound to None when the extra is
# absent so callers can test availability without a second try/except.
try:
    from flow_judge.models.llamafile import Llamafile
except ImportError:
    Llamafile = None
else:
    __all__.append("Llamafile")

Check warning on line 46 in flow_judge/__init__.py

View check run for this annotation

Codecov / codecov/patch

flow_judge/__init__.py#L42-L46

Added lines #L42 - L46 were not covered by tests

def get_available_models():
    """Return the model classes usable in the current environment.

    ``BaseFlowJudgeModel`` is always first in the returned list; the
    optional backends (``Hf``, ``Vllm``, ``Llamafile``) are appended only
    when their extras are installed — i.e. when the module-level name was
    not set to ``None`` by the guarded imports above.

    Returns:
        list: model classes, a fresh list on every call.
    """
    optional_backends = (Hf, Vllm, Llamafile)
    return [BaseFlowJudgeModel] + [cls for cls in optional_backends if cls is not None]

Check warning on line 57 in flow_judge/__init__.py

View check run for this annotation

Codecov / codecov/patch

flow_judge/__init__.py#L50-L57

Added lines #L50 - L57 were not covered by tests

# Expose the availability helper alongside the model classes.
__all__ += ["get_available_models"]

Check warning on line 59 in flow_judge/__init__.py

View check run for this annotation

Codecov / codecov/patch

flow_judge/__init__.py#L59

Added line #L59 was not covered by tests

# Re-export every metric name so `from flow_judge import <metric>` works.
__all__ += list_all_metrics()
4 changes: 2 additions & 2 deletions flow_judge/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from pydantic import BaseModel

Check warning on line 5 in flow_judge/models/huggingface.py

View check run for this annotation

Codecov / codecov/patch

flow_judge/models/huggingface.py#L5

Added line #L5 was not covered by tests

from flow_judge.models.common import BaseFlowJudgeModel, ModelConfig, ModelType

Check warning on line 7 in flow_judge/models/huggingface.py

View check run for this annotation

Codecov / codecov/patch

flow_judge/models/huggingface.py#L7

Added line #L7 was not covered by tests

try:
import torch
from huggingface_hub import snapshot_download
Expand All @@ -15,8 +17,6 @@

from tqdm import tqdm

Check warning on line 18 in flow_judge/models/huggingface.py

View check run for this annotation

Codecov / codecov/patch

flow_judge/models/huggingface.py#L18

Added line #L18 was not covered by tests

from flow_judge.models.common import BaseFlowJudgeModel, ModelConfig, ModelType

logger = logging.getLogger(__name__)


Expand Down
1 change: 0 additions & 1 deletion flow_judge/models/llamafile.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

# Optional dependency probe: the Llamafile backend talks to a local
# server through the OpenAI client library. Default to "unavailable"
# and flip the flag only when the import succeeds.
LLAMAFILE_AVAILABLE = False
try:
    from openai import AsyncOpenAI, OpenAI
except ImportError:
    pass
else:
    LLAMAFILE_AVAILABLE = True

Check warning on line 24 in flow_judge/models/llamafile.py

View check run for this annotation

Codecov / codecov/patch

flow_judge/models/llamafile.py#L20-L24

Added lines #L20 - L24 were not covered by tests
Expand Down
6 changes: 2 additions & 4 deletions flow_judge/models/vllm.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import asyncio
from typing import Any, Dict

Check warning on line 2 in flow_judge/models/vllm.py

View check run for this annotation

Codecov / codecov/patch

flow_judge/models/vllm.py#L1-L2

Added lines #L1 - L2 were not covered by tests

import torch
from transformers import AutoTokenizer

from flow_judge.models.common import (

Check warning on line 4 in flow_judge/models/vllm.py

View check run for this annotation

Codecov / codecov/patch

flow_judge/models/vllm.py#L4

Added line #L4 was not covered by tests
AsyncBaseFlowJudgeModel,
BaseFlowJudgeModel,
Expand All @@ -12,8 +9,9 @@
)

# Optional dependency probe: torch, transformers and vllm are only
# present when the `vllm` extra is installed. Default to "unavailable"
# and flip the flag only if every import succeeds.
VLLM_AVAILABLE = False
try:
    import torch  # noqa: F401 — required at runtime by the vLLM backend
    from transformers import AutoTokenizer
    from vllm import LLM, AsyncEngineArgs, AsyncLLMEngine, SamplingParams
except ImportError:
    pass
else:
    VLLM_AVAILABLE = True

Check warning on line 17 in flow_judge/models/vllm.py

View check run for this annotation

Codecov / codecov/patch

flow_judge/models/vllm.py#L11-L17

Added lines #L11 - L17 were not covered by tests
Expand Down

0 comments on commit 3435564

Please sign in to comment.