Skip to content

Commit

Permalink
fix: fix networking issue inference
Browse files Browse the repository at this point in the history
  • Loading branch information
ishaansehgal99 committed Oct 26, 2023
1 parent 64c7a89 commit b3eec53
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
7 changes: 4 additions & 3 deletions presets/llama-2-chat/inference-api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import uvicorn
from pydantic import BaseModel
from typing import Optional
from multiprocessing import Process
import threading

from llama import Llama
Expand Down Expand Up @@ -147,8 +148,8 @@ def health_check():
return {"status": "Healthy"}

def start_worker_server():
uvicorn.run(app=app_worker, host='0.0.0.0', port=5000)
print(f"Worker {dist.get_rank()} HTTP health server started at port 5000")
uvicorn.run(app=app_worker, host='0.0.0.0', port=5000)

def worker_listen_tasks():
while True:
Expand Down Expand Up @@ -203,8 +204,8 @@ def worker_listen_tasks():

# Start the worker server in a separate process. This worker server will
# provide a healthz endpoint for monitoring the health of the node.
server_thread = threading.Thread(target=start_worker_server, daemon=True)
server_thread.start()
server_process = Process(target=start_worker_server)
server_process.start()

# Regardless of local rank, all non-globally-0-ranked processes will listen
# for tasks (like chat completion) from the main server.
Expand Down
8 changes: 4 additions & 4 deletions presets/llama-2/inference-api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import uvicorn
from pydantic import BaseModel
from typing import Optional
import threading
from multiprocessing import Process
import time
from multiprocessing import Value

Expand Down Expand Up @@ -136,8 +136,8 @@ def health_check():
return {"status": "Healthy"}

def start_worker_server():
uvicorn.run(app=app_worker, host='0.0.0.0', port=5000)
print(f"Worker {dist.get_rank()} HTTP health server started at port 5000")
uvicorn.run(app=app_worker, host='0.0.0.0', port=5000)

def worker_listen_tasks():
while True:
Expand Down Expand Up @@ -191,8 +191,8 @@ def worker_listen_tasks():

# Start the worker server in a separate process. This worker server will
# provide a healthz endpoint for monitoring the health of the node.
server_thread = threading.Thread(target=start_worker_server, daemon=True)
server_thread.start()
server_process = Process(target=start_worker_server)
server_process.start()

# Regardless of local rank, all non-globally-0-ranked processes will listen
# for tasks (like text completion) from the main server.
Expand Down

0 comments on commit b3eec53

Please sign in to comment.