diff --git a/presets/ragengine/models.py b/presets/ragengine/models.py index d9e7d1f60..5b18e6175 100644 --- a/presets/ragengine/models.py +++ b/presets/ragengine/models.py @@ -3,7 +3,8 @@ from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Field, root_validator, ValidationError +from pydantic import BaseModel, Field, model_validator + class Document(BaseModel): text: str @@ -33,8 +34,8 @@ class QueryRequest(BaseModel): description="Optional parameters for reranking, e.g., top_n, batch_size", ) - @root_validator(pre=True) - def validate_params(cls, values): + @model_validator(mode="before") + def validate_params(cls, values: Dict[str, Any]) -> Dict[str, Any]: llm_params = values.get("llm_params", {}) rerank_params = values.get("rerank_params", {}) @@ -43,7 +44,7 @@ def validate_params(cls, values): raise ValueError("Temperature must be between 0.0 and 1.0.") # Validate rerank parameters - top_k = values["top_k"] + top_k = values.get("top_k") if "top_n" in rerank_params and rerank_params["top_n"] > top_k: raise ValueError("Invalid configuration: 'top_n' for reranking cannot exceed 'top_k' from the RAG query.") diff --git a/presets/workspace/inference/text-generation/inference_api.py b/presets/workspace/inference/text-generation/inference_api.py index 51e8bcf19..11b690e0d 100644 --- a/presets/workspace/inference/text-generation/inference_api.py +++ b/presets/workspace/inference/text-generation/inference_api.py @@ -36,7 +36,6 @@ class ModelConfig: """ Transformers Model Configuration Parameters """ - pipeline: Optional[str] = field(default="text-generation", metadata={"help": "The model pipeline for the pre-trained model"}) pretrained_model_name_or_path: Optional[str] = field(default="/workspace/tfs/weights", metadata={"help": "Path to the pretrained model or model identifier from huggingface.co/models"}) combination_type: Optional[str]=field(default="svd", metadata={"help": "The combination type of multi adapters"}) state_dict: Optional[Dict[str, Any]] = field(default=None, metadata={"help": "State dictionary for the model"})