From 6638cdee47c1b5efd3363b2963d63c6e79b2d271 Mon Sep 17 00:00:00 2001 From: Eric Hulburd Date: Mon, 26 Aug 2024 18:19:11 -0700 Subject: [PATCH] feat(quil-py): support extern call instructions --- Cargo.lock | 6 +- quil-py/pyrightconfig.json | 4 + quil-py/quil/instructions/__init__.pyi | 338 +++ quil-py/quil/program/__init__.pyi | 15 + quil-py/src/instruction/declaration.rs | 16 + quil-py/src/instruction/extern_call.rs | 198 ++ quil-py/src/instruction/mod.rs | 19 +- quil-py/src/instruction/pragma.rs | 13 +- quil-py/src/program/mod.rs | 7 + quil-py/test/instructions/test_extern_call.py | 66 + quil-rs/src/expression/mod.rs | 2 +- quil-rs/src/instruction/extern_call.rs | 2060 +++++++++++++++++ quil-rs/src/instruction/mod.rs | 24 +- quil-rs/src/instruction/pragma.rs | 25 +- quil-rs/src/parser/command.rs | 192 +- quil-rs/src/parser/common.rs | 25 +- quil-rs/src/parser/error/mod.rs | 5 + quil-rs/src/parser/expression.rs | 10 +- quil-rs/src/parser/instruction.rs | 1 + quil-rs/src/parser/lexer/mod.rs | 5 +- quil-rs/src/parser/mod.rs | 3 +- quil-rs/src/parser/reserved_pragma_extern.rs | 335 +++ quil-rs/src/parser/token.rs | 3 + .../program/analysis/control_flow_graph.rs | 5 +- quil-rs/src/program/memory.rs | 73 +- quil-rs/src/program/mod.rs | 148 +- quil-rs/src/program/scheduling/graph.rs | 10 +- 27 files changed, 3546 insertions(+), 62 deletions(-) create mode 100644 quil-py/pyrightconfig.json create mode 100644 quil-py/src/instruction/extern_call.rs create mode 100644 quil-py/test/instructions/test_extern_call.py create mode 100644 quil-rs/src/instruction/extern_call.rs create mode 100644 quil-rs/src/parser/reserved_pragma_extern.rs diff --git a/Cargo.lock b/Cargo.lock index 8b152af7..49c0d401 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1132,7 +1132,7 @@ checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" [[package]] name = "quil-cli" -version = "0.4.0" +version = "0.4.1" dependencies = [ "anyhow", "clap", @@ -1141,7 +1141,7 @@ dependencies = [ [[package]] name = "quil-py" -version = "0.11.0" +version = "0.11.2" dependencies = [ "indexmap", "ndarray", @@ -1155,7 +1155,7 @@ dependencies = [ [[package]] name = "quil-rs" -version = "0.27.0" +version = "0.27.1" dependencies = [ "approx", "clap", diff --git a/quil-py/pyrightconfig.json b/quil-py/pyrightconfig.json new file mode 100644 index 00000000..2d49b3a6 --- /dev/null +++ b/quil-py/pyrightconfig.json @@ -0,0 +1,4 @@ +{ + "venv": ".venv", + "venvPath": "." +} diff --git a/quil-py/quil/instructions/__init__.pyi b/quil-py/quil/instructions/__init__.pyi index 545472d1..7af547f8 100644 --- a/quil-py/quil/instructions/__init__.pyi +++ b/quil-py/quil/instructions/__init__.pyi @@ -79,6 +79,7 @@ class Instruction: Calibration, Capture, BinaryLogic, + Call, CircuitDefinition, Convert, Comparison, @@ -101,6 +102,7 @@ class Instruction: Pragma, Pulse, RawCapture, + ReservedPragma, Reset, SetFrequency, SetPhase, @@ -131,6 +133,7 @@ class Instruction: Capture, BinaryLogic, CircuitDefinition, + Call, Convert, Comparison, Declaration, @@ -152,6 +155,7 @@ class Instruction: Pragma, Pulse, RawCapture, + ReservedPragma, Reset, SetFrequency, SetPhase, @@ -168,6 +172,7 @@ class Instruction: def is_arithmetic(self) -> bool: ... def is_binary_logic(self) -> bool: ... def is_calibration_definition(self) -> bool: ... + def is_call(self) -> bool: ... def is_capture(self) -> bool: ... def is_circuit_definition(self) -> bool: ... def is_convert(self) -> bool: ... @@ -193,6 +198,7 @@ class Instruction: def is_pragma(self) -> bool: ... def is_pulse(self) -> bool: ... def is_raw_capture(self) -> bool: ... + def is_reserved_pragma(self) -> bool: ... def is_reset(self) -> bool: ... def is_set_frequency(self) -> bool: ... def is_set_phase(self) -> bool: ... @@ -217,6 +223,8 @@ class Instruction: @staticmethod def from_calibration_definition(inner: Calibration) -> Instruction: ... @staticmethod + def from_call(inner: Call) -> Instruction: ... + @staticmethod def from_capture(inner: Capture) -> Instruction: ... @staticmethod def from_circuit_definition( @@ -269,6 +277,8 @@ class Instruction: @staticmethod def from_raw_capture(inner: RawCapture) -> Instruction: ... @staticmethod + def from_reserved_pragma(inner: ReservedPragma) -> Instruction: ... + @staticmethod def from_set_frequency(inner: SetFrequency) -> Instruction: ... @staticmethod def from_set_phase(inner: SetPhase) -> Instruction: ... @@ -294,6 +304,8 @@ class Instruction: def to_arithmetic(self) -> Arithmetic: ... def as_binary_logic(self) -> Optional[BinaryLogic]: ... def to_binary_logic(self) -> BinaryLogic: ... + def as_call(self) -> Optional[Call]: ... + def to_call(self) -> Call: ... def as_convert(self) -> Optional[Convert]: ... def to_convert(self) -> Convert: ... def as_comparison(self) -> Optional[Comparison]: ... @@ -346,6 +358,8 @@ class Instruction: def to_pulse(self) -> Pulse: ... def as_raw_capture(self) -> Optional[RawCapture]: ... def to_raw_capture(self) -> RawCapture: ... + def as_reserved_pragma(self) -> Optional[ReservedPragma]: ... + def to_reserved_pragma(self) -> ReservedPragma: ... def as_reset(self) -> Optional[Reset]: ... def to_reset(self) -> Reset: ... def as_set_frequency(self) -> Optional[SetFrequency]: ... @@ -1175,6 +1189,303 @@ class FrameIdentifier: If any part of the instruction can't be converted to valid Quil, it will be printed in a human-readable debug format. """ +@final +class CallArgument: + """An argument to a `Call` instruction. + + This may be expressed as an identifier, a memory reference, or an immediate value. Memory references and identifiers require a corresponding memory region declaration by the time of + compilation (at the time of call argument resolution and memory graph construction to be more precise). + + Additionally, an argument's resolved type must match the expected type of the corresponding `ExternParameter` + in the `Externsignature`. + """ + def inner(self) -> Union[List[List[Expression]], List[int], PauliSum]: + """Returns the inner value of the variant. Raises a ``RuntimeError`` if inner data doesn't exist.""" + ... + def is_identifier(self) -> bool: ... + def is_memory_reference(self) -> bool: ... + def is_immediate(self) -> bool: ... + def as_identifier(self) -> Optional[str]: ... + def to_identifier(self) -> str: ... + def as_memory_reference(self) -> Optional["MemoryReference"]: ... + def to_memory_reference(self) -> "MemoryReference": ... + def as_immediate(self) -> Optional[complex]: ... + def to_immediate(self) -> complex: ... + @staticmethod + def from_identifier(inner: str) -> "CallArgument": ... + @staticmethod + def from_memory_reference(inner: "MemoryReference") -> "CallArgument": ... + @staticmethod + def from_immediate(inner: complex) -> "CallArgument": ... + def to_quil(self) -> str: + """Attempt to convert the instruction to a valid Quil string. + + Raises an exception if the instruction can't be converted to valid Quil. + """ + ... + def to_quil_or_debug(self) -> str: + """Convert the instruction to a Quil string. + + If any part of the instruction can't be converted to valid Quil, it will be printed in a human-readable debug format. + """ + +@final +class CallArguments: + """A list of `CallArgument`s for a single `CALL`. + + This abstract is necessary for the inner workings of the `CallArgument` type resolution. + """ + + def inner(self) -> Union[List[List[Expression]], List[int], PauliSum]: + """Returns the inner value of the variant. Raises a ``RuntimeError`` if inner data doesn't exist.""" + ... + def is_arguments(self) -> bool: ... + def as_arguments(self) -> Optional[List["CallArgument"]]: ... + def to_arguments(self) -> List["CallArgument"]: ... + @staticmethod + def from_arguments(inner: List["CallArgument"]) -> "CallArguments": ... + def to_quil(self) -> str: + """Attempt to convert the instruction to a valid Quil string. + + Raises an exception if the instruction can't be converted to valid Quil. + """ + ... + def to_quil_or_debug(self) -> str: + """Convert the instruction to a Quil string. + + If any part of the instruction can't be converted to valid Quil, it will be printed in a human-readable debug format. + """ + +class CallValidationError(ValueError): + """An error that may occur when performing operations on a ``Call``.""" + + ... + +class Call: + """An instruction to an external function declared within an `ExternDefinition`. + + These calls are generally specific to a particular hardware or virtual machine + backend. For further detail, see: + + * `Other instructions and Directives `_ + in the Quil specification. + * `EXTERN / CALL RFC `_ + * `quil#87 `_ + + + Also see `ExternDefinition`. + """ + def __new__( + cls, + name: str, + arguments: List["CallArgument"], + ) -> Self: ... + @property + def name(self) -> str: ... + @name.setter + def name(self, name: str) -> None: ... + @property + def arguments(self) -> "CallArguments": ... + @arguments.setter + def arguments(self, arguments: "CallArguments") -> None: ... + def to_quil(self) -> str: + """Attempt to convert the instruction to a valid Quil string. + + Raises an exception if the instruction can't be converted to valid Quil. + """ + ... + def to_quil_or_debug(self) -> str: + """Convert the instruction to a Quil string. + + If any part of the instruction can't be converted to valid Quil, it will be printed in a human-readable debug format. + """ + def __deepcopy__(self, _: Dict) -> Self: + """Creates and returns a deep copy of the class. + + If the instruction contains any ``QubitPlaceholder`` or ``TargetPlaceholder``, then they will be replaced with + new placeholders so resolving them in the copy will not resolve them in the original. + Should be used by passing an instance of the class to ``copy.deepcopy`` + """ + def __copy__(self) -> Self: + """Returns a shallow copy of the class.""" + +@final +class ExternParameterType: + """The type of an `ExternParameter`. + + This type is used to define the expected type of the parameter in the `ExternSignature`. May be + either a scalar, fixed-length vector, or variable-length vector. + + Note, both scalars and fixed-length vectors are fully specified by a `ScalarType`, but are indeed + distinct types. + """ + + def inner(self) -> Union["ScalarType", "Vector"]: + """Returns the inner value of the variant. Raises a ``RuntimeError`` if inner data doesn't exist.""" + ... + def is_scalar(self) -> bool: ... + def is_fixed_length_vector(self) -> bool: ... + def is_variable_length_vector(self) -> bool: ... + def as_scalar(self) -> Optional["ScalarType"]: ... + def to_scalar(self) -> "ScalarType": ... + def as_fixed_length_vector(self) -> Optional["Vector"]: ... + def to_fixed_length_vector(self) -> "MemoryReference": ... + def as_variable_length_vector(self) -> Optional["ScalarType"]: ... + def to_variable_length_vector(self) -> complex: ... + @staticmethod + def from_scalar(inner: "ScalarType") -> "ExternParameterType": ... + @staticmethod + def from_fixed_length_vector(inner: "Vector") -> "ExternParameterType": ... + @staticmethod + def from_variable_length_vector(inner: "ScalarType") -> "ExternParameterType": ... + def to_quil(self) -> str: + """Attempt to convert the instruction to a valid Quil string. + + Raises an exception if the instruction can't be converted to valid Quil. + """ + ... + def to_quil_or_debug(self) -> str: + """Convert the instruction to a Quil string. + + If any part of the instruction can't be converted to valid Quil, it will be printed in a human-readable debug format. + """ + +class ExternParameter: + """A parameter within an `ExternSignature`. These are defined by a name, mutability, and type.""" + def __new__( + cls, + name: str, + mutable: bool, + data_type: ExternParameterType, + ) -> Self: ... + @property + def name(self) -> str: ... + @name.setter + def name(self, name: str) -> None: ... + @property + def mutable(self) -> bool: ... + @mutable.setter + def mutable(self, mutable: bool) -> None: ... + @property + def data_type(self) -> ExternParameterType: ... + @data_type.setter + def data_type(self, data_type: ExternParameterType) -> None: ... + def to_quil(self) -> str: + """Attempt to convert the instruction to a valid Quil string. + + Raises an exception if the instruction can't be converted to valid Quil. + """ + ... + def to_quil_or_debug(self) -> str: + """Convert the instruction to a Quil string. + + If any part of the instruction can't be converted to valid Quil, it will be printed in a human-readable debug format. + """ + def __deepcopy__(self, _: Dict) -> Self: + """Creates and returns a deep copy of the class. + + If the instruction contains any ``QubitPlaceholder`` or ``TargetPlaceholder``, then they will be replaced with + new placeholders so resolving them in the copy will not resolve them in the original. + Should be used by passing an instance of the class to ``copy.deepcopy`` + """ + def __copy__(self) -> Self: + """Returns a shallow copy of the class.""" + +class ExternSignature: + """The signature of an `ExternDefinition`. + + This signature is defined by a list of `ExternParameter`s and an + optional return type. See the `Quil Specification `_ + for details on how these signatures are formed. + """ + def __new__( + cls, + parameters: List[ExternParameter], + return_type: Optional[ScalarType], + ) -> Self: ... + @property + def parameters(self) -> List[ExternParameter]: ... + @parameters.setter + def parameters(self, parameters: str) -> None: ... + @property + def return_type(self) -> Optional[ScalarType]: ... + @return_type.setter + def return_type(self, return_type: ScalarType) -> None: ... + def to_quil(self) -> str: + """Attempt to convert the instruction to a valid Quil string. + + Raises an exception if the instruction can't be converted to valid Quil. + """ + ... + def to_quil_or_debug(self) -> str: + """Convert the instruction to a Quil string. + + If any part of the instruction can't be converted to valid Quil, it will be printed in a human-readable debug format. + """ + def __deepcopy__(self, _: Dict) -> Self: + """Creates and returns a deep copy of the class. + + If the instruction contains any ``QubitPlaceholder`` or ``TargetPlaceholder``, then they will be replaced with + new placeholders so resolving them in the copy will not resolve them in the original. + Should be used by passing an instance of the class to ``copy.deepcopy`` + """ + def __copy__(self) -> Self: + """Returns a shallow copy of the class.""" + +class ExternValidationError(ValueError): + """An error that may occur when initializing or validating an ``ExternDefinition``.""" + + ... + +class ExternDefinition: + """An external function declaration. + + These are generally specific to a particular hardware or virtual machine backend. Note, + these are not standard Quil instructions, but rather a type of `ReservedPragma`. + + For further detail, see: + + * `Other instructions and Directives `_ + in the Quil specification. + * `EXTERN / CALL RFC `_ + * `quil#87 `_ + + Also see `Call`. + """ + def __new__( + cls, + name: str, + signature: Optional[ExternSignature], + ) -> Self: ... + @property + def name(self) -> str: ... + @name.setter + def name(self, name: str) -> None: ... + @property + def signature(self) -> Optional[ExternSignature]: ... + @signature.setter + def signature(self, signature: ExternSignature) -> None: ... + def to_quil(self) -> str: + """Attempt to convert the instruction to a valid Quil string. + + Raises an exception if the instruction can't be converted to valid Quil. + """ + ... + def to_quil_or_debug(self) -> str: + """Convert the instruction to a Quil string. + + If any part of the instruction can't be converted to valid Quil, it will be printed in a human-readable debug format. + """ + def __deepcopy__(self, _: Dict) -> Self: + """Creates and returns a deep copy of the class. + + If the instruction contains any ``QubitPlaceholder`` or ``TargetPlaceholder``, then they will be replaced with + new placeholders so resolving them in the copy will not resolve them in the original. + Should be used by passing an instance of the class to ``copy.deepcopy`` + """ + def __copy__(self) -> Self: + """Returns a shallow copy of the class.""" + class Capture: def __new__( cls, @@ -1883,6 +2194,33 @@ class PragmaArgument: If any part of the instruction can't be converted to valid Quil, it will be printed in a human-readable debug format. """ +@final +class ReservedPragma: + """An instruction represented by as a specially structured `Pragma`. + + Currently, this may only be an `ExternDefinition`. + """ + + def inner(self) -> Union["ExternDefinition"]: # type: ignore + """Returns the inner value of the variant. Raises a ``RuntimeError`` if inner data doesn't exist.""" + ... + def is_extern_definition(self) -> bool: ... + def as_extern_definition(self) -> Optional[ExternDefinition]: ... + def to_extern_definition(self) -> ExternDefinition: ... + @staticmethod + def from_extern_definition(inner: ExternDefinition) -> "ReservedPragma": ... + def to_quil(self) -> str: + """Attempt to convert the instruction to a valid Quil string. + + Raises an exception if the instruction can't be converted to valid Quil. + """ + ... + def to_quil_or_debug(self) -> str: + """Convert the instruction to a Quil string. + + If any part of the instruction can't be converted to valid Quil, it will be printed in a human-readable debug format. + """ + class Include: def __new__(cls, filename: str) -> Self: ... @property diff --git a/quil-py/quil/program/__init__.pyi b/quil-py/quil/program/__init__.pyi index 5478a841..068865c8 100644 --- a/quil-py/quil/program/__init__.pyi +++ b/quil-py/quil/program/__init__.pyi @@ -141,6 +141,21 @@ class Program: while ensuring that unique value is not already in use within the program. """ ... + def resolve_call_instructions(self) -> None: + """Resolve ``CALL`` instructions within the program. + + This does two things: + + 1. It validates all ``EXTERN`` instructions in the program. + 2. It resolves the ``CALL`` instruction data types based on the declared memory types + and ensures those match one and only one ``EXTERN`` instruction. + + This method is only useful for validating program structure before it is compiled. Note, + this mutates the inner representation of the program. + + :raises ProgramError: If any of the above conditions are not met. + """ + ... def resolve_placeholders_with_custom_resolvers( self, *, diff --git a/quil-py/src/instruction/declaration.rs b/quil-py/src/instruction/declaration.rs index 5d165bd6..256381f8 100644 --- a/quil-py/src/instruction/declaration.rs +++ b/quil-py/src/instruction/declaration.rs @@ -38,6 +38,22 @@ impl_repr!(PyScalarType); impl_to_quil!(PyScalarType); impl_hash!(PyScalarType); +impl rigetti_pyo3::PyTryFrom for PyScalarType { + fn py_try_from(_py: Python, item: &pyo3::PyAny) -> PyResult { + let item = item.extract::()?; + match item.as_str() { + "BIT" => Ok(Self::Bit), + "INTEGER" => Ok(Self::Integer), + "OCTET" => Ok(Self::Octet), + "REAL" => Ok(Self::Real), + _ => Err(PyValueError::new_err(format!( + "Invalid value for ScalarType: {}", + item + ))), + } + } +} + py_wrap_data_struct! { #[derive(Debug, Hash, PartialEq, Eq)] #[pyo3(subclass)] diff --git a/quil-py/src/instruction/extern_call.rs b/quil-py/src/instruction/extern_call.rs new file mode 100644 index 00000000..c815a86d --- /dev/null +++ b/quil-py/src/instruction/extern_call.rs @@ -0,0 +1,198 @@ +use quil_rs::instruction::{ + Call, CallArguments, ExternDefinition, ExternParameter, ExternParameterType, ExternSignature, + ScalarType, UnresolvedCallArgument, +}; + +use rigetti_pyo3::{ + impl_hash, impl_repr, py_wrap_data_struct, py_wrap_error, py_wrap_union_enum, + pyo3::{pymethods, types::PyString, Py, PyResult, Python}, + wrap_error, PyTryFrom, ToPythonError, +}; + +use crate::{impl_copy_for_instruction, impl_eq, impl_pickle_for_instruction, impl_to_quil}; + +use super::{PyScalarType, PyVector}; + +wrap_error!(RustCallValidationError( + quil_rs::instruction::CallValidationError +)); +py_wrap_error!( + quil, + RustCallValidationError, + CallValidationError, + rigetti_pyo3::pyo3::exceptions::PyValueError +); + +wrap_error!(RustExternValidationError( + quil_rs::instruction::ExternValidationError +)); +py_wrap_error!( + quil, + RustExternValidationError, + ExternValidationError, + rigetti_pyo3::pyo3::exceptions::PyValueError +); + +py_wrap_data_struct! { + #[derive(Debug, PartialEq, Eq)] + #[pyo3(subclass, module = "quil.instructions")] + PyCall(Call) as "Call" { + name: String => Py, + arguments: CallArguments => PyCallArguments + } +} +impl_repr!(PyCall); +impl_to_quil!(PyCall); +impl_copy_for_instruction!(PyCall); +impl_hash!(PyCall); +impl_eq!(PyCall); +impl_pickle_for_instruction!(PyCall); + +#[pymethods] +impl PyCall { + #[new] + fn new(py: Python<'_>, name: String, arguments: Vec) -> PyResult { + Ok(Self( + Call::try_new( + name, + Vec::::py_try_from(py, &arguments)?, + ) + .map_err(RustCallValidationError::from) + .map_err(RustCallValidationError::to_py_err)?, + )) + } +} + +py_wrap_union_enum! { + #[derive(Debug, PartialEq, Eq)] + PyCallArgument(UnresolvedCallArgument) as "CallArgument" { + identifier: Identifier => Py, + memory_reference: MemoryReference => super::PyMemoryReference, + immediate: Immediate => Py + } +} +impl_repr!(PyCallArgument); +impl_to_quil!(PyCallArgument); +impl_hash!(PyCallArgument); +impl_eq!(PyCallArgument); + +py_wrap_union_enum! { + #[derive(Debug, PartialEq, Eq)] + PyCallArguments(CallArguments) as "CallArguments" { + arguments: Unresolved => Vec + } +} +impl_repr!(PyCallArguments); +impl_to_quil!(PyCallArguments); +impl_hash!(PyCallArguments); +impl_eq!(PyCallArguments); + +py_wrap_union_enum! { + #[derive(Debug, PartialEq, Eq)] + PyExternParameterType(ExternParameterType) as "ExternParameterType" { + scalar: Scalar => PyScalarType, + fixed_length_vector: FixedLengthVector => PyVector, + variable_length_vector: VariableLengthVector => PyScalarType + } +} +impl_repr!(PyExternParameterType); +impl_to_quil!(PyExternParameterType); +impl_hash!(PyExternParameterType); +impl_eq!(PyExternParameterType); + +py_wrap_data_struct! { + #[derive(Debug, PartialEq, Eq)] + #[pyo3(subclass, module = "quil.instructions")] + PyExternParameter(ExternParameter) as "ExternParameter" { + name: String => Py, + mutable: bool => bool, + data_type: ExternParameterType => PyExternParameterType + } +} +impl_repr!(PyExternParameter); +impl_to_quil!(PyExternParameter); +impl_copy_for_instruction!(PyExternParameter); +impl_hash!(PyExternParameter); +impl_eq!(PyExternParameter); +impl_pickle_for_instruction!(PyExternParameter); + +#[pymethods] +impl PyExternParameter { + #[new] + fn new( + py: Python<'_>, + name: String, + mutable: bool, + data_type: PyExternParameterType, + ) -> PyResult { + Ok(Self(ExternParameter::new( + name, + mutable, + ExternParameterType::py_try_from(py, &data_type)?, + ))) + } +} + +py_wrap_data_struct! { + #[derive(Debug, PartialEq, Eq)] + #[pyo3(subclass, module = "quil.instructions")] + PyExternSignature(ExternSignature) as "ExternSignature" { + return_type: Option => Option, + parameters: Vec => Vec + } +} +impl_repr!(PyExternSignature); +impl_to_quil!(PyExternSignature); +impl_copy_for_instruction!(PyExternSignature); +impl_hash!(PyExternSignature); +impl_eq!(PyExternSignature); +impl_pickle_for_instruction!(PyExternSignature); + +#[pymethods] +impl PyExternSignature { + #[new] + fn new( + py: Python<'_>, + parameters: Vec, + return_type: Option, + ) -> PyResult { + Ok(Self(ExternSignature::new( + return_type + .map(|scalar_type| ScalarType::py_try_from(py, &scalar_type)) + .transpose()?, + Vec::::py_try_from(py, ¶meters)?, + ))) + } +} + +py_wrap_data_struct! { + #[derive(Debug, PartialEq, Eq)] + #[pyo3(subclass, module = "quil.instructions")] + PyExternDefinition(ExternDefinition) as "ExternDefinition" { + name: String => Py, + signature: Option => Option + } +} +impl_repr!(PyExternDefinition); +impl_to_quil!(PyExternDefinition); +impl_copy_for_instruction!(PyExternDefinition); +impl_hash!(PyExternDefinition); +impl_eq!(PyExternDefinition); +impl_pickle_for_instruction!(PyExternDefinition); + +#[pymethods] +impl PyExternDefinition { + #[new] + fn new(py: Python<'_>, name: String, signature: Option) -> PyResult { + Ok(Self( + ExternDefinition::try_new( + name, + signature + .map(|signature| ExternSignature::py_try_from(py, &signature)) + .transpose()?, + ) + .map_err(RustExternValidationError::from) + .map_err(RustExternValidationError::to_py_err)?, + )) + } +} diff --git a/quil-py/src/instruction/mod.rs b/quil-py/src/instruction/mod.rs index 9de51f2f..241a2985 100644 --- a/quil-py/src/instruction/mod.rs +++ b/quil-py/src/instruction/mod.rs @@ -23,6 +23,10 @@ pub use self::{ ParseMemoryReferenceError, PyDeclaration, PyLoad, PyMemoryReference, PyOffset, PyScalarType, PySharing, PyStore, PyVector, }, + extern_call::{ + CallValidationError, ExternValidationError, PyCall, PyCallArgument, PyCallArguments, + PyExternDefinition, PyExternParameter, PyExternParameterType, PyExternSignature, + }, frame::{ PyAttributeValue, PyCapture, PyFrameAttributes, PyFrameDefinition, PyFrameIdentifier, PyPulse, PyRawCapture, PySetFrequency, PySetPhase, PySetScale, PyShiftFrequency, @@ -33,7 +37,7 @@ pub use self::{ PyPauliSum, PyPauliTerm, }, measurement::PyMeasurement, - pragma::{PyInclude, PyPragma, PyPragmaArgument}, + pragma::{PyInclude, PyPragma, PyPragmaArgument, PyReservedPragma}, qubit::{PyQubit, PyQubitPlaceholder}, reset::PyReset, timing::{PyDelay, PyFence}, @@ -45,6 +49,7 @@ mod circuit; mod classical; mod control_flow; mod declaration; +mod extern_call; mod frame; mod gate; mod measurement; @@ -60,6 +65,7 @@ py_wrap_union_enum! { arithmetic: Arithmetic => PyArithmetic, binary_logic: BinaryLogic => PyBinaryLogic, calibration_definition: CalibrationDefinition => PyCalibration, + call: Call => PyCall, capture: Capture => PyCapture, circuit_definition: CircuitDefinition => PyCircuitDefinition, convert: Convert => PyConvert, @@ -85,6 +91,7 @@ py_wrap_union_enum! { pragma: Pragma => PyPragma, pulse: Pulse => PyPulse, raw_capture: RawCapture => PyRawCapture, + reserved_pragma: ReservedPragma => PyReservedPragma, reset: Reset => PyReset, set_frequency: SetFrequency => PySetFrequency, set_phase: SetPhase => PySetPhase, @@ -149,11 +156,18 @@ create_init_submodule! { PyBinaryLogic, PyBinaryOperand, PyBinaryOperator, + PyCall, + PyCallArgument, + PyCallArguments, PyComparison, PyComparisonOperand, PyComparisonOperator, PyConvert, PyExchange, + PyExternDefinition, + PyExternParameter, + PyExternParameterType, + PyExternSignature, PyMove, PyUnaryLogic, PyUnaryOperator, @@ -176,6 +190,7 @@ create_init_submodule! { PyFrameDefinition, PyFrameIdentifier, PyPulse, + PyReservedPragma, PyRawCapture, PySetFrequency, PySetPhase, @@ -207,7 +222,7 @@ create_init_submodule! { PyWaveformDefinition, PyWaveformInvocation ], - errors: [ GateError, ParseMemoryReferenceError ], + errors: [ CallValidationError, ExternValidationError, GateError, ParseMemoryReferenceError ], } /// Implements __copy__ and __deepcopy__ on any variant of the [`PyInstruction`] class, making diff --git a/quil-py/src/instruction/pragma.rs b/quil-py/src/instruction/pragma.rs index f2b5766c..e95e095b 100644 --- a/quil-py/src/instruction/pragma.rs +++ b/quil-py/src/instruction/pragma.rs @@ -1,4 +1,4 @@ -use quil_rs::instruction::{Include, Pragma, PragmaArgument}; +use quil_rs::instruction::{Include, Pragma, PragmaArgument, ReservedPragma}; use rigetti_pyo3::{ impl_hash, impl_repr, py_wrap_data_struct, py_wrap_union_enum, @@ -57,6 +57,17 @@ impl_to_quil!(PyPragmaArgument); impl_hash!(PyPragmaArgument); impl_eq!(PyPragmaArgument); +py_wrap_union_enum! { + #[derive(Debug, PartialEq, Eq)] + PyReservedPragma(ReservedPragma) as "ReservedPragma" { + extern_definition: Extern => super::extern_call::PyExternDefinition + } +} +impl_repr!(PyReservedPragma); +impl_to_quil!(PyReservedPragma); +impl_hash!(PyReservedPragma); +impl_eq!(PyReservedPragma); + py_wrap_data_struct! { #[derive(Debug, PartialEq, Eq)] #[pyo3(subclass, module = "quil.instructions")] diff --git a/quil-py/src/program/mod.rs b/quil-py/src/program/mod.rs index 9df4ab7d..0f416f60 100644 --- a/quil-py/src/program/mod.rs +++ b/quil-py/src/program/mod.rs @@ -268,6 +268,13 @@ impl PyProgram { self.as_inner_mut().resolve_placeholders(); } + pub fn resolve_call_instructions(&mut self) -> PyResult<()> { + self.as_inner_mut() + .resolve_call_instructions() + .map_err(ProgramError::from) + .map_err(ProgramError::to_py_err) + } + pub fn wrap_in_loop( &self, loop_count_reference: PyMemoryReference, diff --git a/quil-py/test/instructions/test_extern_call.py b/quil-py/test/instructions/test_extern_call.py new file mode 100644 index 00000000..3d2a6f38 --- /dev/null +++ b/quil-py/test/instructions/test_extern_call.py @@ -0,0 +1,66 @@ +import pytest + +from quil.instructions import ( + Call, + CallArgument, + Declaration, + ExternDefinition, + ExternParameter, + ExternParameterType, + ExternSignature, + Instruction, + MemoryReference, + ReservedPragma, + ScalarType, + Vector, +) +from quil.program import Program + + +@pytest.mark.parametrize( + "return_argument", + [CallArgument.from_identifier("real"), CallArgument.from_memory_reference(MemoryReference("real", 1))], +) +def test_extern_call_instructions(return_argument: CallArgument): + p = Program() + extern_definition = ExternDefinition( + name="test", + signature=ExternSignature( + return_type=ScalarType.Real, + parameters=[ + ExternParameter( + name="a", + mutable=False, + data_type=ExternParameterType.from_variable_length_vector(ScalarType.Integer), + ) + ], + ), + ) + p.add_instruction(Instruction(ReservedPragma.from_extern_definition(extern_definition))) + real_declaration = Declaration(name="real", size=Vector(data_type=ScalarType.Real, length=3), sharing=None) + p.add_instruction(Instruction(real_declaration)) + integer_declaration = Declaration(name="integer", size=Vector(data_type=ScalarType.Integer, length=3), sharing=None) + p.add_instruction(Instruction(integer_declaration)) + call = Call( + name=extern_definition.name, + arguments=[ + return_argument, + CallArgument.from_identifier(integer_declaration.name), + ], + ) + p.add_instruction(Instruction(call)) + + parsed_program = Program.parse(p.to_quil()) + assert p == parsed_program + + p.resolve_call_instructions() + + +def test_extern_call_quil(): + input = """DECLARE reals REAL[3] +DECLARE octets OCTET[3] +PRAGMA EXTERN foo "OCTET (params : mut REAL[3])" +CALL foo octets[1] reals +""" + program = Program.parse(input) + program.resolve_call_instructions() diff --git a/quil-rs/src/expression/mod.rs b/quil-rs/src/expression/mod.rs index c82a2540..1221afe2 100644 --- a/quil-rs/src/expression/mod.rs +++ b/quil-rs/src/expression/mod.rs @@ -462,7 +462,7 @@ static FORMAT_IMAGINARY_OPTIONS: Lazy = Lazy::new(|| { /// - When imaginary is 0, show real only /// - When both are non-zero, show with the correct operator in between #[inline(always)] -fn format_complex(value: &Complex64) -> String { +pub(crate) fn format_complex(value: &Complex64) -> String { const FORMAT: u128 = format::STANDARD; if value.re == 0f64 && value.im == 0f64 { "0".to_owned() diff --git a/quil-rs/src/instruction/extern_call.rs b/quil-rs/src/instruction/extern_call.rs new file mode 100644 index 00000000..1d8280be --- /dev/null +++ b/quil-rs/src/instruction/extern_call.rs @@ -0,0 +1,2060 @@ +/// This module provides support for the `CALL` instruction and the reserved `PRAGMA EXTERN` instruction. +/// +/// For additional detail on its design and specification see: +/// +/// * [Quil specification "Other"](https://github.com/quil-lang/quil/blob/7f532c7cdde9f51eae6abe7408cc868fba9f91f6/specgen/spec/sec-other.s_) +/// * [Quil EXTERN / CALL RFC](https://github.com/quil-lang/quil/blob/master/rfcs/extern-call.md) +/// * https://github.com/quil-lang/quil/pull/69 +use std::{collections::HashSet, str::FromStr}; + +use indexmap::IndexMap; +use nom_locate::LocatedSpan; +use num_complex::Complex64; + +use crate::{ + expression::format_complex, + hash::hash_f64, + parser::{lex, InternalParseError, ParserErrorKind, ParserInput}, + program::{disallow_leftover, MemoryRegion, SyntaxError}, + quil::Quil, + validation::identifier::{validate_user_identifier, IdentifierValidationError}, +}; + +use super::{MemoryReference, ScalarType, Vector}; + +/// A parameter type within an extern signature. +#[derive(Clone, Debug, PartialEq, Hash, Eq)] +pub enum ExternParameterType { + Scalar(ScalarType), + FixedLengthVector(Vector), + VariableLengthVector(ScalarType), +} + +impl Quil for ExternParameterType { + fn write( + &self, + f: &mut impl std::fmt::Write, + fall_back_to_debug: bool, + ) -> crate::quil::ToQuilResult<()> { + match self { + ExternParameterType::Scalar(value) => value.write(f, fall_back_to_debug), + ExternParameterType::FixedLengthVector(value) => value.write(f, fall_back_to_debug), + ExternParameterType::VariableLengthVector(value) => { + value.write(f, fall_back_to_debug)?; + write!(f, "[]").map_err(Into::into) + } + } + } +} + +/// An extern parameter with a name, mutability, and data type. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct ExternParameter { + /// The name of the parameter. This must be a valid user identifier. + pub name: String, + /// Whether the parameter is mutable. + pub mutable: bool, + /// The data type of the parameter. + pub data_type: ExternParameterType, +} + +impl ExternParameter { + /// Create a new extern parameter. + pub fn new(name: String, mutable: bool, data_type: ExternParameterType) -> Self { + Self { + name, + mutable, + data_type, + } + } +} + +impl Quil for ExternParameter { + fn write( + &self, + writer: &mut impl std::fmt::Write, + fall_back_to_debug: bool, + ) -> Result<(), crate::quil::ToQuilError> { + write!(writer, "{} : ", self.name)?; + if self.mutable { + write!(writer, "mut ")?; + } + self.data_type.write(writer, fall_back_to_debug) + } +} + +/// An extern signature with a return type and parameters. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct ExternSignature { + /// The return type of the extern signature, if any. + pub return_type: Option, + /// The parameters of the extern signature. + pub parameters: Vec, +} + +impl ExternSignature { + /// Create a new extern signature. + pub fn new(return_type: Option, parameters: Vec) -> Self { + Self { + return_type, + parameters, + } + } + + fn has_return_or_parameters(&self) -> bool { + self.return_type.is_some() || !self.parameters.is_empty() + } +} + +/// An error that can occur when parsing an extern signature. +#[derive(Debug, thiserror::Error, PartialEq)] +pub enum ExternSignatureError { + /// An error occurred while parsing the contents of the extern signature. + #[error("invalid extern signature syntax: {0}")] + Syntax(SyntaxError), + /// An error occurred while lexing the extern signature. + #[error("failed to lex extern signature: {0}")] + Lex(crate::parser::LexError), +} + +impl ExternSignatureError { + pub(crate) fn into_internal_parse_error( + self, + input: ParserInput<'_>, + ) -> InternalParseError<'_> { + InternalParseError::from_kind(input, ParserErrorKind::from(Box::new(self))) + } +} + +impl FromStr for ExternSignature { + type Err = ExternSignatureError; + + fn from_str(s: &str) -> Result { + let signature_input = LocatedSpan::new(s); + let signature_tokens = lex(signature_input).map_err(ExternSignatureError::Lex)?; + let signature = disallow_leftover( + crate::parser::reserved_pragma_extern::parse_extern_signature( + signature_tokens.as_slice(), + ) + .map_err(crate::parser::ParseError::from_nom_internal_err), + ) + .map_err(ExternSignatureError::Syntax)?; + Ok(signature) + } +} + +impl Quil for ExternSignature { + fn write( + &self, + writer: &mut impl std::fmt::Write, + fall_back_to_debug: bool, + ) -> Result<(), crate::quil::ToQuilError> { + if let Some(return_type) = &self.return_type { + return_type.write(writer, fall_back_to_debug)?; + if !self.parameters.is_empty() { + write!(writer, " ")?; + } + } + if self.parameters.is_empty() { + return Ok(()); + } + write!(writer, "(")?; + for (i, parameter) in self.parameters.iter().enumerate() { + if i > 0 { + write!(writer, ", ")?; + } + parameter.write(writer, fall_back_to_debug)?; + } + write!(writer, ")").map_err(Into::into) + } +} + +/// An extern definition with a name and optional signature. Note, this is not a +/// Quil instruction or command, though it may become so in the future. Currently, +/// it is defined as a reserved pragma. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct ExternDefinition { + /// The name of the extern definition. This must be a valid user identifier. + pub name: String, + /// The signature of the extern definition, if any. + pub signature: Option, +} + +impl ExternDefinition { + pub fn try_new( + name: String, + signature: Option, + ) -> Result { + validate_user_identifier(name.as_str()).map_err(ExternValidationError::Name)?; + + Ok(Self { name, signature }) + } +} + +impl Quil for ExternDefinition { + fn write( + &self, + writer: &mut impl std::fmt::Write, + fall_back_to_debug: bool, + ) -> Result<(), crate::quil::ToQuilError> { + write!(writer, "{}", self.name)?; + if let Some(signature) = &self.signature { + write!(writer, " \"")?; + signature.write(writer, fall_back_to_debug)?; + write!(writer, "\"")?; + } + Ok(()) + } +} + +/// An error that can occur when validating an extern definition. +#[derive(Debug, thiserror::Error, PartialEq)] +pub enum ExternValidationError { + /// The specified name is not a valid user identifier. + #[error(transparent)] + Name(#[from] IdentifierValidationError), + /// There are more than one extern definitions with the same name. + #[error("duplicate extern definition {0}")] + Duplicate(String), + /// The extern definition has a signature but it lacks a return or parameters. + #[error("extern definition {0} has a signature but it lacks a return or parameters")] + NoReturnOrParameters(String), +} + +impl ExternDefinition { + /// Validate a list of extern definitions from the same program. It validates the + /// names, uniqueness, and the presence of return or parameters in the signature. + pub(crate) fn validate_all( + extern_definitions: &[ExternDefinition], + ) -> Result<(), ExternValidationError> { + extern_definitions + .iter() + .try_fold(HashSet::new(), |mut acc, extern_definition| { + validate_user_identifier(extern_definition.name.as_str()) + .map_err(ExternValidationError::Name)?; + if acc.contains(&extern_definition.name) { + return Err(ExternValidationError::Duplicate( + extern_definition.name.clone(), + )); + } + if let Some(signature) = extern_definition.signature.as_ref() { + if !signature.has_return_or_parameters() { + return Err(ExternValidationError::NoReturnOrParameters( + extern_definition.name.clone(), + )); + } + } + acc.insert(extern_definition.name.clone()); + Ok(acc) + })?; + Ok(()) + } +} + +/// An error that can occur when resolving a call instruction. +#[derive(Clone, Debug, thiserror::Error, PartialEq)] +pub enum CallArgumentResolutionError { + /// An undeclared memory reference was encountered. + #[error("undeclared memory reference {0}")] + UndeclaredMemoryReference(String), + /// A mismatched vector was encountered. + #[error("mismatched vector: expected {expected:?}, found {found:?}")] + MismatchedVector { expected: Vector, found: Vector }, + /// A mismatched scalar was encountered. + #[error("mismatched scalar: expected {expected:?}, found {found:?}")] + MismatchedScalar { + expected: ScalarType, + found: ScalarType, + }, + /// The argument for a vector parameter was invalid. + #[error("vector parameters must be passed as an identifier, found {0:?}")] + InvalidVectorArgument(UnresolvedCallArgument), + /// the argument for a return parameter was invalid. + #[error("return argument must be a memory reference, found {found:?}")] + ReturnArgument { found: UnresolvedCallArgument }, +} + +/// A parsed, but unresolved call argument. This may be resolved into a [`ResolvedCallArgument`] +/// with the appropriate [`ExternSignature`]. Resolution is required for building the +/// [`crate::Program`] memory graph. +#[derive(Clone, Debug, PartialEq)] +pub enum UnresolvedCallArgument { + /// A reference to a declared memory location. Note, this may be resolved to either + /// a scalar or vector. In the former case, the assumed index is 0. + Identifier(String), + /// A reference to a memory location. This may be resolved to a scalar. + MemoryReference(MemoryReference), + /// An immediate value. This may be resolved to a scalar. + Immediate(Complex64), +} + +impl Eq for UnresolvedCallArgument {} + +impl std::hash::Hash for UnresolvedCallArgument { + fn hash(&self, state: &mut H) { + match self { + UnresolvedCallArgument::Identifier(value) => { + "Identifier".hash(state); + value.hash(state); + } + UnresolvedCallArgument::MemoryReference(value) => { + "MemoryReference".hash(state); + value.hash(state); + } + UnresolvedCallArgument::Immediate(value) => { + "Immediate".hash(state); + hash_complex_64(value, state); + } + } + } +} + +impl UnresolvedCallArgument { + /// Check if the argument is compatible with the given [`ExternParameter`]. If so, return + /// the appropriate [`ResolvedCallArgument`]. If not, return an error. + fn resolve( + &self, + memory_regions: &IndexMap, + extern_parameter: &ExternParameter, + ) -> Result { + match self { + UnresolvedCallArgument::Identifier(value) => { + let expected_vector = match &extern_parameter.data_type { + ExternParameterType::Scalar(_) => { + return UnresolvedCallArgument::MemoryReference(MemoryReference::new( + value.clone(), + 0, + )) + .resolve(memory_regions, extern_parameter); + } + ExternParameterType::FixedLengthVector(expected_vector) => { + let memory_region = + memory_regions.get(value.as_str()).ok_or_else(|| { + CallArgumentResolutionError::UndeclaredMemoryReference( + value.clone(), + ) + })?; + if &memory_region.size != expected_vector { + return Err(CallArgumentResolutionError::MismatchedVector { + expected: expected_vector.clone(), + found: memory_region.size.clone(), + }); + } + + Ok(expected_vector.clone()) + } + ExternParameterType::VariableLengthVector(scalar_type) => { + let memory_region = + memory_regions.get(value.as_str()).ok_or_else(|| { + CallArgumentResolutionError::UndeclaredMemoryReference( + value.clone(), + ) + })?; + if &memory_region.size.data_type != scalar_type { + return Err(CallArgumentResolutionError::MismatchedScalar { + expected: *scalar_type, + found: memory_region.size.data_type, + }); + } + Ok(memory_region.size.clone()) + } + }?; + Ok(ResolvedCallArgument::Vector { + memory_region_name: value.clone(), + vector: expected_vector, + mutable: extern_parameter.mutable, + }) + } + UnresolvedCallArgument::MemoryReference(value) => { + let expected_scalar = match extern_parameter.data_type { + ExternParameterType::Scalar(ref scalar) => Ok(scalar), + ExternParameterType::FixedLengthVector(_) + | ExternParameterType::VariableLengthVector(_) => { + Err(CallArgumentResolutionError::InvalidVectorArgument( + Self::MemoryReference(value.clone()), + )) + } + }?; + let memory_region = memory_regions.get(value.name.as_str()).ok_or_else(|| { + CallArgumentResolutionError::UndeclaredMemoryReference(value.name.clone()) + })?; + if memory_region.size.data_type != *expected_scalar { + return Err(CallArgumentResolutionError::MismatchedScalar { + expected: *expected_scalar, + found: memory_region.size.data_type, + }); + } + Ok(ResolvedCallArgument::MemoryReference { + memory_reference: value.clone(), + scalar_type: *expected_scalar, + mutable: extern_parameter.mutable, + }) + } + UnresolvedCallArgument::Immediate(value) => { + let expected_scalar = match extern_parameter.data_type { + ExternParameterType::Scalar(ref scalar) => Ok(scalar), + ExternParameterType::FixedLengthVector(_) + | ExternParameterType::VariableLengthVector(_) => Err( + CallArgumentResolutionError::InvalidVectorArgument(self.clone()), + ), + }?; + Ok(ResolvedCallArgument::Immediate { + value: *value, + scalar_type: *expected_scalar, + }) + } + } + } + + /// Check if the argument is compatible with the return type of the [`ExternSignature`]. If so, + /// return the appropriate [`ResolvedCallArgument`]. If not, return an error. + fn resolve_return( + &self, + memory_regions: &IndexMap, + return_type: ScalarType, + ) -> Result { + let memory_reference = match self { + UnresolvedCallArgument::MemoryReference(memory_reference) => { + Ok(memory_reference.clone()) + } + UnresolvedCallArgument::Identifier(identifier) => { + Ok(MemoryReference::new(identifier.clone(), 0)) + } + _ => Err(CallArgumentResolutionError::ReturnArgument { + found: self.clone(), + }), + }?; + let memory_region = memory_regions + .get(memory_reference.name.as_str()) + .ok_or_else(|| { + CallArgumentResolutionError::UndeclaredMemoryReference( + memory_reference.name.clone(), + ) + })?; + if memory_region.size.data_type != return_type { + return Err(CallArgumentResolutionError::MismatchedScalar { + expected: return_type, + found: memory_region.size.data_type, + }); + } + Ok(ResolvedCallArgument::MemoryReference { + memory_reference: memory_reference.clone(), + scalar_type: return_type, + mutable: true, + }) + } +} + +impl Quil for UnresolvedCallArgument { + fn write( + &self, + f: &mut impl std::fmt::Write, + fall_back_to_debug: bool, + ) -> crate::quil::ToQuilResult<()> { + match &self { + UnresolvedCallArgument::Identifier(value) => write!(f, "{value}",).map_err(Into::into), + UnresolvedCallArgument::MemoryReference(value) => value.write(f, fall_back_to_debug), + UnresolvedCallArgument::Immediate(value) => { + write!(f, "{}", format_complex(value)).map_err(Into::into) + } + } + } +} + +/// A resolved call argument. This is the result of resolving an [`UnresolvedCallArgument`] with +/// the appropriate [`ExternSignature`]. It annotates the argument both with a type (and possibly +/// a length in the case of a vector) and whether it is mutable. +#[derive(Clone, Debug, PartialEq)] +pub enum ResolvedCallArgument { + Vector { + memory_region_name: String, + vector: Vector, + mutable: bool, + }, + MemoryReference { + memory_reference: MemoryReference, + scalar_type: ScalarType, + mutable: bool, + }, + Immediate { + value: Complex64, + scalar_type: ScalarType, + }, +} + +impl From for UnresolvedCallArgument { + fn from(value: ResolvedCallArgument) -> Self { + match value { + ResolvedCallArgument::Vector { + memory_region_name, .. + } => UnresolvedCallArgument::Identifier(memory_region_name), + ResolvedCallArgument::MemoryReference { + memory_reference, .. + } => UnresolvedCallArgument::MemoryReference(memory_reference), + ResolvedCallArgument::Immediate { value, .. } => { + UnresolvedCallArgument::Immediate(value) + } + } + } +} + +impl Eq for ResolvedCallArgument {} + +impl std::hash::Hash for ResolvedCallArgument { + fn hash(&self, state: &mut H) { + match self { + ResolvedCallArgument::Vector { + memory_region_name, + vector, + mutable, + } => { + "Vector".hash(state); + memory_region_name.hash(state); + vector.hash(state); + mutable.hash(state); + } + ResolvedCallArgument::MemoryReference { + memory_reference, + scalar_type, + mutable, + } => { + "MemoryReference".hash(state); + memory_reference.hash(state); + scalar_type.hash(state); + mutable.hash(state); + } + ResolvedCallArgument::Immediate { value, scalar_type } => { + "Immediate".hash(state); + hash_complex_64(value, state); + scalar_type.hash(state); + } + } + } +} + +fn hash_complex_64(value: &Complex64, state: &mut H) { + if value.re.abs() > 0f64 { + hash_f64(value.re, state); + } + if value.im.abs() > 0f64 { + hash_f64(value.im, state); + } +} + +impl Quil for ResolvedCallArgument { + fn write( + &self, + f: &mut impl std::fmt::Write, + fall_back_to_debug: bool, + ) -> crate::quil::ToQuilResult<()> { + match self { + ResolvedCallArgument::Vector { + memory_region_name: value, + .. + } => write!(f, "{value}").map_err(Into::into), + ResolvedCallArgument::MemoryReference { + memory_reference: value, + .. + } => value.write(f, fall_back_to_debug), + ResolvedCallArgument::Immediate { value, .. } => { + write!(f, "{value}").map_err(Into::into) + } + } + } +} + +impl ResolvedCallArgument { + /// Indicates whether the argument is mutable. + pub(crate) fn is_mutable(&self) -> bool { + match self { + ResolvedCallArgument::Vector { mutable, .. } => *mutable, + ResolvedCallArgument::MemoryReference { mutable, .. } => *mutable, + ResolvedCallArgument::Immediate { .. } => false, + } + } + + /// Returns the name of the memory region. In the case of an immediate value, + /// this will be `None`. + pub(crate) fn name(&self) -> Option { + match self { + ResolvedCallArgument::Vector { + memory_region_name, .. + } => Some(memory_region_name.clone()), + ResolvedCallArgument::MemoryReference { + memory_reference, .. + } => Some(memory_reference.name.clone()), + ResolvedCallArgument::Immediate { .. } => None, + } + } +} + +/// A list of arguments for a call instruction. These may be resolved or unresolved. +/// To resolve a [`Call`] instruction, use [`crate::Program::resolve_call_instructions`]. +#[derive(Clone, Debug, PartialEq, Hash, Eq)] +pub enum CallArguments { + /// The resolved call arguments. + Resolved(Vec), + /// The unresolved call arguments. + Unresolved(Vec), +} + +impl Quil for CallArguments { + fn write( + &self, + writer: &mut impl std::fmt::Write, + fall_back_to_debug: bool, + ) -> Result<(), crate::quil::ToQuilError> { + match self { + CallArguments::Resolved(arguments) => { + if !arguments.is_empty() { + write!(writer, " ")?; + } + for (i, argument) in arguments.iter().enumerate() { + argument.write(writer, fall_back_to_debug)?; + if i < arguments.len() - 1 { + write!(writer, " ")?; + } + } + } + CallArguments::Unresolved(arguments) => { + if !arguments.is_empty() { + write!(writer, " ")?; + } + for (i, argument) in arguments.iter().enumerate() { + argument.write(writer, fall_back_to_debug)?; + if i < arguments.len() - 1 { + write!(writer, " ")?; + } + } + } + } + Ok(()) + } +} + +/// An error that can occur when validating a call instruction. +#[derive(Clone, Debug, PartialEq, thiserror::Error, Eq)] +pub enum CallValidationError { + /// The specified name is not a valid user identifier. + #[error(transparent)] + Name(#[from] IdentifierValidationError), +} + +/// A call instruction with a name and arguments. +#[derive(Clone, Debug, PartialEq, Hash, Eq)] +pub struct Call { + /// The name of the call instruction. This must be a valid user identifier. + pub name: String, + /// The arguments of the call instruction. + pub arguments: CallArguments, +} + +impl Call { + /// Create a new call instruction with resolved arguments. This will validate the + /// name as a user identifier. + pub fn try_new( + name: String, + arguments: Vec, + ) -> Result { + validate_user_identifier(name.as_str()).map_err(CallValidationError::Name)?; + + Ok(Self { + name, + arguments: CallArguments::Unresolved(arguments), + }) + } +} + +/// An error that can occur when resolving a call instruction. +#[derive(Clone, Debug, thiserror::Error, PartialEq)] +pub enum CallArgumentError { + /// The return argument could not be resolved. + #[error("error resolving return argument: {0:?}")] + Return(CallArgumentResolutionError), + /// An argument could not be resolved. + #[error("error resolving argument {index}: {error:?}")] + Argument { + index: usize, + error: CallArgumentResolutionError, + }, +} + +/// An error that can occur when resolving a call instruction to a specific +/// [`ExternSignature`]. +#[derive(Debug, thiserror::Error, PartialEq)] +pub enum CallSignatureError { + #[error("expected {expected} arguments, found {found}")] + ParameterCount { expected: usize, found: usize }, + #[error("error resolving arguments: {0:?}")] + Arguments(Vec), +} + +/// An error that can occur when resolving a call instruction within the context +/// of a program. +#[derive(Debug, thiserror::Error, PartialEq)] +pub enum CallResolutionError { + /// A matching extern instruction was found, but the signature validation failed. + #[error("call found matching extern instruction for {name}, but signature validation failed: {error:?}")] + Signature { + name: String, + error: CallSignatureError, + }, + /// A matching extern instruction was found, but it has no signature. + #[error("call found matching extern instruction for {0}, but it has no signature")] + NoSignature(String), + /// No matching extern instruction was found. + #[error("no extern instruction found with name {0}")] + NoMatchingExternInstruction(String), +} + +#[allow(clippy::manual_try_fold)] +fn convert_unresolved_to_resolved_call_arguments( + arguments: &[UnresolvedCallArgument], + signature: &ExternSignature, + memory_regions: &IndexMap, +) -> Result, CallSignatureError> { + arguments + .iter() + .enumerate() + .map(|(i, argument)| { + if i == 0 { + if let Some(return_type) = signature.return_type { + return argument + .resolve_return(memory_regions, return_type) + .map_err(CallArgumentError::Return); + } + } + let parameter_index = if signature.return_type.is_some() { + i - 1 + } else { + i + }; + let parameter = &signature.parameters[parameter_index]; + argument + .resolve(memory_regions, parameter) + .map_err(|error| CallArgumentError::Argument { + index: parameter_index, + error, + }) + }) + .fold( + Ok(Vec::new()), + |acc: Result, Vec>, + result: Result| { + match (acc, result) { + (Ok(mut acc), Ok(resolved)) => { + acc.push(resolved); + Ok(acc) + } + (Ok(_), Err(error)) => Err(vec![error]), + (Err(errors), Ok(_)) => Err(errors), + (Err(mut errors), Err(error)) => { + errors.push(error); + Err(errors) + } + } + }, + ) + .map_err(CallSignatureError::Arguments) +} + +impl Call { + /// Resolve the [`Call`] instruction to the given [`ExternSignature`]. + fn resolve_to_signature( + &mut self, + signature: &ExternSignature, + memory_regions: &IndexMap, + ) -> Result<(), CallSignatureError> { + let mut expected_parameter_count = signature.parameters.len(); + if signature.return_type.is_some() { + expected_parameter_count += 1; + } + let actual_parameter_count = match &self.arguments { + CallArguments::Resolved(arguments) => arguments.len(), + CallArguments::Unresolved(arguments) => arguments.len(), + }; + if actual_parameter_count != expected_parameter_count { + return Err(CallSignatureError::ParameterCount { + expected: expected_parameter_count, + found: actual_parameter_count, + }); + } + + let resolved_call_arguments = match &self.arguments { + CallArguments::Resolved(arguments) => { + let unresolved_call_arguments = arguments + .iter() + .cloned() + .map(UnresolvedCallArgument::from) + .collect::>(); + + convert_unresolved_to_resolved_call_arguments( + unresolved_call_arguments.as_slice(), + signature, + memory_regions, + )? + } + CallArguments::Unresolved(arguments) => { + convert_unresolved_to_resolved_call_arguments(arguments, signature, memory_regions)? + } + }; + + self.arguments = CallArguments::Resolved(resolved_call_arguments); + Ok(()) + } + + /// Resolve the [`Call`] instruction to any of the given [`ExternSignature`] and memory regions. + /// If no matching extern instruction is found, return an error. + pub fn resolve( + &mut self, + memory_regions: &IndexMap, + extern_definitions: &[ExternDefinition], + ) -> Result<(), CallResolutionError> { + for definition in extern_definitions { + if definition.name == self.name { + let signature = definition + .signature + .as_ref() + .ok_or_else(|| CallResolutionError::NoSignature(self.name.clone()))?; + return self + .resolve_to_signature(signature, memory_regions) + .map_err(|error| CallResolutionError::Signature { + name: self.name.clone(), + error, + }); + } + } + Err(CallResolutionError::NoMatchingExternInstruction( + self.name.clone(), + )) + } +} + +impl Quil for Call { + fn write( + &self, + f: &mut impl std::fmt::Write, + fall_back_to_debug: bool, + ) -> crate::quil::ToQuilResult<()> { + write!(f, "CALL {}", self.name)?; + self.arguments.write(f, fall_back_to_debug) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rstest::*; + + /// Test cases for the `ExternDefinition` Quil representation. + struct ExternDefinitionQuilTestCase { + /// The extern definition to test. + definition: ExternDefinition, + /// The expected Quil representation. + expected: &'static str, + } + + impl ExternDefinitionQuilTestCase { + /// Signature with return and parameters + fn case_01() -> Self { + Self { + definition: ExternDefinition { + name: "foo".to_string(), + signature: Some(ExternSignature { + return_type: Some(ScalarType::Integer), + parameters: vec![ + ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }, + ExternParameter { + name: "baz".to_string(), + mutable: true, + data_type: ExternParameterType::FixedLengthVector(Vector { + data_type: ScalarType::Bit, + length: 2, + }), + }, + ], + }), + }, + expected: "foo \"INTEGER (bar : INTEGER, baz : mut BIT[2])\"", + } + } + + /// Signature with only parameters + fn case_02() -> Self { + let definition = ExternDefinition { + name: "foo".to_string(), + signature: Some(ExternSignature { + return_type: None, + parameters: vec![ + ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }, + ExternParameter { + name: "baz".to_string(), + mutable: true, + data_type: ExternParameterType::FixedLengthVector(Vector { + data_type: ScalarType::Bit, + length: 2, + }), + }, + ], + }), + }; + Self { + definition, + expected: "foo \"(bar : INTEGER, baz : mut BIT[2])\"", + } + } + + /// Signature with return only + fn case_03() -> Self { + let definition = ExternDefinition { + name: "foo".to_string(), + signature: Some(ExternSignature { + return_type: Some(ScalarType::Integer), + parameters: vec![], + }), + }; + Self { + definition, + expected: "foo \"INTEGER\"", + } + } + + /// Signature with no return nor parameters + fn case_04() -> Self { + let definition = ExternDefinition { + name: "foo".to_string(), + signature: Some(ExternSignature { + return_type: None, + parameters: vec![], + }), + }; + Self { + definition, + expected: "foo \"\"", + } + } + + /// No signature + fn case_05() -> Self { + let definition = ExternDefinition { + name: "foo".to_string(), + signature: None, + }; + Self { + definition, + expected: "foo", + } + } + + /// Variable length vector + fn case_06() -> Self { + let definition = ExternDefinition { + name: "foo".to_string(), + signature: Some(ExternSignature { + return_type: None, + parameters: vec![ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::VariableLengthVector(ScalarType::Integer), + }], + }), + }; + Self { + definition, + expected: "foo \"(bar : INTEGER[])\"", + } + } + } + + /// Test that the Quil representation of an `ExternDefinition` is as expected. + #[rstest] + #[case(ExternDefinitionQuilTestCase::case_01())] + #[case(ExternDefinitionQuilTestCase::case_02())] + #[case(ExternDefinitionQuilTestCase::case_03())] + #[case(ExternDefinitionQuilTestCase::case_04())] + #[case(ExternDefinitionQuilTestCase::case_05())] + #[case(ExternDefinitionQuilTestCase::case_06())] + fn test_extern_definition_quil(#[case] test_case: ExternDefinitionQuilTestCase) { + assert_eq!( + test_case + .definition + .to_quil() + .expect("must be able to call to quil"), + test_case.expected.to_string() + ); + } + + /// Test cases for the `Call` Quil representation. + struct CallQuilTestCase { + /// The call instruction to test. + call: Call, + /// The expected Quil representation. + expected: &'static str, + } + + impl CallQuilTestCase { + fn case_01() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: CallArguments::Unresolved(vec![ + UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "bar".to_string(), + index: 0, + }), + UnresolvedCallArgument::Immediate(Complex64::new(2.0, 0.0)), + UnresolvedCallArgument::Identifier("baz".to_string()), + ]), + }; + Self { + call, + expected: "CALL foo bar[0] 2 baz", + } + } + + fn case_02() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: CallArguments::Unresolved(vec![ + UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "bar".to_string(), + index: 0, + }), + UnresolvedCallArgument::Identifier("baz".to_string()), + ]), + }; + Self { + call, + expected: "CALL foo bar[0] baz", + } + } + + fn case_03() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: CallArguments::Unresolved(vec![ + UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "bar".to_string(), + index: 0, + }), + ]), + }; + Self { + call, + expected: "CALL foo bar[0]", + } + } + + fn case_04() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: CallArguments::Unresolved(vec![]), + }; + + Self { + call, + expected: "CALL foo", + } + } + } + + /// Test that the Quil representation of a `Call` instruction is as expected. + #[rstest] + #[case(CallQuilTestCase::case_01())] + #[case(CallQuilTestCase::case_02())] + #[case(CallQuilTestCase::case_03())] + #[case(CallQuilTestCase::case_04())] + fn test_call_definition_quil(#[case] test_case: CallQuilTestCase) { + assert_eq!( + test_case + .call + .to_quil() + .expect("must be able to call to quil"), + test_case.expected.to_string() + ); + } + + /// Build a set of memory regions for testing. + fn build_declarations() -> IndexMap { + [ + ("integer", Vector::new(ScalarType::Integer, 3)), + ("real", Vector::new(ScalarType::Real, 3)), + ("bit", Vector::new(ScalarType::Bit, 3)), + ("octet", Vector::new(ScalarType::Octet, 3)), + ] + .into_iter() + .map(|(name, vector)| (name.to_string(), MemoryRegion::new(vector, None))) + .collect() + } + + /// Test cases for resolving call arguments. + struct ArgumentResolutionTestCase { + call_argument: UnresolvedCallArgument, + extern_parameter: ExternParameter, + expected: Result, + } + + impl ArgumentResolutionTestCase { + /// Memory reference as scalar + fn case_01() -> Self { + ArgumentResolutionTestCase { + call_argument: UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + }), + extern_parameter: ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }, + expected: Ok(ResolvedCallArgument::MemoryReference { + memory_reference: MemoryReference { + name: "integer".to_string(), + index: 0, + }, + scalar_type: ScalarType::Integer, + mutable: false, + }), + } + } + + /// Identifier as vector + fn case_02() -> Self { + ArgumentResolutionTestCase { + call_argument: UnresolvedCallArgument::Identifier("real".to_string()), + extern_parameter: ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::FixedLengthVector(Vector { + data_type: ScalarType::Real, + length: 3, + }), + }, + expected: Ok(ResolvedCallArgument::Vector { + memory_region_name: "real".to_string(), + vector: Vector { + data_type: ScalarType::Real, + length: 3, + }, + mutable: false, + }), + } + } + + /// Immediate value + fn case_03() -> Self { + ArgumentResolutionTestCase { + call_argument: UnresolvedCallArgument::Immediate(Complex64::new(2.0, 0.0)), + extern_parameter: ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }, + expected: Ok(ResolvedCallArgument::Immediate { + value: Complex64::new(2.0, 0.0), + scalar_type: ScalarType::Integer, + }), + } + } + + /// Undeclared identifier + fn case_04() -> Self { + ArgumentResolutionTestCase { + call_argument: UnresolvedCallArgument::Identifier("undeclared".to_string()), + extern_parameter: ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::FixedLengthVector(Vector { + data_type: ScalarType::Real, + length: 3, + }), + }, + expected: Err(CallArgumentResolutionError::UndeclaredMemoryReference( + "undeclared".to_string(), + )), + } + } + + /// Undeclared memory reference + fn case_05() -> Self { + ArgumentResolutionTestCase { + call_argument: UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "undeclared".to_string(), + index: 0, + }), + extern_parameter: ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }, + expected: Err(CallArgumentResolutionError::UndeclaredMemoryReference( + "undeclared".to_string(), + )), + } + } + + /// Vector data type mismatch + fn case_06() -> Self { + ArgumentResolutionTestCase { + call_argument: UnresolvedCallArgument::Identifier("integer".to_string()), + extern_parameter: ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::FixedLengthVector(Vector { + data_type: ScalarType::Real, + length: 3, + }), + }, + expected: Err(CallArgumentResolutionError::MismatchedVector { + expected: Vector { + data_type: ScalarType::Real, + length: 3, + }, + found: Vector { + data_type: ScalarType::Integer, + length: 3, + }, + }), + } + } + + /// Vector length mismatch + fn case_07() -> Self { + ArgumentResolutionTestCase { + call_argument: UnresolvedCallArgument::Identifier("integer".to_string()), + extern_parameter: ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::FixedLengthVector(Vector { + data_type: ScalarType::Integer, + length: 4, + }), + }, + expected: Err(CallArgumentResolutionError::MismatchedVector { + expected: Vector { + data_type: ScalarType::Integer, + length: 4, + }, + found: Vector { + data_type: ScalarType::Integer, + length: 3, + }, + }), + } + } + + /// Scalar data type mismatch + fn case_08() -> Self { + ArgumentResolutionTestCase { + call_argument: UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "octet".to_string(), + index: 0, + }), + extern_parameter: ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }, + expected: Err(CallArgumentResolutionError::MismatchedScalar { + expected: ScalarType::Integer, + found: ScalarType::Octet, + }), + } + } + + /// Scalar arguments may be passed as identifiers, in which case `0` index is + /// inferred. + fn case_09() -> Self { + let call_argument = UnresolvedCallArgument::Identifier("integer".to_string()); + ArgumentResolutionTestCase { + call_argument: call_argument.clone(), + extern_parameter: ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }, + expected: Ok(ResolvedCallArgument::MemoryReference { + memory_reference: MemoryReference::new("integer".to_string(), 0), + scalar_type: ScalarType::Integer, + mutable: false, + }), + } + } + + /// Vector arguments must be passed as identifiers, not memory references. + fn case_10() -> Self { + let call_argument = UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + }); + ArgumentResolutionTestCase { + call_argument: call_argument.clone(), + extern_parameter: ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::FixedLengthVector(Vector { + data_type: ScalarType::Integer, + length: 3, + }), + }, + expected: Err(CallArgumentResolutionError::InvalidVectorArgument( + call_argument, + )), + } + } + + /// Vector arguments must be passed as identifiers, not immediate values. + fn case_11() -> Self { + let call_argument = UnresolvedCallArgument::Immediate(Complex64::new(2.0, 0.0)); + ArgumentResolutionTestCase { + call_argument: call_argument.clone(), + extern_parameter: ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::FixedLengthVector(Vector { + data_type: ScalarType::Integer, + length: 3, + }), + }, + expected: Err(CallArgumentResolutionError::InvalidVectorArgument( + call_argument, + )), + } + } + + /// Variable vector arguments are resolved to a specific vector length based on the + /// declaration (see [`build_declarations`]). + fn case_12() -> Self { + let call_argument = UnresolvedCallArgument::Identifier("integer".to_string()); + ArgumentResolutionTestCase { + call_argument: call_argument.clone(), + extern_parameter: ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::VariableLengthVector(ScalarType::Integer), + }, + expected: Ok(ResolvedCallArgument::Vector { + memory_region_name: "integer".to_string(), + mutable: false, + vector: Vector { + data_type: ScalarType::Integer, + length: 3, + }, + }), + } + } + } + + /// Test resolution of call arguments. + #[rstest] + #[case(ArgumentResolutionTestCase::case_01())] + #[case(ArgumentResolutionTestCase::case_02())] + #[case(ArgumentResolutionTestCase::case_03())] + #[case(ArgumentResolutionTestCase::case_04())] + #[case(ArgumentResolutionTestCase::case_05())] + #[case(ArgumentResolutionTestCase::case_06())] + #[case(ArgumentResolutionTestCase::case_07())] + #[case(ArgumentResolutionTestCase::case_08())] + #[case(ArgumentResolutionTestCase::case_09())] + #[case(ArgumentResolutionTestCase::case_10())] + #[case(ArgumentResolutionTestCase::case_11())] + #[case(ArgumentResolutionTestCase::case_12())] + fn test_argument_resolution(#[case] test_case: ArgumentResolutionTestCase) { + let memory_regions = build_declarations(); + let found = test_case + .call_argument + .resolve(&memory_regions, &test_case.extern_parameter); + match (test_case.expected, found) { + (Ok(expected), Ok(found)) => assert_eq!(expected, found), + (Ok(expected), Err(found)) => { + panic!("expected resolution {:?}, found err {:?}", expected, found) + } + (Err(expected), Ok(found)) => { + panic!("expected err {:?}, found resolution {:?}", expected, found) + } + (Err(expected), Err(found)) => assert_eq!(expected, found), + } + } + + /// Test cases for resolving return arguments. + struct ReturnArgumentResolutionTestCase { + /// The call argument to resolve. + call_argument: UnresolvedCallArgument, + /// The return type of the function. + return_type: ScalarType, + /// The expected result of the resolution. + expected: Result, + } + + impl ReturnArgumentResolutionTestCase { + /// Memory reference is ok. + fn case_01() -> Self { + let call_argument = UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + }); + let expected = Ok(ResolvedCallArgument::MemoryReference { + memory_reference: MemoryReference { + name: "integer".to_string(), + index: 0, + }, + scalar_type: ScalarType::Integer, + mutable: true, + }); + Self { + call_argument, + return_type: ScalarType::Integer, + expected, + } + } + + /// Immediate value is not ok. + fn case_02() -> Self { + let call_argument = UnresolvedCallArgument::Immediate(Complex64::new(2.0, 0.0)); + let expected = Err(CallArgumentResolutionError::ReturnArgument { + found: call_argument.clone(), + }); + Self { + call_argument, + return_type: ScalarType::Integer, + expected, + } + } + + /// Allow plain identifiers to be upcast to memory references. + fn case_03() -> Self { + let call_argument = UnresolvedCallArgument::Identifier("integer".to_string()); + let expected = Ok(ResolvedCallArgument::MemoryReference { + memory_reference: MemoryReference::new("integer".to_string(), 0), + scalar_type: ScalarType::Integer, + mutable: true, + }); + Self { + call_argument, + return_type: ScalarType::Integer, + expected, + } + } + } + + /// Test resolution of return arguments. + #[rstest] + #[case(ReturnArgumentResolutionTestCase::case_01())] + #[case(ReturnArgumentResolutionTestCase::case_02())] + #[case(ReturnArgumentResolutionTestCase::case_03())] + fn test_return_argument_resolution(#[case] test_case: ReturnArgumentResolutionTestCase) { + let memory_regions = build_declarations(); + + let found = test_case + .call_argument + .resolve_return(&memory_regions, test_case.return_type); + match (test_case.expected, found) { + (Ok(expected), Ok(found)) => assert_eq!(expected, found), + (Ok(expected), Err(found)) => { + panic!("expected resolution {:?}, found err {:?}", expected, found) + } + (Err(expected), Ok(found)) => { + panic!("expected err {:?}, found resolution {:?}", expected, found) + } + (Err(expected), Err(found)) => assert_eq!(expected, found), + } + } + + /// Test cases for resolving call arguments to a specific signature. + struct ResolveToSignatureTestCase { + /// The call instruction to resolve. + call: Call, + /// The signature to resolve to. + signature: ExternSignature, + /// The expected result of the resolution. + expected: Result, CallSignatureError>, + } + + impl ResolveToSignatureTestCase { + /// Valid match with return and parameters + fn case_01() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: CallArguments::Unresolved(vec![ + UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + }), + UnresolvedCallArgument::Immediate(Complex64::new(2.0, 0.0)), + UnresolvedCallArgument::Identifier("bit".to_string()), + ]), + }; + let signature = ExternSignature { + return_type: Some(ScalarType::Integer), + parameters: vec![ + ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }, + ExternParameter { + name: "baz".to_string(), + mutable: true, + data_type: ExternParameterType::FixedLengthVector(Vector { + data_type: ScalarType::Bit, + length: 3, + }), + }, + ], + }; + let resolved = vec![ + ResolvedCallArgument::MemoryReference { + memory_reference: MemoryReference { + name: "integer".to_string(), + index: 0, + }, + scalar_type: ScalarType::Integer, + mutable: true, + }, + ResolvedCallArgument::Immediate { + value: Complex64::new(2.0, 0.0), + scalar_type: ScalarType::Integer, + }, + ResolvedCallArgument::Vector { + memory_region_name: "bit".to_string(), + vector: Vector { + data_type: ScalarType::Bit, + length: 3, + }, + mutable: true, + }, + ]; + Self { + call, + signature, + expected: Ok(resolved), + } + } + + /// Valid match with parameteters only + fn case_02() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: CallArguments::Unresolved(vec![ + UnresolvedCallArgument::Immediate(Complex64::new(2.0, 0.0)), + UnresolvedCallArgument::Identifier("bit".to_string()), + ]), + }; + let signature = ExternSignature { + return_type: None, + parameters: vec![ + ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }, + ExternParameter { + name: "baz".to_string(), + mutable: true, + data_type: ExternParameterType::FixedLengthVector(Vector { + data_type: ScalarType::Bit, + length: 3, + }), + }, + ], + }; + let resolved = vec![ + ResolvedCallArgument::Immediate { + value: Complex64::new(2.0, 0.0), + scalar_type: ScalarType::Integer, + }, + ResolvedCallArgument::Vector { + memory_region_name: "bit".to_string(), + vector: Vector { + data_type: ScalarType::Bit, + length: 3, + }, + mutable: true, + }, + ]; + Self { + call, + signature, + expected: Ok(resolved), + } + } + + /// Valid match with return only + fn case_03() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: CallArguments::Unresolved(vec![ + UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + }), + ]), + }; + let signature = ExternSignature { + return_type: Some(ScalarType::Integer), + parameters: vec![], + }; + let resolved = vec![ResolvedCallArgument::MemoryReference { + memory_reference: MemoryReference { + name: "integer".to_string(), + index: 0, + }, + scalar_type: ScalarType::Integer, + mutable: true, + }]; + Self { + call, + signature, + expected: Ok(resolved), + } + } + + /// Already resolved is converted back to unresolved and re-resolved. + fn case_04() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: CallArguments::Resolved(vec![ResolvedCallArgument::MemoryReference { + memory_reference: MemoryReference { + name: "integer".to_string(), + index: 0, + }, + scalar_type: ScalarType::Integer, + mutable: true, + }]), + }; + let signature = ExternSignature { + return_type: Some(ScalarType::Integer), + parameters: vec![], + }; + Self { + call: call.clone(), + signature, + expected: Ok(vec![ResolvedCallArgument::MemoryReference { + memory_reference: MemoryReference { + name: "integer".to_string(), + index: 0, + }, + scalar_type: ScalarType::Integer, + mutable: true, + }]), + } + } + + /// Parameter count mismatch with return and parameters + fn case_05() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: CallArguments::Unresolved(vec![ + UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + }), + UnresolvedCallArgument::Immediate(Complex64::new(2.0, 0.0)), + UnresolvedCallArgument::Identifier("bit".to_string()), + ]), + }; + let signature = ExternSignature { + return_type: Some(ScalarType::Integer), + parameters: vec![ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }], + }; + + Self { + call, + signature, + expected: Err(CallSignatureError::ParameterCount { + expected: 2, + found: 3, + }), + } + } + + /// Parameter count mismatch return only + fn case_06() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: CallArguments::Unresolved(vec![ + UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + }), + UnresolvedCallArgument::Immediate(Complex64::new(2.0, 0.0)), + ]), + }; + let signature = ExternSignature { + return_type: Some(ScalarType::Integer), + parameters: vec![], + }; + + Self { + call, + signature, + expected: Err(CallSignatureError::ParameterCount { + expected: 1, + found: 2, + }), + } + } + + /// Parameter count mismatch parameters only + fn case_07() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: CallArguments::Unresolved(vec![ + UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + }), + UnresolvedCallArgument::Immediate(Complex64::new(2.0, 0.0)), + UnresolvedCallArgument::Identifier("bit".to_string()), + ]), + }; + let signature = ExternSignature { + return_type: None, + parameters: vec![ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }], + }; + + Self { + call, + signature, + expected: Err(CallSignatureError::ParameterCount { + expected: 1, + found: 3, + }), + } + } + + /// Argument mismatch + fn case_08() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: CallArguments::Unresolved(vec![ + UnresolvedCallArgument::Immediate(Complex64::new(2.0, 0.0)), + UnresolvedCallArgument::Identifier("bit".to_string()), + ]), + }; + let signature = ExternSignature { + return_type: Some(ScalarType::Integer), + parameters: vec![ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Real), + }], + }; + + Self { + call, + signature, + expected: Err(CallSignatureError::Arguments(vec![ + CallArgumentError::Return(CallArgumentResolutionError::ReturnArgument { + found: UnresolvedCallArgument::Immediate(Complex64::new(2.0, 0.0)), + }), + CallArgumentError::Argument { + index: 0, + error: CallArgumentResolutionError::MismatchedScalar { + expected: ScalarType::Real, + found: ScalarType::Bit, + }, + }, + ])), + } + } + } + + /// Test resolution of `Call` instructions to a specific signature. + #[rstest] + #[case(ResolveToSignatureTestCase::case_01())] + #[case(ResolveToSignatureTestCase::case_02())] + #[case(ResolveToSignatureTestCase::case_03())] + #[case(ResolveToSignatureTestCase::case_04())] + #[case(ResolveToSignatureTestCase::case_05())] + #[case(ResolveToSignatureTestCase::case_06())] + #[case(ResolveToSignatureTestCase::case_07())] + #[case(ResolveToSignatureTestCase::case_08())] + fn test_assert_matching_signature(#[case] mut test_case: ResolveToSignatureTestCase) { + let memory_regions = build_declarations(); + let found = test_case + .call + .resolve_to_signature(&test_case.signature, &memory_regions); + match (test_case.expected, found) { + (Ok(_), Ok(_)) => {} + (Ok(expected), Err(found)) => { + panic!("expected resolution {:?}, found err {:?}", expected, found) + } + (Err(expected), Ok(found)) => { + panic!("expected err {:?}, found resolution {:?}", expected, found) + } + (Err(expected), Err(found)) => assert_eq!(expected, found), + } + } + + /// Test cases for call resolution against a set of extern definitions. + struct CallResolutionTestCase { + /// The call instruction to resolve. + call: Call, + /// The set of extern definitions to resolve against. + extern_definitions: Vec, + /// The expected result of the resolution. + expected: Result, CallResolutionError>, + } + + impl CallResolutionTestCase { + /// Valid resolution + fn case_01() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: CallArguments::Unresolved(vec![ + UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + }), + ]), + }; + let signature = ExternSignature { + return_type: Some(ScalarType::Integer), + parameters: vec![], + }; + let resolved = vec![ResolvedCallArgument::MemoryReference { + memory_reference: MemoryReference { + name: "integer".to_string(), + index: 0, + }, + scalar_type: ScalarType::Integer, + mutable: true, + }]; + Self { + call, + extern_definitions: vec![ExternDefinition { + name: "foo".to_string(), + signature: Some(signature), + }], + expected: Ok(resolved), + } + } + + /// Signature does not match + fn case_02() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: CallArguments::Unresolved(vec![ + UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + }), + ]), + }; + let signature = ExternSignature { + return_type: Some(ScalarType::Real), + parameters: vec![], + }; + Self { + call, + extern_definitions: vec![ExternDefinition { + name: "foo".to_string(), + signature: Some(signature), + }], + expected: Err(CallResolutionError::Signature { + name: "foo".to_string(), + error: CallSignatureError::Arguments(vec![CallArgumentError::Return( + CallArgumentResolutionError::MismatchedScalar { + expected: ScalarType::Real, + found: ScalarType::Integer, + }, + )]), + }), + } + } + + /// No signature on extern definition + fn case_03() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: CallArguments::Unresolved(vec![ + UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + }), + ]), + }; + Self { + call, + extern_definitions: vec![ExternDefinition { + name: "foo".to_string(), + signature: None, + }], + expected: Err(CallResolutionError::NoSignature("foo".to_string())), + } + } + + /// No corresponding extern definition + fn case_04() -> Self { + let call = Call { + name: "undeclared".to_string(), + arguments: CallArguments::Unresolved(vec![ + UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + }), + ]), + }; + let signature = ExternSignature { + return_type: Some(ScalarType::Real), + parameters: vec![], + }; + Self { + call, + extern_definitions: vec![ExternDefinition { + name: "foo".to_string(), + signature: Some(signature), + }], + expected: Err(CallResolutionError::NoMatchingExternInstruction( + "undeclared".to_string(), + )), + } + } + } + + /// Test resolution of `Call` instructions against a set of extern definitions. + #[rstest] + #[case(CallResolutionTestCase::case_01())] + #[case(CallResolutionTestCase::case_02())] + #[case(CallResolutionTestCase::case_03())] + #[case(CallResolutionTestCase::case_04())] + fn test_call_resolution(#[case] mut test_case: CallResolutionTestCase) { + let memory_regions = build_declarations(); + let found = test_case + .call + .resolve(&memory_regions, &test_case.extern_definitions); + match (test_case.expected, found) { + (Ok(expected), Ok(_)) => { + assert_eq!(CallArguments::Resolved(expected), test_case.call.arguments) + } + (Ok(expected), Err(found)) => { + panic!("expected resolution {:?}, found err {:?}", expected, found) + } + (Err(expected), Ok(_)) => { + panic!( + "expected err {:?}, found resolution {:?}", + expected, test_case.call.arguments + ) + } + (Err(expected), Err(found)) => assert_eq!(expected, found), + } + } + + /// Test cases for validating extern definitions. + struct ExternDefinitionTestCase { + /// The set of extern definitions to validate. + definitions: Vec, + /// The expected result of the validation. + expected: Result<(), ExternValidationError>, + } + + impl ExternDefinitionTestCase { + /// Valid definitions + fn case_01() -> Self { + let definition1 = ExternDefinition { + name: "foo".to_string(), + signature: Some(ExternSignature { + return_type: Some(ScalarType::Integer), + parameters: vec![ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }], + }), + }; + let definition2 = ExternDefinition { + name: "baz".to_string(), + signature: Some(ExternSignature { + return_type: Some(ScalarType::Real), + parameters: vec![ExternParameter { + name: "biz".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Real), + }], + }), + }; + let definitions = vec![definition1, definition2]; + let expected = Ok(()); + Self { + definitions, + expected, + } + } + + /// Duplicate + fn case_02() -> Self { + let definition1 = ExternDefinition { + name: "foo".to_string(), + signature: Some(ExternSignature { + return_type: Some(ScalarType::Integer), + parameters: vec![ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }], + }), + }; + let definition2 = definition1.clone(); + let definitions = vec![definition1, definition2]; + let expected = Err(ExternValidationError::Duplicate("foo".to_string())); + Self { + definitions, + expected, + } + } + + /// No return nor parameters + fn case_03() -> Self { + let definition1 = ExternDefinition { + name: "foo".to_string(), + signature: Some(ExternSignature { + return_type: None, + parameters: vec![], + }), + }; + let definitions = vec![definition1]; + let expected = Err(ExternValidationError::NoReturnOrParameters( + "foo".to_string(), + )); + Self { + definitions, + expected, + } + } + } + + /// Test validation of extern definitions. + #[rstest] + #[case(ExternDefinitionTestCase::case_01())] + #[case(ExternDefinitionTestCase::case_02())] + #[case(ExternDefinitionTestCase::case_03())] + fn test_extern_definition_validation(#[case] test_case: ExternDefinitionTestCase) { + let found = ExternDefinition::validate_all(test_case.definitions.as_slice()); + match (test_case.expected, found) { + (Ok(_), Ok(_)) => {} + (Ok(_), Err(found)) => { + panic!("expected valid, found err {:?}", found) + } + (Err(expected), Ok(_)) => { + panic!("expected err {:?}, found valid", expected) + } + (Err(expected), Err(found)) => assert_eq!(expected, found), + } + } +} diff --git a/quil-rs/src/instruction/mod.rs b/quil-rs/src/instruction/mod.rs index 02855c75..9dbe9271 100644 --- a/quil-rs/src/instruction/mod.rs +++ b/quil-rs/src/instruction/mod.rs @@ -30,6 +30,7 @@ mod circuit; mod classical; mod control_flow; mod declaration; +mod extern_call; mod frame; mod gate; mod measurement; @@ -50,6 +51,7 @@ pub use self::control_flow::{Jump, JumpUnless, JumpWhen, Label, Target, TargetPl pub use self::declaration::{ Declaration, Load, MemoryReference, Offset, ScalarType, Sharing, Store, Vector, }; +pub use self::extern_call::*; pub use self::frame::{ AttributeValue, Capture, FrameAttributes, FrameDefinition, FrameIdentifier, Pulse, RawCapture, SetFrequency, SetPhase, SetScale, ShiftFrequency, ShiftPhase, SwapPhases, @@ -59,7 +61,7 @@ pub use self::gate::{ PauliSum, PauliTerm, }; pub use self::measurement::Measurement; -pub use self::pragma::{Include, Pragma, PragmaArgument}; +pub use self::pragma::{Include, Pragma, PragmaArgument, ReservedPragma, RESERVED_PRAGMA_EXTERN}; pub use self::qubit::{Qubit, QubitPlaceholder}; pub use self::reset::Reset; pub use self::timing::{Delay, Fence}; @@ -76,6 +78,7 @@ pub enum Instruction { Arithmetic(Arithmetic), BinaryLogic(BinaryLogic), CalibrationDefinition(Calibration), + Call(Call), Capture(Capture), CircuitDefinition(CircuitDefinition), Convert(Convert), @@ -101,6 +104,7 @@ pub enum Instruction { Pragma(Pragma), Pulse(Pulse), RawCapture(RawCapture), + ReservedPragma(ReservedPragma), Reset(Reset), SetFrequency(SetFrequency), SetPhase(SetPhase), @@ -149,6 +153,7 @@ impl From<&Instruction> for InstructionRole { | Instruction::ShiftPhase(_) | Instruction::SwapPhases(_) => InstructionRole::RFControl, Instruction::Arithmetic(_) + | Instruction::Call(_) | Instruction::Comparison(_) | Instruction::Convert(_) | Instruction::BinaryLogic(_) @@ -158,6 +163,7 @@ impl From<&Instruction> for InstructionRole { | Instruction::Load(_) | Instruction::Nop | Instruction::Pragma(_) + | Instruction::ReservedPragma(ReservedPragma::Extern(_)) | Instruction::Store(_) => InstructionRole::ClassicalCompute, Instruction::Halt | Instruction::Jump(_) @@ -267,6 +273,7 @@ impl Quil for Instruction { Instruction::CalibrationDefinition(calibration) => { calibration.write(f, fall_back_to_debug) } + Instruction::Call(call) => call.write(f, fall_back_to_debug), Instruction::Capture(capture) => capture.write(f, fall_back_to_debug), Instruction::CircuitDefinition(circuit) => circuit.write(f, fall_back_to_debug), Instruction::Convert(convert) => convert.write(f, fall_back_to_debug), @@ -292,6 +299,9 @@ impl Quil for Instruction { Instruction::Pulse(pulse) => pulse.write(f, fall_back_to_debug), Instruction::Pragma(pragma) => pragma.write(f, fall_back_to_debug), Instruction::RawCapture(raw_capture) => raw_capture.write(f, fall_back_to_debug), + Instruction::ReservedPragma(reserved_pragma) => { + reserved_pragma.write(f, fall_back_to_debug) + } Instruction::Reset(reset) => reset.write(f, fall_back_to_debug), Instruction::SetFrequency(set_frequency) => set_frequency.write(f, fall_back_to_debug), Instruction::SetPhase(set_phase) => set_phase.write(f, fall_back_to_debug), @@ -534,6 +544,7 @@ impl Instruction { Instruction::Arithmetic(_) | Instruction::BinaryLogic(_) | Instruction::CalibrationDefinition(_) + | Instruction::Call(_) | Instruction::CircuitDefinition(_) | Instruction::Comparison(_) | Instruction::Convert(_) @@ -554,6 +565,7 @@ impl Instruction { | Instruction::Move(_) | Instruction::Nop | Instruction::Pragma(_) + | Instruction::ReservedPragma(ReservedPragma::Extern(_)) | Instruction::Store(_) | Instruction::UnaryLogic(_) | Instruction::WaveformDefinition(_) @@ -663,6 +675,7 @@ impl Instruction { | Instruction::WaveformDefinition(_) => true, Instruction::Arithmetic(_) | Instruction::BinaryLogic(_) + | Instruction::Call(_) | Instruction::CircuitDefinition(_) | Instruction::Convert(_) | Instruction::Comparison(_) @@ -681,6 +694,7 @@ impl Instruction { | Instruction::Move(_) | Instruction::Nop | Instruction::Pragma(_) + | Instruction::ReservedPragma(ReservedPragma::Extern(_)) | Instruction::Reset(_) | Instruction::Store(_) | Instruction::Wait @@ -708,6 +722,7 @@ impl Instruction { Instruction::Arithmetic(_) | Instruction::BinaryLogic(_) | Instruction::CalibrationDefinition(_) + | Instruction::Call(_) | Instruction::CircuitDefinition(_) | Instruction::Convert(_) | Instruction::Comparison(_) @@ -728,6 +743,7 @@ impl Instruction { | Instruction::Move(_) | Instruction::Nop | Instruction::Pragma(_) + | Instruction::ReservedPragma(ReservedPragma::Extern(_)) | Instruction::Reset(_) | Instruction::Store(_) | Instruction::UnaryLogic(_) @@ -917,10 +933,14 @@ impl InstructionHandler { /// This uses the return value of the override function, if set and returns `Some`. If not set /// or the function returns `None`, defaults to the return value of /// [`Instruction::get_memory_accesses`]. - pub fn memory_accesses(&mut self, instruction: &Instruction) -> MemoryAccesses { + pub fn memory_accesses( + &mut self, + instruction: &Instruction, + ) -> crate::program::MemoryAccessesResult { self.get_memory_accesses .as_mut() .and_then(|f| f(instruction)) + .map(Ok) .unwrap_or_else(|| instruction.get_memory_accesses()) } } diff --git a/quil-rs/src/instruction/pragma.rs b/quil-rs/src/instruction/pragma.rs index d24ffbd0..d04d34cc 100644 --- a/quil-rs/src/instruction/pragma.rs +++ b/quil-rs/src/instruction/pragma.rs @@ -1,6 +1,6 @@ use crate::quil::Quil; -use super::QuotedString; +use super::{extern_call::ExternDefinition, QuotedString}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Pragma { @@ -77,3 +77,26 @@ impl Include { Self { filename } } } + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum ReservedPragma { + Extern(ExternDefinition), +} + +pub const RESERVED_PRAGMA_EXTERN: &str = "EXTERN"; + +impl Quil for ReservedPragma { + fn write( + &self, + f: &mut impl std::fmt::Write, + fall_back_to_debug: bool, + ) -> crate::quil::ToQuilResult<()> { + write!(f, "PRAGMA ")?; + match self { + ReservedPragma::Extern(extern_definition) => { + write!(f, "{} ", RESERVED_PRAGMA_EXTERN)?; + extern_definition.write(f, fall_back_to_debug) + } + } + } +} diff --git a/quil-rs/src/parser/command.rs b/quil-rs/src/parser/command.rs index d1009a94..70331bbd 100644 --- a/quil-rs/src/parser/command.rs +++ b/quil-rs/src/parser/command.rs @@ -2,24 +2,27 @@ use nom::branch::alt; use nom::combinator::{map, map_res, opt}; use nom::multi::{many0, many1, separated_list0, separated_list1}; use nom::sequence::{delimited, pair, preceded, tuple}; +use num_complex::Complex64; use crate::expression::Expression; use crate::instruction::{ - Arithmetic, ArithmeticOperator, BinaryLogic, BinaryOperator, Calibration, Capture, - CircuitDefinition, Comparison, ComparisonOperator, Convert, Declaration, Delay, Exchange, - Fence, FrameDefinition, GateDefinition, GateSpecification, GateType, Include, Instruction, - Jump, JumpUnless, JumpWhen, Label, Load, MeasureCalibrationDefinition, Measurement, Move, - PauliSum, Pragma, PragmaArgument, Pulse, Qubit, RawCapture, Reset, SetFrequency, SetPhase, - SetScale, ShiftFrequency, ShiftPhase, Store, SwapPhases, Target, UnaryLogic, UnaryOperator, - ValidationError, Waveform, WaveformDefinition, + Arithmetic, ArithmeticOperator, BinaryLogic, BinaryOperator, Calibration, Call, CallArguments, + Capture, CircuitDefinition, Comparison, ComparisonOperator, Convert, Declaration, Delay, + Exchange, Fence, FrameDefinition, GateDefinition, GateSpecification, GateType, Include, + Instruction, Jump, JumpUnless, JumpWhen, Label, Load, MeasureCalibrationDefinition, + Measurement, Move, PauliSum, Pragma, PragmaArgument, Pulse, Qubit, RawCapture, Reset, + SetFrequency, SetPhase, SetScale, ShiftFrequency, ShiftPhase, Store, SwapPhases, Target, + UnaryLogic, UnaryOperator, UnresolvedCallArgument, ValidationError, Waveform, + WaveformDefinition, RESERVED_PRAGMA_EXTERN, }; use crate::parser::instruction::parse_block; use crate::parser::InternalParserResult; use crate::quil::Quil; -use crate::{real, token}; +use crate::{expected_token, real, token, unexpected_eof}; -use super::common::parse_variable_qubit; +use super::common::{parse_i, parse_memory_reference_with_brackets, parse_variable_qubit}; +use super::Token; use super::{ common::{ parse_arithmetic_operand, parse_binary_logic_operand, parse_comparison_operand, @@ -120,6 +123,61 @@ pub(crate) fn parse_declare<'a>(input: ParserInput<'a>) -> InternalParserResult< )) } +/// Parse the contents of a `CALL` instruction. +/// +/// Note, the `CALL` instruction here is unresolved; it may only be resolved within the +/// full context of a program. See [`crate::Program::resolve_call_instructions`]. +/// +/// Call instructions are of the form: +/// `CALL @ms{Identifier} @rep[:min 1]{@group{@ms{Identifier} @alt @ms{Memory Reference} @alt @ms{Complex}}}` +/// +/// For additional detail, see the [Quil specification "Call"](https://github.com/quil-lang/quil/blob/7f532c7cdde9f51eae6abe7408cc868fba9f91f6/specgen/spec/sec-other.s). +pub(crate) fn parse_call<'a>(input: ParserInput<'a>) -> InternalParserResult<'a, Instruction> { + let (input, name) = token!(Identifier(v))(input)?; + + let (input, arguments) = many0(parse_call_argument)(input)?; + let call = Call { + name, + arguments: CallArguments::Unresolved(arguments), + }; + + Ok((input, Instruction::Call(call))) +} + +fn parse_call_argument<'a>( + input: ParserInput<'a>, +) -> InternalParserResult<'a, UnresolvedCallArgument> { + alt(( + map( + parse_memory_reference_with_brackets, + UnresolvedCallArgument::MemoryReference, + ), + map(token!(Identifier(v)), UnresolvedCallArgument::Identifier), + map(parse_immediate_value, UnresolvedCallArgument::Immediate), + ))(input) +} + +fn parse_immediate_value(input: ParserInput) -> InternalParserResult { + match super::split_first_token(input) { + Some((Token::Integer(value), remainder)) => { + let (remainder, imaginary) = opt(parse_i)(remainder)?; + match imaginary { + None => Ok((remainder, crate::real!(*value as f64))), + Some(_) => Ok((remainder, crate::imag!(*value as f64))), + } + } + Some((Token::Float(value), remainder)) => { + let (remainder, imaginary) = opt(parse_i)(remainder)?; + match imaginary { + None => Ok((remainder, crate::real!(*value))), + Some(_) => Ok((remainder, crate::imag!(*value))), + } + } + Some((token, _)) => expected_token!(input, token, "integer or float".to_owned()), + None => unexpected_eof!(input), + } +} + /// Parse the contents of a `CAPTURE` instruction. /// /// Unlike most other instructions, this can be _prefixed_ with the NONBLOCKING keyword, @@ -463,6 +521,9 @@ pub(crate) fn parse_store<'a>(input: ParserInput<'a>) -> InternalParserResult<'a /// Parse the contents of a `PRAGMA` instruction. pub(crate) fn parse_pragma<'a>(input: ParserInput<'a>) -> InternalParserResult<'a, Instruction> { let (input, pragma_type) = token!(Identifier(v))(input)?; + if pragma_type == RESERVED_PRAGMA_EXTERN { + return super::reserved_pragma_extern::parse_reserved_pragma_extern(input); + } let (input, arguments) = many0(alt(( map(token!(Identifier(v)), PragmaArgument::Identifier), map(token!(Integer(i)), PragmaArgument::Integer), @@ -604,10 +665,11 @@ mod tests { PrefixExpression, PrefixOperator, }; use crate::instruction::{ - GateDefinition, GateSpecification, Offset, PauliGate, PauliSum, PauliTerm, PragmaArgument, - Sharing, + Call, CallArguments, GateDefinition, GateSpecification, Offset, PauliGate, PauliSum, + PauliTerm, PragmaArgument, Sharing, UnresolvedCallArgument, }; use crate::parser::lexer::lex; + use crate::parser::Token; use crate::{imag, real}; use crate::{ instruction::{ @@ -616,6 +678,7 @@ mod tests { }, make_test, }; + use rstest::*; use super::{parse_declare, parse_defcircuit, parse_defgate, parse_measurement, parse_pragma}; @@ -1021,4 +1084,111 @@ mod tests { }) }) ); + + /// Test case for parsing a `CALL` instruction. + struct ParseCallTestCase { + /// The input to parse. + input: &'static str, + /// The remaining tokens after parsing. + remainder: Vec, + /// The expected result. + expected: Result, + } + + impl ParseCallTestCase { + /// Basic call with arguments. + fn case_01() -> Self { + Self { + input: "foo integer[0] 1.0 bar", + remainder: vec![], + expected: Ok(Call { + name: "foo".to_string(), + arguments: CallArguments::Unresolved(vec![ + UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + }), + UnresolvedCallArgument::Immediate(real!(1.0)), + UnresolvedCallArgument::Identifier("bar".to_string()), + ]), + }), + } + } + + /// No arguments does in fact parse. + fn case_02() -> Self { + Self { + input: "foo", + remainder: vec![], + expected: Ok(Call { + name: "foo".to_string(), + arguments: CallArguments::Unresolved(vec![]), + }), + } + } + + /// Invalid identifier. + fn case_03() -> Self { + Self { + input: "INCLUDE", + remainder: vec![], + expected: Err( + "ExpectedToken { actual: COMMAND(INCLUDE), expected: \"Identifier\" }" + .to_string(), + ), + } + } + + /// Valid with leftover + fn case_04() -> Self { + Self { + input: "foo integer[0] 1.0 bar; baz", + remainder: vec![Token::Semicolon, Token::Identifier("baz".to_string())], + expected: Ok(Call { + name: "foo".to_string(), + arguments: CallArguments::Unresolved(vec![ + UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + }), + UnresolvedCallArgument::Immediate(real!(1.0)), + UnresolvedCallArgument::Identifier("bar".to_string()), + ]), + }), + } + } + } + + /// Test that the `parse_call` function works as expected. + #[rstest] + #[case(ParseCallTestCase::case_01())] + #[case(ParseCallTestCase::case_02())] + #[case(ParseCallTestCase::case_03())] + #[case(ParseCallTestCase::case_04())] + fn test_parse_call(#[case] test_case: ParseCallTestCase) { + let input = ::nom_locate::LocatedSpan::new(test_case.input); + let tokens = lex(input).unwrap(); + match (test_case.expected, super::parse_call(&tokens)) { + (Ok(expected), Ok((remainder, parsed))) => { + assert_eq!(parsed, Instruction::Call(expected)); + let remainder: Vec<_> = remainder.iter().map(|t| t.as_token().clone()).collect(); + assert_eq!(remainder, test_case.remainder); + } + (Ok(expected), Err(e)) => { + panic!("Expected {:?}, got error: {:?}", expected, e); + } + (Err(expected), Ok((_, parsed))) => { + panic!("Expected error: {:?}, got {:?}", expected, parsed); + } + (Err(expected), Err(found)) => { + let found = format!("{found:?}"); + assert!( + found.contains(&expected), + "`{}` not in `{}`", + expected, + found + ); + } + } + } } diff --git a/quil-rs/src/parser/common.rs b/quil-rs/src/parser/common.rs index a85dd845..dae8ac73 100644 --- a/quil-rs/src/parser/common.rs +++ b/quil-rs/src/parser/common.rs @@ -30,7 +30,7 @@ use crate::{ Vector, WaveformInvocation, WaveformParameters, }, parser::lexer::Operator, - token, + token, unexpected_eof, }; use crate::parser::{InternalParseError, InternalParserResult}; @@ -348,7 +348,7 @@ pub(crate) fn parse_variable_qubit(input: ParserInput) -> InternalParserResult ScalarType { +pub(super) fn match_data_type_token(token: DataType) -> ScalarType { match token { DataType::Bit => ScalarType::Bit, DataType::Integer => ScalarType::Integer, @@ -399,6 +399,18 @@ pub(crate) fn parse_vector<'a>(input: ParserInput<'a>) -> InternalParserResult<' Ok((input, Vector { data_type, length })) } +/// Parse a "vector" which is an integer index, such as `[0]`. The brackets are requried here. +pub(crate) fn parse_vector_with_brackets<'a>( + input: ParserInput<'a>, +) -> InternalParserResult<'a, Vector> { + let (input, data_type_token) = token!(DataType(v))(input)?; + let data_type = match_data_type_token(data_type_token); + + let (input, length) = delimited(token!(LBracket), token!(Integer(v)), token!(RBracket))(input)?; + + Ok((input, Vector { data_type, length })) +} + /// Parse a waveform name which may look like `custom` or `q20_q27_xy/sqrtiSWAP` pub(crate) fn parse_waveform_name<'a>(input: ParserInput<'a>) -> InternalParserResult<'a, String> { use crate::parser::lexer::Operator::Slash; @@ -425,6 +437,15 @@ pub(crate) fn skip_newlines_and_comments<'a>( Ok((input, ())) } +/// Returns successfully if the head of input is the identifier `i`, returns error otherwise. +pub(super) fn parse_i(input: ParserInput) -> InternalParserResult<()> { + match super::split_first_token(input) { + None => unexpected_eof!(input), + Some((Token::Identifier(v), remainder)) if v == "i" => Ok((remainder, ())), + Some((other_token, _)) => expected_token!(input, other_token, "i".to_owned()), + } +} + #[cfg(test)] mod describe_skip_newlines_and_comments { use crate::parser::lex; diff --git a/quil-rs/src/parser/error/mod.rs b/quil-rs/src/parser/error/mod.rs index f277bd17..4e1230a4 100644 --- a/quil-rs/src/parser/error/mod.rs +++ b/quil-rs/src/parser/error/mod.rs @@ -18,6 +18,8 @@ mod input; mod internal; mod kind; +use crate::instruction::ExternSignatureError; + use super::lexer::{Command, Token}; pub use error::Error; @@ -74,4 +76,7 @@ pub enum ParserErrorKind { "expected a Pauli term with a word length of {word_length} to match the number of arguments, {num_args}" )] PauliTermArgumentMismatch { word_length: usize, num_args: usize }, + + #[error(transparent)] + ExternSignature(#[from] Box), } diff --git a/quil-rs/src/parser/expression.rs b/quil-rs/src/parser/expression.rs index 63f73d62..08725d66 100644 --- a/quil-rs/src/parser/expression.rs +++ b/quil-rs/src/parser/expression.rs @@ -25,6 +25,7 @@ use crate::{ token, unexpected_eof, }; +use super::common::parse_i; use super::lexer::{Operator, Token}; use super::ParserInput; @@ -122,15 +123,6 @@ fn parse(input: ParserInput, precedence: Precedence) -> InternalParserResult InternalParserResult<()> { - match super::split_first_token(input) { - None => unexpected_eof!(input), - Some((Token::Identifier(v), remainder)) if v == "i" => Ok((remainder, ())), - Some((other_token, _)) => expected_token!(input, other_token, "i".to_owned()), - } -} - /// Given an expression function, parse the expression within its parentheses. fn parse_function_call<'a>( input: ParserInput<'a>, diff --git a/quil-rs/src/parser/instruction.rs b/quil-rs/src/parser/instruction.rs index 8c539311..d9900d92 100644 --- a/quil-rs/src/parser/instruction.rs +++ b/quil-rs/src/parser/instruction.rs @@ -45,6 +45,7 @@ pub(crate) fn parse_instruction(input: ParserInput) -> InternalParserResult match command { Command::Add => command::parse_arithmetic(ArithmeticOperator::Add, remainder), Command::And => command::parse_logical_binary(BinaryOperator::And, remainder), + Command::Call => command::parse_call(remainder), Command::Capture => command::parse_capture(remainder, true), Command::Convert => command::parse_convert(remainder), Command::Declare => command::parse_declare(remainder), diff --git a/quil-rs/src/parser/lexer/mod.rs b/quil-rs/src/parser/lexer/mod.rs index 235bc810..07d847dd 100644 --- a/quil-rs/src/parser/lexer/mod.rs +++ b/quil-rs/src/parser/lexer/mod.rs @@ -17,7 +17,7 @@ mod quoted_strings; mod wrapped_parsers; use nom::{ - bytes::complete::{is_a, take_till, take_while, take_while1}, + bytes::complete::{is_a, tag_no_case, take_till, take_while, take_while1}, character::complete::{digit1, one_of}, combinator::{all_consuming, map, recognize, value}, multi::many0, @@ -39,6 +39,7 @@ pub use error::{LexError, LexErrorKind}; pub enum Command { Add, And, + Call, Capture, Convert, Declare, @@ -196,6 +197,7 @@ fn recognize_command_or_identifier(identifier: String) -> Token { "DEFGATE" => Token::Command(DefGate), "ADD" => Token::Command(Add), "AND" => Token::Command(And), + "CALL" => Token::Command(Call), "CONVERT" => Token::Command(Convert), "DIV" => Token::Command(Div), "EQ" => Token::Command(Eq), @@ -316,6 +318,7 @@ fn lex_modifier(input: LexInput) -> InternalLexResult { value(Token::Modifier(Modifier::Controlled), tag("CONTROLLED")), value(Token::Modifier(Modifier::Dagger), tag("DAGGER")), value(Token::Modifier(Modifier::Forked), tag("FORKED")), + value(Token::Mutable, tag_no_case("MUT")), value(Token::Offset, tag("OFFSET")), value(Token::PauliSum, tag("PAULI-SUM")), value(Token::Permutation, tag("PERMUTATION")), diff --git a/quil-rs/src/parser/mod.rs b/quil-rs/src/parser/mod.rs index 95dc8f75..29d367ba 100644 --- a/quil-rs/src/parser/mod.rs +++ b/quil-rs/src/parser/mod.rs @@ -26,6 +26,7 @@ mod error; mod expression; pub(crate) mod instruction; mod lexer; +pub(crate) mod reserved_pragma_extern; mod token; pub(crate) use error::{ErrorInput, InternalParseError}; @@ -33,7 +34,7 @@ pub use error::{ParseError, ParserErrorKind}; pub use lexer::LexError; pub use token::{Token, TokenWithLocation}; -type ParserInput<'a> = &'a [TokenWithLocation<'a>]; +pub(crate) type ParserInput<'a> = &'a [TokenWithLocation<'a>]; type InternalParserResult<'a, R, E = InternalParseError<'a>> = IResult, R, E>; /// Pops the first token off of the `input` and returns it and the remaining input. diff --git a/quil-rs/src/parser/reserved_pragma_extern.rs b/quil-rs/src/parser/reserved_pragma_extern.rs new file mode 100644 index 00000000..1f37b534 --- /dev/null +++ b/quil-rs/src/parser/reserved_pragma_extern.rs @@ -0,0 +1,335 @@ +use std::str::FromStr; + +use nom::{ + branch::alt, + combinator::{map, opt}, + multi::separated_list0, +}; + +use crate::{ + instruction::{ + ExternDefinition, ExternParameter, ExternParameterType, ExternSignature, ReservedPragma, + ScalarType, + }, + token, +}; + +use super::{ + common::{match_data_type_token, parse_vector_with_brackets}, + InternalParserResult, ParserInput, +}; + +pub(super) fn parse_reserved_pragma_extern<'a>( + input: ParserInput<'a>, +) -> InternalParserResult<'a, crate::instruction::Instruction> { + let (input, name) = token!(Identifier(v))(input)?; + let (remainder, signature) = opt(token!(String(v)))(input)?; + let signature = if let Some(signature) = signature { + Some( + ExternSignature::from_str(signature.as_str()) + .map_err(|e| nom::Err::Error(e.into_internal_parse_error(input)))?, + ) + } else { + None + }; + + let extern_definition = ExternDefinition { name, signature }; + + Ok(( + remainder, + crate::instruction::Instruction::ReservedPragma(ReservedPragma::Extern(extern_definition)), + )) +} + +/// Parse an [`ExternSignature`] from a string. Note, externs are currently defined within a +/// [`crate::instruction::Pragma] instruction as `PRAGMA EXTERN foo "signature"`; the "signature" +/// currently represents its own mini-language within the Quil specification. +/// +/// Signatures are of the form: +/// `@rep[:min 0 :max 1]{@ms{Base Type}} ( @ms{Extern Parameter} @rep[:min 0]{@group{ , @ms{Extern Parameter} }} )` +/// +/// For details on the signature format, see the [Quil specification "Extern Signature"](https://github.com/quil-lang/quil/blob/7f532c7cdde9f51eae6abe7408cc868fba9f91f6/specgen/spec/sec-other.s). +pub(crate) fn parse_extern_signature<'a>( + input: ParserInput<'a>, +) -> InternalParserResult<'a, ExternSignature> { + let (input, return_type) = opt(token!(DataType(v)))(input)?; + let (input, lparen) = opt(token!(LParenthesis))(input)?; + let (input, parameters) = if lparen.is_some() { + let (input, parameters) = + opt(separated_list0(token!(Comma), parse_extern_parameter))(input)?; + let (input, _) = token!(RParenthesis)(input)?; + (input, parameters.unwrap_or_default()) + } else { + (input, vec![]) + }; + + let signature = ExternSignature { + return_type: return_type.map(match_data_type_token), + parameters, + }; + + Ok((input, signature)) +} + +fn parse_extern_parameter<'a>(input: ParserInput<'a>) -> InternalParserResult<'a, ExternParameter> { + let (input, name) = token!(Identifier(v))(input)?; + let (input, _) = token!(Colon)(input)?; + let (input, mutable) = opt(token!(Mutable))(input)?; + let (input, data_type) = alt(( + map( + parse_vector_with_brackets, + ExternParameterType::FixedLengthVector, + ), + map( + parse_variable_length_vector, + ExternParameterType::VariableLengthVector, + ), + map(token!(DataType(v)), |data_type| { + ExternParameterType::Scalar(match_data_type_token(data_type)) + }), + ))(input)?; + Ok(( + input, + ExternParameter { + name, + mutable: mutable.is_some(), + data_type, + }, + )) +} + +/// Parse a "vector", which is a [`DataType`] followed by empty brackets `[]`. +fn parse_variable_length_vector<'a>( + input: ParserInput<'a>, +) -> InternalParserResult<'a, ScalarType> { + let (input, data_type_token) = token!(DataType(v))(input)?; + let data_type = match_data_type_token(data_type_token); + let (input, _) = token!(LBracket)(input)?; + let (input, _) = token!(RBracket)(input)?; + + Ok((input, data_type)) +} + +#[cfg(test)] +mod tests { + use crate::{ + instruction::Vector, + parser::{lex, InternalParseError, Token}, + }; + + use super::*; + use rstest::*; + + struct ParseReservedPragmaExternTestCase { + input: &'static str, + remainder: Vec, + expected: Result>, + } + + impl ParseReservedPragmaExternTestCase { + /// No signature + fn case_01() -> Self { + Self { + input: "foo", + remainder: vec![], + expected: Ok(crate::instruction::ExternDefinition { + name: "foo".to_string(), + signature: None, + }), + } + } + + /// Empty signature + fn case_02() -> Self { + Self { + input: "foo \"\"", + remainder: vec![], + expected: Ok(crate::instruction::ExternDefinition { + name: "foo".to_string(), + signature: Some(ExternSignature { + return_type: None, + parameters: vec![], + }), + }), + } + } + + /// Empty signature with parentheses + fn case_03() -> Self { + Self { + input: "foo \"()\";", + remainder: vec![Token::Semicolon], + expected: Ok(crate::instruction::ExternDefinition { + name: "foo".to_string(), + signature: Some(ExternSignature { + return_type: None, + parameters: vec![], + }), + }), + } + } + + /// Return without parameters + fn case_04() -> Self { + Self { + input: "foo \"INTEGER\"", + remainder: vec![], + expected: Ok(crate::instruction::ExternDefinition { + name: "foo".to_string(), + signature: Some(crate::instruction::ExternSignature { + return_type: Some(ScalarType::Integer), + parameters: vec![], + }), + }), + } + } + + /// Return with empty parentheses + fn case_05() -> Self { + Self { + input: "foo \"INTEGER ()\"", + remainder: vec![], + expected: Ok(crate::instruction::ExternDefinition { + name: "foo".to_string(), + signature: Some(crate::instruction::ExternSignature { + return_type: Some(ScalarType::Integer), + parameters: vec![], + }), + }), + } + } + + /// Return with parameters + fn case_06() -> Self { + Self { + input: "foo \"INTEGER (bar: REAL, baz: BIT[10], biz: mut OCTET)\"", + remainder: vec![], + expected: Ok(crate::instruction::ExternDefinition { + name: "foo".to_string(), + signature: Some(crate::instruction::ExternSignature { + return_type: Some(ScalarType::Integer), + parameters: vec![ + ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Real), + }, + ExternParameter { + name: "baz".to_string(), + mutable: false, + data_type: ExternParameterType::FixedLengthVector(Vector { + data_type: ScalarType::Bit, + length: 10, + }), + }, + ExternParameter { + name: "biz".to_string(), + mutable: true, + data_type: ExternParameterType::Scalar(ScalarType::Octet), + }, + ], + }), + }), + } + } + + /// Parameters without return + fn case_07() -> Self { + Self { + input: "foo \"(bar: REAL, baz: BIT[10], biz : mut OCTET)\";", + remainder: vec![Token::Semicolon], + expected: Ok(crate::instruction::ExternDefinition { + name: "foo".to_string(), + signature: Some(crate::instruction::ExternSignature { + return_type: None, + parameters: vec![ + ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Real), + }, + ExternParameter { + name: "baz".to_string(), + mutable: false, + data_type: ExternParameterType::FixedLengthVector(Vector { + data_type: ScalarType::Bit, + length: 10, + }), + }, + ExternParameter { + name: "biz".to_string(), + mutable: true, + data_type: ExternParameterType::Scalar(ScalarType::Octet), + }, + ], + }), + }), + } + } + + /// Variable length vector. + fn case_08() -> Self { + Self { + input: "foo \"(bar : mut REAL[])\";", + remainder: vec![Token::Semicolon], + expected: Ok(crate::instruction::ExternDefinition { + name: "foo".to_string(), + signature: Some(crate::instruction::ExternSignature { + return_type: None, + parameters: vec![ExternParameter { + name: "bar".to_string(), + mutable: true, + data_type: ExternParameterType::VariableLengthVector(ScalarType::Real), + }], + }), + }), + } + } + } + + /// Test parsing of `PRAGMA EXTERN` instructions. + #[rstest] + #[case(ParseReservedPragmaExternTestCase::case_01())] + #[case(ParseReservedPragmaExternTestCase::case_02())] + #[case(ParseReservedPragmaExternTestCase::case_03())] + #[case(ParseReservedPragmaExternTestCase::case_04())] + #[case(ParseReservedPragmaExternTestCase::case_05())] + #[case(ParseReservedPragmaExternTestCase::case_06())] + #[case(ParseReservedPragmaExternTestCase::case_07())] + #[case(ParseReservedPragmaExternTestCase::case_08())] + fn test_parse_reserved_pragma_extern(#[case] test_case: ParseReservedPragmaExternTestCase) { + let input = ::nom_locate::LocatedSpan::new(test_case.input); + let tokens = lex(input).unwrap(); + match ( + test_case.expected, + super::parse_reserved_pragma_extern(&tokens), + ) { + (Ok(expected), Ok((remainder, parsed))) => { + assert_eq!( + parsed, + crate::instruction::Instruction::ReservedPragma( + crate::instruction::ReservedPragma::Extern(expected) + ) + ); + let remainder: Vec<_> = remainder.iter().map(|t| t.as_token().clone()).collect(); + assert_eq!(remainder, test_case.remainder); + } + (Ok(expected), Err(e)) => { + panic!("Expected {:?}, got error: {:?}", expected, e); + } + (Err(expected), Ok((_, parsed))) => { + panic!("Expected error: {:?}, got {:?}", expected, parsed); + } + (Err(expected), Err(found)) => { + let expected = format!("{expected:?}"); + let found = format!("{found:?}"); + assert!( + found.contains(&expected), + "`{}` not in `{}`", + expected, + found + ); + } + } + } +} diff --git a/quil-rs/src/parser/token.rs b/quil-rs/src/parser/token.rs index 18210b27..6f0b1a9e 100644 --- a/quil-rs/src/parser/token.rs +++ b/quil-rs/src/parser/token.rs @@ -85,6 +85,7 @@ pub enum Token { NonBlocking, Matrix, Modifier(Modifier), + Mutable, NewLine, Operator(Operator), Offset, @@ -117,6 +118,7 @@ impl fmt::Display for Token { Token::NonBlocking => write!(f, "NONBLOCKING"), Token::Matrix => write!(f, "MATRIX"), Token::Modifier(m) => write!(f, "{m}"), + Token::Mutable => write!(f, "MUT"), Token::NewLine => write!(f, "NEWLINE"), Token::Operator(op) => write!(f, "{op}"), Token::Offset => write!(f, "OFFSET"), @@ -151,6 +153,7 @@ impl fmt::Debug for Token { Token::NonBlocking => write!(f, "{self}"), Token::Matrix => write!(f, "{self}"), Token::Modifier(m) => write!(f, "MODIFIER({m})"), + Token::Mutable => write!(f, "{self}"), Token::NewLine => write!(f, "NEWLINE"), Token::Operator(op) => write!(f, "OPERATOR({op})"), Token::Offset => write!(f, "{self}"), diff --git a/quil-rs/src/program/analysis/control_flow_graph.rs b/quil-rs/src/program/analysis/control_flow_graph.rs index 98f60170..9abaaa63 100644 --- a/quil-rs/src/program/analysis/control_flow_graph.rs +++ b/quil-rs/src/program/analysis/control_flow_graph.rs @@ -21,7 +21,8 @@ use std::{ use crate::{ instruction::{ - Instruction, InstructionHandler, Jump, JumpUnless, JumpWhen, Label, MemoryReference, Target, + Instruction, InstructionHandler, Jump, JumpUnless, JumpWhen, Label, MemoryReference, + ReservedPragma, Target, }, program::{ scheduling::{ @@ -418,6 +419,7 @@ impl<'p> From<&'p Program> for ControlFlowGraph<'p> { match instruction { Instruction::Arithmetic(_) | Instruction::BinaryLogic(_) + | Instruction::Call(_) | Instruction::Capture(_) | Instruction::Convert(_) | Instruction::Comparison(_) @@ -432,6 +434,7 @@ impl<'p> From<&'p Program> for ControlFlowGraph<'p> { | Instruction::Nop | Instruction::Pulse(_) | Instruction::RawCapture(_) + | Instruction::ReservedPragma(ReservedPragma::Extern(_)) | Instruction::Reset(_) | Instruction::SetFrequency(_) | Instruction::SetPhase(_) diff --git a/quil-rs/src/program/memory.rs b/quil-rs/src/program/memory.rs index 3d914aff..2ee18f9a 100644 --- a/quil-rs/src/program/memory.rs +++ b/quil-rs/src/program/memory.rs @@ -16,12 +16,14 @@ use std::collections::HashSet; use crate::expression::{Expression, FunctionCallExpression, InfixExpression, PrefixExpression}; use crate::instruction::{ - Arithmetic, ArithmeticOperand, BinaryLogic, BinaryOperand, Capture, CircuitDefinition, - Comparison, ComparisonOperand, Convert, Delay, Exchange, Gate, GateDefinition, - GateSpecification, Instruction, JumpUnless, JumpWhen, Load, MeasureCalibrationDefinition, - Measurement, MemoryReference, Move, Pulse, RawCapture, SetFrequency, SetPhase, SetScale, - Sharing, ShiftFrequency, ShiftPhase, Store, UnaryLogic, Vector, WaveformInvocation, + Arithmetic, ArithmeticOperand, BinaryLogic, BinaryOperand, Call, CallArguments, Capture, + CircuitDefinition, Comparison, ComparisonOperand, Convert, Delay, Exchange, Gate, + GateDefinition, GateSpecification, Instruction, JumpUnless, JumpWhen, Load, + MeasureCalibrationDefinition, Measurement, MemoryReference, Move, Pulse, RawCapture, + ReservedPragma, SetFrequency, SetPhase, SetScale, Sharing, ShiftFrequency, ShiftPhase, Store, + UnaryLogic, Vector, WaveformInvocation, }; +use crate::quil::Quil; #[derive(Clone, Debug, Hash, PartialEq)] pub struct MemoryRegion { @@ -43,7 +45,7 @@ pub struct MemoryAccess { pub access_type: MemoryAccessType, } -#[derive(Clone, Debug, Default)] +#[derive(Clone, Debug, Default, PartialEq)] pub struct MemoryAccesses { pub captures: HashSet, pub reads: HashSet, @@ -92,10 +94,18 @@ macro_rules! set_from_memory_references { }; } +#[derive(Clone, thiserror::Error, Debug, PartialEq)] +pub enum MemoryAccessesError { + #[error("cannot get instruction memory accesses before resolving call instructions: {}", .0.to_quil_or_debug())] + UnresolvedCallInstruction(Call), +} + +pub type MemoryAccessesResult = Result; + impl Instruction { /// Return all memory accesses by the instruction - in expressions, captures, and memory manipulation - pub fn get_memory_accesses(&self) -> MemoryAccesses { - match self { + pub fn get_memory_accesses(&self) -> MemoryAccessesResult { + Ok(match self { Instruction::Convert(Convert { source, destination, @@ -104,6 +114,28 @@ impl Instruction { writes: set_from_memory_references![[destination]], ..Default::default() }, + Instruction::Call(call) => match call.arguments { + CallArguments::Resolved(ref arguments) => { + let mut reads = HashSet::new(); + let mut writes = HashSet::new(); + for argument in arguments { + if let Some(name) = argument.name() { + if argument.is_mutable() { + writes.insert(name.clone()); + } + reads.insert(name); + } + } + MemoryAccesses { + reads, + writes, + ..Default::default() + } + } + CallArguments::Unresolved(_) => { + return Err(MemoryAccessesError::UnresolvedCallInstruction(call.clone())); + } + }, Instruction::Comparison(Comparison { destination, lhs, @@ -187,14 +219,16 @@ impl Instruction { | Instruction::MeasureCalibrationDefinition(MeasureCalibrationDefinition { instructions, .. - }) => instructions.iter().fold(Default::default(), |acc, el| { - let el_accesses = el.get_memory_accesses(); - MemoryAccesses { - reads: merge_sets!(acc.reads, el_accesses.reads), - writes: merge_sets!(acc.writes, el_accesses.writes), - captures: merge_sets!(acc.captures, el_accesses.captures), - } - }), + }) => instructions + .iter() + .try_fold(Default::default(), |acc: MemoryAccesses, el| { + let el_accesses = el.get_memory_accesses()?; + Ok(MemoryAccesses { + reads: merge_sets!(acc.reads, el_accesses.reads), + writes: merge_sets!(acc.writes, el_accesses.writes), + captures: merge_sets!(acc.captures, el_accesses.captures), + }) + })?, Instruction::Delay(Delay { duration, .. }) => MemoryAccesses { reads: set_from_memory_references!(duration.get_memory_references()), ..Default::default() @@ -298,10 +332,11 @@ impl Instruction { | Instruction::Label(_) | Instruction::Nop | Instruction::Pragma(_) + | Instruction::ReservedPragma(ReservedPragma::Extern(_)) | Instruction::Reset(_) | Instruction::SwapPhases(_) | Instruction::WaveformDefinition(_) => Default::default(), - } + }) } } @@ -451,7 +486,9 @@ mod tests { #[case] instruction: Instruction, #[case] expected: MemoryAccesses, ) { - let memory_accesses = instruction.get_memory_accesses(); + let memory_accesses = instruction + .get_memory_accesses() + .expect("must be able to get memory accesses"); assert_eq!(memory_accesses.captures, expected.captures); assert_eq!(memory_accesses.reads, expected.reads); assert_eq!(memory_accesses.writes, expected.writes); diff --git a/quil-rs/src/program/mod.rs b/quil-rs/src/program/mod.rs index 7846bb79..c1a885d7 100644 --- a/quil-rs/src/program/mod.rs +++ b/quil-rs/src/program/mod.rs @@ -21,20 +21,25 @@ use ndarray::Array2; use nom_locate::LocatedSpan; use crate::instruction::{ - Arithmetic, ArithmeticOperand, ArithmeticOperator, Declaration, FrameDefinition, - FrameIdentifier, GateDefinition, GateError, Instruction, Jump, JumpUnless, Label, Matrix, - MemoryReference, Move, Qubit, QubitPlaceholder, ScalarType, Target, TargetPlaceholder, Vector, - Waveform, WaveformDefinition, + Arithmetic, ArithmeticOperand, ArithmeticOperator, CallResolutionError, Declaration, + ExternDefinition, ExternValidationError, FrameDefinition, FrameIdentifier, GateDefinition, + GateError, Instruction, Jump, JumpUnless, Label, Matrix, MemoryReference, Move, Qubit, + QubitPlaceholder, ReservedPragma, ScalarType, Target, TargetPlaceholder, Vector, Waveform, + WaveformDefinition, }; use crate::parser::{lex, parse_instructions, ParseError}; use crate::quil::Quil; pub use self::calibration::Calibrations; pub use self::calibration_set::CalibrationSet; -pub use self::error::{disallow_leftover, map_parsed, recover, ParseProgramError, SyntaxError}; +pub use self::error::{ + disallow_leftover, map_parsed, recover, LeftoverError, ParseProgramError, SyntaxError, +}; pub use self::frame::FrameSet; pub use self::frame::MatchedFrames; -pub use self::memory::{MemoryAccess, MemoryAccesses, MemoryRegion}; +pub use self::memory::{ + MemoryAccess, MemoryAccesses, MemoryAccessesError, MemoryAccessesResult, MemoryRegion, +}; pub mod analysis; mod calibration; @@ -61,6 +66,12 @@ pub enum ProgramError { #[error("can only compute program unitary for programs composed of `Gate`s; found unsupported instruction: {}", .0.to_quil_or_debug())] UnsupportedForUnitary(Instruction), + + #[error("failed to resolve call instructions: {0}")] + CallResolution(#[from] CallResolutionError), + + #[error(transparent)] + ExternValidation(#[from] ExternValidationError), } type Result = std::result::Result; @@ -622,6 +633,47 @@ impl Program { pub fn get_instruction(&self, index: usize) -> Option<&Instruction> { self.instructions.get(index) } + + pub fn resolve_call_instructions(&mut self) -> Result<()> { + if !self + .instructions + .iter() + .any(|instruction| matches!(instruction, Instruction::Call(_))) + { + return Ok(()); + } + + let extern_definitions = self + .instructions + .iter() + .filter_map(|instruction| { + if let Instruction::ReservedPragma(ReservedPragma::Extern(extern_definition)) = + instruction + { + Some(extern_definition.clone()) + } else { + None + } + }) + .collect::>(); + ExternDefinition::validate_all(extern_definitions.as_slice())?; + + let calls = self + .instructions + .iter_mut() + .filter_map(|instruction| { + if let Instruction::Call(call) = instruction { + Some(call) + } else { + None + } + }) + .collect::>(); + for call in calls { + call.resolve(&self.memory_regions, extern_definitions.as_slice())?; + } + Ok(()) + } } impl Quil for Program { @@ -699,9 +751,12 @@ mod tests { use crate::{ imag, instruction::{ - Gate, Instruction, Jump, JumpUnless, JumpWhen, Label, Matrix, MemoryReference, Qubit, - QubitPlaceholder, Target, TargetPlaceholder, + Call, CallArguments, Declaration, ExternDefinition, ExternParameter, + ExternParameterType, ExternSignature, Gate, Instruction, Jump, JumpUnless, JumpWhen, + Label, Matrix, MemoryReference, Qubit, QubitPlaceholder, ReservedPragma, ScalarType, + Target, TargetPlaceholder, UnresolvedCallArgument, Vector, }, + program::MemoryAccesses, quil::Quil, real, }; @@ -1507,4 +1562,81 @@ DEFFRAME 0 \"xy\": assert_eq!(new_program.to_quil().unwrap(), quil); } } + + /// Test that a program with a `CALL` instruction can be parsed and properly resolved to + /// the corresponding `EXTERN` instruction. Finally, test that the memory accesses are + /// correctly calculated with the resolved `CALL` instruction. + #[test] + fn test_extern_call() { + let input = r#"DECLARE reals REAL[3] +DECLARE octets OCTET[3] +PRAGMA EXTERN foo "OCTET (params : mut REAL[3])" +CALL foo octets[1] reals +"#; + let mut program = Program::from_str(input).expect("should be able to parse program"); + let reserialized = program + .to_quil() + .expect("should be able to serialize program"); + assert_eq!(input, reserialized); + + let expected_program = Program::from_instructions(vec![ + Instruction::Declaration(Declaration::new( + "reals".to_string(), + Vector::new(ScalarType::Real, 3), + None, + )), + Instruction::Declaration(Declaration::new( + "octets".to_string(), + Vector::new(ScalarType::Octet, 3), + None, + )), + Instruction::ReservedPragma(ReservedPragma::Extern(ExternDefinition { + name: "foo".to_string(), + signature: Some(ExternSignature { + return_type: Some(ScalarType::Octet), + parameters: vec![ExternParameter { + name: "params".to_string(), + mutable: true, + data_type: ExternParameterType::FixedLengthVector(Vector::new( + ScalarType::Real, + 3, + )), + }], + }), + })), + Instruction::Call(Call { + name: "foo".to_string(), + arguments: CallArguments::Unresolved(vec![ + UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "octets".to_string(), + index: 1, + }), + UnresolvedCallArgument::Identifier("reals".to_string()), + ]), + }), + ]); + assert_eq!(expected_program, program); + + program + .resolve_call_instructions() + .expect("should be able to resolve calls"); + + assert_eq!( + program.instructions[0] + .get_memory_accesses() + .expect("should be able to get memory accesses"), + MemoryAccesses::default() + ); + + assert_eq!( + program.instructions[1] + .get_memory_accesses() + .expect("should be able to get memory accesses"), + MemoryAccesses { + reads: ["octets", "reals"].into_iter().map(String::from).collect(), + writes: ["octets", "reals"].into_iter().map(String::from).collect(), + ..MemoryAccesses::default() + } + ); + } } diff --git a/quil-rs/src/program/scheduling/graph.rs b/quil-rs/src/program/scheduling/graph.rs index 2e9cd4e8..171cc7c2 100644 --- a/quil-rs/src/program/scheduling/graph.rs +++ b/quil-rs/src/program/scheduling/graph.rs @@ -31,6 +31,7 @@ pub use crate::program::memory::MemoryAccessType; pub enum ScheduleErrorVariant { DuplicateLabel, UncalibratedInstruction, + UnresolvedCallInstruction, UnschedulableInstruction, } @@ -303,7 +304,14 @@ impl<'a> ScheduledBasicBlock<'a> { for (index, &instruction) in basic_block.instructions().iter().enumerate() { let node = graph.add_node(ScheduledGraphNode::InstructionIndex(index)); - let accesses = custom_handler.memory_accesses(instruction); + let accesses = + custom_handler + .memory_accesses(instruction) + .map_err(|_| ScheduleError { + instruction_index: Some(index), + instruction: instruction.clone(), + variant: ScheduleErrorVariant::UnresolvedCallInstruction, + })?; let memory_dependencies = [ (accesses.reads, MemoryAccessType::Read),