diff --git a/.gitignore b/.gitignore index d91fab29..7fbdc4bb 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,7 @@ __pycache__ # quil-py documentation quil-py/build +quil-py/pyrightconfig.json # unversioned developer notes scratch/ diff --git a/Cargo.lock b/Cargo.lock index c655ffaf..e15f9f50 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1132,7 +1132,7 @@ checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" [[package]] name = "quil-cli" -version = "0.5.1-rc.0" +version = "0.5.1" dependencies = [ "anyhow", "clap", @@ -1141,7 +1141,7 @@ dependencies = [ [[package]] name = "quil-py" -version = "0.12.1-rc.0" +version = "0.12.1" dependencies = [ "indexmap", "ndarray", @@ -1155,7 +1155,7 @@ dependencies = [ [[package]] name = "quil-rs" -version = "0.28.1-rc.0" +version = "0.28.1" dependencies = [ "approx", "clap", diff --git a/quil-py/quil/instructions/__init__.pyi b/quil-py/quil/instructions/__init__.pyi index 545472d1..1b7c005a 100644 --- a/quil-py/quil/instructions/__init__.pyi +++ b/quil-py/quil/instructions/__init__.pyi @@ -79,6 +79,7 @@ class Instruction: Calibration, Capture, BinaryLogic, + Call, CircuitDefinition, Convert, Comparison, @@ -131,6 +132,7 @@ class Instruction: Capture, BinaryLogic, CircuitDefinition, + Call, Convert, Comparison, Declaration, @@ -168,6 +170,7 @@ class Instruction: def is_arithmetic(self) -> bool: ... def is_binary_logic(self) -> bool: ... def is_calibration_definition(self) -> bool: ... + def is_call(self) -> bool: ... def is_capture(self) -> bool: ... def is_circuit_definition(self) -> bool: ... def is_convert(self) -> bool: ... @@ -217,6 +220,8 @@ class Instruction: @staticmethod def from_calibration_definition(inner: Calibration) -> Instruction: ... @staticmethod + def from_call(inner: Call) -> Instruction: ... + @staticmethod def from_capture(inner: Capture) -> Instruction: ... @staticmethod def from_circuit_definition( @@ -294,6 +299,8 @@ class Instruction: def to_arithmetic(self) -> Arithmetic: ... def as_binary_logic(self) -> Optional[BinaryLogic]: ... def to_binary_logic(self) -> BinaryLogic: ... + def as_call(self) -> Optional[Call]: ... + def to_call(self) -> Call: ... def as_convert(self) -> Optional[Convert]: ... def to_convert(self) -> Convert: ... def as_comparison(self) -> Optional[Comparison]: ... @@ -1175,6 +1182,213 @@ class FrameIdentifier: If any part of the instruction can't be converted to valid Quil, it will be printed in a human-readable debug format. """ +@final +class CallArgument: + """An argument to a ``Call`` instruction. + + This may be expressed as an identifier, a memory reference, or an immediate value. Memory references and identifiers require a corresponding memory region declaration by the time of + compilation (at the time of call argument resolution and memory graph construction, to be more precise). + + Additionally, an argument's resolved type must match the expected type of the corresponding ``ExternParameter`` + in the ``ExternSignature``. + """ + def inner(self) -> Union[List[List[Expression]], List[int], PauliSum]: + """Returns the inner value of the variant. Raises a ``RuntimeError`` if inner data doesn't exist.""" + ... + def is_identifier(self) -> bool: ... + def is_memory_reference(self) -> bool: ... + def is_immediate(self) -> bool: ... + def as_identifier(self) -> Optional[str]: ... + def to_identifier(self) -> str: ... + def as_memory_reference(self) -> Optional["MemoryReference"]: ... + def to_memory_reference(self) -> "MemoryReference": ... + def as_immediate(self) -> Optional[complex]: ... + def to_immediate(self) -> complex: ... + @staticmethod + def from_identifier(inner: str) -> "CallArgument": ... + @staticmethod + def from_memory_reference(inner: "MemoryReference") -> "CallArgument": ... + @staticmethod + def from_immediate(inner: complex) -> "CallArgument": ... + def to_quil(self) -> str: + """Attempt to convert the instruction to a valid Quil string. + + Raises an exception if the instruction can't be converted to valid Quil. + """ + ... + def to_quil_or_debug(self) -> str: + """Convert the instruction to a Quil string. + + If any part of the instruction can't be converted to valid Quil, it will be printed in a human-readable debug format. + """ + +class CallError(ValueError): + """An error that may occur when initializing a ``Call``.""" + + ... + +class Call: + """An instruction that calls an external function declared with a `PRAGMA EXTERN` instruction. + + These calls are generally specific to a particular hardware or virtual machine + backend. For further detail, see: + + * `Other instructions and Directives `_ + in the Quil specification. + * `EXTERN / CALL RFC `_ + * `quil#87 `_ + + + Also see ``ExternSignature``. + """ + def __new__( + cls, + name: str, + arguments: List["CallArgument"], + ) -> Self: ... + @property + def name(self) -> str: ... + @property + def arguments(self) -> List[CallArgument]: ... + def to_quil(self) -> str: + """Attempt to convert the instruction to a valid Quil string. + + Raises an exception if the instruction can't be converted to valid Quil. + """ + ... + def to_quil_or_debug(self) -> str: + """Convert the instruction to a Quil string. + + If any part of the instruction can't be converted to valid Quil, it will be printed in a human-readable debug format. + """ + def __deepcopy__(self, _: Dict) -> Self: + """Creates and returns a deep copy of the class. + + If the instruction contains any ``QubitPlaceholder`` or ``TargetPlaceholder``, then they will be replaced with + new placeholders so resolving them in the copy will not resolve them in the original. + Should be used by passing an instance of the class to ``copy.deepcopy`` + """ + def __copy__(self) -> Self: + """Returns a shallow copy of the class.""" + +@final +class ExternParameterType: + """The type of an ``ExternParameter``. + + This type is used to define the expected type of a parameter within an ``ExternSignature``. May be + either a scalar, fixed-length vector, or variable-length vector. + + Note, both scalars and fixed-length vectors are fully specified by a ``ScalarType``, but are indeed + distinct ``ExternParameterType``s. + """ + + def inner(self) -> Union["ScalarType", "Vector"]: + """Returns the inner value of the variant. Raises a ``RuntimeError`` if inner data doesn't exist.""" + ... + def is_scalar(self) -> bool: ... + def is_fixed_length_vector(self) -> bool: ... + def is_variable_length_vector(self) -> bool: ... + def as_scalar(self) -> Optional["ScalarType"]: ... + def to_scalar(self) -> "ScalarType": ... + def as_fixed_length_vector(self) -> Optional["Vector"]: ... + def to_fixed_length_vector(self) -> "MemoryReference": ... + def as_variable_length_vector(self) -> Optional["ScalarType"]: ... + def to_variable_length_vector(self) -> complex: ... + @staticmethod + def from_scalar(inner: "ScalarType") -> "ExternParameterType": ... + @staticmethod + def from_fixed_length_vector(inner: "Vector") -> "ExternParameterType": ... + @staticmethod + def from_variable_length_vector(inner: "ScalarType") -> "ExternParameterType": ... + def to_quil(self) -> str: + """Attempt to convert the instruction to a valid Quil string. + + Raises an exception if the instruction can't be converted to valid Quil. + """ + ... + def to_quil_or_debug(self) -> str: + """Convert the instruction to a Quil string. + + If any part of the instruction can't be converted to valid Quil, it will be printed in a human-readable debug format. + """ + +class ExternParameter: + """A parameter within an ``ExternSignature``. These are defined by a name, mutability, and type.""" + def __new__( + cls, + name: str, + mutable: bool, + data_type: ExternParameterType, + ) -> Self: ... + @property + def name(self) -> str: ... + @property + def mutable(self) -> bool: ... + @property + def data_type(self) -> ExternParameterType: ... + def to_quil(self) -> str: + """Attempt to convert the instruction to a valid Quil string. + + Raises an exception if the instruction can't be converted to valid Quil. + """ + ... + def to_quil_or_debug(self) -> str: + """Convert the instruction to a Quil string. + + If any part of the instruction can't be converted to valid Quil, it will be printed in a human-readable debug format. + """ + def __deepcopy__(self, _: Dict) -> Self: + """Creates and returns a deep copy of the class. + + If the instruction contains any ``QubitPlaceholder`` or ``TargetPlaceholder``, then they will be replaced with + new placeholders so resolving them in the copy will not resolve them in the original. + Should be used by passing an instance of the class to ``copy.deepcopy`` + """ + def __copy__(self) -> Self: + """Returns a shallow copy of the class.""" + +class ExternSignature: + """The signature of a ``PRAGMA EXTERN`` instruction. + + This signature is defined by a list of ``ExternParameter``s and an + optional return type. See the `Quil Specification `_ + for details on how these signatures are formed. + """ + def __new__( + cls, + parameters: List[ExternParameter], + return_type: Optional[ScalarType], + ) -> Self: ... + @property + def parameters(self) -> List[ExternParameter]: ... + @property + def return_type(self) -> Optional[ScalarType]: ... + def to_quil(self) -> str: + """Attempt to convert the instruction to a valid Quil string. + + Raises an exception if the instruction can't be converted to valid Quil. + """ + ... + def to_quil_or_debug(self) -> str: + """Convert the instruction to a Quil string. + + If any part of the instruction can't be converted to valid Quil, it will be printed in a human-readable debug format. + """ + def __deepcopy__(self, _: Dict) -> Self: + """Creates and returns a deep copy of the class. + + If the instruction contains any ``QubitPlaceholder`` or ``TargetPlaceholder``, then they will be replaced with + new placeholders so resolving them in the copy will not resolve them in the original. + Should be used by passing an instance of the class to ``copy.deepcopy`` + """ + def __copy__(self) -> Self: + """Returns a shallow copy of the class.""" + +class ExternError(ValueError): + """An error that may occur when initializing or validating a ``PRAGMA EXTERN`` instruction.""" + + ... + class Capture: def __new__( cls, diff --git a/quil-py/src/instruction/declaration.rs b/quil-py/src/instruction/declaration.rs index 5d165bd6..256381f8 100644 --- a/quil-py/src/instruction/declaration.rs +++ b/quil-py/src/instruction/declaration.rs @@ -38,6 +38,22 @@ impl_repr!(PyScalarType); impl_to_quil!(PyScalarType); impl_hash!(PyScalarType); +impl rigetti_pyo3::PyTryFrom for PyScalarType { + fn py_try_from(_py: Python, item: &pyo3::PyAny) -> PyResult { + let item = item.extract::()?; + match item.as_str() { + "BIT" => Ok(Self::Bit), + "INTEGER" => Ok(Self::Integer), + "OCTET" => Ok(Self::Octet), + "REAL" => Ok(Self::Real), + _ => Err(PyValueError::new_err(format!( + "Invalid value for ScalarType: {}", + item + ))), + } + } +} + py_wrap_data_struct! { #[derive(Debug, Hash, PartialEq, Eq)] #[pyo3(subclass)] diff --git a/quil-py/src/instruction/extern_call.rs b/quil-py/src/instruction/extern_call.rs new file mode 100644 index 00000000..d3beaff2 --- /dev/null +++ b/quil-py/src/instruction/extern_call.rs @@ -0,0 +1,189 @@ +use quil_rs::instruction::{ + Call, ExternParameter, ExternParameterType, ExternSignature, ScalarType, UnresolvedCallArgument, +}; + +use rigetti_pyo3::{ + impl_hash, impl_repr, py_wrap_error, py_wrap_union_enum, + pyo3::{pymethods, types::PyString, Py, PyResult, Python}, + wrap_error, PyTryFrom, ToPythonError, +}; + +use crate::{impl_copy_for_instruction, impl_eq, impl_pickle_for_instruction, impl_to_quil}; + +use super::{PyScalarType, PyVector}; + +wrap_error!(RustCallError(quil_rs::instruction::CallError)); +py_wrap_error!( + quil, + RustCallError, + CallError, + rigetti_pyo3::pyo3::exceptions::PyValueError +); + +wrap_error!(RustExternError(quil_rs::instruction::ExternError)); +py_wrap_error!( + quil, + RustExternError, + ExternError, + rigetti_pyo3::pyo3::exceptions::PyValueError +); + +rigetti_pyo3::py_wrap_type! { + #[derive(Debug, PartialEq, Eq)] + #[pyo3(subclass, module = "quil.instructions")] + PyCall(Call) as "Call" +} + +#[pymethods] +impl PyCall { + #[new] + fn new(name: String, arguments: Vec) -> PyResult { + Call::try_new( + name, + arguments.into_iter().map(PyCallArgument::into).collect(), + ) + .map(Self) + .map_err(RustCallError::from) + .map_err(RustCallError::to_py_err) + } + + #[getter] + fn name(&self) -> &str { + self.0.name() + } + + #[getter] + fn arguments(&self) -> Vec { + self.0 + .arguments() + .iter() + .map(PyCallArgument::from) + .collect() + } +} + +rigetti_pyo3::impl_as_mut_for_wrapper!(PyCall); +impl_repr!(PyCall); +impl_to_quil!(PyCall); +impl_copy_for_instruction!(PyCall); +impl_hash!(PyCall); +impl_eq!(PyCall); +impl_pickle_for_instruction!(PyCall); + +py_wrap_union_enum! { + #[derive(Debug, PartialEq, Eq)] + PyCallArgument(UnresolvedCallArgument) as "CallArgument" { + identifier: Identifier => Py, + memory_reference: MemoryReference => super::PyMemoryReference, + immediate: Immediate => Py + } +} +impl_repr!(PyCallArgument); +impl_to_quil!(PyCallArgument); +impl_hash!(PyCallArgument); +impl_eq!(PyCallArgument); + +py_wrap_union_enum! { + #[derive(Debug, PartialEq, Eq)] + PyExternParameterType(ExternParameterType) as "ExternParameterType" { + scalar: Scalar => PyScalarType, + fixed_length_vector: FixedLengthVector => PyVector, + variable_length_vector: VariableLengthVector => PyScalarType + } +} +impl_repr!(PyExternParameterType); +impl_to_quil!(PyExternParameterType); +impl_hash!(PyExternParameterType); +impl_eq!(PyExternParameterType); + +rigetti_pyo3::py_wrap_type! { + #[derive(Debug, PartialEq, Eq)] + #[pyo3(subclass, module = "quil.instructions")] + PyExternParameter(ExternParameter) as "ExternParameter" +} +rigetti_pyo3::impl_as_mut_for_wrapper!(PyExternParameter); +impl_repr!(PyExternParameter); +impl_to_quil!(PyExternParameter); +impl_copy_for_instruction!(PyExternParameter); +impl_hash!(PyExternParameter); +impl_eq!(PyExternParameter); +impl_pickle_for_instruction!(PyExternParameter); + +#[pymethods] +impl PyExternParameter { + #[new] + fn new( + py: Python<'_>, + name: String, + mutable: bool, + data_type: PyExternParameterType, + ) -> PyResult { + ExternParameter::try_new( + name, + mutable, + ExternParameterType::py_try_from(py, &data_type)?, + ) + .map(Self) + .map_err(RustExternError::from) + .map_err(RustExternError::to_py_err) + } + + #[getter] + fn name(&self) -> &str { + self.0.name() + } + + #[getter] + fn mutable(&self) -> bool { + self.0.mutable() + } + + #[getter] + fn data_type(&self) -> PyExternParameterType { + self.0.data_type().into() + } +} + +rigetti_pyo3::py_wrap_type! { + #[derive(Debug, PartialEq, Eq)] + #[pyo3(subclass, module = "quil.instructions")] + PyExternSignature(ExternSignature) as "ExternSignature" +} +rigetti_pyo3::impl_as_mut_for_wrapper!(PyExternSignature); +impl_repr!(PyExternSignature); +impl_to_quil!(PyExternSignature); +impl_copy_for_instruction!(PyExternSignature); +impl_hash!(PyExternSignature); +impl_eq!(PyExternSignature); +impl_pickle_for_instruction!(PyExternSignature); + +#[pymethods] +impl PyExternSignature { + #[new] + fn new( + py: Python<'_>, + parameters: Vec, + return_type: Option, + ) -> PyResult { + Ok(Self(ExternSignature::new( + return_type + .map(|scalar_type| ScalarType::py_try_from(py, &scalar_type)) + .transpose()?, + Vec::::py_try_from(py, ¶meters)?, + ))) + } + + #[getter] + fn parameters(&self) -> Vec { + self.0 + .parameters() + .iter() + .map(PyExternParameter::from) + .collect() + } + + #[getter] + fn return_type(&self) -> Option { + self.0.return_type().map(Into::into) + } +} diff --git a/quil-py/src/instruction/mod.rs b/quil-py/src/instruction/mod.rs index 9de51f2f..e0ea02d1 100644 --- a/quil-py/src/instruction/mod.rs +++ b/quil-py/src/instruction/mod.rs @@ -23,6 +23,10 @@ pub use self::{ ParseMemoryReferenceError, PyDeclaration, PyLoad, PyMemoryReference, PyOffset, PyScalarType, PySharing, PyStore, PyVector, }, + extern_call::{ + CallError, ExternError, PyCall, PyCallArgument, PyExternParameter, PyExternParameterType, + PyExternSignature, + }, frame::{ PyAttributeValue, PyCapture, PyFrameAttributes, PyFrameDefinition, PyFrameIdentifier, PyPulse, PyRawCapture, PySetFrequency, PySetPhase, PySetScale, PyShiftFrequency, @@ -45,6 +49,7 @@ mod circuit; mod classical; mod control_flow; mod declaration; +mod extern_call; mod frame; mod gate; mod measurement; @@ -60,6 +65,7 @@ py_wrap_union_enum! { arithmetic: Arithmetic => PyArithmetic, binary_logic: BinaryLogic => PyBinaryLogic, calibration_definition: CalibrationDefinition => PyCalibration, + call: Call => PyCall, capture: Capture => PyCapture, circuit_definition: CircuitDefinition => PyCircuitDefinition, convert: Convert => PyConvert, @@ -149,11 +155,16 @@ create_init_submodule! { PyBinaryLogic, PyBinaryOperand, PyBinaryOperator, + PyCall, + PyCallArgument, PyComparison, PyComparisonOperand, PyComparisonOperator, PyConvert, PyExchange, + PyExternParameter, + PyExternParameterType, + PyExternSignature, PyMove, PyUnaryLogic, PyUnaryOperator, @@ -207,7 +218,7 @@ create_init_submodule! { PyWaveformDefinition, PyWaveformInvocation ], - errors: [ GateError, ParseMemoryReferenceError ], + errors: [ CallError, ExternError, GateError, ParseMemoryReferenceError ], } /// Implements __copy__ and __deepcopy__ on any variant of the [`PyInstruction`] class, making diff --git a/quil-py/test/instructions/test_extern_call.py b/quil-py/test/instructions/test_extern_call.py new file mode 100644 index 00000000..dbe1d8d3 --- /dev/null +++ b/quil-py/test/instructions/test_extern_call.py @@ -0,0 +1,66 @@ +import pytest + +from quil.instructions import ( + Call, + CallArgument, + Declaration, + ExternParameter, + ExternParameterType, + ExternSignature, + Instruction, + MemoryReference, + Pragma, + PragmaArgument, + ScalarType, + Vector, +) +from quil.program import Program + + +@pytest.mark.parametrize( + "return_argument", + [CallArgument.from_identifier("real"), CallArgument.from_memory_reference(MemoryReference("real", 1))], +) +def test_extern_call_instructions(return_argument: CallArgument): + p = Program() + extern_signature = ExternSignature( + return_type=ScalarType.Real, + parameters=[ + ExternParameter( + name="a", + mutable=False, + data_type=ExternParameterType.from_variable_length_vector(ScalarType.Integer), + ) + ], + ) + pragma = Pragma( + "EXTERN", + [PragmaArgument.from_identifier("foo")], + f'"{extern_signature.to_quil()}"', + ) + p.add_instruction(Instruction(pragma)) + real_declaration = Declaration(name="real", size=Vector(data_type=ScalarType.Real, length=3), sharing=None) + p.add_instruction(Instruction(real_declaration)) + integer_declaration = Declaration(name="integer", size=Vector(data_type=ScalarType.Integer, length=3), sharing=None) + p.add_instruction(Instruction(integer_declaration)) + call = Call( + name="foo", + arguments=[ + return_argument, + CallArgument.from_identifier(integer_declaration.name), + ], + ) + p.add_instruction(Instruction(call)) + + parsed_program = Program.parse(p.to_quil()) + assert p == parsed_program + + +def test_extern_call_quil(): + input = """PRAGMA EXTERN foo "OCTET (params : mut REAL[3])" +DECLARE reals REAL[3] +DECLARE octets OCTET[3] +CALL foo octets[1] reals +""" + program = Program.parse(input) + assert program == Program.parse(program.to_quil()) diff --git a/quil-rs/proptest-regressions/expression/mod.txt b/quil-rs/proptest-regressions/expression/mod.txt new file mode 100644 index 00000000..0fad84e1 --- /dev/null +++ b/quil-rs/proptest-regressions/expression/mod.txt @@ -0,0 +1,8 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 4c32128d724ed0f840715fae4e194c99262dc153c64be39d2acf45b8903b20f7 # shrinks to value = Complex { re: 0.0, im: -0.13530277317700273 } +cc 5cc95f2159ad7120bbaf296d3a9fb26fef30f57b61e76b3e0dc99f4759009fdb # shrinks to e = Number(Complex { re: 0.0, im: -2.772221265116396 }) diff --git a/quil-rs/src/expression/mod.rs b/quil-rs/src/expression/mod.rs index c82a2540..1221afe2 100644 --- a/quil-rs/src/expression/mod.rs +++ b/quil-rs/src/expression/mod.rs @@ -462,7 +462,7 @@ static FORMAT_IMAGINARY_OPTIONS: Lazy = Lazy::new(|| { /// - When imaginary is 0, show real only /// - When both are non-zero, show with the correct operator in between #[inline(always)] -fn format_complex(value: &Complex64) -> String { +pub(crate) fn format_complex(value: &Complex64) -> String { const FORMAT: u128 = format::STANDARD; if value.re == 0f64 && value.im == 0f64 { "0".to_owned() diff --git a/quil-rs/src/instruction/extern_call.rs b/quil-rs/src/instruction/extern_call.rs new file mode 100644 index 00000000..a2e811f4 --- /dev/null +++ b/quil-rs/src/instruction/extern_call.rs @@ -0,0 +1,2294 @@ +/// This module provides support for the `CALL` instruction and the reserved `PRAGMA EXTERN` instruction. +/// +/// For additional detail on its design and specification see: +/// +/// * [Quil specification "Other"](https://github.com/quil-lang/quil/blob/7f532c7cdde9f51eae6abe7408cc868fba9f91f6/specgen/spec/sec-other.s_) +/// * [Quil EXTERN / CALL RFC](https://github.com/quil-lang/quil/blob/master/rfcs/extern-call.md) +/// * [quil#69](https://github.com/quil-lang/quil/pull/69) +use std::{collections::HashSet, str::FromStr}; + +use indexmap::IndexMap; +use nom_locate::LocatedSpan; +use num_complex::Complex64; + +use crate::{ + expression::format_complex, + hash::hash_f64, + parser::lex, + program::{disallow_leftover, MemoryAccesses, MemoryRegion, SyntaxError}, + quil::Quil, + validation::identifier::{validate_user_identifier, IdentifierValidationError}, +}; + +use super::{ + Instruction, MemoryReference, Pragma, PragmaArgument, ScalarType, Vector, + RESERVED_PRAGMA_EXTERN, +}; + +/// A parameter type within an extern signature. +#[derive(Clone, Debug, PartialEq, Hash, Eq)] +pub enum ExternParameterType { + /// A scalar parameter, which may accept a memory reference or immediate value. + /// + /// For instance `PRAGMA EXTERN foo "(bar : INTEGER)"`. + Scalar(ScalarType), + /// A fixed-length vector, which must accept a memory region name of the appropriate + /// length and data type. + /// + /// For instance `PRAGMA EXTERN foo "(bar : INTEGER[2])"`. + FixedLengthVector(Vector), + /// A variable-length vector, which must accept a memory region name of the appropriate + /// data type. + /// + /// For instance `PRAGMA EXTERN foo "(bar : INTEGER[])"`. + VariableLengthVector(ScalarType), +} + +impl Quil for ExternParameterType { + fn write( + &self, + f: &mut impl std::fmt::Write, + fall_back_to_debug: bool, + ) -> crate::quil::ToQuilResult<()> { + match self { + ExternParameterType::Scalar(value) => value.write(f, fall_back_to_debug), + ExternParameterType::FixedLengthVector(value) => value.write(f, fall_back_to_debug), + ExternParameterType::VariableLengthVector(value) => { + value.write(f, fall_back_to_debug)?; + Ok(write!(f, "[]")?) + } + } + } +} + +/// An extern parameter with a name, mutability, and data type. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct ExternParameter { + /// The name of the parameter. This must be a valid user identifier. + pub(crate) name: String, + /// Whether the parameter is mutable. + pub(crate) mutable: bool, + /// The data type of the parameter. + pub(crate) data_type: ExternParameterType, +} + +impl ExternParameter { + /// Create a new extern parameter. This will fail if the parameter name + /// is not a valid user identifier. + pub fn try_new( + name: String, + mutable: bool, + data_type: ExternParameterType, + ) -> Result { + validate_user_identifier(name.as_str()).map_err(ExternError::from)?; + Ok(Self { + name, + mutable, + data_type, + }) + } + + pub fn name(&self) -> &str { + self.name.as_str() + } + + pub fn mutable(&self) -> bool { + self.mutable + } + + pub fn data_type(&self) -> &ExternParameterType { + &self.data_type + } +} + +impl Quil for ExternParameter { + fn write( + &self, + writer: &mut impl std::fmt::Write, + fall_back_to_debug: bool, + ) -> Result<(), crate::quil::ToQuilError> { + write!(writer, "{} : ", self.name)?; + if self.mutable { + write!(writer, "mut ")?; + } + self.data_type.write(writer, fall_back_to_debug) + } +} + +/// An extern signature with a return type and parameters. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct ExternSignature { + /// The return type of the extern signature, if any. + pub(crate) return_type: Option, + /// The parameters of the extern signature. + pub(crate) parameters: Vec, +} + +impl ExternSignature { + /// Create a new extern signature. + pub fn new(return_type: Option, parameters: Vec) -> Self { + Self { + return_type, + parameters, + } + } + + pub fn return_type(&self) -> Option<&ScalarType> { + self.return_type.as_ref() + } + + pub fn parameters(&self) -> &[ExternParameter] { + self.parameters.as_slice() + } +} + +const EXPECTED_PRAGMA_EXTERN_STRUCTURE: &str = "PRAGMA EXTERN {name} \"{scalar type}? (\\(({parameter name} : mut? {parameter type}) (, {parameter name} : mut? {parameter type})*\\))?\""; + +/// An error that can occur when parsing an extern signature. +#[derive(Debug, thiserror::Error, PartialEq)] +pub enum ExternError { + /// An error occurred while parsing the contents of the extern signature. + #[error( + "invalid extern signature syntax: {0:?} (expected `{EXPECTED_PRAGMA_EXTERN_STRUCTURE}`)" + )] + Syntax(SyntaxError), + /// An error occurred while lexing the extern signature. + #[error( + "failed to lex extern signature: {0:?} (expected `{EXPECTED_PRAGMA_EXTERN_STRUCTURE}`)" + )] + Lex(crate::parser::LexError), + /// Pragma arguments are invalid. + #[error("`PRAGMA EXTERN` must have a single argument representing the extern name")] + InvalidPragmaArguments, + /// No signature found. + #[error("`PRAGMA EXTERN` instruction has no signature")] + NoSignature, + /// No extern name found. + #[error("`PRAGMA EXTERN` instruction has no name")] + NoName, + /// Pragma is not EXTERN. + #[error("ExternPragmaMap contained a pragma that was not EXTERN")] + PragmaIsNotExtern, + /// The extern definition has a signature but it has neither a return nor parameters. + #[error("extern definition has a signature but it has neither a return nor parameters")] + NoReturnOrParameters, + /// Either the name of the extern or one of its parameters is invalid. + #[error("invalid identifier: {0:?}")] + Name(#[from] IdentifierValidationError), +} + +impl FromStr for ExternSignature { + type Err = ExternError; + + fn from_str(s: &str) -> Result { + let signature_input = LocatedSpan::new(s); + let signature_tokens = lex(signature_input).map_err(ExternError::Lex)?; + let signature = disallow_leftover( + crate::parser::pragma_extern::parse_extern_signature(signature_tokens.as_slice()) + .map_err(crate::parser::ParseError::from_nom_internal_err), + ) + .map_err(ExternError::Syntax)?; + if signature.return_type.is_none() && signature.parameters.is_empty() { + return Err(ExternError::NoReturnOrParameters); + } + for parameter in &signature.parameters { + validate_user_identifier(parameter.name.as_str()).map_err(ExternError::from)?; + } + Ok(signature) + } +} + +impl Quil for ExternSignature { + fn write( + &self, + writer: &mut impl std::fmt::Write, + fall_back_to_debug: bool, + ) -> Result<(), crate::quil::ToQuilError> { + if let Some(return_type) = &self.return_type { + return_type.write(writer, fall_back_to_debug)?; + if !self.parameters.is_empty() { + write!(writer, " ")?; + } + } + if self.parameters.is_empty() { + return Ok(()); + } + write!(writer, "(")?; + for (i, parameter) in self.parameters.iter().enumerate() { + if i > 0 { + write!(writer, ", ")?; + } + parameter.write(writer, fall_back_to_debug)?; + } + write!(writer, ")").map_err(Into::into) + } +} + +impl TryFrom for ExternSignature { + type Error = ExternError; + + fn try_from(value: Pragma) -> Result { + if value.name != RESERVED_PRAGMA_EXTERN { + return Err(ExternError::PragmaIsNotExtern); + } + if value.arguments.is_empty() + || !matches!(value.arguments[0], PragmaArgument::Identifier(_)) + { + return Err(ExternError::NoName); + } + if value.arguments.len() > 1 { + return Err(ExternError::InvalidPragmaArguments); + } + + match value.data { + Some(data) => ExternSignature::from_str(data.as_str()), + None => Err(ExternError::NoSignature), + } + } +} + +/// A map of all program `PRAGMA EXTERN` instructions from their name (if any) to +/// the corresponding [`Pragma`] instruction. Note, keys are [`Option`]s, but a +/// `None` key will be considered invalid when converting to an [`ExternSignatureMap`]. +#[derive(Clone, Debug, PartialEq, Default)] +pub struct ExternPragmaMap(IndexMap, Pragma>); + +impl ExternPragmaMap { + pub(crate) fn len(&self) -> usize { + self.0.len() + } + + pub(crate) fn to_instructions(&self) -> Vec { + self.0.values().cloned().map(Instruction::Pragma).collect() + } + + /// Insert a `PRAGMA EXTERN` instruction into the underlying [`IndexMap`]. + /// + /// If the first argument to the [`Pragma`] is not a [`PragmaArgument::Identifier`], or + /// does not exist, then the [`Pragma`] will be inserted with a `None` key. + /// + /// If the key already exists, the previous [`Pragma`] will be returned, similar to + /// the behavior of [`IndexMap::insert`]. + pub(crate) fn insert(&mut self, pragma: Pragma) -> Option { + self.0.insert( + match pragma.arguments.first() { + Some(PragmaArgument::Identifier(name)) => Some(name.clone()), + _ => None, + }, + pragma, + ) + } + + pub(crate) fn retain(&mut self, f: F) + where + F: FnMut(&Option, &mut Pragma) -> bool, + { + self.0.retain(f) + } +} + +/// A map of all program `PRAGMA EXTERN` instructions from their name to the corresponding +/// parsed and validated [`ExternSignature`]. +#[derive(Clone, Debug, PartialEq, Default)] +pub struct ExternSignatureMap(IndexMap); + +impl TryFrom for ExternSignatureMap { + /// The error type for converting an [`ExternPragmaMap`] to an [`ExternSignatureMap`] includes + /// the offending [`Pragma`] instruction and the error that occurred. + type Error = (Pragma, ExternError); + + fn try_from(value: ExternPragmaMap) -> Result { + Ok(ExternSignatureMap( + value + .0 + .into_iter() + .map(|(key, value)| -> Result<_, Self::Error> { + match key { + Some(name) => { + validate_user_identifier(name.as_str()) + .map_err(ExternError::from) + .map_err(|error| (value.clone(), error))?; + let signature = ExternSignature::try_from(value.clone()) + .map_err(|error| (value, error))?; + Ok((name, signature)) + } + _ => Err((value, ExternError::NoName)), + } + }) + .collect::>()?, + )) + } +} + +impl ExternSignatureMap { + #[cfg(test)] + pub(crate) fn len(&self) -> usize { + self.0.len() + } +} + +/// An error that can occur when resolving a call instruction. +#[derive(Clone, Debug, thiserror::Error, PartialEq)] +pub enum CallArgumentResolutionError { + /// An undeclared memory reference was encountered. + #[error("undeclared memory reference {0}")] + UndeclaredMemoryReference(String), + /// A mismatched vector was encountered. + #[error("mismatched vector: expected {expected:?}, found {found:?}")] + MismatchedVector { expected: Vector, found: Vector }, + /// A mismatched scalar was encountered. + #[error("mismatched scalar: expected {expected:?}, found {found:?}")] + MismatchedScalar { + expected: ScalarType, + found: ScalarType, + }, + /// The argument for a vector parameter was invalid. + #[error("vector parameters must be passed as an identifier, found {0:?}")] + InvalidVectorArgument(UnresolvedCallArgument), + /// The argument for a return parameter was invalid. + #[error("return argument must be a memory reference or identifier, found {found:?}")] + ReturnArgument { found: UnresolvedCallArgument }, + /// Immediate arguments cannot be specified for mutable parameters. + #[error("immediate arguments cannot be specified for mutable parameter {0}")] + ImmediateArgumentForMutable(String), +} + +/// A parsed, but unresolved call argument. This may be resolved into a [`ResolvedCallArgument`] +/// with the appropriate [`ExternSignature`]. Resolution is required for building the +/// [`crate::Program`] memory graph. +#[derive(Clone, Debug, PartialEq)] +pub enum UnresolvedCallArgument { + /// A reference to a declared memory location. Note, this may be resolved to either + /// a scalar or vector. In the former case, the assumed index is 0. + Identifier(String), + /// A reference to a memory location. This may be resolved to a scalar. + MemoryReference(MemoryReference), + /// An immediate value. This may be resolved to a non-mutable scalar. + Immediate(Complex64), +} + +impl Eq for UnresolvedCallArgument {} + +impl std::hash::Hash for UnresolvedCallArgument { + fn hash(&self, state: &mut H) { + match self { + UnresolvedCallArgument::Identifier(value) => { + "Identifier".hash(state); + value.hash(state); + } + UnresolvedCallArgument::MemoryReference(value) => { + "MemoryReference".hash(state); + value.hash(state); + } + UnresolvedCallArgument::Immediate(value) => { + "Immediate".hash(state); + hash_complex_64(value, state); + } + } + } +} + +impl UnresolvedCallArgument { + /// Check if the argument is compatible with the given [`ExternParameter`]. If so, return + /// the appropriate [`ResolvedCallArgument`]. If not, return an error. + fn resolve( + &self, + memory_regions: &IndexMap, + extern_parameter: &ExternParameter, + ) -> Result { + match self { + UnresolvedCallArgument::Identifier(value) => { + let expected_vector = match &extern_parameter.data_type { + ExternParameterType::Scalar(_) => { + return UnresolvedCallArgument::MemoryReference(MemoryReference::new( + value.clone(), + 0, + )) + .resolve(memory_regions, extern_parameter); + } + ExternParameterType::FixedLengthVector(expected_vector) => { + let memory_region = + memory_regions.get(value.as_str()).ok_or_else(|| { + CallArgumentResolutionError::UndeclaredMemoryReference( + value.clone(), + ) + })?; + if &memory_region.size != expected_vector { + return Err(CallArgumentResolutionError::MismatchedVector { + expected: expected_vector.clone(), + found: memory_region.size.clone(), + }); + } + + Ok(expected_vector.clone()) + } + ExternParameterType::VariableLengthVector(scalar_type) => { + let memory_region = + memory_regions.get(value.as_str()).ok_or_else(|| { + CallArgumentResolutionError::UndeclaredMemoryReference( + value.clone(), + ) + })?; + if &memory_region.size.data_type != scalar_type { + return Err(CallArgumentResolutionError::MismatchedScalar { + expected: *scalar_type, + found: memory_region.size.data_type, + }); + } + Ok(memory_region.size.clone()) + } + }?; + Ok(ResolvedCallArgument::Vector { + memory_region_name: value.clone(), + vector: expected_vector, + mutable: extern_parameter.mutable, + }) + } + UnresolvedCallArgument::MemoryReference(value) => { + let expected_scalar = match extern_parameter.data_type { + ExternParameterType::Scalar(ref scalar) => Ok(scalar), + ExternParameterType::FixedLengthVector(_) + | ExternParameterType::VariableLengthVector(_) => { + Err(CallArgumentResolutionError::InvalidVectorArgument( + Self::MemoryReference(value.clone()), + )) + } + }?; + let memory_region = memory_regions.get(value.name.as_str()).ok_or_else(|| { + CallArgumentResolutionError::UndeclaredMemoryReference(value.name.clone()) + })?; + if memory_region.size.data_type != *expected_scalar { + return Err(CallArgumentResolutionError::MismatchedScalar { + expected: *expected_scalar, + found: memory_region.size.data_type, + }); + } + Ok(ResolvedCallArgument::MemoryReference { + memory_reference: value.clone(), + scalar_type: *expected_scalar, + mutable: extern_parameter.mutable, + }) + } + UnresolvedCallArgument::Immediate(value) => { + if extern_parameter.mutable { + return Err(CallArgumentResolutionError::ImmediateArgumentForMutable( + extern_parameter.name.clone(), + )); + } + let expected_scalar = match extern_parameter.data_type { + ExternParameterType::Scalar(ref scalar) => Ok(scalar), + ExternParameterType::FixedLengthVector(_) + | ExternParameterType::VariableLengthVector(_) => Err( + CallArgumentResolutionError::InvalidVectorArgument(self.clone()), + ), + }?; + Ok(ResolvedCallArgument::Immediate { + value: *value, + scalar_type: *expected_scalar, + }) + } + } + } + + /// Check if the argument is compatible with the return type of the [`ExternSignature`]. If so, + /// return the appropriate [`ResolvedCallArgument`]. If not, return an error. + fn resolve_return( + &self, + memory_regions: &IndexMap, + return_type: ScalarType, + ) -> Result { + let memory_reference = match self { + UnresolvedCallArgument::MemoryReference(memory_reference) => { + Ok(memory_reference.clone()) + } + UnresolvedCallArgument::Identifier(identifier) => { + Ok(MemoryReference::new(identifier.clone(), 0)) + } + _ => Err(CallArgumentResolutionError::ReturnArgument { + found: self.clone(), + }), + }?; + let memory_region = memory_regions + .get(memory_reference.name.as_str()) + .ok_or_else(|| { + CallArgumentResolutionError::UndeclaredMemoryReference( + memory_reference.name.clone(), + ) + })?; + if memory_region.size.data_type != return_type { + return Err(CallArgumentResolutionError::MismatchedScalar { + expected: return_type, + found: memory_region.size.data_type, + }); + } + Ok(ResolvedCallArgument::MemoryReference { + memory_reference: memory_reference.clone(), + scalar_type: return_type, + mutable: true, + }) + } +} + +impl Quil for UnresolvedCallArgument { + fn write( + &self, + f: &mut impl std::fmt::Write, + fall_back_to_debug: bool, + ) -> crate::quil::ToQuilResult<()> { + match &self { + UnresolvedCallArgument::Identifier(value) => write!(f, "{value}",).map_err(Into::into), + UnresolvedCallArgument::MemoryReference(value) => value.write(f, fall_back_to_debug), + UnresolvedCallArgument::Immediate(value) => { + write!(f, "{}", format_complex(value)).map_err(Into::into) + } + } + } +} + +/// A resolved call argument. This is the result of resolving an [`UnresolvedCallArgument`] with +/// the appropriate [`ExternParameter`]. It annotates the argument both with a type (and possibly +/// a length in the case of a vector) and mutability. +#[derive(Clone, Debug, PartialEq)] +pub enum ResolvedCallArgument { + /// A resolved vector argument, including its scalar type, length, and mutability. + Vector { + memory_region_name: String, + vector: Vector, + mutable: bool, + }, + /// A resolved memory reference, including its scalar type and mutability. + MemoryReference { + memory_reference: MemoryReference, + scalar_type: ScalarType, + mutable: bool, + }, + /// A resolved immediate value, including its scalar type. + Immediate { + value: Complex64, + scalar_type: ScalarType, + }, +} + +impl From for UnresolvedCallArgument { + fn from(value: ResolvedCallArgument) -> Self { + match value { + ResolvedCallArgument::Vector { + memory_region_name, .. + } => UnresolvedCallArgument::Identifier(memory_region_name), + ResolvedCallArgument::MemoryReference { + memory_reference, .. + } => UnresolvedCallArgument::MemoryReference(memory_reference), + ResolvedCallArgument::Immediate { value, .. } => { + UnresolvedCallArgument::Immediate(value) + } + } + } +} + +impl Eq for ResolvedCallArgument {} + +impl std::hash::Hash for ResolvedCallArgument { + fn hash(&self, state: &mut H) { + match self { + ResolvedCallArgument::Vector { + memory_region_name, + vector, + mutable, + } => { + "Vector".hash(state); + memory_region_name.hash(state); + vector.hash(state); + mutable.hash(state); + } + ResolvedCallArgument::MemoryReference { + memory_reference, + scalar_type, + mutable, + } => { + "MemoryReference".hash(state); + memory_reference.hash(state); + scalar_type.hash(state); + mutable.hash(state); + } + ResolvedCallArgument::Immediate { value, scalar_type } => { + "Immediate".hash(state); + hash_complex_64(value, state); + scalar_type.hash(state); + } + } + } +} + +fn hash_complex_64(value: &Complex64, state: &mut H) { + if value.re.abs() > 0f64 { + hash_f64(value.re, state); + } + if value.im.abs() > 0f64 { + hash_f64(value.im, state); + } +} + +/// An error that can occur when validating a call instruction. +#[derive(Clone, Debug, PartialEq, thiserror::Error, Eq)] +pub enum CallError { + /// The specified name is not a valid user identifier. + #[error(transparent)] + Name(#[from] IdentifierValidationError), +} + +/// A call instruction with a name and arguments. +#[derive(Clone, Debug, PartialEq, Hash, Eq)] +pub struct Call { + /// The name of the call instruction. This must be a valid user identifier. + pub name: String, + /// The arguments of the call instruction. + pub arguments: Vec, +} + +impl Call { + /// Create a new call instruction with resolved arguments. This will validate the + /// name as a user identifier. + pub fn try_new( + name: String, + arguments: Vec, + ) -> Result { + validate_user_identifier(name.as_str()).map_err(CallError::Name)?; + + Ok(Self { name, arguments }) + } + + pub fn name(&self) -> &str { + self.name.as_str() + } + + pub fn arguments(&self) -> &[UnresolvedCallArgument] { + self.arguments.as_slice() + } +} + +/// An error that can occur when resolving a call instruction argument. +#[derive(Clone, Debug, thiserror::Error, PartialEq)] +pub enum CallArgumentError { + /// The return argument could not be resolved. + #[error("error resolving return argument: {0:?}")] + Return(CallArgumentResolutionError), + /// An argument could not be resolved. + #[error("error resolving argument {index}: {error:?}")] + Argument { + index: usize, + error: CallArgumentResolutionError, + }, +} + +/// An error that can occur when resolving a call instruction to a specific +/// [`ExternSignature`]. +#[derive(Debug, thiserror::Error, PartialEq, Clone)] +pub enum CallSignatureError { + #[error("expected {expected} arguments, found {found}")] + ParameterCount { expected: usize, found: usize }, + #[error("error resolving arguments: {0:?}")] + Arguments(Vec), +} + +/// An error that can occur when resolving a call instruction, given a complete +/// [`ExternPragmaMap`] for the [`crate::program::Program`]. +#[derive(Debug, thiserror::Error, PartialEq)] +pub enum CallResolutionError { + /// A matching extern instruction was found, but signature validation failed. + #[error("call found matching extern instruction for {name}, but signature validation failed: {error:?}")] + Signature { + name: String, + error: CallSignatureError, + }, + /// No matching extern instruction was found. + #[error("no extern instruction found with name {0}")] + NoMatchingExternInstruction(String), + /// Failed to convernt the [`ExternPragmaMap`] to an [`ExternSignatureMap`]. + #[error(transparent)] + ExternSignature(#[from] ExternError), +} + +#[allow(clippy::manual_try_fold)] +fn convert_unresolved_to_resolved_call_arguments( + arguments: &[UnresolvedCallArgument], + signature: &ExternSignature, + memory_regions: &IndexMap, +) -> Result, CallSignatureError> { + arguments + .iter() + .enumerate() + .map(|(i, argument)| { + if i == 0 { + if let Some(return_type) = signature.return_type { + return argument + .resolve_return(memory_regions, return_type) + .map_err(CallArgumentError::Return); + } + } + let parameter_index = if signature.return_type.is_some() { + i - 1 + } else { + i + }; + let parameter = &signature.parameters[parameter_index]; + argument + .resolve(memory_regions, parameter) + .map_err(|error| CallArgumentError::Argument { + index: parameter_index, + error, + }) + }) + .fold( + Ok(Vec::new()), + |acc: Result, Vec>, + result: Result| { + match (acc, result) { + (Ok(mut acc), Ok(resolved)) => { + acc.push(resolved); + Ok(acc) + } + (Ok(_), Err(error)) => Err(vec![error]), + (Err(errors), Ok(_)) => Err(errors), + (Err(mut errors), Err(error)) => { + errors.push(error); + Err(errors) + } + } + }, + ) + .map_err(CallSignatureError::Arguments) +} + +impl Call { + /// Resolve the [`Call`] instruction to the given [`ExternSignature`]. + fn resolve_to_signature( + &self, + signature: &ExternSignature, + memory_regions: &IndexMap, + ) -> Result, CallSignatureError> { + let mut expected_parameter_count = signature.parameters.len(); + if signature.return_type.is_some() { + expected_parameter_count += 1; + } + + if self.arguments.len() != expected_parameter_count { + return Err(CallSignatureError::ParameterCount { + expected: expected_parameter_count, + found: self.arguments.len(), + }); + } + + let resolved_call_arguments = convert_unresolved_to_resolved_call_arguments( + &self.arguments, + signature, + memory_regions, + )?; + + Ok(resolved_call_arguments) + } + + /// Resolve the [`Call`] instruction to any of the given [`ExternSignature`]s and memory regions. + /// If no matching extern instruction is found, return an error. + pub fn resolve_arguments( + &self, + memory_regions: &IndexMap, + extern_signature_map: &ExternSignatureMap, + ) -> Result, CallResolutionError> { + let extern_signature = extern_signature_map + .0 + .get(self.name.as_str()) + .ok_or_else(|| CallResolutionError::NoMatchingExternInstruction(self.name.clone()))?; + + self.resolve_to_signature(extern_signature, memory_regions) + .map_err(|error| CallResolutionError::Signature { + name: self.name.clone(), + error, + }) + } + + /// Return the [`MemoryAccesses`] for the [`Call`] instruction given the [`ExternSignatureMap`]. + /// This assumes ALL parameters are read, including mutable parameters. + pub(crate) fn get_memory_accesses( + &self, + extern_signatures: &ExternSignatureMap, + ) -> Result { + let extern_signature = extern_signatures + .0 + .get(self.name.as_str()) + .ok_or_else(|| CallResolutionError::NoMatchingExternInstruction(self.name.clone()))?; + + let mut reads = HashSet::new(); + let mut writes = HashSet::new(); + let mut arguments = self.arguments.iter(); + if extern_signature.return_type.is_some() { + if let Some(argument) = self.arguments.first() { + arguments.next(); + match argument { + UnresolvedCallArgument::MemoryReference(memory_reference) => { + reads.insert(memory_reference.name.clone()); + writes.insert(memory_reference.name.clone()); + } + UnresolvedCallArgument::Identifier(identifier) => { + reads.insert(identifier.clone()); + writes.insert(identifier.clone()); + } + _ => {} + } + } + } + for (argument, parameter) in std::iter::zip(arguments, extern_signature.parameters.iter()) { + match argument { + UnresolvedCallArgument::MemoryReference(memory_reference) => { + reads.insert(memory_reference.name.clone()); + if parameter.mutable { + writes.insert(memory_reference.name.clone()); + } + } + UnresolvedCallArgument::Identifier(identifier) => { + reads.insert(identifier.clone()); + if parameter.mutable { + writes.insert(identifier.clone()); + } + } + _ => {} + } + } + Ok(MemoryAccesses { + reads, + writes, + captures: HashSet::new(), + }) + } +} + +impl Quil for Call { + fn write( + &self, + f: &mut impl std::fmt::Write, + fall_back_to_debug: bool, + ) -> crate::quil::ToQuilResult<()> { + write!(f, "CALL {}", self.name)?; + for argument in self.arguments.as_slice() { + write!(f, " ")?; + argument.write(f, fall_back_to_debug)?; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::instruction::PragmaArgument; + use rstest::*; + + /// Test cases for the [`ExternSignature`] Quil representation. + struct ExternSignatureQuilTestCase { + /// The extern signature to test. + signature: ExternSignature, + /// The expected Quil representation. + expected: &'static str, + } + + impl ExternSignatureQuilTestCase { + /// Signature with return and parameters + fn case_01() -> Self { + Self { + signature: ExternSignature { + return_type: Some(ScalarType::Integer), + parameters: vec![ + ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }, + ExternParameter { + name: "baz".to_string(), + mutable: true, + data_type: ExternParameterType::FixedLengthVector(Vector { + data_type: ScalarType::Bit, + length: 2, + }), + }, + ], + }, + expected: "INTEGER (bar : INTEGER, baz : mut BIT[2])", + } + } + + /// Signature with only parameters + fn case_02() -> Self { + let signature = ExternSignature { + return_type: None, + parameters: vec![ + ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }, + ExternParameter { + name: "baz".to_string(), + mutable: true, + data_type: ExternParameterType::FixedLengthVector(Vector { + data_type: ScalarType::Bit, + length: 2, + }), + }, + ], + }; + Self { + signature, + expected: "(bar : INTEGER, baz : mut BIT[2])", + } + } + + /// Signature with return only + fn case_03() -> Self { + let signature = ExternSignature { + return_type: Some(ScalarType::Integer), + parameters: vec![], + }; + Self { + signature, + expected: "INTEGER", + } + } + + /// Signature with no return nor parameters + fn case_04() -> Self { + let signature = ExternSignature { + return_type: None, + parameters: vec![], + }; + Self { + signature, + expected: "", + } + } + + /// Variable length vector + fn case_05() -> Self { + let signature = ExternSignature { + return_type: None, + parameters: vec![ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::VariableLengthVector(ScalarType::Integer), + }], + }; + Self { + signature, + expected: "(bar : INTEGER[])", + } + } + } + + /// Test that the Quil representation of an [`ExternSignature`] is as expected. + #[rstest] + #[case(ExternSignatureQuilTestCase::case_01())] + #[case(ExternSignatureQuilTestCase::case_02())] + #[case(ExternSignatureQuilTestCase::case_03())] + #[case(ExternSignatureQuilTestCase::case_04())] + #[case(ExternSignatureQuilTestCase::case_05())] + #[case(ExternSignatureQuilTestCase::case_05())] + fn test_extern_signature_quil(#[case] test_case: ExternSignatureQuilTestCase) { + assert_eq!( + test_case + .signature + .to_quil() + .expect("must be able to call to quil"), + test_case.expected.to_string() + ); + } + + /// Test cases for the [`Call`] Quil representation. + struct CallQuilTestCase { + /// The call instruction to test. + call: Call, + /// The expected Quil representation. + expected: &'static str, + } + + impl CallQuilTestCase { + fn case_01() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: vec![ + UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "bar".to_string(), + index: 0, + }), + UnresolvedCallArgument::Immediate(Complex64::new(2.0, 0.0)), + UnresolvedCallArgument::Identifier("baz".to_string()), + ], + }; + Self { + call, + expected: "CALL foo bar[0] 2 baz", + } + } + + fn case_02() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: vec![ + UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "bar".to_string(), + index: 0, + }), + UnresolvedCallArgument::Identifier("baz".to_string()), + ], + }; + Self { + call, + expected: "CALL foo bar[0] baz", + } + } + + fn case_03() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: vec![UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "bar".to_string(), + index: 0, + })], + }; + Self { + call, + expected: "CALL foo bar[0]", + } + } + + /// No arguments. + fn case_04() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: vec![], + }; + + Self { + call, + expected: "CALL foo", + } + } + } + + /// Test that the Quil representation of a [`Call`] instruction is as expected. + #[rstest] + #[case(CallQuilTestCase::case_01())] + #[case(CallQuilTestCase::case_02())] + #[case(CallQuilTestCase::case_03())] + #[case(CallQuilTestCase::case_04())] + fn test_call_quil(#[case] test_case: CallQuilTestCase) { + assert_eq!( + test_case + .call + .to_quil() + .expect("must be able to call to quil"), + test_case.expected.to_string() + ); + } + + /// Build a set of memory regions for testing. + fn build_declarations() -> IndexMap { + [ + ("integer", Vector::new(ScalarType::Integer, 3)), + ("real", Vector::new(ScalarType::Real, 3)), + ("bit", Vector::new(ScalarType::Bit, 3)), + ("octet", Vector::new(ScalarType::Octet, 3)), + ] + .into_iter() + .map(|(name, vector)| (name.to_string(), MemoryRegion::new(vector, None))) + .collect() + } + + /// Test cases for resolving call arguments. + struct ArgumentResolutionTestCase { + call_argument: UnresolvedCallArgument, + extern_parameter: ExternParameter, + expected: Result, + } + + impl ArgumentResolutionTestCase { + /// Memory reference as scalar + fn case_01() -> Self { + ArgumentResolutionTestCase { + call_argument: UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + }), + extern_parameter: ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }, + expected: Ok(ResolvedCallArgument::MemoryReference { + memory_reference: MemoryReference { + name: "integer".to_string(), + index: 0, + }, + scalar_type: ScalarType::Integer, + mutable: false, + }), + } + } + + /// Identifier as vector + fn case_02() -> Self { + ArgumentResolutionTestCase { + call_argument: UnresolvedCallArgument::Identifier("real".to_string()), + extern_parameter: ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::FixedLengthVector(Vector { + data_type: ScalarType::Real, + length: 3, + }), + }, + expected: Ok(ResolvedCallArgument::Vector { + memory_region_name: "real".to_string(), + vector: Vector { + data_type: ScalarType::Real, + length: 3, + }, + mutable: false, + }), + } + } + + /// Immediate value as scalar + fn case_03() -> Self { + ArgumentResolutionTestCase { + call_argument: UnresolvedCallArgument::Immediate(Complex64::new(2.0, 0.0)), + extern_parameter: ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }, + expected: Ok(ResolvedCallArgument::Immediate { + value: Complex64::new(2.0, 0.0), + scalar_type: ScalarType::Integer, + }), + } + } + + /// Undeclared identifier + fn case_04() -> Self { + ArgumentResolutionTestCase { + call_argument: UnresolvedCallArgument::Identifier("undeclared".to_string()), + extern_parameter: ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::FixedLengthVector(Vector { + data_type: ScalarType::Real, + length: 3, + }), + }, + expected: Err(CallArgumentResolutionError::UndeclaredMemoryReference( + "undeclared".to_string(), + )), + } + } + + /// Undeclared memory reference + fn case_05() -> Self { + ArgumentResolutionTestCase { + call_argument: UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "undeclared".to_string(), + index: 0, + }), + extern_parameter: ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }, + expected: Err(CallArgumentResolutionError::UndeclaredMemoryReference( + "undeclared".to_string(), + )), + } + } + + /// Vector data type mismatch + fn case_06() -> Self { + ArgumentResolutionTestCase { + call_argument: UnresolvedCallArgument::Identifier("integer".to_string()), + extern_parameter: ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::FixedLengthVector(Vector { + data_type: ScalarType::Real, + length: 3, + }), + }, + expected: Err(CallArgumentResolutionError::MismatchedVector { + expected: Vector { + data_type: ScalarType::Real, + length: 3, + }, + found: Vector { + data_type: ScalarType::Integer, + length: 3, + }, + }), + } + } + + /// Vector length mismatch + fn case_07() -> Self { + ArgumentResolutionTestCase { + call_argument: UnresolvedCallArgument::Identifier("integer".to_string()), + extern_parameter: ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::FixedLengthVector(Vector { + data_type: ScalarType::Integer, + length: 4, + }), + }, + expected: Err(CallArgumentResolutionError::MismatchedVector { + expected: Vector { + data_type: ScalarType::Integer, + length: 4, + }, + found: Vector { + data_type: ScalarType::Integer, + length: 3, + }, + }), + } + } + + /// Scalar data type mismatch + fn case_08() -> Self { + ArgumentResolutionTestCase { + call_argument: UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "octet".to_string(), + index: 0, + }), + extern_parameter: ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }, + expected: Err(CallArgumentResolutionError::MismatchedScalar { + expected: ScalarType::Integer, + found: ScalarType::Octet, + }), + } + } + + /// Scalar arguments may be passed as identifiers, in which case `0` index is + /// inferred. + fn case_09() -> Self { + let call_argument = UnresolvedCallArgument::Identifier("integer".to_string()); + ArgumentResolutionTestCase { + call_argument: call_argument.clone(), + extern_parameter: ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }, + expected: Ok(ResolvedCallArgument::MemoryReference { + memory_reference: MemoryReference::new("integer".to_string(), 0), + scalar_type: ScalarType::Integer, + mutable: false, + }), + } + } + + /// Vector arguments must be passed as identifiers, not memory references. + fn case_10() -> Self { + let call_argument = UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + }); + ArgumentResolutionTestCase { + call_argument: call_argument.clone(), + extern_parameter: ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::FixedLengthVector(Vector { + data_type: ScalarType::Integer, + length: 3, + }), + }, + expected: Err(CallArgumentResolutionError::InvalidVectorArgument( + call_argument, + )), + } + } + + /// Vector arguments must be passed as identifiers, not immediate values. + fn case_11() -> Self { + let call_argument = UnresolvedCallArgument::Immediate(Complex64::new(2.0, 0.0)); + ArgumentResolutionTestCase { + call_argument: call_argument.clone(), + extern_parameter: ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::FixedLengthVector(Vector { + data_type: ScalarType::Integer, + length: 3, + }), + }, + expected: Err(CallArgumentResolutionError::InvalidVectorArgument( + call_argument, + )), + } + } + + /// Variable vector arguments are resolved to a specific vector length based on the + /// declaration (see [`build_declarations`]). + fn case_12() -> Self { + let call_argument = UnresolvedCallArgument::Identifier("integer".to_string()); + ArgumentResolutionTestCase { + call_argument: call_argument.clone(), + extern_parameter: ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::VariableLengthVector(ScalarType::Integer), + }, + expected: Ok(ResolvedCallArgument::Vector { + memory_region_name: "integer".to_string(), + mutable: false, + vector: Vector { + data_type: ScalarType::Integer, + length: 3, + }, + }), + } + } + + /// Immediate arguments cannot be passed for mutable parameters. + fn case_13() -> Self { + let call_argument = UnresolvedCallArgument::Immediate(Complex64::new(2.0, 0.0)); + ArgumentResolutionTestCase { + call_argument: call_argument.clone(), + extern_parameter: ExternParameter { + name: "bar".to_string(), + mutable: true, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }, + expected: Err(CallArgumentResolutionError::ImmediateArgumentForMutable( + "bar".to_string(), + )), + } + } + } + + /// Test resolution of call arguments. + #[rstest] + #[case(ArgumentResolutionTestCase::case_01())] + #[case(ArgumentResolutionTestCase::case_02())] + #[case(ArgumentResolutionTestCase::case_03())] + #[case(ArgumentResolutionTestCase::case_04())] + #[case(ArgumentResolutionTestCase::case_05())] + #[case(ArgumentResolutionTestCase::case_06())] + #[case(ArgumentResolutionTestCase::case_07())] + #[case(ArgumentResolutionTestCase::case_08())] + #[case(ArgumentResolutionTestCase::case_09())] + #[case(ArgumentResolutionTestCase::case_10())] + #[case(ArgumentResolutionTestCase::case_11())] + #[case(ArgumentResolutionTestCase::case_12())] + #[case(ArgumentResolutionTestCase::case_13())] + fn test_argument_resolution(#[case] test_case: ArgumentResolutionTestCase) { + let memory_regions = build_declarations(); + let found = test_case + .call_argument + .resolve(&memory_regions, &test_case.extern_parameter); + match (test_case.expected, found) { + (Ok(expected), Ok(found)) => assert_eq!(expected, found), + (Ok(expected), Err(found)) => { + panic!("expected resolution {:?}, found err {:?}", expected, found) + } + (Err(expected), Ok(found)) => { + panic!("expected err {:?}, found resolution {:?}", expected, found) + } + (Err(expected), Err(found)) => assert_eq!(expected, found), + } + } + + /// Test cases for resolving return arguments. + struct ReturnArgumentResolutionTestCase { + /// The call argument to resolve. + call_argument: UnresolvedCallArgument, + /// The return type of the function. + return_type: ScalarType, + /// The expected result of the resolution. + expected: Result, + } + + impl ReturnArgumentResolutionTestCase { + /// Memory reference is ok. + fn case_01() -> Self { + let call_argument = UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + }); + let expected = Ok(ResolvedCallArgument::MemoryReference { + memory_reference: MemoryReference { + name: "integer".to_string(), + index: 0, + }, + scalar_type: ScalarType::Integer, + mutable: true, + }); + Self { + call_argument, + return_type: ScalarType::Integer, + expected, + } + } + + /// Immediate value is not ok. + fn case_02() -> Self { + let call_argument = UnresolvedCallArgument::Immediate(Complex64::new(2.0, 0.0)); + let expected = Err(CallArgumentResolutionError::ReturnArgument { + found: call_argument.clone(), + }); + Self { + call_argument, + return_type: ScalarType::Integer, + expected, + } + } + + /// Allow plain identifiers to be upcast to memory references. + fn case_03() -> Self { + let call_argument = UnresolvedCallArgument::Identifier("integer".to_string()); + let expected = Ok(ResolvedCallArgument::MemoryReference { + memory_reference: MemoryReference::new("integer".to_string(), 0), + scalar_type: ScalarType::Integer, + mutable: true, + }); + Self { + call_argument, + return_type: ScalarType::Integer, + expected, + } + } + } + + /// Test resolution of return arguments. + #[rstest] + #[case(ReturnArgumentResolutionTestCase::case_01())] + #[case(ReturnArgumentResolutionTestCase::case_02())] + #[case(ReturnArgumentResolutionTestCase::case_03())] + fn test_return_argument_resolution(#[case] test_case: ReturnArgumentResolutionTestCase) { + let memory_regions = build_declarations(); + + let found = test_case + .call_argument + .resolve_return(&memory_regions, test_case.return_type); + match (test_case.expected, found) { + (Ok(expected), Ok(found)) => assert_eq!(expected, found), + (Ok(expected), Err(found)) => { + panic!("expected resolution {:?}, found err {:?}", expected, found) + } + (Err(expected), Ok(found)) => { + panic!("expected err {:?}, found resolution {:?}", expected, found) + } + (Err(expected), Err(found)) => assert_eq!(expected, found), + } + } + + /// Test cases for resolving call arguments to a specific signature. + struct ResolveToSignatureTestCase { + /// The call instruction to resolve. + call: Call, + /// The signature to resolve to. + signature: ExternSignature, + /// The expected result of the resolution. + expected: Result, CallSignatureError>, + } + + impl ResolveToSignatureTestCase { + /// Valid match with return and parameters + fn case_01() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: vec![ + UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + }), + UnresolvedCallArgument::Immediate(Complex64::new(2.0, 0.0)), + UnresolvedCallArgument::Identifier("bit".to_string()), + ], + }; + let signature = ExternSignature { + return_type: Some(ScalarType::Integer), + parameters: vec![ + ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }, + ExternParameter { + name: "baz".to_string(), + mutable: true, + data_type: ExternParameterType::FixedLengthVector(Vector { + data_type: ScalarType::Bit, + length: 3, + }), + }, + ], + }; + let resolved = vec![ + ResolvedCallArgument::MemoryReference { + memory_reference: MemoryReference { + name: "integer".to_string(), + index: 0, + }, + scalar_type: ScalarType::Integer, + mutable: true, + }, + ResolvedCallArgument::Immediate { + value: Complex64::new(2.0, 0.0), + scalar_type: ScalarType::Integer, + }, + ResolvedCallArgument::Vector { + memory_region_name: "bit".to_string(), + vector: Vector { + data_type: ScalarType::Bit, + length: 3, + }, + mutable: true, + }, + ]; + Self { + call, + signature, + expected: Ok(resolved), + } + } + + /// Valid match with parameteters only + fn case_02() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: vec![ + UnresolvedCallArgument::Immediate(Complex64::new(2.0, 0.0)), + UnresolvedCallArgument::Identifier("bit".to_string()), + ], + }; + let signature = ExternSignature { + return_type: None, + parameters: vec![ + ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }, + ExternParameter { + name: "baz".to_string(), + mutable: true, + data_type: ExternParameterType::FixedLengthVector(Vector { + data_type: ScalarType::Bit, + length: 3, + }), + }, + ], + }; + let resolved = vec![ + ResolvedCallArgument::Immediate { + value: Complex64::new(2.0, 0.0), + scalar_type: ScalarType::Integer, + }, + ResolvedCallArgument::Vector { + memory_region_name: "bit".to_string(), + vector: Vector { + data_type: ScalarType::Bit, + length: 3, + }, + mutable: true, + }, + ]; + Self { + call, + signature, + expected: Ok(resolved), + } + } + + /// Valid match with return only + fn case_03() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: vec![UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + })], + }; + let signature = ExternSignature { + return_type: Some(ScalarType::Integer), + parameters: vec![], + }; + let resolved = vec![ResolvedCallArgument::MemoryReference { + memory_reference: MemoryReference { + name: "integer".to_string(), + index: 0, + }, + scalar_type: ScalarType::Integer, + mutable: true, + }]; + Self { + call, + signature, + expected: Ok(resolved), + } + } + + /// Parameter count mismatch with return and parameters + fn case_04() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: vec![ + UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + }), + UnresolvedCallArgument::Immediate(Complex64::new(2.0, 0.0)), + UnresolvedCallArgument::Identifier("bit".to_string()), + ], + }; + let signature = ExternSignature { + return_type: Some(ScalarType::Integer), + parameters: vec![ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }], + }; + + Self { + call, + signature, + expected: Err(CallSignatureError::ParameterCount { + expected: 2, + found: 3, + }), + } + } + + /// Parameter count mismatch return only + fn case_05() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: vec![ + UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + }), + UnresolvedCallArgument::Immediate(Complex64::new(2.0, 0.0)), + ], + }; + let signature = ExternSignature { + return_type: Some(ScalarType::Integer), + parameters: vec![], + }; + + Self { + call, + signature, + expected: Err(CallSignatureError::ParameterCount { + expected: 1, + found: 2, + }), + } + } + + /// Parameter count mismatch parameters only + fn case_06() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: vec![ + UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + }), + UnresolvedCallArgument::Immediate(Complex64::new(2.0, 0.0)), + UnresolvedCallArgument::Identifier("bit".to_string()), + ], + }; + let signature = ExternSignature { + return_type: None, + parameters: vec![ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }], + }; + + Self { + call, + signature, + expected: Err(CallSignatureError::ParameterCount { + expected: 1, + found: 3, + }), + } + } + + /// Argument mismatch + fn case_07() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: vec![ + UnresolvedCallArgument::Immediate(Complex64::new(2.0, 0.0)), + UnresolvedCallArgument::Identifier("bit".to_string()), + ], + }; + let signature = ExternSignature { + return_type: Some(ScalarType::Integer), + parameters: vec![ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Real), + }], + }; + + Self { + call, + signature, + expected: Err(CallSignatureError::Arguments(vec![ + CallArgumentError::Return(CallArgumentResolutionError::ReturnArgument { + found: UnresolvedCallArgument::Immediate(Complex64::new(2.0, 0.0)), + }), + CallArgumentError::Argument { + index: 0, + error: CallArgumentResolutionError::MismatchedScalar { + expected: ScalarType::Real, + found: ScalarType::Bit, + }, + }, + ])), + } + } + } + + /// Test resolution of `Call` instructions to a specific signature. + #[rstest] + #[case(ResolveToSignatureTestCase::case_01())] + #[case(ResolveToSignatureTestCase::case_02())] + #[case(ResolveToSignatureTestCase::case_03())] + #[case(ResolveToSignatureTestCase::case_04())] + #[case(ResolveToSignatureTestCase::case_05())] + #[case(ResolveToSignatureTestCase::case_06())] + #[case(ResolveToSignatureTestCase::case_07())] + fn test_assert_matching_signature(#[case] test_case: ResolveToSignatureTestCase) { + let memory_regions = build_declarations(); + let found = test_case + .call + .resolve_to_signature(&test_case.signature, &memory_regions); + match (test_case.expected, found) { + (Ok(_), Ok(_)) => {} + (Ok(expected), Err(found)) => { + panic!("expected resolution {:?}, found err {:?}", expected, found) + } + (Err(expected), Ok(found)) => { + panic!("expected err {:?}, found resolution {:?}", expected, found) + } + (Err(expected), Err(found)) => assert_eq!(expected, found), + } + } + + /// Test cases for call resolution against an [`ExternSignatureMap`]. + struct CallResolutionTestCase { + /// The call instruction to resolve. + call: Call, + /// The set of extern definitions to resolve against. + extern_signature_map: ExternSignatureMap, + /// The expected result of the resolution. + expected: Result, CallResolutionError>, + } + + impl CallResolutionTestCase { + /// Valid resolution + fn case_01() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: vec![UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + })], + }; + let signature = ExternSignature { + return_type: Some(ScalarType::Integer), + parameters: vec![], + }; + let resolved = vec![ResolvedCallArgument::MemoryReference { + memory_reference: MemoryReference { + name: "integer".to_string(), + index: 0, + }, + scalar_type: ScalarType::Integer, + mutable: true, + }]; + Self { + call, + extern_signature_map: ExternSignatureMap( + [("foo".to_string(), signature)].iter().cloned().collect(), + ), + expected: Ok(resolved), + } + } + + /// Signature does not match + fn case_02() -> Self { + let call = Call { + name: "foo".to_string(), + arguments: vec![UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + })], + }; + let signature = ExternSignature { + return_type: Some(ScalarType::Real), + parameters: vec![], + }; + Self { + call, + extern_signature_map: ExternSignatureMap( + [("foo".to_string(), signature)].iter().cloned().collect(), + ), + expected: Err(CallResolutionError::Signature { + name: "foo".to_string(), + error: CallSignatureError::Arguments(vec![CallArgumentError::Return( + CallArgumentResolutionError::MismatchedScalar { + expected: ScalarType::Real, + found: ScalarType::Integer, + }, + )]), + }), + } + } + + /// No corresponding extern definition + fn case_03() -> Self { + let call = Call { + name: "undeclared".to_string(), + arguments: vec![UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + })], + }; + let signature = ExternSignature { + return_type: Some(ScalarType::Real), + parameters: vec![], + }; + Self { + call, + extern_signature_map: ExternSignatureMap( + [("foo".to_string(), signature)].iter().cloned().collect(), + ), + expected: Err(CallResolutionError::NoMatchingExternInstruction( + "undeclared".to_string(), + )), + } + } + } + + /// Test resolution of [`Call`] instructions against a set of extern definitions. + #[rstest] + #[case(CallResolutionTestCase::case_01())] + #[case(CallResolutionTestCase::case_02())] + #[case(CallResolutionTestCase::case_03())] + fn test_call_resolution(#[case] test_case: CallResolutionTestCase) { + let memory_regions = build_declarations(); + let found = test_case + .call + .resolve_arguments(&memory_regions, &test_case.extern_signature_map); + match (test_case.expected, found) { + (Ok(expected), Ok(found)) => { + assert_eq!(expected, found); + } + (Ok(expected), Err(found)) => { + panic!("expected resolution {:?}, found err {:?}", expected, found) + } + (Err(expected), Ok(_)) => { + panic!( + "expected err {:?}, found resolution {:?}", + expected, test_case.call.arguments + ) + } + (Err(expected), Err(found)) => assert_eq!(expected, found), + } + } + + /// Test cases for converting [`ExternPragmaMap`] to [`ExternSignatureMap`]. + struct ExternPragmaMapConverstionTestCase { + /// The set of extern definitions to validate. + extern_pragma_map: ExternPragmaMap, + /// The expected result of the validation. + expected: Result, + } + + impl ExternPragmaMapConverstionTestCase { + /// Valid [`ExternPragmaMap`]s. + fn case_01() -> Self { + let pragma1 = Pragma { + name: RESERVED_PRAGMA_EXTERN.to_string(), + arguments: vec![PragmaArgument::Identifier("foo".to_string())], + data: Some("(bar : INTEGER)".to_string()), + }; + let signature1 = ExternSignature { + return_type: None, + parameters: vec![ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Integer), + }], + }; + let pragma2 = Pragma { + name: RESERVED_PRAGMA_EXTERN.to_string(), + arguments: vec![PragmaArgument::Identifier("baz".to_string())], + data: Some("REAL (biz : REAL)".to_string()), + }; + let signature2 = ExternSignature { + return_type: Some(ScalarType::Real), + parameters: vec![ExternParameter { + name: "biz".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Real), + }], + }; + let pragma3 = Pragma { + name: RESERVED_PRAGMA_EXTERN.to_string(), + arguments: vec![PragmaArgument::Identifier("buzz".to_string())], + data: Some("OCTET".to_string()), + }; + let signature3 = ExternSignature { + return_type: Some(ScalarType::Octet), + parameters: vec![], + }; + Self { + extern_pragma_map: ExternPragmaMap( + [("foo", pragma1), ("baz", pragma2), ("buzz", pragma3)] + .into_iter() + .map(|(name, pragma)| (Some(name.to_string()), pragma)) + .collect(), + ), + expected: Ok(ExternSignatureMap( + [ + ("foo", signature1), + ("baz", signature2), + ("buzz", signature3), + ] + .into_iter() + .map(|(name, signature)| (name.to_string(), signature)) + .collect(), + )), + } + } + + /// No Signature + fn case_02() -> Self { + let pragma = Pragma { + name: RESERVED_PRAGMA_EXTERN.to_string(), + arguments: vec![PragmaArgument::Identifier("foo".to_string())], + data: None, + }; + let expected = Err(ExternError::NoSignature); + Self { + extern_pragma_map: ExternPragmaMap( + [(Some("foo".to_string()), pragma)].into_iter().collect(), + ), + expected, + } + } + + /// No return nor parameters + fn case_03() -> Self { + let pragma = Pragma { + name: RESERVED_PRAGMA_EXTERN.to_string(), + arguments: vec![PragmaArgument::Identifier("foo".to_string())], + data: Some("()".to_string()), + }; + let expected = Err(ExternError::NoReturnOrParameters); + Self { + extern_pragma_map: ExternPragmaMap( + [(Some("foo".to_string()), pragma)].into_iter().collect(), + ), + expected, + } + } + + /// No name + fn case_04() -> Self { + let pragma = Pragma { + name: RESERVED_PRAGMA_EXTERN.to_string(), + arguments: vec![], + data: Some("(bar : REAL)".to_string()), + }; + let expected = Err(ExternError::NoName); + Self { + extern_pragma_map: ExternPragmaMap([(None, pragma)].into_iter().collect()), + expected, + } + } + + /// Not extern + fn case_05() -> Self { + let pragma = Pragma { + name: "NOTEXTERN".to_string(), + arguments: vec![PragmaArgument::Identifier("foo".to_string())], + data: Some("(bar : REAL)".to_string()), + }; + let expected = Err(ExternError::NoName); + Self { + extern_pragma_map: ExternPragmaMap([(None, pragma)].into_iter().collect()), + expected, + } + } + + /// Extraneous arguments + fn case_06() -> Self { + let pragma = Pragma { + name: RESERVED_PRAGMA_EXTERN.to_string(), + arguments: vec![ + PragmaArgument::Identifier("foo".to_string()), + PragmaArgument::Identifier("bar".to_string()), + ], + data: Some("OCTET".to_string()), + }; + let expected = Err(ExternError::NoName); + Self { + extern_pragma_map: ExternPragmaMap([(None, pragma)].into_iter().collect()), + expected, + } + } + + /// Integer is not a name + fn case_07() -> Self { + let pragma = Pragma { + name: RESERVED_PRAGMA_EXTERN.to_string(), + arguments: vec![PragmaArgument::Integer(0)], + data: Some("OCTET".to_string()), + }; + let expected = Err(ExternError::NoName); + Self { + extern_pragma_map: ExternPragmaMap([(None, pragma)].into_iter().collect()), + expected, + } + } + + /// Lex error + fn case_08() -> Self { + let pragma = Pragma { + name: RESERVED_PRAGMA_EXTERN.to_string(), + arguments: vec![PragmaArgument::Identifier("foo".to_string())], + data: Some("OCTET (ㆆ _ ㆆ)".to_string()), + }; + let expected = Err(ExternSignature::from_str("OCTET (ㆆ _ ㆆ)").unwrap_err()); + Self { + extern_pragma_map: ExternPragmaMap( + [(Some("foo".to_string()), pragma)].into_iter().collect(), + ), + expected, + } + } + + /// Syntax error - missing parenthesis + fn case_09() -> Self { + let pragma = Pragma { + name: RESERVED_PRAGMA_EXTERN.to_string(), + arguments: vec![PragmaArgument::Identifier("foo".to_string())], + data: Some("OCTET (bar : INTEGER".to_string()), + }; + let expected = Err(ExternSignature::from_str("OCTET (bar : INTEGER").unwrap_err()); + Self { + extern_pragma_map: ExternPragmaMap( + [(Some("foo".to_string()), pragma)].into_iter().collect(), + ), + expected, + } + } + } + + /// Test conversion of [`ExternPragmaMap`] to [`ExternSignatureMap`]. + #[rstest] + #[case(ExternPragmaMapConverstionTestCase::case_01())] + #[case(ExternPragmaMapConverstionTestCase::case_02())] + #[case(ExternPragmaMapConverstionTestCase::case_03())] + #[case(ExternPragmaMapConverstionTestCase::case_04())] + #[case(ExternPragmaMapConverstionTestCase::case_05())] + #[case(ExternPragmaMapConverstionTestCase::case_06())] + #[case(ExternPragmaMapConverstionTestCase::case_07())] + #[case(ExternPragmaMapConverstionTestCase::case_08())] + #[case(ExternPragmaMapConverstionTestCase::case_09())] + fn test_extern_signature_map_validation(#[case] test_case: ExternPragmaMapConverstionTestCase) { + let found = ExternSignatureMap::try_from(test_case.extern_pragma_map); + match (test_case.expected, found) { + (Ok(expected), Ok(found)) => { + assert_eq!(expected, found); + } + (Ok(_), Err(found)) => { + panic!("expected valid, found err {:?}", found) + } + (Err(expected), Ok(_)) => { + panic!("expected err {:?}, found valid", expected) + } + (Err(expected), Err((_, found))) => assert_eq!(expected, found), + } + } + + /// Test cases for parsing [`ExternSignature`]s. + struct ExternSignatureFromStrTestCase { + /// This string to parse. + input: &'static str, + /// The parsing result. + expected: Result, + } + + impl ExternSignatureFromStrTestCase { + /// Empty signature + fn case_01() -> Self { + Self { + input: "", + expected: Err(ExternError::NoReturnOrParameters), + } + } + + /// Empty signature with parentheses + fn case_02() -> Self { + Self { + input: "()", + expected: Err(ExternError::NoReturnOrParameters), + } + } + + /// Return without parameters + fn case_03() -> Self { + Self { + input: "INTEGER", + expected: Ok(crate::instruction::ExternSignature { + return_type: Some(ScalarType::Integer), + parameters: vec![], + }), + } + } + + /// Return with empty parentheses + fn case_04() -> Self { + Self { + input: "INTEGER ()", + expected: Ok(crate::instruction::ExternSignature { + return_type: Some(ScalarType::Integer), + parameters: vec![], + }), + } + } + + /// Return with parameters + fn case_05() -> Self { + Self { + input: "INTEGER (bar: REAL, baz: BIT[10], biz: mut OCTET)", + expected: Ok(crate::instruction::ExternSignature { + return_type: Some(ScalarType::Integer), + parameters: vec![ + ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Real), + }, + ExternParameter { + name: "baz".to_string(), + mutable: false, + data_type: ExternParameterType::FixedLengthVector(Vector { + data_type: ScalarType::Bit, + length: 10, + }), + }, + ExternParameter { + name: "biz".to_string(), + mutable: true, + data_type: ExternParameterType::Scalar(ScalarType::Octet), + }, + ], + }), + } + } + + /// Parameters without return + fn case_06() -> Self { + Self { + input: "(bar: REAL, baz: BIT[10], biz : mut OCTET)", + expected: Ok(crate::instruction::ExternSignature { + return_type: None, + parameters: vec![ + ExternParameter { + name: "bar".to_string(), + mutable: false, + data_type: ExternParameterType::Scalar(ScalarType::Real), + }, + ExternParameter { + name: "baz".to_string(), + mutable: false, + data_type: ExternParameterType::FixedLengthVector(Vector { + data_type: ScalarType::Bit, + length: 10, + }), + }, + ExternParameter { + name: "biz".to_string(), + mutable: true, + data_type: ExternParameterType::Scalar(ScalarType::Octet), + }, + ], + }), + } + } + + /// Variable length vector. + fn case_07() -> Self { + Self { + input: "(bar : mut REAL[])", + expected: Ok(crate::instruction::ExternSignature { + return_type: None, + parameters: vec![ExternParameter { + name: "bar".to_string(), + mutable: true, + data_type: ExternParameterType::VariableLengthVector(ScalarType::Real), + }], + }), + } + } + } + + /// Test parsing of `PRAGMA EXTERN` instructions. + #[rstest] + #[case(ExternSignatureFromStrTestCase::case_01())] + #[case(ExternSignatureFromStrTestCase::case_02())] + #[case(ExternSignatureFromStrTestCase::case_03())] + #[case(ExternSignatureFromStrTestCase::case_04())] + #[case(ExternSignatureFromStrTestCase::case_05())] + #[case(ExternSignatureFromStrTestCase::case_06())] + #[case(ExternSignatureFromStrTestCase::case_07())] + fn test_parse_reserved_pragma_extern(#[case] test_case: ExternSignatureFromStrTestCase) { + match ( + test_case.expected, + ExternSignature::from_str(test_case.input), + ) { + (Ok(expected), Ok(parsed)) => { + assert_eq!(expected, parsed); + } + (Ok(expected), Err(e)) => { + panic!("Expected {:?}, got error: {:?}", expected, e); + } + (Err(expected), Ok(parsed)) => { + panic!("Expected error: {:?}, got {:?}", expected, parsed); + } + (Err(expected), Err(found)) => { + let expected = format!("{expected:?}"); + let found = format!("{found:?}"); + assert!( + found.contains(&expected), + "`{}` not in `{}`", + expected, + found + ); + } + } + } +} diff --git a/quil-rs/src/instruction/mod.rs b/quil-rs/src/instruction/mod.rs index dd154dad..1914b666 100644 --- a/quil-rs/src/instruction/mod.rs +++ b/quil-rs/src/instruction/mod.rs @@ -31,6 +31,7 @@ mod circuit; mod classical; mod control_flow; mod declaration; +mod extern_call; mod frame; mod gate; mod measurement; @@ -51,6 +52,7 @@ pub use self::control_flow::{Jump, JumpUnless, JumpWhen, Label, Target, TargetPl pub use self::declaration::{ Declaration, Load, MemoryReference, Offset, ScalarType, Sharing, Store, Vector, }; +pub use self::extern_call::*; pub use self::frame::{ AttributeValue, Capture, FrameAttributes, FrameDefinition, FrameIdentifier, Pulse, RawCapture, SetFrequency, SetPhase, SetScale, ShiftFrequency, ShiftPhase, SwapPhases, @@ -60,7 +62,7 @@ pub use self::gate::{ PauliSum, PauliTerm, }; pub use self::measurement::Measurement; -pub use self::pragma::{Include, Pragma, PragmaArgument}; +pub use self::pragma::{Include, Pragma, PragmaArgument, RESERVED_PRAGMA_EXTERN}; pub use self::qubit::{Qubit, QubitPlaceholder}; pub use self::reset::Reset; pub use self::timing::{Delay, Fence}; @@ -77,6 +79,7 @@ pub enum Instruction { Arithmetic(Arithmetic), BinaryLogic(BinaryLogic), CalibrationDefinition(Calibration), + Call(Call), Capture(Capture), CircuitDefinition(CircuitDefinition), Convert(Convert), @@ -150,6 +153,7 @@ impl From<&Instruction> for InstructionRole { | Instruction::ShiftPhase(_) | Instruction::SwapPhases(_) => InstructionRole::RFControl, Instruction::Arithmetic(_) + | Instruction::Call(_) | Instruction::Comparison(_) | Instruction::Convert(_) | Instruction::BinaryLogic(_) @@ -268,6 +272,7 @@ impl Quil for Instruction { Instruction::CalibrationDefinition(calibration) => { calibration.write(f, fall_back_to_debug) } + Instruction::Call(call) => call.write(f, fall_back_to_debug), Instruction::Capture(capture) => capture.write(f, fall_back_to_debug), Instruction::CircuitDefinition(circuit) => circuit.write(f, fall_back_to_debug), Instruction::Convert(convert) => convert.write(f, fall_back_to_debug), @@ -535,6 +540,7 @@ impl Instruction { Instruction::Arithmetic(_) | Instruction::BinaryLogic(_) | Instruction::CalibrationDefinition(_) + | Instruction::Call(_) | Instruction::CircuitDefinition(_) | Instruction::Comparison(_) | Instruction::Convert(_) @@ -664,6 +670,7 @@ impl Instruction { | Instruction::WaveformDefinition(_) => true, Instruction::Arithmetic(_) | Instruction::BinaryLogic(_) + | Instruction::Call(_) | Instruction::CircuitDefinition(_) | Instruction::Convert(_) | Instruction::Comparison(_) @@ -709,6 +716,7 @@ impl Instruction { Instruction::Arithmetic(_) | Instruction::BinaryLogic(_) | Instruction::CalibrationDefinition(_) + | Instruction::Call(_) | Instruction::CircuitDefinition(_) | Instruction::Convert(_) | Instruction::Comparison(_) @@ -918,11 +926,16 @@ impl InstructionHandler { /// This uses the return value of the override function, if set and returns `Some`. If not set /// or the function returns `None`, defaults to the return value of /// [`Instruction::get_memory_accesses`]. - pub fn memory_accesses(&mut self, instruction: &Instruction) -> MemoryAccesses { + pub fn memory_accesses( + &mut self, + instruction: &Instruction, + extern_signature_map: &ExternSignatureMap, + ) -> crate::program::MemoryAccessesResult { self.get_memory_accesses .as_mut() .and_then(|f| f(instruction)) - .unwrap_or_else(|| instruction.get_memory_accesses()) + .map(Ok) + .unwrap_or_else(|| instruction.get_memory_accesses(extern_signature_map)) } /// Like [`Program::into_simplified`], but using custom instruction handling. diff --git a/quil-rs/src/instruction/pragma.rs b/quil-rs/src/instruction/pragma.rs index d24ffbd0..699d72be 100644 --- a/quil-rs/src/instruction/pragma.rs +++ b/quil-rs/src/instruction/pragma.rs @@ -77,3 +77,5 @@ impl Include { Self { filename } } } + +pub const RESERVED_PRAGMA_EXTERN: &str = "EXTERN"; diff --git a/quil-rs/src/parser/command.rs b/quil-rs/src/parser/command.rs index d1009a94..4a943157 100644 --- a/quil-rs/src/parser/command.rs +++ b/quil-rs/src/parser/command.rs @@ -5,13 +5,13 @@ use nom::sequence::{delimited, pair, preceded, tuple}; use crate::expression::Expression; use crate::instruction::{ - Arithmetic, ArithmeticOperator, BinaryLogic, BinaryOperator, Calibration, Capture, + Arithmetic, ArithmeticOperator, BinaryLogic, BinaryOperator, Calibration, Call, Capture, CircuitDefinition, Comparison, ComparisonOperator, Convert, Declaration, Delay, Exchange, Fence, FrameDefinition, GateDefinition, GateSpecification, GateType, Include, Instruction, Jump, JumpUnless, JumpWhen, Label, Load, MeasureCalibrationDefinition, Measurement, Move, PauliSum, Pragma, PragmaArgument, Pulse, Qubit, RawCapture, Reset, SetFrequency, SetPhase, SetScale, ShiftFrequency, ShiftPhase, Store, SwapPhases, Target, UnaryLogic, UnaryOperator, - ValidationError, Waveform, WaveformDefinition, + UnresolvedCallArgument, ValidationError, Waveform, WaveformDefinition, }; use crate::parser::instruction::parse_block; @@ -19,7 +19,7 @@ use crate::parser::InternalParserResult; use crate::quil::Quil; use crate::{real, token}; -use super::common::parse_variable_qubit; +use super::common::{parse_memory_reference_with_brackets, parse_variable_qubit}; use super::{ common::{ parse_arithmetic_operand, parse_binary_logic_operand, parse_comparison_operand, @@ -120,6 +120,40 @@ pub(crate) fn parse_declare<'a>(input: ParserInput<'a>) -> InternalParserResult< )) } +/// Parse the contents of a `CALL` instruction. +/// +/// Note, the `CALL` instruction here is unresolved; it can only be resolved within the +/// full context of a program from an [`crate::instruction::extern_call::ExternSignatureMap`]. +/// +/// Call instructions are of the form: +/// `CALL @ms{Identifier} @rep[:min 1]{@group{@ms{Identifier} @alt @ms{Memory Reference} @alt @ms{Complex}}}` +/// +/// For additional detail, see ["Call" in the Quil specification](https://github.com/quil-lang/quil/blob/7f532c7cdde9f51eae6abe7408cc868fba9f91f6/specgen/spec/sec-other.s). +pub(crate) fn parse_call<'a>(input: ParserInput<'a>) -> InternalParserResult<'a, Instruction> { + let (input, name) = token!(Identifier(v))(input)?; + + let (input, arguments) = many0(parse_call_argument)(input)?; + let call = Call { name, arguments }; + + Ok((input, Instruction::Call(call))) +} + +fn parse_call_argument<'a>( + input: ParserInput<'a>, +) -> InternalParserResult<'a, UnresolvedCallArgument> { + alt(( + map( + parse_memory_reference_with_brackets, + UnresolvedCallArgument::MemoryReference, + ), + map(token!(Identifier(v)), UnresolvedCallArgument::Identifier), + map( + super::expression::parse_immediate_value, + UnresolvedCallArgument::Immediate, + ), + ))(input) +} + /// Parse the contents of a `CAPTURE` instruction. /// /// Unlike most other instructions, this can be _prefixed_ with the NONBLOCKING keyword, @@ -604,10 +638,11 @@ mod tests { PrefixExpression, PrefixOperator, }; use crate::instruction::{ - GateDefinition, GateSpecification, Offset, PauliGate, PauliSum, PauliTerm, PragmaArgument, - Sharing, + Call, GateDefinition, GateSpecification, Offset, PauliGate, PauliSum, PauliTerm, + PragmaArgument, Sharing, UnresolvedCallArgument, }; use crate::parser::lexer::lex; + use crate::parser::Token; use crate::{imag, real}; use crate::{ instruction::{ @@ -616,6 +651,7 @@ mod tests { }, make_test, }; + use rstest::*; use super::{parse_declare, parse_defcircuit, parse_defgate, parse_measurement, parse_pragma}; @@ -1021,4 +1057,111 @@ mod tests { }) }) ); + + /// Test case for parsing a `CALL` instruction. + struct ParseCallTestCase { + /// The input to parse. + input: &'static str, + /// The remaining tokens after parsing. + remainder: Vec, + /// The expected result. + expected: Result, + } + + impl ParseCallTestCase { + /// Basic call with arguments. + fn case_01() -> Self { + Self { + input: "foo integer[0] 1.0 bar", + remainder: vec![], + expected: Ok(Call { + name: "foo".to_string(), + arguments: vec![ + UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + }), + UnresolvedCallArgument::Immediate(real!(1.0)), + UnresolvedCallArgument::Identifier("bar".to_string()), + ], + }), + } + } + + /// No arguments does in fact parse. + fn case_02() -> Self { + Self { + input: "foo", + remainder: vec![], + expected: Ok(Call { + name: "foo".to_string(), + arguments: vec![], + }), + } + } + + /// Invalid identifier. + fn case_03() -> Self { + Self { + input: "INCLUDE", + remainder: vec![], + expected: Err( + "ExpectedToken { actual: COMMAND(INCLUDE), expected: \"Identifier\" }" + .to_string(), + ), + } + } + + /// Valid with leftover + fn case_04() -> Self { + Self { + input: "foo integer[0] 1.0 bar; baz", + remainder: vec![Token::Semicolon, Token::Identifier("baz".to_string())], + expected: Ok(Call { + name: "foo".to_string(), + arguments: vec![ + UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "integer".to_string(), + index: 0, + }), + UnresolvedCallArgument::Immediate(real!(1.0)), + UnresolvedCallArgument::Identifier("bar".to_string()), + ], + }), + } + } + } + + /// Test that the `parse_call` function works as expected. + #[rstest] + #[case(ParseCallTestCase::case_01())] + #[case(ParseCallTestCase::case_02())] + #[case(ParseCallTestCase::case_03())] + #[case(ParseCallTestCase::case_04())] + fn test_parse_call(#[case] test_case: ParseCallTestCase) { + let input = ::nom_locate::LocatedSpan::new(test_case.input); + let tokens = lex(input).unwrap(); + match (test_case.expected, super::parse_call(&tokens)) { + (Ok(expected), Ok((remainder, parsed))) => { + assert_eq!(parsed, Instruction::Call(expected)); + let remainder: Vec<_> = remainder.iter().map(|t| t.as_token().clone()).collect(); + assert_eq!(remainder, test_case.remainder); + } + (Ok(expected), Err(e)) => { + panic!("Expected {:?}, got error: {:?}", expected, e); + } + (Err(expected), Ok((_, parsed))) => { + panic!("Expected error: {:?}, got {:?}", expected, parsed); + } + (Err(expected), Err(found)) => { + let found = format!("{found:?}"); + assert!( + found.contains(&expected), + "`{}` not in `{}`", + expected, + found + ); + } + } + } } diff --git a/quil-rs/src/parser/common.rs b/quil-rs/src/parser/common.rs index a85dd845..dae8ac73 100644 --- a/quil-rs/src/parser/common.rs +++ b/quil-rs/src/parser/common.rs @@ -30,7 +30,7 @@ use crate::{ Vector, WaveformInvocation, WaveformParameters, }, parser::lexer::Operator, - token, + token, unexpected_eof, }; use crate::parser::{InternalParseError, InternalParserResult}; @@ -348,7 +348,7 @@ pub(crate) fn parse_variable_qubit(input: ParserInput) -> InternalParserResult ScalarType { +pub(super) fn match_data_type_token(token: DataType) -> ScalarType { match token { DataType::Bit => ScalarType::Bit, DataType::Integer => ScalarType::Integer, @@ -399,6 +399,18 @@ pub(crate) fn parse_vector<'a>(input: ParserInput<'a>) -> InternalParserResult<' Ok((input, Vector { data_type, length })) } +/// Parse a "vector" which is an integer index, such as `[0]`. The brackets are requried here. +pub(crate) fn parse_vector_with_brackets<'a>( + input: ParserInput<'a>, +) -> InternalParserResult<'a, Vector> { + let (input, data_type_token) = token!(DataType(v))(input)?; + let data_type = match_data_type_token(data_type_token); + + let (input, length) = delimited(token!(LBracket), token!(Integer(v)), token!(RBracket))(input)?; + + Ok((input, Vector { data_type, length })) +} + /// Parse a waveform name which may look like `custom` or `q20_q27_xy/sqrtiSWAP` pub(crate) fn parse_waveform_name<'a>(input: ParserInput<'a>) -> InternalParserResult<'a, String> { use crate::parser::lexer::Operator::Slash; @@ -425,6 +437,15 @@ pub(crate) fn skip_newlines_and_comments<'a>( Ok((input, ())) } +/// Returns successfully if the head of input is the identifier `i`, returns error otherwise. +pub(super) fn parse_i(input: ParserInput) -> InternalParserResult<()> { + match super::split_first_token(input) { + None => unexpected_eof!(input), + Some((Token::Identifier(v), remainder)) if v == "i" => Ok((remainder, ())), + Some((other_token, _)) => expected_token!(input, other_token, "i".to_owned()), + } +} + #[cfg(test)] mod describe_skip_newlines_and_comments { use crate::parser::lex; diff --git a/quil-rs/src/parser/expression.rs b/quil-rs/src/parser/expression.rs index 63f73d62..24a7bd7a 100644 --- a/quil-rs/src/parser/expression.rs +++ b/quil-rs/src/parser/expression.rs @@ -13,6 +13,7 @@ // limitations under the License. use nom::combinator::opt; +use num_complex::Complex64; use crate::expression::{FunctionCallExpression, InfixExpression, PrefixExpression}; use crate::parser::InternalParserResult; @@ -25,6 +26,7 @@ use crate::{ token, unexpected_eof, }; +use super::common::parse_i; use super::lexer::{Operator, Token}; use super::ParserInput; @@ -74,31 +76,19 @@ pub(crate) fn parse_expression(input: ParserInput) -> InternalParserResult InternalParserResult { let (input, prefix) = opt(parse_prefix)(input)?; - let (mut input, mut left) = match super::split_first_token(input) { - None => unexpected_eof!(input), - Some((Token::Integer(value), remainder)) => { - let (remainder, imaginary) = opt(parse_i)(remainder)?; - match imaginary { - None => Ok((remainder, Expression::Number(crate::real!(*value as f64)))), - Some(_) => Ok((remainder, Expression::Number(crate::imag!(*value as f64)))), + let (input, maybe_immediate_value) = opt(parse_immediate_value)(input)?; + + let (mut input, mut left) = maybe_immediate_value + .map(|number| Ok((input, Expression::Number(number)))) + .unwrap_or_else(|| match super::split_first_token(input) { + None => unexpected_eof!(input), + Some((Token::Variable(name), remainder)) => { + Ok((remainder, Expression::Variable(name.clone()))) } - } - Some((Token::Float(value), remainder)) => { - let (remainder, imaginary) = opt(parse_i)(remainder)?; - match imaginary { - None => Ok((remainder, Expression::Number(crate::real!(*value)))), - Some(_) => Ok((remainder, Expression::Number(crate::imag!(*value)))), - } - } - Some((Token::Variable(name), remainder)) => { - Ok((remainder, Expression::Variable(name.clone()))) - } - Some((Token::Identifier(_), _)) => parse_expression_identifier(input), - Some((Token::LParenthesis, remainder)) => parse_grouped_expression(remainder), - Some((token, _)) => { - expected_token!(input, token, "expression".to_owned()) - } - }?; + Some((Token::Identifier(_), _)) => parse_expression_identifier(input), + Some((Token::LParenthesis, remainder)) => parse_grouped_expression(remainder), + Some((token, _)) => expected_token!(input, token, "expression".to_owned()), + })?; if let Some(prefix) = prefix { left = Expression::Prefix(PrefixExpression { @@ -122,12 +112,24 @@ fn parse(input: ParserInput, precedence: Precedence) -> InternalParserResult InternalParserResult<()> { +pub(super) fn parse_immediate_value(input: ParserInput) -> InternalParserResult { match super::split_first_token(input) { + Some((Token::Integer(value), remainder)) => { + let (remainder, imaginary) = opt(parse_i)(remainder)?; + match imaginary { + None => Ok((remainder, crate::real!(*value as f64))), + Some(_) => Ok((remainder, crate::imag!(*value as f64))), + } + } + Some((Token::Float(value), remainder)) => { + let (remainder, imaginary) = opt(parse_i)(remainder)?; + match imaginary { + None => Ok((remainder, crate::real!(*value))), + Some(_) => Ok((remainder, crate::imag!(*value))), + } + } + Some((token, _)) => expected_token!(input, token, "integer or float".to_owned()), None => unexpected_eof!(input), - Some((Token::Identifier(v), remainder)) if v == "i" => Ok((remainder, ())), - Some((other_token, _)) => expected_token!(input, other_token, "i".to_owned()), } } diff --git a/quil-rs/src/parser/instruction.rs b/quil-rs/src/parser/instruction.rs index 8c539311..d9900d92 100644 --- a/quil-rs/src/parser/instruction.rs +++ b/quil-rs/src/parser/instruction.rs @@ -45,6 +45,7 @@ pub(crate) fn parse_instruction(input: ParserInput) -> InternalParserResult match command { Command::Add => command::parse_arithmetic(ArithmeticOperator::Add, remainder), Command::And => command::parse_logical_binary(BinaryOperator::And, remainder), + Command::Call => command::parse_call(remainder), Command::Capture => command::parse_capture(remainder, true), Command::Convert => command::parse_convert(remainder), Command::Declare => command::parse_declare(remainder), diff --git a/quil-rs/src/parser/lexer/mod.rs b/quil-rs/src/parser/lexer/mod.rs index cd4d95a0..9c043cc6 100644 --- a/quil-rs/src/parser/lexer/mod.rs +++ b/quil-rs/src/parser/lexer/mod.rs @@ -17,7 +17,7 @@ mod quoted_strings; mod wrapped_parsers; use nom::{ - bytes::complete::{is_a, take_till, take_while, take_while1}, + bytes::complete::{is_a, tag_no_case, take_till, take_while, take_while1}, character::complete::{digit1, one_of}, combinator::{all_consuming, map, recognize, value}, multi::many0, @@ -39,6 +39,7 @@ pub use error::{LexError, LexErrorKind}; pub enum Command { Add, And, + Call, Capture, Convert, Declare, @@ -205,6 +206,7 @@ fn recognize_command_or_identifier(identifier: String) -> Token { "DEFGATE" => Token::Command(DefGate), "ADD" => Token::Command(Add), "AND" => Token::Command(And), + "CALL" => Token::Command(Call), "CONVERT" => Token::Command(Convert), "DIV" => Token::Command(Div), "EQ" => Token::Command(Eq), @@ -325,6 +327,7 @@ fn lex_modifier(input: LexInput) -> InternalLexResult { value(Token::Modifier(Modifier::Controlled), tag("CONTROLLED")), value(Token::Modifier(Modifier::Dagger), tag("DAGGER")), value(Token::Modifier(Modifier::Forked), tag("FORKED")), + value(Token::Mutable, tag_no_case("MUT")), value(Token::Offset, tag("OFFSET")), value(Token::PauliSum, tag("PAULI-SUM")), value(Token::Permutation, tag("PERMUTATION")), diff --git a/quil-rs/src/parser/mod.rs b/quil-rs/src/parser/mod.rs index 95dc8f75..d4d00cf1 100644 --- a/quil-rs/src/parser/mod.rs +++ b/quil-rs/src/parser/mod.rs @@ -26,6 +26,7 @@ mod error; mod expression; pub(crate) mod instruction; mod lexer; +pub(crate) mod pragma_extern; mod token; pub(crate) use error::{ErrorInput, InternalParseError}; @@ -33,7 +34,7 @@ pub use error::{ParseError, ParserErrorKind}; pub use lexer::LexError; pub use token::{Token, TokenWithLocation}; -type ParserInput<'a> = &'a [TokenWithLocation<'a>]; +pub(crate) type ParserInput<'a> = &'a [TokenWithLocation<'a>]; type InternalParserResult<'a, R, E = InternalParseError<'a>> = IResult, R, E>; /// Pops the first token off of the `input` and returns it and the remaining input. diff --git a/quil-rs/src/parser/pragma_extern.rs b/quil-rs/src/parser/pragma_extern.rs new file mode 100644 index 00000000..2c29ca2c --- /dev/null +++ b/quil-rs/src/parser/pragma_extern.rs @@ -0,0 +1,88 @@ +use nom::{ + branch::alt, + combinator::{map, opt}, + multi::separated_list0, +}; + +use crate::{ + instruction::{ExternParameter, ExternParameterType, ExternSignature, ScalarType}, + token, +}; + +use super::{ + common::{match_data_type_token, parse_vector_with_brackets}, + InternalParserResult, ParserInput, +}; + +/// Parse an [`ExternSignature`] from a string. Note, externs are currently defined within a +/// [`crate::instruction::Pragma`] instruction as `PRAGMA EXTERN foo "signature"`; the "signature" +/// currently represents its own mini-language within the Quil specification. +/// +/// Signatures are of the form: +/// `@rep[:min 0 :max 1]{@ms{Base Type}} ( @ms{Extern Parameter} @rep[:min 0]{@group{ , @ms{Extern Parameter} }} )` +/// +/// For details on the signature format, see the [Quil specification for "Extern Signature"](https://github.com/quil-lang/quil/blob/7f532c7cdde9f51eae6abe7408cc868fba9f91f6/specgen/spec/sec-other.s). +/// +/// Note, there are test cases for this parser in [`crate::instruction::extern_call::tests`] (via +/// [`std::str::FromStr`] for [`crate::instruction::ExternSignature`]). +pub(crate) fn parse_extern_signature<'a>( + input: ParserInput<'a>, +) -> InternalParserResult<'a, ExternSignature> { + let (input, return_type) = opt(token!(DataType(v)))(input)?; + let (input, lparen) = opt(token!(LParenthesis))(input)?; + let (input, parameters) = if lparen.is_some() { + let (input, parameters) = + opt(separated_list0(token!(Comma), parse_extern_parameter))(input)?; + let (input, _) = token!(RParenthesis)(input)?; + (input, parameters.unwrap_or_default()) + } else { + (input, vec![]) + }; + + let signature = ExternSignature { + return_type: return_type.map(match_data_type_token), + parameters, + }; + + Ok((input, signature)) +} + +fn parse_extern_parameter<'a>(input: ParserInput<'a>) -> InternalParserResult<'a, ExternParameter> { + let (input, name) = token!(Identifier(v))(input)?; + let (input, _) = token!(Colon)(input)?; + let (input, mutable) = opt(token!(Mutable))(input)?; + let (input, data_type) = alt(( + map( + parse_vector_with_brackets, + ExternParameterType::FixedLengthVector, + ), + map( + parse_variable_length_vector, + ExternParameterType::VariableLengthVector, + ), + map(token!(DataType(v)), |data_type| { + ExternParameterType::Scalar(match_data_type_token(data_type)) + }), + ))(input)?; + Ok(( + input, + ExternParameter { + name, + mutable: mutable.is_some(), + data_type, + }, + )) +} + +/// Parse a variable length [`crate::instruction::Vector`], which is represented as [`ScalarType`] +/// followed by empty brackets `[]`. +fn parse_variable_length_vector<'a>( + input: ParserInput<'a>, +) -> InternalParserResult<'a, ScalarType> { + let (input, data_type_token) = token!(DataType(v))(input)?; + let data_type = match_data_type_token(data_type_token); + let (input, _) = token!(LBracket)(input)?; + let (input, _) = token!(RBracket)(input)?; + + Ok((input, data_type)) +} diff --git a/quil-rs/src/parser/token.rs b/quil-rs/src/parser/token.rs index 18210b27..6f0b1a9e 100644 --- a/quil-rs/src/parser/token.rs +++ b/quil-rs/src/parser/token.rs @@ -85,6 +85,7 @@ pub enum Token { NonBlocking, Matrix, Modifier(Modifier), + Mutable, NewLine, Operator(Operator), Offset, @@ -117,6 +118,7 @@ impl fmt::Display for Token { Token::NonBlocking => write!(f, "NONBLOCKING"), Token::Matrix => write!(f, "MATRIX"), Token::Modifier(m) => write!(f, "{m}"), + Token::Mutable => write!(f, "MUT"), Token::NewLine => write!(f, "NEWLINE"), Token::Operator(op) => write!(f, "{op}"), Token::Offset => write!(f, "OFFSET"), @@ -151,6 +153,7 @@ impl fmt::Debug for Token { Token::NonBlocking => write!(f, "{self}"), Token::Matrix => write!(f, "{self}"), Token::Modifier(m) => write!(f, "MODIFIER({m})"), + Token::Mutable => write!(f, "{self}"), Token::NewLine => write!(f, "NEWLINE"), Token::Operator(op) => write!(f, "OPERATOR({op})"), Token::Offset => write!(f, "{self}"), diff --git a/quil-rs/src/program/analysis/control_flow_graph.rs b/quil-rs/src/program/analysis/control_flow_graph.rs index 98f60170..f050d0f6 100644 --- a/quil-rs/src/program/analysis/control_flow_graph.rs +++ b/quil-rs/src/program/analysis/control_flow_graph.rs @@ -418,6 +418,7 @@ impl<'p> From<&'p Program> for ControlFlowGraph<'p> { match instruction { Instruction::Arithmetic(_) | Instruction::BinaryLogic(_) + | Instruction::Call(_) | Instruction::Capture(_) | Instruction::Convert(_) | Instruction::Comparison(_) diff --git a/quil-rs/src/program/memory.rs b/quil-rs/src/program/memory.rs index 3d914aff..c733ec0f 100644 --- a/quil-rs/src/program/memory.rs +++ b/quil-rs/src/program/memory.rs @@ -16,11 +16,12 @@ use std::collections::HashSet; use crate::expression::{Expression, FunctionCallExpression, InfixExpression, PrefixExpression}; use crate::instruction::{ - Arithmetic, ArithmeticOperand, BinaryLogic, BinaryOperand, Capture, CircuitDefinition, - Comparison, ComparisonOperand, Convert, Delay, Exchange, Gate, GateDefinition, - GateSpecification, Instruction, JumpUnless, JumpWhen, Load, MeasureCalibrationDefinition, - Measurement, MemoryReference, Move, Pulse, RawCapture, SetFrequency, SetPhase, SetScale, - Sharing, ShiftFrequency, ShiftPhase, Store, UnaryLogic, Vector, WaveformInvocation, + Arithmetic, ArithmeticOperand, BinaryLogic, BinaryOperand, CallResolutionError, Capture, + CircuitDefinition, Comparison, ComparisonOperand, Convert, Delay, Exchange, ExternSignatureMap, + Gate, GateDefinition, GateSpecification, Instruction, JumpUnless, JumpWhen, Load, + MeasureCalibrationDefinition, Measurement, MemoryReference, Move, Pulse, RawCapture, + SetFrequency, SetPhase, SetScale, Sharing, ShiftFrequency, ShiftPhase, Store, UnaryLogic, + Vector, WaveformInvocation, }; #[derive(Clone, Debug, Hash, PartialEq)] @@ -43,7 +44,7 @@ pub struct MemoryAccess { pub access_type: MemoryAccessType, } -#[derive(Clone, Debug, Default)] +#[derive(Clone, Debug, Default, PartialEq)] pub struct MemoryAccesses { pub captures: HashSet, pub reads: HashSet, @@ -92,10 +93,26 @@ macro_rules! set_from_memory_references { }; } +#[derive(thiserror::Error, Debug, PartialEq)] +pub enum MemoryAccessesError { + #[error(transparent)] + CallResolution(#[from] CallResolutionError), +} + +pub type MemoryAccessesResult = Result; + impl Instruction { - /// Return all memory accesses by the instruction - in expressions, captures, and memory manipulation - pub fn get_memory_accesses(&self) -> MemoryAccesses { - match self { + /// Return all memory accesses by the instruction - in expressions, captures, and memory manipulation. + /// + /// This will fail if the program contains [`Instruction::Call`] instructions that cannot + /// be resolved against a signature in the provided [`ExternSignatureMap`] (either because + /// they call functions that don't appear in the map or because the types of the parameters + /// are wrong). + pub fn get_memory_accesses( + &self, + extern_signature_map: &ExternSignatureMap, + ) -> MemoryAccessesResult { + Ok(match self { Instruction::Convert(Convert { source, destination, @@ -104,6 +121,7 @@ impl Instruction { writes: set_from_memory_references![[destination]], ..Default::default() }, + Instruction::Call(call) => call.get_memory_accesses(extern_signature_map)?, Instruction::Comparison(Comparison { destination, lhs, @@ -187,14 +205,17 @@ impl Instruction { | Instruction::MeasureCalibrationDefinition(MeasureCalibrationDefinition { instructions, .. - }) => instructions.iter().fold(Default::default(), |acc, el| { - let el_accesses = el.get_memory_accesses(); - MemoryAccesses { - reads: merge_sets!(acc.reads, el_accesses.reads), - writes: merge_sets!(acc.writes, el_accesses.writes), - captures: merge_sets!(acc.captures, el_accesses.captures), - } - }), + }) => instructions.iter().try_fold( + Default::default(), + |acc: MemoryAccesses, el| -> MemoryAccessesResult { + let el_accesses = el.get_memory_accesses(extern_signature_map)?; + Ok(MemoryAccesses { + reads: merge_sets!(acc.reads, el_accesses.reads), + writes: merge_sets!(acc.writes, el_accesses.writes), + captures: merge_sets!(acc.captures, el_accesses.captures), + }) + }, + )?, Instruction::Delay(Delay { duration, .. }) => MemoryAccesses { reads: set_from_memory_references!(duration.get_memory_references()), ..Default::default() @@ -301,7 +322,7 @@ impl Instruction { | Instruction::Reset(_) | Instruction::SwapPhases(_) | Instruction::WaveformDefinition(_) => Default::default(), - } + }) } } @@ -354,8 +375,8 @@ mod tests { use crate::expression::Expression; use crate::instruction::{ - ArithmeticOperand, Convert, Exchange, FrameIdentifier, Instruction, MemoryReference, Qubit, - SetFrequency, ShiftFrequency, Store, + ArithmeticOperand, Convert, Exchange, ExternSignatureMap, FrameIdentifier, Instruction, + MemoryReference, Qubit, SetFrequency, ShiftFrequency, Store, }; use crate::program::MemoryAccesses; use std::collections::HashSet; @@ -451,7 +472,9 @@ mod tests { #[case] instruction: Instruction, #[case] expected: MemoryAccesses, ) { - let memory_accesses = instruction.get_memory_accesses(); + let memory_accesses = instruction + .get_memory_accesses(&ExternSignatureMap::default()) + .expect("must be able to get memory accesses"); assert_eq!(memory_accesses.captures, expected.captures); assert_eq!(memory_accesses.reads, expected.reads); assert_eq!(memory_accesses.writes, expected.writes); diff --git a/quil-rs/src/program/mod.rs b/quil-rs/src/program/mod.rs index 60792ad9..d80256c2 100644 --- a/quil-rs/src/program/mod.rs +++ b/quil-rs/src/program/mod.rs @@ -9,7 +9,7 @@ // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and +//See the License for the specific language governing permissions and // limitations under the License. use std::collections::{HashMap, HashSet}; @@ -21,20 +21,24 @@ use ndarray::Array2; use nom_locate::LocatedSpan; use crate::instruction::{ - Arithmetic, ArithmeticOperand, ArithmeticOperator, Declaration, FrameDefinition, - FrameIdentifier, GateDefinition, GateError, Instruction, InstructionHandler, Jump, JumpUnless, - Label, Matrix, MemoryReference, Move, Qubit, QubitPlaceholder, ScalarType, Target, - TargetPlaceholder, Vector, Waveform, WaveformDefinition, + Arithmetic, ArithmeticOperand, ArithmeticOperator, Call, Declaration, ExternPragmaMap, + FrameDefinition, FrameIdentifier, GateDefinition, GateError, Instruction, InstructionHandler, + Jump, JumpUnless, Label, Matrix, MemoryReference, Move, Qubit, QubitPlaceholder, ScalarType, + Target, TargetPlaceholder, Vector, Waveform, WaveformDefinition, RESERVED_PRAGMA_EXTERN, }; use crate::parser::{lex, parse_instructions, ParseError}; use crate::quil::Quil; pub use self::calibration::Calibrations; pub use self::calibration_set::CalibrationSet; -pub use self::error::{disallow_leftover, map_parsed, recover, ParseProgramError, SyntaxError}; +pub use self::error::{ + disallow_leftover, map_parsed, recover, LeftoverError, ParseProgramError, SyntaxError, +}; pub use self::frame::FrameSet; pub use self::frame::MatchedFrames; -pub use self::memory::{MemoryAccess, MemoryAccesses, MemoryRegion}; +pub use self::memory::{ + MemoryAccess, MemoryAccesses, MemoryAccessesError, MemoryAccessesResult, MemoryRegion, +}; pub mod analysis; mod calibration; @@ -73,6 +77,7 @@ type Result = std::result::Result; #[derive(Clone, Debug, Default, PartialEq)] pub struct Program { pub calibrations: Calibrations, + extern_pragma_map: ExternPragmaMap, pub frames: FrameSet, pub memory_regions: IndexMap, pub waveforms: IndexMap, @@ -86,6 +91,7 @@ impl Program { pub fn new() -> Self { Program { calibrations: Calibrations::default(), + extern_pragma_map: ExternPragmaMap::default(), frames: FrameSet::new(), memory_regions: IndexMap::new(), waveforms: IndexMap::new(), @@ -123,6 +129,7 @@ impl Program { Self { instructions: Vec::new(), calibrations: self.calibrations.clone(), + extern_pragma_map: self.extern_pragma_map.clone(), frames: self.frames.clone(), memory_regions: self.memory_regions.clone(), gate_definitions: self.gate_definitions.clone(), @@ -132,6 +139,11 @@ impl Program { } /// Add an instruction to the end of the program. + /// + /// Note, parsing extern signatures is deferred here to maintain infallibility + /// of [`Program::add_instruction`]. This means that invalid `PRAGMA EXTERN` + /// instructions are still added to the [`Program::extern_pragma_map`]; + /// duplicate `PRAGMA EXTERN` names are overwritten. pub fn add_instruction(&mut self, instruction: Instruction) { self.used_qubits .extend(instruction.get_qubits().into_iter().cloned()); @@ -187,6 +199,9 @@ impl Program { Instruction::Pulse(pulse) => { self.instructions.push(Instruction::Pulse(pulse)); } + Instruction::Pragma(pragma) if pragma.name == RESERVED_PRAGMA_EXTERN => { + self.extern_pragma_map.insert(pragma); + } Instruction::RawCapture(raw_capture) => { self.instructions.push(Instruction::RawCapture(raw_capture)); } @@ -253,6 +268,7 @@ impl Program { let mut new_program = Self { calibrations: self.calibrations.clone(), + extern_pragma_map: self.extern_pragma_map.clone(), frames: self.frames.clone(), memory_regions: self.memory_regions.clone(), waveforms: self.waveforms.clone(), @@ -361,6 +377,7 @@ impl Program { let mut frames_used: HashSet<&FrameIdentifier> = HashSet::new(); let mut waveforms_used: HashSet<&String> = HashSet::new(); + let mut extern_signatures_used: HashSet<&String> = HashSet::new(); for instruction in &expanded_program.instructions { if let Some(matched_frames) = @@ -372,12 +389,23 @@ impl Program { if let Some(waveform) = instruction.get_waveform_invocation() { waveforms_used.insert(&waveform.name); } + + if let Instruction::Call(Call { name, .. }) = instruction { + extern_signatures_used.insert(name); + } } expanded_program.frames = self.frames.intersection(&frames_used); expanded_program .waveforms .retain(|name, _definition| waveforms_used.contains(name)); + expanded_program + .extern_pragma_map + .retain(|name, _signature| { + name.as_ref() + .map(|name| extern_signatures_used.contains(name)) + .unwrap_or(false) + }); Ok(expanded_program) } @@ -389,6 +417,8 @@ impl Program { /// - All calibrations, following calibration expansion /// - Frame definitions which are not used by any instruction such as `PULSE` or `CAPTURE` /// - Waveform definitions which are not used by any instruction + /// - `PRAGMA EXTERN` instructions which are not used by any `CALL` instruction (see + /// [`Program::extern_pragma_map`]). /// /// When a valid program is simplified, it remains valid. /// @@ -582,12 +612,14 @@ impl Program { + self.waveforms.len() + self.gate_definitions.len() + self.instructions.len() + + self.extern_pragma_map.len() } /// Return a copy of all of the instructions which constitute this [`Program`]. pub fn to_instructions(&self) -> Vec { let mut instructions: Vec = Vec::with_capacity(self.len()); + instructions.extend(self.extern_pragma_map.to_instructions()); instructions.extend(self.memory_regions.iter().map(|(name, descriptor)| { Instruction::Declaration(Declaration { name: name.clone(), @@ -713,9 +745,11 @@ mod tests { use crate::{ imag, instruction::{ - Gate, Instruction, Jump, JumpUnless, JumpWhen, Label, Matrix, MemoryReference, Qubit, - QubitPlaceholder, Target, TargetPlaceholder, + Call, Declaration, ExternSignatureMap, Gate, Instruction, Jump, JumpUnless, JumpWhen, + Label, Matrix, MemoryReference, Qubit, QubitPlaceholder, ScalarType, Target, + TargetPlaceholder, UnresolvedCallArgument, Vector, RESERVED_PRAGMA_EXTERN, }, + program::MemoryAccesses, quil::{Quil, INDENT}, real, }; @@ -1523,4 +1557,100 @@ DEFFRAME 0 \"xy\": assert_eq!(new_program.to_quil().unwrap(), quil); } } + + /// Test that a program with a `CALL` instruction can be parsed and properly resolved to + /// the corresponding `EXTERN` instruction. Additionally, test that the memory accesses are + /// correctly calculated with the resolved `CALL` instruction. + #[test] + fn test_extern_call() { + let input = r#"PRAGMA EXTERN foo "OCTET (params : mut REAL[3])" +DECLARE reals REAL[3] +DECLARE octets OCTET[3] +CALL foo octets[1] reals +"#; + let program = Program::from_str(input).expect("should be able to parse program"); + let reserialized = program + .to_quil() + .expect("should be able to serialize program"); + assert_eq!(input, reserialized); + + let pragma = crate::instruction::Pragma { + name: RESERVED_PRAGMA_EXTERN.to_string(), + arguments: vec![crate::instruction::PragmaArgument::Identifier( + "foo".to_string(), + )], + data: Some("OCTET (params : mut REAL[3])".to_string()), + }; + let call = Call { + name: "foo".to_string(), + arguments: vec![ + UnresolvedCallArgument::MemoryReference(MemoryReference { + name: "octets".to_string(), + index: 1, + }), + UnresolvedCallArgument::Identifier("reals".to_string()), + ], + }; + let expected_program = Program::from_instructions(vec![ + Instruction::Declaration(Declaration::new( + "reals".to_string(), + Vector::new(ScalarType::Real, 3), + None, + )), + Instruction::Declaration(Declaration::new( + "octets".to_string(), + Vector::new(ScalarType::Octet, 3), + None, + )), + Instruction::Pragma(pragma.clone()), + Instruction::Call(call.clone()), + ]); + assert_eq!(expected_program, program); + + let extern_signature_map = ExternSignatureMap::try_from(program.extern_pragma_map) + .expect("should be able parse extern pragmas"); + assert_eq!(extern_signature_map.len(), 1); + + assert_eq!( + Instruction::Pragma(pragma) + .get_memory_accesses(&extern_signature_map) + .expect("should be able to get memory accesses"), + MemoryAccesses::default() + ); + + assert_eq!( + call.get_memory_accesses(&extern_signature_map) + .expect("should be able to get memory accesses"), + MemoryAccesses { + reads: ["octets", "reals"].into_iter().map(String::from).collect(), + writes: ["octets", "reals"].into_iter().map(String::from).collect(), + ..MemoryAccesses::default() + } + ); + } + + /// Test that unused `PRAGMA EXTERN` instructions are removed when simplifying a program. + #[test] + fn test_extern_call_simplification() { + let input = r#"PRAGMA EXTERN foo "OCTET (params : mut REAL[3])" +PRAGMA EXTERN bar "OCTET (params : mut REAL[3])" +DECLARE reals REAL[3] +DECLARE octets OCTET[3] +CALL foo octets[1] reals +"#; + let program = Program::from_str(input).expect("should be able to parse program"); + + let expected = r#"PRAGMA EXTERN foo "OCTET (params : mut REAL[3])" +DECLARE reals REAL[3] +DECLARE octets OCTET[3] +CALL foo octets[1] reals +"#; + + let reserialized = program + .into_simplified() + .expect("should be able to simplify program") + .to_quil() + .expect("should be able to serialize program"); + assert_eq!(expected, reserialized); + } } diff --git a/quil-rs/src/program/scheduling/graph.rs b/quil-rs/src/program/scheduling/graph.rs index 2e9cd4e8..878524a2 100644 --- a/quil-rs/src/program/scheduling/graph.rs +++ b/quil-rs/src/program/scheduling/graph.rs @@ -19,7 +19,9 @@ use std::collections::{HashMap, HashSet}; use petgraph::graphmap::GraphMap; use petgraph::Directed; -use crate::instruction::{FrameIdentifier, Instruction, InstructionHandler, Target}; +use crate::instruction::{ + ExternSignatureMap, FrameIdentifier, Instruction, InstructionHandler, Target, +}; use crate::program::analysis::{ BasicBlock, BasicBlockOwned, BasicBlockTerminator, ControlFlowGraph, }; @@ -30,7 +32,9 @@ pub use crate::program::memory::MemoryAccessType; #[derive(Debug, Clone, Copy)] pub enum ScheduleErrorVariant { DuplicateLabel, + Extern, UncalibratedInstruction, + UnresolvedCallInstruction, UnschedulableInstruction, } @@ -300,10 +304,22 @@ impl<'a> ScheduledBasicBlock<'a> { // NOTE: this may be refined to serialize by memory region offset rather than by entire region. let mut pending_memory_access: HashMap = HashMap::new(); + let extern_signature_map = ExternSignatureMap::try_from(program.extern_pragma_map.clone()) + .map_err(|(pragma, _)| ScheduleError { + instruction_index: None, + instruction: Instruction::Pragma(pragma), + variant: ScheduleErrorVariant::Extern, + })?; for (index, &instruction) in basic_block.instructions().iter().enumerate() { let node = graph.add_node(ScheduledGraphNode::InstructionIndex(index)); - let accesses = custom_handler.memory_accesses(instruction); + let accesses = custom_handler + .memory_accesses(instruction, &extern_signature_map) + .map_err(|_| ScheduleError { + instruction_index: Some(index), + instruction: instruction.clone(), + variant: ScheduleErrorVariant::UnresolvedCallInstruction, + })?; let memory_dependencies = [ (accesses.reads, MemoryAccessType::Read),