From bdf06b79a837ec5bbe1de8749b3b3cdb6a08f2ee Mon Sep 17 00:00:00 2001 From: wi0lono Date: Fri, 21 Jun 2024 12:19:34 +0530 Subject: [PATCH] Expose import --- nl2dsl/__init__.py | 2 +- nl2dsl/utils/dsl_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/nl2dsl/__init__.py b/nl2dsl/__init__.py index 01d56a8..574592d 100644 --- a/nl2dsl/__init__.py +++ b/nl2dsl/__init__.py @@ -2,7 +2,7 @@ from termcolor import cprint from typing import List, Dict, Callable -from .utils.mini_llm import call_llm_for_json +from .utils.mini_llm import call_llm_for_json, chat_completion_request from .utils.dsl_utils import ( update_flow, update_global_variables, diff --git a/nl2dsl/utils/dsl_utils.py b/nl2dsl/utils/dsl_utils.py index c885968..5cc6de6 100644 --- a/nl2dsl/utils/dsl_utils.py +++ b/nl2dsl/utils/dsl_utils.py @@ -51,10 +51,10 @@ def update_flow(step, plugins={}, flow=[], debug=False): dsl_list.insert(-1, llm_response) elif step_type == "edit": edited_task = llm_response - if edited_task["task_id"] == "end": + if edited_task["name"] == "end": cprint(f"Cannot edit end task.", "red") - elif edited_task["task_id"] == "start": + elif edited_task["name"] == "start": for i, task in enumerate(dsl_list): if task["name"] == "start": if edited_task.get("goto"):