Skip to content

Commit

Permalink
Don't re-import TRS workflow if it already exists
Browse files Browse the repository at this point in the history
Works for postgres and sqlite
  • Loading branch information
mvdbeek committed Jul 16, 2023
1 parent 58d596f commit c6cef1d
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 21 deletions.
28 changes: 28 additions & 0 deletions lib/galaxy/managers/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@
and_,
desc,
false,
func,
or_,
select,
true,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import (
aliased,
joinedload,
Expand Down Expand Up @@ -320,6 +323,31 @@ def attach_stored_workflow(self, trans, workflow):
trans.sa_session.commit()
return stored_workflow

def get_workflow_by_trs_id_and_version(self, sa_session, trs_id: str, trs_version: str, user_id: int) -> Optional[model.Workflow]:
def to_json(column, keys: List[str]):
if sa_session.bind.dialect.name == "postgresql":
cast = func.cast(func.convert_from(column, "UTF8"), JSONB)
for key in keys:
cast = cast.__getitem__(key)
return cast.astext
else:
for key in keys:
column = column.__getitem__(key)
return column

return sa_session.execute(
select([model.Workflow])
.join(model.StoredWorkflow, model.Workflow.stored_workflow_id == model.StoredWorkflow.id)
.filter(
and_(
to_json(model.Workflow.source_metadata, ["trs_tool_id"]) == trs_id,
to_json(model.Workflow.source_metadata, ["trs_version_id"]) == trs_version,
model.StoredWorkflow.user_id == user_id,
model.StoredWorkflow.latest_workflow_id == model.Workflow.id,
)
)
).scalar()

def get_owned_workflow(self, trans, encoded_workflow_id):
"""Get a workflow (non-stored) from a encoded workflow id and
make sure it accessible to the user.
Expand Down
17 changes: 12 additions & 5 deletions lib/galaxy/model/custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from sqlalchemy.inspection import inspect
from sqlalchemy.types import (
CHAR,
JSON,
LargeBinary,
String,
TypeDecorator,
Expand Down Expand Up @@ -94,27 +95,33 @@ class JSONType(TypeDecorator):
cache_ok = True

def process_bind_param(self, value, dialect):
if value is not None:
if value is not None and dialect.name in ("postgresql", "mysql"):
value = json_encoder.encode(value).encode()
return value

def process_result_value(self, value, dialect):
if value is not None:
if value is not None and dialect.name == "postgresql":
value = json_decoder.decode(unicodify(_sniffnfix_pg9_hex(value)))
return value

def load_dialect_impl(self, dialect):
if dialect.name == "mysql":
return dialect.type_descriptor(sqlalchemy.dialects.mysql.MEDIUMBLOB)
else:
return self.impl
self.impl = dialect.type_descriptor(sqlalchemy.dialects.mysql.MEDIUMBLOB)
elif dialect.name == "sqlite":
self.impl = dialect.type_descriptor(sqlalchemy.dialects.sqlite.JSON)
return self.impl

def copy_value(self, value):
return copy.deepcopy(value)

def compare_values(self, x, y):
return x == y

@property
def comparator_factory(self):
"""express comparison behavior in terms of the base type"""
return sqlalchemy.dialects.sqlite.JSON.comparator_factory


class MutableJSONType(JSONType):
"""Associated with MutationObj"""
Expand Down
49 changes: 33 additions & 16 deletions lib/galaxy/webapps/galaxy/api/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,16 @@ def create(self, trans: GalaxyWebTransaction, payload=None, **kwd):
trs_tool_id = payload.get("trs_tool_id")
trs_version_id = payload.get("trs_version_id")

workflow = self.workflow_manager.get_workflow_by_trs_id_and_version(
trans.sa_session, trs_tool_id, trs_version_id, trans.user.id
)
if workflow and workflow.stored_workflow:
return self.__import_response(
trans,
workflow,
workflow.stored_workflow.id,
message=f"Workflow '{escape(workflow.name)}' already imported.",
)
archive_data = server.get_version_descriptor(trs_tool_id, trs_version_id)
else:
try:
Expand Down Expand Up @@ -622,6 +632,23 @@ def get_tool_predictions(self, trans: ProvidesUserContext, payload, **kwd):
#
# -- Helper methods --
#
def __import_response(self, trans: GalaxyWebTransaction, workflow: model.Workflow, workflow_id: str, message: str):
response = {
"message": message,
"status": "success",
"id": trans.security.encode_id(workflow_id),
}
if workflow.has_errors:
response["message"] = "Imported, but some steps in this workflow have validation errors."
response["status"] = "error"
elif len(workflow.steps) == 0:
response["message"] = "Imported, but this workflow has no steps."
response["status"] = "error"
elif workflow.has_cycles:
response["message"] = "Imported, but this workflow contains cycles."
response["status"] = "error"
return response

def __api_import_from_archive(self, trans: GalaxyWebTransaction, archive_data, source=None, payload=None):
payload = payload or {}
try:
Expand All @@ -640,22 +667,12 @@ def __api_import_from_archive(self, trans: GalaxyWebTransaction, archive_data, s
)
workflow_id = workflow.id
workflow = workflow.latest_workflow

response = {
"message": f"Workflow '{escape(workflow.name)}' imported successfully.",
"status": "success",
"id": trans.security.encode_id(workflow_id),
}
if workflow.has_errors:
response["message"] = "Imported, but some steps in this workflow have validation errors."
response["status"] = "error"
elif len(workflow.steps) == 0:
response["message"] = "Imported, but this workflow has no steps."
response["status"] = "error"
elif workflow.has_cycles:
response["message"] = "Imported, but this workflow contains cycles."
response["status"] = "error"
return response
self.__import_response(
trans,
workflow,
workflow_id=workflow_id,
message=f"Workflow '{escape(workflow.name)}' imported successfully.",
)

def __api_import_new_workflow(self, trans: GalaxyWebTransaction, payload, **kwd):
data = payload["workflow"]
Expand Down
22 changes: 22 additions & 0 deletions test/unit/workflows/test_workflows_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from galaxy import model

from .workflow_support import MockApp

TRS_TOOL_ID = "#the_id"
TRS_TOOL_VERSION = "v1"


def test_find_workflow_by_trs_id():
app = MockApp()
with app.model.session.begin():
w = model.Workflow()

w.stored_workflow = model.StoredWorkflow()
w.stored_workflow.latest_workflow = w
u = model.User("test@test.com", "test", "test")
w.stored_workflow.user = u
w.source_metadata = {"trs_server": "dockstore", "trs_tool_id": TRS_TOOL_ID, "trs_version_id": TRS_TOOL_VERSION}
app.model.session.add(w)
app.model.session.commit()
app.model.session.flush()
assert app.workflow_manager.get_workflow_by_trs_id_and_version(app.model.session, TRS_TOOL_ID, TRS_TOOL_VERSION, u.id) == w

0 comments on commit c6cef1d

Please sign in to comment.