Skip to content

Commit

Permalink
feat(celery): Set celery worker queue via CELERY_WORKER_QUEUE (#1633)
Browse files Browse the repository at this point in the history
Setting `CELERY_WORKER_QUEUE` env var will make that instance of seer
both publish tasks to and the celery worker consume from the queue with
that name.

Will follow up w/ SRE to set this value.
  • Loading branch information
jennmueng authored Dec 17, 2024
1 parent bf00249 commit e135497
Show file tree
Hide file tree
Showing 11 changed files with 86 additions and 37 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,12 @@ VCRs are a way to record and replay HTTP requests. They are useful for recordin
To use VCRs, add the `@pytest.mark.vcr()` decorator to your test.

To record new VCRs, delete the existing cassettes and run the test. Subsequent test runs will use the cassette instead of making requests.


# Production

## Celery Worker Queue

You can set the queue that the celery worker listens on via the `CELERY_WORKER_QUEUE` environment variable.

If not set, the default queue name is `"seer"`.
10 changes: 8 additions & 2 deletions celeryworker.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
#!/bin/bash

# TODO: Remove debug log level once celery debugging is done
WORKER_CMD="celery -A src.celery_app.tasks worker --loglevel=info $CELERY_WORKER_OPTIONS"
# You can set the celery queue name via the CELERY_WORKER_QUEUE environment variable.
# If not set, the default queue name is "seer".
QUEUE="seer"
if [ "$CELERY_WORKER_QUEUE" != "" ]; then
QUEUE="$CELERY_WORKER_QUEUE"
fi

WORKER_CMD="celery -A src.celery_app.tasks worker --loglevel=info -Q $QUEUE $CELERY_WORKER_OPTIONS"

if [ "$CELERY_WORKER_ENABLE" = "true" ]; then
if [ "$DEV" = "true" ] || [ "$DEV" = "1" ]; then
Expand Down
14 changes: 4 additions & 10 deletions src/celery_app/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from enum import StrEnum
from typing import Any

from seer.bootup import module
Expand All @@ -10,11 +9,6 @@ class CeleryConfig(dict[str, Any]):
pass


class CeleryQueues(StrEnum):
DEFAULT = "seer"
CUDA = "seer-cuda"


@module.provider
def celery_config(app_config: AppConfig = injected) -> CeleryConfig:
return CeleryConfig(
Expand All @@ -23,11 +17,11 @@ def celery_config(app_config: AppConfig = injected) -> CeleryConfig:
result_serializer="json",
accept_content=["json"],
enable_utc=True,
task_default_queue=CeleryQueues.DEFAULT,
task_default_queue=app_config.CELERY_WORKER_QUEUE,
task_queues={
CeleryQueues.DEFAULT: {
"exchange": CeleryQueues.DEFAULT,
"routing_key": CeleryQueues.DEFAULT,
app_config.CELERY_WORKER_QUEUE: {
"exchange": app_config.CELERY_WORKER_QUEUE,
"routing_key": app_config.CELERY_WORKER_QUEUE,
}
},
result_backend="rpc://",
Expand Down
11 changes: 6 additions & 5 deletions src/celery_app/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import seer.app # noqa: F401
from celery_app.app import celery_app as celery # noqa: F401
from celery_app.config import CeleryQueues
from seer.anomaly_detection.tasks import cleanup_disabled_alerts, cleanup_timeseries # noqa: F401
from seer.automation.autofix.tasks import check_and_mark_recent_autofix_runs
from seer.automation.tasks import delete_data_for_ttl
Expand All @@ -17,13 +16,15 @@ def setup_periodic_tasks(sender, config: AppConfig = injected, **kwargs):
if config.is_autofix_enabled:
sender.add_periodic_task(
crontab(minute="0", hour="*"),
check_and_mark_recent_autofix_runs.signature(kwargs={}, queue=CeleryQueues.DEFAULT),
check_and_mark_recent_autofix_runs.signature(
kwargs={}, queue=config.CELERY_WORKER_QUEUE
),
name="Check and mark recent autofix runs every hour",
)

sender.add_periodic_task(
crontab(minute="0", hour="0"), # run once a day
delete_data_for_ttl.signature(kwargs={}, queue=CeleryQueues.DEFAULT),
delete_data_for_ttl.signature(kwargs={}, queue=config.CELERY_WORKER_QUEUE),
name="Delete old Automation runs for 90 day time-to-live",
)

Expand All @@ -32,13 +33,13 @@ def setup_periodic_tasks(sender, config: AppConfig = injected, **kwargs):

sender.add_periodic_task(
crontab(minute="*", hour="*"), # run every minute
try_grpc_client.signature(kwargs={}, queue=CeleryQueues.DEFAULT),
try_grpc_client.signature(kwargs={}, queue=config.CELERY_WORKER_QUEUE),
name="Try executing grpc request every minute.",
)

if config.ANOMALY_DETECTION_ENABLED:
sender.add_periodic_task(
crontab(minute="0", hour="0", day_of_week="0"), # Run once a week on Sunday
cleanup_disabled_alerts.signature(kwargs={}, queue=CeleryQueues.DEFAULT),
cleanup_disabled_alerts.signature(kwargs={}, queue=config.CELERY_WORKER_QUEUE),
name="Clean up old disabled timeseries every week",
)
8 changes: 5 additions & 3 deletions src/seer/automation/autofix/steps/coding_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from sentry_sdk.ai.monitoring import ai_track

from celery_app.app import celery_app
from celery_app.config import CeleryQueues
from seer.automation.agent.models import Message
from seer.automation.autofix.components.coding.component import CodingComponent
from seer.automation.autofix.components.coding.models import CodingRequest
Expand All @@ -21,6 +20,8 @@
from seer.automation.models import EventDetails
from seer.automation.pipeline import PipelineStepTaskRequest
from seer.automation.utils import make_kill_signal
from seer.configuration import AppConfig
from seer.dependency_injection import inject, injected


class AutofixCodingStepRequest(PipelineStepTaskRequest):
Expand Down Expand Up @@ -54,7 +55,8 @@ def get_task():

@observe(name="Autofix - Plan+Code Step")
@ai_track(description="Autofix - Plan+Code Step")
def _invoke(self, **kwargs):
@inject
def _invoke(self, app_config: AppConfig = injected):
self.context.event_manager.clear_file_changes()

self.logger.info("Executing Autofix - Plan+Code Step")
Expand Down Expand Up @@ -107,5 +109,5 @@ def _invoke(self, **kwargs):
pr_to_comment_on=pr_to_comment_on,
),
),
queue=CeleryQueues.DEFAULT,
queue=app_config.CELERY_WORKER_QUEUE,
)
43 changes: 32 additions & 11 deletions src/seer/automation/autofix/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from langfuse import Langfuse

from celery_app.app import celery_app
from celery_app.config import CeleryQueues
from seer.automation.agent.models import Message
from seer.automation.agent.utils import parse_json_with_keys
from seer.automation.autofix.autofix_context import AutofixContext
Expand Down Expand Up @@ -42,7 +41,9 @@
from seer.automation.autofix.steps.root_cause_step import RootCauseStep, RootCauseStepRequest
from seer.automation.models import InitializationError
from seer.automation.utils import process_repo_provider, raise_if_no_genai_consent
from seer.configuration import AppConfig
from seer.db import DbPrIdToAutofixRunIdMapping, DbRunState, Session
from seer.dependency_injection import inject, injected
from seer.rpc import get_sentry_client

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -126,8 +127,10 @@ def check_and_mark_if_timed_out(state: ContinuationState):
logger.error(f"Autofix run {cur.run_id} has timed out")


@inject
def run_autofix_root_cause(
request: AutofixRequest,
app_config: AppConfig = injected,
):
state = create_initial_autofix_run(request)

Expand All @@ -142,13 +145,17 @@ def run_autofix_root_cause(
RootCauseStepRequest(
run_id=cur_state.run_id,
),
queue=CeleryQueues.DEFAULT,
queue=app_config.CELERY_WORKER_QUEUE,
).apply_async()

return cur_state.run_id


def run_autofix_execution(request: AutofixUpdateRequest):
@inject
def run_autofix_execution(
request: AutofixUpdateRequest,
app_config: AppConfig = injected,
):
state = ContinuationState(request.run_id)

raise_if_no_genai_consent(state.get().request.organization_id)
Expand All @@ -174,14 +181,18 @@ def run_autofix_execution(request: AutofixUpdateRequest):
AutofixCodingStepRequest(
run_id=cur.run_id,
),
queue=CeleryQueues.DEFAULT,
queue=app_config.CELERY_WORKER_QUEUE,
).apply_async()
except InitializationError as e:
sentry_sdk.capture_exception(e)
raise e


def run_autofix_create_pr(request: AutofixUpdateRequest):
@inject
def run_autofix_create_pr(
request: AutofixUpdateRequest,
app_config: AppConfig = injected,
):
if not isinstance(request.payload, AutofixCreatePrUpdatePayload):
raise ValueError("Invalid payload type for create_pr")

Expand All @@ -200,6 +211,7 @@ def run_autofix_create_pr(request: AutofixUpdateRequest):
)


@inject
def restart_step_with_user_response(
state: ContinuationState,
memory: list[Message],
Expand All @@ -208,6 +220,7 @@ def restart_step_with_user_response(
step_to_restart: Step,
step_class: Type[AutofixCodingStep | RootCauseStep],
step_request_class: Type[AutofixCodingStepRequest | RootCauseStepRequest],
app_config: AppConfig = injected,
):
cur_state = state.get()
if memory:
Expand All @@ -224,7 +237,7 @@ def restart_step_with_user_response(
run_id=cur_state.run_id,
initial_memory=memory,
),
queue=CeleryQueues.DEFAULT,
queue=app_config.CELERY_WORKER_QUEUE,
).apply_async()


Expand Down Expand Up @@ -333,7 +346,11 @@ def truncate_memory_to_match_insights(memory: list[Message], step: DefaultStep):
return truncated_memory if truncated_memory else memory


def restart_from_point_with_feedback(request: AutofixUpdateRequest):
@inject
def restart_from_point_with_feedback(
request: AutofixUpdateRequest,
app_config: AppConfig = injected,
):
if not isinstance(request.payload, AutofixRestartFromPointPayload):
raise ValueError("Invalid payload type for restart_from_point_with_feedback")

Expand Down Expand Up @@ -396,15 +413,15 @@ def restart_from_point_with_feedback(request: AutofixUpdateRequest):
run_id=state.get().run_id,
initial_memory=memory,
),
queue=CeleryQueues.DEFAULT,
queue=app_config.CELERY_WORKER_QUEUE,
).apply_async()
else:
RootCauseStep.get_signature(
RootCauseStepRequest(
run_id=state.get().run_id,
initial_memory=memory,
),
queue=CeleryQueues.DEFAULT,
queue=app_config.CELERY_WORKER_QUEUE,
).apply_async()


Expand Down Expand Up @@ -485,7 +502,11 @@ def update_code_change(request: AutofixUpdateRequest):
cur.steps[-1] = last_step


def run_autofix_evaluation(request: AutofixEvaluationRequest):
@inject
def run_autofix_evaluation(
request: AutofixEvaluationRequest,
app_config: AppConfig = injected,
):
langfuse = Langfuse()

dataset = langfuse.get_dataset(request.dataset_name)
Expand Down Expand Up @@ -521,7 +542,7 @@ def run_autofix_evaluation(request: AutofixEvaluationRequest):
item_index=i,
item_count=len(items),
),
queue=CeleryQueues.DEFAULT,
queue=app_config.CELERY_WORKER_QUEUE,
)


Expand Down
8 changes: 5 additions & 3 deletions src/seer/automation/codegen/tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from celery_app.config import CeleryQueues
from seer.automation.codegen.models import (
CodegenContinuation,
CodegenPrReviewRequest,
Expand All @@ -10,6 +9,8 @@
from seer.automation.codegen.state import CodegenContinuationState
from seer.automation.codegen.unittest_step import UnittestStep, UnittestStepRequest
from seer.automation.state import DbState, DbStateRunTypes
from seer.configuration import AppConfig
from seer.dependency_injection import inject, injected


def create_initial_unittest_run(request: CodegenUnitTestsRequest) -> DbState[CodegenContinuation]:
Expand All @@ -25,7 +26,8 @@ def create_initial_unittest_run(request: CodegenUnitTestsRequest) -> DbState[Cod
return state


def codegen_unittest(request: CodegenUnitTestsRequest):
@inject
def codegen_unittest(request: CodegenUnitTestsRequest, app_config: AppConfig = injected):
state = create_initial_unittest_run(request)

cur_state = state.get()
Expand All @@ -39,7 +41,7 @@ def codegen_unittest(request: CodegenUnitTestsRequest):
pr_id=request.pr_id,
repo_definition=request.repo,
)
UnittestStep.get_signature(unittest_request, queue=CeleryQueues.DEFAULT).apply_async()
UnittestStep.get_signature(unittest_request, queue=app_config.CELERY_WORKER_QUEUE).apply_async()

return CodegenUnitTestsResponse(run_id=cur_state.run_id)

Expand Down
9 changes: 7 additions & 2 deletions src/seer/automation/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@

import sentry_sdk

from celery_app.config import CeleryQueues
from seer.automation.pipeline import (
PipelineChain,
PipelineStep,
PipelineStepTaskRequest,
SerializedSignature,
)
from seer.automation.utils import make_done_signal
from seer.configuration import AppConfig
from seer.dependency_injection import inject, injected


class ConditionalStepRequest(PipelineStepTaskRequest):
Expand Down Expand Up @@ -88,6 +89,10 @@ def _get_conditional_step_class() -> Type[ParallelizedChainConditionalStep]:
pass

def _invoke(self, **kwargs):
self._send_signatures()

@inject
def _send_signatures(self, app_config: AppConfig = injected):
signatures = [self.instantiate_signature(step) for step in self.request.steps]

expected_signals = [
Expand All @@ -105,6 +110,6 @@ def _invoke(self, **kwargs):
expected_signals=expected_signals,
on_success=self.request.on_success,
),
queue=CeleryQueues.DEFAULT,
queue=app_config.CELERY_WORKER_QUEUE,
),
)
7 changes: 7 additions & 0 deletions src/seer/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ class AppConfig(BaseModel):
GRPC_SERVER_ENABLE: ParseBool = False
HOSTNAME: str = Field(default_factory=gethostname)

CELERY_WORKER_QUEUE: str = "seer"

# Test utility that disables deployment conditional behavior.
# Update this to reflect new kinds of conditional behavior by adding
# more test coverage and locking them in.
Expand Down Expand Up @@ -151,6 +153,11 @@ def do_validation(self):
if not self.DEV and not self.GITHUB_SENTRY_PRIVATE_KEY:
logger.warning("GITHUB_SENTRY_PRIVATE_KEY is missing in production!")

if not self.CELERY_WORKER_QUEUE:
logger.warning(
'CELERY_WORKER_QUEUE is not set in production! Will default to "seer"'
)


@configuration_module.provider
def load_from_environment(environ: dict[str, str] | None = None) -> AppConfig:
Expand Down
2 changes: 1 addition & 1 deletion supervisord.conf
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ stderr_logfile_maxbytes=0

; The celery worker program is disabled by default. Set CELERY_WORKER_ENABLE=true in the environment to enable it.
[program:celeryworker-default]
command=env CELERY_WORKER_OPTIONS="-c 16 -Q seer -n seer@%%h" /app/celeryworker.sh
command=env CELERY_WORKER_OPTIONS="-c 16 -n seer@%%h" /app/celeryworker.sh
directory=/app
startsecs=0
autostart=true
Expand Down
2 changes: 2 additions & 0 deletions tests/test_celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def test_celery_app_configuration():

app.finalize()
celery_app.finalize()

assert app.conf.task_default_queue == resolve(CeleryConfig)["task_default_queue"]
assert app.conf.task_queues == resolve(CeleryConfig)["task_queues"]
assert celery_app.conf.task_queues == app.conf.task_queues

Expand Down

0 comments on commit e135497

Please sign in to comment.