diff --git a/docs/stable/store/quickstart.md b/docs/stable/store/quickstart.md
index 82e7b04..b2df44c 100644
--- a/docs/stable/store/quickstart.md
+++ b/docs/stable/store/quickstart.md
@@ -158,7 +158,10 @@ class VllmModelDownloader:
         from vllm import LLM
         from vllm.config import LoadFormat
 
-        def _run_writer(input_dir, output_dir):
+        # set the model storage path
+        storage_path = os.getenv("STORAGE_PATH", "./models")
+
+        def _run_writer(input_dir, model_name):
             # load models from the input directory
             llm_writer = LLM(
                 model=input_dir,
@@ -169,10 +172,11 @@ class VllmModelDownloader:
                 enforce_eager=True,
                 max_model_len=1,
             )
+            model_path = os.path.join(storage_path, model_name)
             model_executer = llm_writer.llm_engine.model_executor
             # save the models in the ServerlessLLM format
             model_executer.save_serverless_llm_state(
-                path=output_dir, pattern=pattern, max_size=max_size
+                path=model_path, pattern=pattern, max_size=max_size
             )
             for file in os.listdir(input_dir):
                 # Copy the metadata files into the output directory
@@ -182,11 +186,11 @@ class VllmModelDownloader:
                     ".safetensors",
                 ):
                     src_path = os.path.join(input_dir, file)
-                    dest_path = os.path.join(output_dir, file)
+                    dest_path = os.path.join(model_path, file)
                     if os.path.isdir(src_path):
                         shutil.copytree(src_path, dest_path)
                     else:
-                        shutil.copy(src_path, output_dir)
+                        shutil.copy(src_path, dest_path)
             del model_executer
             del llm_writer
             gc.collect()
@@ -194,36 +198,25 @@ class VllmModelDownloader:
                 torch.cuda.empty_cache()
                 torch.cuda.synchronize()
 
-        # set the model storage path
-        storage_path = os.getenv("STORAGE_PATH", "./models")
-        model_dir = os.path.join(storage_path, model_name)
-
-        # create the output directory
-        if os.path.exists(model_dir):
-            print(f"Already exists: {model_dir}")
-            return
-        os.makedirs(model_dir, exist_ok=True)
-
         try:
             with TemporaryDirectory() as cache_dir:
-                # download model from huggingface
+                # download from huggingface
                 input_dir = snapshot_download(
                     model_name,
                     cache_dir=cache_dir,
                     allow_patterns=["*.safetensors", "*.bin", "*.json", "*.txt"],
                 )
-                _run_writer(input_dir, model_dir)
+                _run_writer(input_dir, model_name)
         except Exception as e:
             print(f"An error occurred while saving the model: {e}")
             # remove the output dir
-            shutil.rmtree(model_dir)
+            shutil.rmtree(os.path.join(storage_path, model_name))
             raise RuntimeError(
-                f"Failed to save model {model_name} for vllm backend: {e}"
+                f"Failed to save {model_name} for vllm backend: {e}"
             )
 
 downloader = VllmModelDownloader()
 downloader.download_vllm_model("facebook/opt-1.3b", "float16", 1)
-
 ```
 
 After downloading the model, you can launch the checkpoint store server and load the model in vLLM through the `serverless_llm` load format.
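
As a minimal sketch of that loading step — assuming ServerlessLLM's patched vLLM is installed, the checkpoint store server is already running against the same storage path, and `storage_path`/`model_name` mirror the values used by the downloader above:

```python
import os

from vllm import LLM

# Assumed to match the STORAGE_PATH the downloader saved into above.
storage_path = os.getenv("STORAGE_PATH", "./models")
model_name = "facebook/opt-1.3b"

# Load the ServerlessLLM-format checkpoint through the store server;
# the server must be up and serving the same storage path.
llm = LLM(
    model=os.path.join(storage_path, model_name),
    load_format="serverless_llm",
    dtype="float16",
)

outputs = llm.generate(["Hello, my name is"])
print(outputs[0].outputs[0].text)
```

Note that `model` points at the saved checkpoint directory under `storage_path`, not the Hugging Face model ID, and the `serverless_llm` load format routes weight loading through the store server instead of reading safetensors from disk; if the server is not running, initialization will fail.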