diff --git a/nl2dsl/__init__.py b/nl2dsl/__init__.py index 01d56a8..574592d 100644 --- a/nl2dsl/__init__.py +++ b/nl2dsl/__init__.py @@ -2,7 +2,7 @@ from termcolor import cprint from typing import List, Dict, Callable -from .utils.mini_llm import call_llm_for_json +from .utils.mini_llm import call_llm_for_json, chat_completion_request from .utils.dsl_utils import ( update_flow, update_global_variables, diff --git a/nl2dsl/utils/dsl_utils.py b/nl2dsl/utils/dsl_utils.py index c885968..5cc6de6 100644 --- a/nl2dsl/utils/dsl_utils.py +++ b/nl2dsl/utils/dsl_utils.py @@ -51,10 +51,10 @@ def update_flow(step, plugins={}, flow=[], debug=False): dsl_list.insert(-1, llm_response) elif step_type == "edit": edited_task = llm_response - if edited_task["task_id"] == "end": + if edited_task["name"] == "end": cprint(f"Cannot edit end task.", "red") - elif edited_task["task_id"] == "start": + elif edited_task["name"] == "start": for i, task in enumerate(dsl_list): if task["name"] == "start": if edited_task.get("goto"):