diff --git a/lib/galaxy/managers/workflows.py b/lib/galaxy/managers/workflows.py index 077165451be7..b9241ea8bbcd 100644 --- a/lib/galaxy/managers/workflows.py +++ b/lib/galaxy/managers/workflows.py @@ -27,9 +27,12 @@ and_, desc, false, + func, or_, + select, true, ) +from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import ( aliased, joinedload, @@ -320,6 +323,31 @@ def attach_stored_workflow(self, trans, workflow): trans.sa_session.commit() return stored_workflow + def get_workflow_by_trs_id_and_version(self, sa_session, trs_id: str, trs_version: str, user_id: int) -> Optional[model.Workflow]: + def to_json(column, keys: List[str]): + if sa_session.bind.dialect.name == "postgresql": + cast = func.cast(func.convert_from(column, "UTF8"), JSONB) + for key in keys: + cast = cast.__getitem__(key) + return cast.astext + else: + for key in keys: + column = column.__getitem__(key) + return column + + return sa_session.execute( + select([model.Workflow]) + .join(model.StoredWorkflow, model.Workflow.stored_workflow_id == model.StoredWorkflow.id) + .filter( + and_( + to_json(model.Workflow.source_metadata, ["trs_tool_id"]) == trs_id, + to_json(model.Workflow.source_metadata, ["trs_version_id"]) == trs_version, + model.StoredWorkflow.user_id == user_id, + model.StoredWorkflow.latest_workflow_id == model.Workflow.id, + ) + ) + ).scalar() + def get_owned_workflow(self, trans, encoded_workflow_id): """Get a workflow (non-stored) from a encoded workflow id and make sure it accessible to the user. diff --git a/lib/galaxy/model/custom_types.py b/lib/galaxy/model/custom_types.py index 8f292a792a43..f7e977e36b5d 100644 --- a/lib/galaxy/model/custom_types.py +++ b/lib/galaxy/model/custom_types.py @@ -14,6 +14,7 @@ from sqlalchemy.inspection import inspect from sqlalchemy.types import ( CHAR, + JSON, LargeBinary, String, TypeDecorator, @@ -94,20 +95,21 @@ class JSONType(TypeDecorator): cache_ok = True def process_bind_param(self, value, dialect): - if value is not None: + if value is not None and dialect.name in ("postgresql", "mysql"): value = json_encoder.encode(value).encode() return value def process_result_value(self, value, dialect): - if value is not None: + if value is not None and dialect.name == "postgresql": value = json_decoder.decode(unicodify(_sniffnfix_pg9_hex(value))) return value def load_dialect_impl(self, dialect): if dialect.name == "mysql": - return dialect.type_descriptor(sqlalchemy.dialects.mysql.MEDIUMBLOB) - else: - return self.impl + self.impl = dialect.type_descriptor(sqlalchemy.dialects.mysql.MEDIUMBLOB) + elif dialect.name == "sqlite": + self.impl = dialect.type_descriptor(sqlalchemy.dialects.sqlite.JSON) + return self.impl def copy_value(self, value): return copy.deepcopy(value) @@ -115,6 +117,11 @@ def copy_value(self, value): def compare_values(self, x, y): return x == y + @property + def comparator_factory(self): + """express comparison behavior in terms of the base type""" + return sqlalchemy.dialects.sqlite.JSON.comparator_factory + class MutableJSONType(JSONType): """Associated with MutationObj""" diff --git a/lib/galaxy/webapps/galaxy/api/workflows.py b/lib/galaxy/webapps/galaxy/api/workflows.py index 8c99c6cb983e..32648bebc0c1 100644 --- a/lib/galaxy/webapps/galaxy/api/workflows.py +++ b/lib/galaxy/webapps/galaxy/api/workflows.py @@ -289,6 +289,16 @@ def create(self, trans: GalaxyWebTransaction, payload=None, **kwd): trs_tool_id = payload.get("trs_tool_id") trs_version_id = payload.get("trs_version_id") + workflow = self.workflow_manager.get_workflow_by_trs_id_and_version( + trans.sa_session, trs_tool_id, trs_version_id, trans.user.id + ) + if workflow and workflow.stored_workflow: + return self.__import_response( + trans, + workflow, + workflow.stored_workflow.id, + message=f"Workflow '{escape(workflow.name)}' already imported.", + ) archive_data = server.get_version_descriptor(trs_tool_id, trs_version_id) else: try: @@ -622,6 +632,23 @@ def get_tool_predictions(self, trans: ProvidesUserContext, payload, **kwd): # # -- Helper methods -- # + def __import_response(self, trans: GalaxyWebTransaction, workflow: model.Workflow, workflow_id: str, message: str): + response = { + "message": message, + "status": "success", + "id": trans.security.encode_id(workflow_id), + } + if workflow.has_errors: + response["message"] = "Imported, but some steps in this workflow have validation errors." + response["status"] = "error" + elif len(workflow.steps) == 0: + response["message"] = "Imported, but this workflow has no steps." + response["status"] = "error" + elif workflow.has_cycles: + response["message"] = "Imported, but this workflow contains cycles." + response["status"] = "error" + return response + def __api_import_from_archive(self, trans: GalaxyWebTransaction, archive_data, source=None, payload=None): payload = payload or {} try: @@ -640,22 +667,12 @@ def __api_import_from_archive(self, trans: GalaxyWebTransaction, archive_data, s ) workflow_id = workflow.id workflow = workflow.latest_workflow - - response = { - "message": f"Workflow '{escape(workflow.name)}' imported successfully.", - "status": "success", - "id": trans.security.encode_id(workflow_id), - } - if workflow.has_errors: - response["message"] = "Imported, but some steps in this workflow have validation errors." - response["status"] = "error" - elif len(workflow.steps) == 0: - response["message"] = "Imported, but this workflow has no steps." - response["status"] = "error" - elif workflow.has_cycles: - response["message"] = "Imported, but this workflow contains cycles." - response["status"] = "error" - return response + self.__import_response( + trans, + workflow, + workflow_id=workflow_id, + message=f"Workflow '{escape(workflow.name)}' imported successfully.", + ) def __api_import_new_workflow(self, trans: GalaxyWebTransaction, payload, **kwd): data = payload["workflow"] diff --git a/test/unit/workflows/test_workflows_manager.py b/test/unit/workflows/test_workflows_manager.py new file mode 100644 index 000000000000..10ca66058c58 --- /dev/null +++ b/test/unit/workflows/test_workflows_manager.py @@ -0,0 +1,22 @@ +from galaxy import model + +from .workflow_support import MockApp + +TRS_TOOL_ID = "#the_id" +TRS_TOOL_VERSION = "v1" + + +def test_find_workflow_by_trs_id(): + app = MockApp() + with app.model.session.begin(): + w = model.Workflow() + + w.stored_workflow = model.StoredWorkflow() + w.stored_workflow.latest_workflow = w + u = model.User("test@test.com", "test", "test") + w.stored_workflow.user = u + w.source_metadata = {"trs_server": "dockstore", "trs_tool_id": TRS_TOOL_ID, "trs_version_id": TRS_TOOL_VERSION} + app.model.session.add(w) + app.model.session.commit() + app.model.session.flush() + assert app.workflow_manager.get_workflow_by_trs_id_and_version(app.model.session, TRS_TOOL_ID, TRS_TOOL_VERSION, u.id) == w