From ecc56d1b1449a8463e3ea2d16eb80d6d7c973aa8 Mon Sep 17 00:00:00 2001
From: Ishaan Sehgal
Date: Thu, 25 Jan 2024 20:08:02 -0600
Subject: [PATCH] feat: Organize Inference Files - Part 1 (#216)

This is the first part of the work toward the new inference pipeline:
organizing the inference files into their own folder. This makes it easier
to trigger build pipelines (on edits to presets/models/inference) and to
maintain and add new types of inference.
---
 pkg/inference/preset-inference-types.go | 1 -
 presets/models/falcon/config.yaml | 27 ---
 presets/models/falcon/inference-api.py | 118 ------------
 .../llama2-chat}/inference-api.py | 20 +-
 .../llama2-completion}/inference-api.py | 20 +-
 .../text-generation/inference-api.py | 182 ++++++++++++++++++
 .../text-generation}/requirements.txt | 3 +-
 .../falcon-40b-instruct-statefulset.yaml | 2 +-
 .../falcon-40b/falcon-40b-statefulset.yaml | 2 +-
 .../falcon-7b-instruct-statefulset.yaml | 2 +-
 .../falcon-7b/falcon-7b-statefulset.yaml | 2 +-
 11 files changed, 208 insertions(+), 171 deletions(-)
 delete mode 100644 presets/models/falcon/config.yaml
 delete mode 100644 presets/models/falcon/inference-api.py
 rename presets/models/{llama2chat => inference/llama2-chat}/inference-api.py (100%)
 rename presets/models/{llama2 => inference/llama2-completion}/inference-api.py (100%)
 create mode 100644 presets/models/inference/text-generation/inference-api.py
 rename presets/models/{falcon => inference/text-generation}/requirements.txt (86%)

diff --git a/pkg/inference/preset-inference-types.go b/pkg/inference/preset-inference-types.go
index 8e89143e8..d3157262e 100644
--- a/pkg/inference/preset-inference-types.go
+++ b/pkg/inference/preset-inference-types.go
@@ -48,7 +48,6 @@ var (
 }
 DefaultAccelerateParams = map[string]string{
- "config_file": DefaultConfigFile,
 "num_processes": DefaultNumProcesses,
 "num_machines": DefaultNumMachines,
 "machine_rank": DefaultMachineRank,
diff --git a/presets/models/falcon/config.yaml b/presets/models/falcon/config.yaml
deleted file mode 100644
index ccfcbea60..000000000
--- a/presets/models/falcon/config.yaml
+++ /dev/null
@@ -1,27 +0,0 @@
-compute_environment: LOCAL_MACHINE
-debug: false
-deepspeed_config:
- deepspeed_multinode_launcher: standard
- gradient_accumulation_steps: 1
- offload_optimizer_device: none
- offload_param_device: none
- zero3_init_flag: false
- zero_stage: 2
-distributed_type: DEEPSPEED
-downcast_bf16: 'no'
-dynamo_config:
- dynamo_backend: INDUCTOR
-gpu_ids: all # GPUs you want to use (i.e. 0,1,2)
-machine_rank: 0 # Machine accelerate launch is called on
-main_process_ip: localhost # Master IP
-main_process_port: 29500 # Master Port
-main_training_function: main
-mixed_precision: bf16
-num_machines: 1 # Num of machines
-num_processes: 1 # Num of processes, processes_per_node is set for each node using its (num_machines // num_processes)
-rdzv_backend: static
-same_network: true
-tpu_env: []
-tpu_use_cluster: false
-tpu_use_sudo: false
-use_cpu: false
\ No newline at end of file
diff --git a/presets/models/falcon/inference-api.py b/presets/models/falcon/inference-api.py
deleted file mode 100644
index 731a5045a..000000000
--- a/presets/models/falcon/inference-api.py
+++ /dev/null
@@ -1,118 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT license.
-
-import argparse -import os -from typing import List, Optional - -import torch -import transformers -import uvicorn -from fastapi import FastAPI, HTTPException -from pydantic import BaseModel -from transformers import AutoModelForCausalLM, AutoTokenizer - -parser = argparse.ArgumentParser(description='Falcon Model Configuration') -parser.add_argument('--load_in_8bit', default=False, action='store_true', help='Load model in 8-bit mode') -# parser.add_argument('--model_id', required=True, type=str, help='The Falcon ID for the pre-trained model') -args = parser.parse_args() - -app = FastAPI() - -tokenizer = AutoTokenizer.from_pretrained("/workspace/tfs/weights", local_files_only=True) -model = AutoModelForCausalLM.from_pretrained( - "/workspace/tfs/weights", # args.model_id, - device_map="auto", - torch_dtype=torch.bfloat16, - load_in_8bit=args.load_in_8bit, - local_files_only=True -) - -pipeline = transformers.pipeline( - "text-generation", - model=model, - tokenizer=tokenizer, - torch_dtype=torch.bfloat16, - device_map="auto", -) - -@app.get('/') -def home(): - return "Server is running", 200 - -@app.get("/healthz") -def health_check(): - if not torch.cuda.is_available(): - raise HTTPException(status_code=500, detail="No GPU available") - if not model: - raise HTTPException(status_code=500, detail="Falcon model not initialized") - if not pipeline: - raise HTTPException(status_code=500, detail="Falcon pipeline not initialized") - return {"status": "Healthy"} - -class GenerationParams(BaseModel): - prompt: str - max_length: int = 200 - min_length: int = 0 - do_sample: bool = True - early_stopping: bool = False - num_beams: int = 1 - num_beam_groups: int = 1 - diversity_penalty: float = 0.0 - temperature: float = 1.0 - top_k: int = 10 - top_p: float = 1 - typical_p: float = 1 - repetition_penalty: float = 1 - length_penalty: float = 1 - no_repeat_ngram_size: int = 0 - encoder_no_repeat_ngram_size: int = 0 - bad_words_ids: List[int] = None - num_return_sequences: int = 1 - output_scores: bool = False - return_dict_in_generate: bool = False - pad_token_id: Optional[int] = tokenizer.pad_token_id - eos_token_id: Optional[int] = tokenizer.eos_token_id - forced_bos_token_id: Optional[int] = None - forced_eos_token_id: Optional[int] = None - remove_invalid_values: Optional[bool] = None - -@app.post("/chat") -def generate_text(params: GenerationParams): - sequences = pipeline( - params.prompt, - max_length=params.max_length, - min_length=params.min_length, - do_sample=params.do_sample, - early_stopping=params.early_stopping, - num_beams=params.num_beams, - num_beam_groups=params.num_beam_groups, - diversity_penalty=params.diversity_penalty, - temperature=params.temperature, - top_k=params.top_k, - top_p=params.top_p, - typical_p=params.typical_p, - repetition_penalty=params.repetition_penalty, - length_penalty=params.length_penalty, - no_repeat_ngram_size=params.no_repeat_ngram_size, - encoder_no_repeat_ngram_size=params.encoder_no_repeat_ngram_size, - bad_words_ids=params.bad_words_ids, - num_return_sequences=params.num_return_sequences, - output_scores=params.output_scores, - return_dict_in_generate=params.return_dict_in_generate, - forced_bos_token_id=params.forced_bos_token_id, - forced_eos_token_id=params.forced_eos_token_id, - eos_token_id=params.eos_token_id, - remove_invalid_values=params.remove_invalid_values - ) - - result = "" - for seq in sequences: - print(f"Result: {seq['generated_text']}") - result += seq['generated_text'] - - return {"Result": result} - -if __name__ == "__main__": - 
local_rank = int(os.environ.get("LOCAL_RANK", 0)) # Default to 0 if not set - port = 5000 + local_rank # Adjust port based on local rank - uvicorn.run(app=app, host='0.0.0.0', port=port) diff --git a/presets/models/llama2chat/inference-api.py b/presets/models/inference/llama2-chat/inference-api.py similarity index 100% rename from presets/models/llama2chat/inference-api.py rename to presets/models/inference/llama2-chat/inference-api.py index 1a2708b21..7f7899de1 100644 --- a/presets/models/llama2chat/inference-api.py +++ b/presets/models/inference/llama2-chat/inference-api.py @@ -1,21 +1,21 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from fastapi import FastAPI, HTTPException -import uvicorn -from pydantic import BaseModel -from typing import Optional +import argparse +import functools import multiprocessing import multiprocessing.pool +import os +import signal +import sys import threading -import functools +from typing import Optional -from llama import Llama import torch -import sys -import signal -import os import torch.distributed as dist -import argparse +import uvicorn +from fastapi import FastAPI, HTTPException +from llama import Llama +from pydantic import BaseModel # Setup argparse parser = argparse.ArgumentParser(description="Llama API server.") diff --git a/presets/models/llama2/inference-api.py b/presets/models/inference/llama2-completion/inference-api.py similarity index 100% rename from presets/models/llama2/inference-api.py rename to presets/models/inference/llama2-completion/inference-api.py index 9d4c2d500..90ae8c1dc 100644 --- a/presets/models/llama2/inference-api.py +++ b/presets/models/inference/llama2-completion/inference-api.py @@ -1,21 +1,21 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from fastapi import FastAPI, HTTPException -import uvicorn -from pydantic import BaseModel -from typing import Optional +import argparse +import functools import multiprocessing import multiprocessing.pool +import os +import signal +import sys import threading -import functools +from typing import Optional -from llama import Llama import torch -import sys -import signal -import os import torch.distributed as dist -import argparse +import uvicorn +from fastapi import FastAPI, HTTPException +from llama import Llama +from pydantic import BaseModel # Setup argparse parser = argparse.ArgumentParser(description="Llama API server.") diff --git a/presets/models/inference/text-generation/inference-api.py b/presets/models/inference/text-generation/inference-api.py new file mode 100644 index 000000000..7aa8c8ea4 --- /dev/null +++ b/presets/models/inference/text-generation/inference-api.py @@ -0,0 +1,182 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+import argparse +import os +from typing import Any, Dict, List, Optional + +import GPUtil +import torch +import transformers +import uvicorn +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel, Field +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + + +def dtype_type(string): + if hasattr(torch, string): + return getattr(torch, string) + else: + raise ValueError(f"Invalid torch dtype: {string}") + +parser = argparse.ArgumentParser(description='Model Configuration') +parser.add_argument('--pipeline', required=True, type=str, help='The model pipeline for the pre-trained model') +parser.add_argument('--load_in_8bit', default=False, action='store_true', help='Load model in 8-bit mode') +parser.add_argument('--trust_remote_code', default=False, action='store_true', help='Enable trusting remote code when loading the model') +parser.add_argument('--torch_dtype', default=None, type=dtype_type, help='The torch dtype for the pre-trained model') +parser.add_argument('--device_map', default="auto", type=str, help='The device map for the pre-trained model') +parser.add_argument('--cache_dir', type=str, default=None, help='Cache directory for the model') +parser.add_argument('--from_tf', action='store_true', default=False, help='Load model from a TensorFlow checkpoint') +parser.add_argument('--force_download', action='store_true', default=False, help='Force the download of the model') +parser.add_argument('--resume_download', action='store_true', default=False, help='Resume an interrupted download') +parser.add_argument('--proxies', type=str, default=None, help='Proxy configuration for downloading the model') +parser.add_argument('--revision', type=str, default="main", help='Specific model version to use') +# parser.add_argument('--local_files_only', action='store_true', default=False, help='Only use local files for model loading') +parser.add_argument('--output_loading_info', action='store_true', default=False, help='Output additional loading information') + +args = parser.parse_args() + +app = FastAPI() + +supported_pipelines = {"conversational", "text-generation"} +if args.pipeline not in supported_pipelines: + raise HTTPException(status_code=400, detail="Invalid pipeline specified") + +model_kwargs = { + "cache_dir": args.cache_dir, + "from_tf": args.from_tf, + "force_download": args.force_download, + "resume_download": args.resume_download, + "proxies": args.proxies, + "revision": args.revision, + "output_loading_info": args.output_loading_info, + "trust_remote_code": args.trust_remote_code, + "device_map": args.device_map, + "local_files_only": True, +} + +if args.load_in_8bit: + model_kwargs["load_in_8bit"] = args.load_in_8bit +if args.torch_dtype: + model_kwargs["torch_dtype"] = args.torch_dtype + +tokenizer = AutoTokenizer.from_pretrained("/workspace/tfs/weights", **model_kwargs) +model = AutoModelForCausalLM.from_pretrained( + "/workspace/tfs/weights", + **model_kwargs +) + +pipeline_kwargs = { + "trust_remote_code": args.trust_remote_code, + "device_map": args.device_map, +} + +if args.torch_dtype: + pipeline_kwargs["torch_dtype"] = args.torch_dtype + +pipeline = transformers.pipeline( + args.pipeline, + model=model, + tokenizer=tokenizer, + **pipeline_kwargs +) + +try: + # Attempt to load the generation configuration + default_generate_config = GenerationConfig.from_pretrained("/workspace/tfs/weights", local_files_only=True).to_dict() +except Exception as e: + default_generate_config = {} + +@app.get('/') +def home(): + return "Server 
is running", 200 + +@app.get("/healthz") +def health_check(): + if not torch.cuda.is_available(): + raise HTTPException(status_code=500, detail="No GPU available") + if not model: + raise HTTPException(status_code=500, detail="Model not initialized") + if not pipeline: + raise HTTPException(status_code=500, detail="Pipeline not initialized") + return {"status": "Healthy"} + +class UnifiedRequestModel(BaseModel): + # Fields for text generation + prompt: Optional[str] = Field(None, description="Prompt for text generation") + # Mutually Exclusive with return_full_text + # return_tensors: Optional[bool] = Field(False, description="Return tensors of predictions") + # return_text: Optional[bool] = Field(True, description="Return decoded texts in the outputs") + return_full_text: Optional[bool] = Field(True, description="Return full text if True, else only added text") + clean_up_tokenization_spaces: Optional[bool] = Field(False, description="Clean up extra spaces in text output") + prefix: Optional[str] = Field(None, description="Prefix added to prompt") + handle_long_generation: Optional[str] = Field(None, description="Strategy to handle long generation") + generate_kwargs: Optional[Dict[str, Any]] = Field(None, description="Additional kwargs for generate method") + + # Field for conversational model + messages: Optional[List[dict]] = Field(None, description="Messages for conversational model") + +@app.post("/chat") +def generate_text(request_model: UnifiedRequestModel): + user_generate_kwargs = request_model.generate_kwargs or {} + generate_kwargs = {**default_generate_config, **user_generate_kwargs} + + if args.pipeline == "text-generation": + if not request_model.prompt: + raise HTTPException(status_code=400, detail="Text generation parameter prompt required") + sequences = pipeline( + request_model.prompt, + # return_tensors=request_model.return_tensors, + # return_text=request_model.return_text, + return_full_text=request_model.return_full_text, + clean_up_tokenization_spaces=request_model.clean_up_tokenization_spaces, + prefix=request_model.prefix, + handle_long_generation=request_model.handle_long_generation, + **generate_kwargs + ) + + result = "" + for seq in sequences: + print(f"Result: {seq['generated_text']}") + result += seq['generated_text'] + + return {"Result": result} + + elif args.pipeline == "conversational": + if not request_model.messages: + raise HTTPException(status_code=400, detail="Conversational parameter messages required") + + response = pipeline( + request_model.messages, + clean_up_tokenization_spaces=request_model.clean_up_tokenization_spaces, + **generate_kwargs + ) + return {"Result": str(response[-1])} + + else: + raise HTTPException(status_code=400, detail="Invalid pipeline type") + +@app.get("/metrics") +def get_metrics(): + try: + gpus = GPUtil.getGPUs() + gpu_info = [] + for gpu in gpus: + gpu_info.append({ + "id": gpu.id, + "name": gpu.name, + "load": f"{gpu.load * 100:.2f}%", # Format as percentage + "temperature": f"{gpu.temperature} C", + "memory": { + "used": f"{gpu.memoryUsed / 1024:.2f} GB", + "total": f"{gpu.memoryTotal / 1024:.2f} GB" + } + }) + return {"gpu_info": gpu_info} + except Exception as e: + return {"error": str(e)} + +if __name__ == "__main__": + local_rank = int(os.environ.get("LOCAL_RANK", 0)) # Default to 0 if not set + port = 5000 + local_rank # Adjust port based on local rank + uvicorn.run(app=app, host='0.0.0.0', port=port) diff --git a/presets/models/falcon/requirements.txt 
b/presets/models/inference/text-generation/requirements.txt similarity index 86% rename from presets/models/falcon/requirements.txt rename to presets/models/inference/text-generation/requirements.txt index 9b7562702..c97d0ca03 100644 --- a/presets/models/falcon/requirements.txt +++ b/presets/models/inference/text-generation/requirements.txt @@ -6,4 +6,5 @@ fastapi==0.103.2 pydantic==1.10.9 uvicorn[standard]==0.23.2 bitsandbytes==0.41.1 -deepspeed==0.11.1 \ No newline at end of file +deepspeed==0.11.1 +gputil==1.4.0 \ No newline at end of file diff --git a/presets/test/manifests/falcon-40b-instruct/falcon-40b-instruct-statefulset.yaml b/presets/test/manifests/falcon-40b-instruct/falcon-40b-instruct-statefulset.yaml index 0c39ec31e..ee3940d98 100644 --- a/presets/test/manifests/falcon-40b-instruct/falcon-40b-instruct-statefulset.yaml +++ b/presets/test/manifests/falcon-40b-instruct/falcon-40b-instruct-statefulset.yaml @@ -19,7 +19,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --config_file config.yaml --num_processes 1 --num_machines 1 --use_deepspeed --machine_rank 0 --gpu_ids all inference-api.py + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference-api.py --pipeline text-generation --torch_dtype bfloat16 livenessProbe: httpGet: path: /healthz diff --git a/presets/test/manifests/falcon-40b/falcon-40b-statefulset.yaml b/presets/test/manifests/falcon-40b/falcon-40b-statefulset.yaml index c213867df..afacb7da6 100644 --- a/presets/test/manifests/falcon-40b/falcon-40b-statefulset.yaml +++ b/presets/test/manifests/falcon-40b/falcon-40b-statefulset.yaml @@ -19,7 +19,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --config_file config.yaml --num_processes 1 --num_machines 1 --use_deepspeed --machine_rank 0 --gpu_ids all inference-api.py + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference-api.py --pipeline text-generation --torch_dtype bfloat16 livenessProbe: httpGet: path: /healthz diff --git a/presets/test/manifests/falcon-7b-instruct/falcon-7b-instruct-statefulset.yaml b/presets/test/manifests/falcon-7b-instruct/falcon-7b-instruct-statefulset.yaml index f6f4cc679..126d461c7 100644 --- a/presets/test/manifests/falcon-7b-instruct/falcon-7b-instruct-statefulset.yaml +++ b/presets/test/manifests/falcon-7b-instruct/falcon-7b-instruct-statefulset.yaml @@ -19,7 +19,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --config_file config.yaml --num_processes 1 --num_machines 1 --use_deepspeed --machine_rank 0 --gpu_ids all inference-api.py + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference-api.py --pipeline text-generation --torch_dtype bfloat16 livenessProbe: httpGet: path: /healthz diff --git a/presets/test/manifests/falcon-7b/falcon-7b-statefulset.yaml b/presets/test/manifests/falcon-7b/falcon-7b-statefulset.yaml index 67ba004d9..691f0a5f0 100644 --- a/presets/test/manifests/falcon-7b/falcon-7b-statefulset.yaml +++ b/presets/test/manifests/falcon-7b/falcon-7b-statefulset.yaml @@ -19,7 +19,7 @@ spec: command: - /bin/sh - -c - - accelerate launch --config_file config.yaml --num_processes 1 --num_machines 1 --use_deepspeed --machine_rank 0 --gpu_ids all inference-api.py + - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference-api.py --pipeline text-generation --torch_dtype bfloat16 livenessProbe: httpGet: path: /healthz
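
A minimal client sketch for the unified /chat endpoint introduced in this patch, assuming a server launched with --pipeline text-generation and reachable at http://localhost:5000 (port 5000 + LOCAL_RANK, which defaults to 0). The prompt and generation settings below are illustrative only; the snippet uses just the Python standard library.

# Minimal client sketch for the unified text-generation API (assumptions noted above).
import json
import urllib.request

BASE_URL = "http://localhost:5000"  # assumed local, single-process deployment

def post_json(path, payload):
    # POST a JSON body and decode the JSON response.
    req = urllib.request.Request(
        BASE_URL + path,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())

# /healthz returns {"status": "Healthy"} once the GPU, model, and pipeline are up.
with urllib.request.urlopen(BASE_URL + "/healthz") as resp:
    print(json.loads(resp.read()))

# /chat with the text-generation pipeline requires "prompt"; any generate_kwargs
# are merged on the server over the model's default GenerationConfig.
result = post_json("/chat", {
    "prompt": "The capital of France is",
    "return_full_text": True,
    "clean_up_tokenization_spaces": False,
    "generate_kwargs": {"max_length": 64, "do_sample": True, "temperature": 0.7},
})
print(result["Result"])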
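
A similar sketch for the conversational branch of /chat and the new GPUtil-backed /metrics endpoint, assuming a deployment started with --pipeline conversational at the same address. The API only requires "messages" to be a list of dicts; the role/content shape used here is an assumption based on the usual transformers chat format.

# Conversational request and GPU metrics check (assumptions noted above).
import json
import urllib.request

BASE_URL = "http://localhost:5000"  # assumed conversational deployment

# The conversational branch of /chat requires "messages"; the server returns
# the stringified last element of the pipeline output under "Result".
payload = {"messages": [{"role": "user", "content": "Write a haiku about Kubernetes."}]}
req = urllib.request.Request(
    BASE_URL + "/chat",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["Result"])

# /metrics reports per-GPU load, temperature, and memory collected via GPUtil.
with urllib.request.urlopen(BASE_URL + "/metrics") as resp:
    print(json.loads(resp.read()))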