Skip to content

Commit

Permalink
Merge pull request #19332 from mvdbeek/fix_null_replacement
Browse files Browse the repository at this point in the history
[24.2] Fix workflows with optional non-default parameter input
  • Loading branch information
jmchilton authored Jan 7, 2025
2 parents 1d44a55 + 7a7836e commit f8ae3ce
Show file tree
Hide file tree
Showing 14 changed files with 163 additions and 65 deletions.
11 changes: 6 additions & 5 deletions lib/galaxy/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8108,7 +8108,7 @@ class WorkflowStep(Base, RepresentById, UsesCreateAndUpdateTime):
tool_errors: Mapped[Optional[bytes]] = mapped_column(JSONType)
position: Mapped[Optional[bytes]] = mapped_column(MutableJSONType)
config: Mapped[Optional[bytes]] = mapped_column(JSONType)
order_index: Mapped[Optional[int]]
order_index: Mapped[int]
when_expression: Mapped[Optional[bytes]] = mapped_column(JSONType)
uuid: Mapped[Optional[Union[UUID, str]]] = mapped_column(UUIDType)
label: Mapped[Optional[str]] = mapped_column(Unicode(255))
Expand Down Expand Up @@ -8206,17 +8206,18 @@ def setup_inputs_by_name(self):
# Ensure input_connections has already been set.

# Make connection information available on each step by input name.
inputs_by_name = {}
inputs_by_name: Dict[str, Any] = {}
for step_input in self.inputs:
input_name = step_input.name
assert input_name not in inputs_by_name
inputs_by_name[input_name] = step_input
self._inputs_by_name = inputs_by_name
return inputs_by_name

@property
def inputs_by_name(self):
if self._inputs_by_name is None:
self.setup_inputs_by_name()
return self.setup_inputs_by_name()
return self._inputs_by_name

def get_input(self, input_name):
Expand Down Expand Up @@ -8695,7 +8696,7 @@ class WorkflowInvocation(Base, UsesCreateAndUpdateTime, Dictifiable, Serializabl
state: Mapped[Optional[str]] = mapped_column(TrimmedString(64), index=True)
scheduler: Mapped[Optional[str]] = mapped_column(TrimmedString(255), index=True)
handler: Mapped[Optional[str]] = mapped_column(TrimmedString(255), index=True)
uuid: Mapped[Optional[Union[UUID, str]]] = mapped_column(UUIDType())
uuid: Mapped[Optional[Union[UUID]]] = mapped_column(UUIDType())
history_id: Mapped[Optional[int]] = mapped_column(ForeignKey("history.id"), index=True)

history = relationship("History", back_populates="workflow_invocations")
Expand Down Expand Up @@ -9420,7 +9421,7 @@ class WorkflowInvocationStep(Base, Dictifiable, Serializable):
)
action: Mapped[Optional[bytes]] = mapped_column(MutableJSONType)

workflow_step = relationship("WorkflowStep")
workflow_step: Mapped[WorkflowStep] = relationship("WorkflowStep")
job: Mapped[Optional["Job"]] = relationship(back_populates="workflow_invocation_step", uselist=False)
implicit_collection_jobs = relationship("ImplicitCollectionJobs", uselist=False)
output_dataset_collections = relationship(
Expand Down
2 changes: 2 additions & 0 deletions lib/galaxy/model/store/ro_crate_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ def _add_step_parameter_pv(self, step: WorkflowInvocationStep, crate: ROCrate):

def _add_step_parameter_fp(self, step: WorkflowInvocationStep, crate: ROCrate):
param_id = step.workflow_step.label
assert step.workflow_step.tool_inputs
param_type = step.workflow_step.tool_inputs["parameter_type"]
return crate.add(
ContextEntity(
Expand All @@ -375,6 +376,7 @@ def _add_step_parameter_fp(self, step: WorkflowInvocationStep, crate: ROCrate):

def _add_step_tool_pv(self, step: WorkflowInvocationStep, tool_input: str, crate: ROCrate):
param_id = tool_input
assert step.workflow_step.tool_inputs
return crate.add(
ContextEntity(
crate,
Expand Down
27 changes: 25 additions & 2 deletions lib/galaxy/tools/parameters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
)
from .workflow_utils import (
is_runtime_value,
NO_REPLACEMENT,
runtime_to_json,
)
from .wrapped import flat_to_nested_state
Expand Down Expand Up @@ -174,14 +175,36 @@ def callback_helper(input, input_values, name_prefix, label_prefix, parent_prefi
if input.name not in input_values:
args["error"] = f"No value found for '{args.get('prefixed_label')}'."
new_value = callback(**args)

# is this good enough ? feels very ugh
if new_value == [no_replacement_value]:
# Single unspecified value in multiple="true" input with a single null input, pretend it's a singular value
new_value = no_replacement_value
if isinstance(new_value, list):
# Maybe mixed input, I guess tool defaults don't really make sense here ?
# Would e.g. be default dataset in multiple="true" input, you wouldn't expect the default to be inserted
# if other inputs are connected and provided.
new_value = [item if not item == no_replacement_value else None for item in new_value]

if no_replacement_value is REPLACE_ON_TRUTHY:
replace = bool(new_value)
else:
replace = new_value != no_replacement_value
if replace:
input_values[input.name] = new_value
elif replace_optional_connections and is_runtime_value(value) and hasattr(input, "value"):
input_values[input.name] = input.value
elif replace_optional_connections:
# Only used in workflow context
has_default = hasattr(input, "value")
if new_value is value is NO_REPLACEMENT or is_runtime_value(value):
# NO_REPLACEMENT means value was connected but left unspecified
if has_default:
# Use default if we have one
input_values[input.name] = input.value
else:
# Should fail if input is not optional and does not have default value
# Effectively however depends on parameter implementation.
# We might want to raise an exception here, instead of depending on a tool parameter value error.
input_values[input.name] = None

def get_current_case(input, input_values):
test_parameter = input.test_param
Expand Down
14 changes: 11 additions & 3 deletions lib/galaxy/tools/parameters/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@
ParameterParseException,
text_input_is_optional,
)
from galaxy.tools.parameters.workflow_utils import workflow_building_modes
from galaxy.tools.parameters.workflow_utils import (
NO_REPLACEMENT,
workflow_building_modes,
)
from galaxy.util import (
sanitize_param,
string_as_bool,
Expand Down Expand Up @@ -247,6 +250,8 @@ def to_python(self, value, app):
def value_to_basic(self, value, app, use_security=False):
if is_runtime_value(value):
return runtime_to_json(value)
elif value == NO_REPLACEMENT:
return {"__class__": "NoReplacement"}
return self.to_json(value, app, use_security)

def value_from_basic(self, value, app, ignore_errors=False):
Expand All @@ -255,8 +260,11 @@ def value_from_basic(self, value, app, ignore_errors=False):
if isinstance(self, HiddenToolParameter):
raise ParameterValueError(message_suffix="Runtime Parameter not valid", parameter_name=self.name)
return runtime_to_object(value)
elif isinstance(value, MutableMapping) and value.get("__class__") == "UnvalidatedValue":
return value["value"]
elif isinstance(value, MutableMapping):
if value.get("__class__") == "UnvalidatedValue":
return value["value"]
elif value.get("__class__") == "NoReplacement":
return NO_REPLACEMENT
# Delegate to the 'to_python' method
if ignore_errors:
try:
Expand Down
9 changes: 9 additions & 0 deletions lib/galaxy/tools/parameters/workflow_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
from collections.abc import MutableMapping


class NoReplacement:
    """Type of the ``NO_REPLACEMENT`` sentinel.

    The module-level ``NO_REPLACEMENT`` instance is a marker meaning "no
    replacement value was supplied". Consumers distinguish it from ``None``
    and from real values by identity (``value is NO_REPLACEMENT``), so there
    must only ever be one live instance reachable through copies of state.
    """

    def __str__(self):
        return "NO_REPLACEMENT singleton"

    def __repr__(self):
        # Short, stable form so the marker is recognisable inside containers
        # and log output instead of ``<NoReplacement object at 0x...>``.
        return "NO_REPLACEMENT"

    def __copy__(self):
        # A copied marker must still satisfy ``is NO_REPLACEMENT``.
        return self

    def __deepcopy__(self, memo):
        # Deep-copying a state dict must not silently clone the sentinel;
        # a clone would fail every identity/equality check against it.
        return self


NO_REPLACEMENT = NoReplacement()


class workflow_building_modes:
DISABLED = False
ENABLED = True
Expand Down
56 changes: 31 additions & 25 deletions lib/galaxy/workflow/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@
from galaxy.tools.parameters.workflow_utils import (
ConnectedValue,
is_runtime_value,
NO_REPLACEMENT,
NoReplacement,
runtime_to_json,
workflow_building_modes,
)
Expand Down Expand Up @@ -129,14 +131,6 @@
POSSIBLE_PARAMETER_TYPES: Tuple[INPUT_PARAMETER_TYPES] = get_args(INPUT_PARAMETER_TYPES)


class NoReplacement:
def __str__(self):
return "NO_REPLACEMENT singleton"


NO_REPLACEMENT = NoReplacement()


class ConditionalStepWhen(BooleanToolParameter):
pass

Expand Down Expand Up @@ -416,14 +410,11 @@ def get_runtime_inputs(self, step, connections: Optional[Iterable[WorkflowStepCo
"""
return {}

def compute_runtime_state(self, trans, step=None, step_updates=None):
def compute_runtime_state(self, trans, step=None, step_updates=None, replace_default_values=False):
"""Determine the runtime state (potentially different from self.state
which describes configuration state). This (again unlike self.state) is
currently always a `DefaultToolState` object.
If `step` is not `None`, it will be used to search for default values
defined in workflow input steps.
If `step_updates` is `None`, this is likely for rendering the run form
for instance and no runtime properties are available and state must be
solely determined by the default runtime state described by the step.
Expand All @@ -432,6 +423,8 @@ def compute_runtime_state(self, trans, step=None, step_updates=None):
supplied by the workflow runner.
"""
state = self.get_runtime_state()
if replace_default_values and step:
state.inputs = step.state.inputs
step_errors = {}

if step is not None:
Expand All @@ -441,8 +434,11 @@ def update_value(input, context, prefixed_name, **kwargs):
if step_input is None:
return NO_REPLACEMENT

if step_input.default_value_set:
return step_input.default_value
if replace_default_values and step_input.default_value_set:
input_value = step_input.default_value
if isinstance(input, BaseDataToolParameter):
input_value = raw_to_galaxy(trans.app, trans.history, input_value)
return input_value

return NO_REPLACEMENT

Expand All @@ -469,7 +465,7 @@ def update_value(input, context, prefixed_name, **kwargs):

return state, step_errors

def encode_runtime_state(self, step, runtime_state):
def encode_runtime_state(self, step, runtime_state: DefaultToolState):
"""Takes the computed runtime state and serializes it during run request creation."""
return runtime_state.encode(Bunch(inputs=self.get_runtime_inputs(step)), self.trans.app)

Expand Down Expand Up @@ -954,7 +950,7 @@ class InputModule(WorkflowModule):

def get_runtime_state(self):
state = DefaultToolState()
state.inputs = dict(input=None)
state.inputs = dict(input=NO_REPLACEMENT)
return state

def get_all_inputs(self, data_only=False, connectable_only=False):
Expand All @@ -966,7 +962,7 @@ def execute(
invocation = invocation_step.workflow_invocation
step = invocation_step.workflow_step
input_value = step.state.inputs["input"]
if input_value is None:
if input_value is NO_REPLACEMENT:
default_value = step.get_input_default_value(NO_REPLACEMENT)
if default_value is not NO_REPLACEMENT:
input_value = raw_to_galaxy(trans.app, trans.history, default_value)
Expand All @@ -993,7 +989,7 @@ def execute(
# everything should come in from the API and this can be eliminated.
if not invocation.has_input_for_step(step.id):
content = next(iter(step_outputs.values()))
if content:
if content and content is not NO_REPLACEMENT:
invocation.add_input(content, step.id)
progress.set_outputs_for_input(invocation_step, step_outputs)
return None
Expand Down Expand Up @@ -1582,7 +1578,7 @@ def _parameter_def_list_to_options(parameter_value):

def get_runtime_state(self):
state = DefaultToolState()
state.inputs = dict(input=None)
state.inputs = dict(input=NO_REPLACEMENT)
return state

def get_all_outputs(self, data_only=False):
Expand All @@ -1609,7 +1605,7 @@ def execute(
input_value = progress.inputs_by_step_id[step.id]
else:
input_value = step.state.inputs["input"]
if input_value is None:
if input_value is NO_REPLACEMENT:
default_value = step.get_input_default_value(NO_REPLACEMENT)
# TODO: look at parameter type and infer if value should be a dictionary
# instead. Guessing only field parameter types in CWL branch would have
Expand Down Expand Up @@ -2233,13 +2229,14 @@ def get_runtime_state(self):
def get_runtime_inputs(self, step, connections: Optional[Iterable[WorkflowStepConnection]] = None):
return self.get_inputs()

def compute_runtime_state(self, trans, step=None, step_updates=None):
def compute_runtime_state(self, trans, step=None, step_updates=None, replace_default_values=False):
# Warning: This method destructively modifies existing step state.
if self.tool:
step_errors = {}
state = self.state
self.runtime_post_job_actions = {}
state, step_errors = super().compute_runtime_state(trans, step, step_updates)
state, step_errors = super().compute_runtime_state(
trans, step, step_updates, replace_default_values=replace_default_values
)
if step_updates:
self.runtime_post_job_actions = step_updates.get(RUNTIME_POST_JOB_ACTIONS_KEY, {})
step_metadata_runtime_state = self.__step_meta_runtime_state()
Expand All @@ -2266,7 +2263,11 @@ def decode_runtime_state(self, step, runtime_state):
)

def execute(
self, trans, progress: "WorkflowProgress", invocation_step, use_cached_job: bool = False
self,
trans,
progress: "WorkflowProgress",
invocation_step: "WorkflowInvocationStep",
use_cached_job: bool = False,
) -> Optional[bool]:
invocation = invocation_step.workflow_invocation
step = invocation_step.workflow_step
Expand All @@ -2275,7 +2276,12 @@ def execute(
# TODO: why do we even create an invocation, seems like something we could check on submit?
message = f"Specified tool [{tool.id}] in step {step.order_index + 1} is not workflow-compatible."
raise exceptions.MessageException(message)
self.state, _ = self.compute_runtime_state(
trans, step, step_updates=progress.param_map.get(step.id), replace_default_values=True
)
step.state = self.state
tool_state = step.state
assert tool_state is not None
tool_inputs = tool.inputs.copy()
# Not strictly needed - but keep Tool state clean by stripping runtime
# metadata parameters from it.
Expand Down Expand Up @@ -2404,7 +2410,7 @@ def callback(input, prefixed_name: str, **kwargs):
mapping_params=mapping_params,
history=invocation.history,
collection_info=collection_info,
workflow_invocation_uuid=invocation.uuid.hex,
workflow_invocation_uuid=invocation.uuid.hex if invocation.uuid else None,
invocation_step=invocation_step,
max_num_jobs=max_num_jobs,
validate_outputs=validate_outputs,
Expand Down
6 changes: 2 additions & 4 deletions lib/galaxy/workflow/refactor/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from galaxy.tools.parameters.basic import contains_workflow_parameter
from galaxy.tools.parameters.workflow_utils import (
ConnectedValue,
NO_REPLACEMENT,
runtime_to_json,
)
from .schema import (
Expand Down Expand Up @@ -41,10 +42,7 @@
UpgradeSubworkflowAction,
UpgradeToolAction,
)
from ..modules import (
InputParameterModule,
NO_REPLACEMENT,
)
from ..modules import InputParameterModule

log = logging.getLogger(__name__)

Expand Down
Loading

0 comments on commit f8ae3ce

Please sign in to comment.