diff --git a/nl2dsl/__init__.py b/nl2dsl/__init__.py index 01d56a8..574592d 100644 --- a/nl2dsl/__init__.py +++ b/nl2dsl/__init__.py @@ -2,7 +2,7 @@ from termcolor import cprint from typing import List, Dict, Callable -from .utils.mini_llm import call_llm_for_json +from .utils.mini_llm import call_llm_for_json, chat_completion_request from .utils.dsl_utils import ( update_flow, update_global_variables, diff --git a/nl2dsl/utils/dsl_utils.py b/nl2dsl/utils/dsl_utils.py index da6b3a8..5cc6de6 100644 --- a/nl2dsl/utils/dsl_utils.py +++ b/nl2dsl/utils/dsl_utils.py @@ -51,28 +51,36 @@ def update_flow(step, plugins={}, flow=[], debug=False): dsl_list.insert(-1, llm_response) elif step_type == "edit": edited_task = llm_response - edited = False - for i, task in enumerate(dsl_list): - if task["name"] == edited_task["name"]: - dsl_list[i] = edited_task - edited = True - break - if not edited: - dsl_list.insert(-1, edited_task) + if edited_task["name"] == "end": + cprint(f"Cannot edit end task.", "red") + + elif edited_task["name"] == "start": + for i, task in enumerate(dsl_list): + if task["name"] == "start": + if edited_task.get("goto"): + dsl_list[i]["goto"] = edited_task["goto"] + break + else: + edited = False + for i, task in enumerate(dsl_list): + if task["name"] == edited_task["name"]: + dsl_list[i] = edited_task + edited = True + break + if not edited: + dsl_list.insert(-1, edited_task) except TypeError as e: cprint(f"TypeError: {e}") cprint( f"It is likely that the assistant failed to return a valid json.", "red" ) - dsl_list = json.loads(flow) if step_type == "delete": - dsl_list = json.loads(flow) + dsl_list: list = json.loads(flow) delete_task_plan = step deleted = False if delete_task_plan["task_id"] == "start" or delete_task_plan["task_id"] == "end": cprint(f"Cannot delete start or end task.", "red") - dsl_list = json.loads(flow) else: for i, task in enumerate(dsl_list): if task["name"] == delete_task_plan["task_id"]: @@ -84,7 +92,6 @@ def update_flow(step, plugins={}, flow=[], debug=False): f"Task with id {delete_task_plan['task_id']} does not exist in the flow.", "red", ) - dsl_list = json.loads(flow) if debug: cprint(f"Intermediate DSL: {json.dumps(dsl_list, indent=4)}", "light_red")