diff --git a/flow_judge/__init__.py b/flow_judge/__init__.py
index 6255276..e3989cb 100644
--- a/flow_judge/__init__.py
+++ b/flow_judge/__init__.py
@@ -3,7 +3,6 @@
 from flow_judge.eval_data_types import EvalInput, EvalOutput
 from flow_judge.flow_judge import AsyncFlowJudge, FlowJudge
 from flow_judge.metrics import CustomMetric, Metric, RubricItem, list_all_metrics
-from flow_judge.models import Hf, Llamafile, Vllm
 from flow_judge.models.common import BaseFlowJudgeModel
 from flow_judge.utils.prompt_formatter import format_rubric, format_user_prompt, format_vars
 
@@ -25,10 +24,39 @@
     "CustomMetric",
     "BaseFlowJudgeModel",
     "EvalOutput",
-    "Llamafile",
-    "Vllm",
-    "Hf",
 ]
 
+# Conditional imports for optional dependencies
+try:
+    from flow_judge.models.hf import Hf
+    __all__.append("Hf")
+except ImportError:
+    Hf = None
+
+try:
+    from flow_judge.models.vllm import Vllm
+    __all__.append("Vllm")
+except ImportError:
+    Vllm = None
+
+try:
+    from flow_judge.models.llamafile import Llamafile
+    __all__.append("Llamafile")
+except ImportError:
+    Llamafile = None
+
+def get_available_models():
+    """Return a list of available model classes based on installed extras."""
+    models = [BaseFlowJudgeModel]
+    if Hf is not None:
+        models.append(Hf)
+    if Vllm is not None:
+        models.append(Vllm)
+    if Llamafile is not None:
+        models.append(Llamafile)
+    return models
+
+__all__.append("get_available_models")
+
 # Add all metric names to __all__
 __all__ += list_all_metrics()
diff --git a/flow_judge/models/huggingface.py b/flow_judge/models/huggingface.py
index eb3461b..da0944e 100644
--- a/flow_judge/models/huggingface.py
+++ b/flow_judge/models/huggingface.py
@@ -4,6 +4,8 @@
 
 from pydantic import BaseModel
 
+from flow_judge.models.common import BaseFlowJudgeModel, ModelConfig, ModelType
+
 try:
     import torch
     from huggingface_hub import snapshot_download
@@ -15,8 +17,6 @@
 
 from tqdm import tqdm
 
-from flow_judge.models.common import BaseFlowJudgeModel, ModelConfig, ModelType
-
 logger = logging.getLogger(__name__)
 
 
diff --git a/flow_judge/models/llamafile.py b/flow_judge/models/llamafile.py
index 45e6021..f3de89e 100644
--- a/flow_judge/models/llamafile.py
+++ b/flow_judge/models/llamafile.py
@@ -19,7 +19,6 @@
 
 try:
     from openai import AsyncOpenAI, OpenAI
-
     LLAMAFILE_AVAILABLE = True
 except ImportError:
     LLAMAFILE_AVAILABLE = False
diff --git a/flow_judge/models/vllm.py b/flow_judge/models/vllm.py
index 627621f..a2af319 100644
--- a/flow_judge/models/vllm.py
+++ b/flow_judge/models/vllm.py
@@ -1,9 +1,6 @@
 import asyncio
 from typing import Any, Dict
 
-import torch
-from transformers import AutoTokenizer
-
 from flow_judge.models.common import (
     AsyncBaseFlowJudgeModel,
     BaseFlowJudgeModel,
@@ -12,8 +9,9 @@
 )
 
 try:
+    import torch
+    from transformers import AutoTokenizer
     from vllm import LLM, AsyncEngineArgs, AsyncLLMEngine, SamplingParams
-
     VLLM_AVAILABLE = True
 except ImportError:
     VLLM_AVAILABLE = False
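
The net effect of the patch is that `torch`, `transformers`, `vllm`, and `openai` are only imported behind `try`/`except ImportError` guards, so `import flow_judge` succeeds even when the optional extras are absent. A minimal usage sketch of the resulting surface, assuming only what the diff shows (each optional class resolves to `None` when its dependencies are missing):

```python
# Sketch only: feature-detecting optional backends after this change.
import flow_judge

# get_available_models() lists the classes whose optional imports succeeded;
# BaseFlowJudgeModel is always included.
print([cls.__name__ for cls in flow_judge.get_available_models()])

# Each optional class is None rather than an ImportError at import time,
# so callers can branch instead of guarding the import themselves.
if flow_judge.Vllm is None:
    print("vLLM extra not installed; choose another model backend.")
```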