From 12b1118a190d102a68b137cf3fdef377d6724cb0 Mon Sep 17 00:00:00 2001 From: Filippo Miatto Date: Wed, 13 Nov 2024 17:26:12 -0800 Subject: [PATCH 01/13] new wires --- mrmustard/physics/new_wires.py | 598 +++++++++++++++++++++++++++++++ mrmustard/widgets/__init__.py | 4 +- tests/test_physics/test_wires.py | 89 +---- 3 files changed, 605 insertions(+), 86 deletions(-) create mode 100644 mrmustard/physics/new_wires.py diff --git a/mrmustard/physics/new_wires.py b/mrmustard/physics/new_wires.py new file mode 100644 index 000000000..4417167d2 --- /dev/null +++ b/mrmustard/physics/new_wires.py @@ -0,0 +1,598 @@ +from __future__ import annotations +from dataclasses import dataclass, field +from random import randint +from copy import deepcopy +from functools import cached_property, lru_cache +from enum import Enum, auto +from IPython.display import display + +from mrmustard import widgets + +__all__ = ["Wires"] + + +class Repr(Enum): + UNSPECIFIED = auto() + BARGMANN = auto() + FOCK = auto() + QUADRATURE = auto() + PHASESPACE = auto() + CHARACTERISTIC = auto() + + +@dataclass +class QuantumWire: + mode: int + is_out: bool + is_ket: bool + index: int + repr: Repr = Repr.UNSPECIFIED + id: int = field(default_factory=lambda: randint(0, 2**32 - 1)) + + @property + def is_dv(self) -> bool: + return self.repr == Repr.FOCK + + def __hash__(self) -> int: + return hash((self.mode, self.is_out, self.is_ket)) + + def __repr__(self) -> str: + return f"QuantumWire(mode={self.mode}, out={self.is_out}, ket={self.is_ket}, dv={self.is_dv}, repr={self.repr}, index={self.index})" + + def __eq__(self, other: QuantumWire) -> bool: + return ( + self.mode == other.mode + and self.is_out == other.is_out + and self.is_ket == other.is_ket + and self.is_dv == other.is_dv + and self.repr == other.repr + ) + + +@dataclass +class ClassicalWire: + mode: int + is_out: bool + index: int + repr: Repr = Repr.UNSPECIFIED + id: int = field(default_factory=lambda: randint(0, 2**32 - 1)) + + @property + def is_dv(self) -> bool: + return self.repr == Repr.FOCK + + def __hash__(self) -> int: + return hash((self.mode, self.is_out, self.is_dv)) + + def __repr__(self) -> str: + return f"ClassicalWire(mode={self.mode}, out={self.is_out}, dv={self.is_dv}, index={self.index})" + + def __eq__(self, other: ClassicalWire) -> bool: + return self.mode == other.mode and self.is_out == other.is_out and self.is_dv == other.is_dv + + +class Wires: + r""" + A class with wire functionality for tensor network applications. + + In MrMustard, instances of ``CircuitComponent`` have a ``Wires`` attribute. + The wires describe how they connect with the surrounding components in a tensor network picture, + where states flow from left to right. ``CircuitComponent``\s can have wires on the + bra and/or on the ket side. Additionally, they may have classical wires. Here are some examples + for the types of components available on ``mrmustard.lab_dev``: + + .. code-block:: + + A channel acting on mode ``1`` has input and output wires on both ket and bra sides: + + ┌──────┐ 1 ╔═════════╗ 1 ┌───────┐ + │Bra in│─────▶║ ║─────▶│Bra out│ + └──────┘ ║ Channel ║ └───────┘ + ┌──────┐ 1 ║ ║ 1 ┌───────┐ + │Ket in│─────▶║ ║─────▶│Ket out│ + └──────┘ ╚═════════╝ └───────┘ + + + A unitary acting on mode ``2`` has input and output wires on the ket side: + + ┌──────┐ 2 ╔═════════╗ 2 ┌───────┐ + │Ket in│─────▶║ Unitary ║─────▶│Ket out│ + └──────┘ ╚═════════╝ └───────┘ + + + A density matrix representing the state of mode ``0`` has only output wires: + + ╔═════════╗ 0 ┌───────┐ + ║ ║─────▶│Bra out│ + ║ Density ║ └───────┘ + ║ Matrix ║ 0 ┌───────┐ + ║ ║─────▶│Ket out│ + ╚═════════╝ └───────┘ + + + A ket representing the state of mode ``1`` has only output wires: + + ╔═════════╗ 1 ┌───────┐ + ║ Ket ║─────▶│Ket out│ + ╚═════════╝ └───────┘ + + A measurement acting on mode ``0`` has input wires on the ket side and classical output wires: + + ┌──────┐ 0 ╔═════════════╗ 0 ┌─────────────┐ + │Ket in│─────▶║ Measurement ║─────▶│Classical out│ + └──────┘ ╚═════════════╝ └─────────────┘ + + The ``Wires`` class can then be used to create subsets of wires: + + .. code-block:: + + >>> from mrmustard.physics.wires import Wires + + >>> modes_out_bra={0, 1} + >>> modes_in_bra={1, 2} + >>> modes_out_ket={0, 13} + >>> modes_in_ket={1, 2, 13} + >>> w = Wires(modes_out_bra, modes_in_bra, modes_out_ket, modes_in_ket) + + >>> # all the modes + >>> modes = w.modes + >>> assert w.modes == {0, 1, 2, 13} + + >>> # input/output modes + >>> assert w.input.modes == {1, 2, 13} + >>> assert w.output.modes == {0, 1, 13} + + >>> # get ket/bra modes + >>> assert w.ket.modes == {0, 1, 2, 13} + >>> assert w.bra.modes == {0, 1, 2} + + >>> # combined subsets + >>> assert w.output.ket.modes == {0, 13} + >>> assert w.input.bra.modes == {1, 2} + + Here's a diagram of the original ``Wires`` object in the example above, + with the indices of the wires (the number in parenthesis) given in the "standard" order + (``bra_out``, ``bra_in``, ``ket_out``, ``ket_in``, and the modes in sorted increasing order): + + .. code-block:: + + ╔═════════════╗ + 1 (2) ─────▶ ║ ║─────▶ 0 (0) + 2 (3) ─────▶ ║ ║─────▶ 1 (1) + ║ ║ + ║ ``Wires`` ║ + 1 (6) ─────▶ ║ ║ + 2 (7) ─────▶ ║ ║─────▶ 0 (4) + 13 (8) ─────▶ ║ ║─────▶ 13 (5) + ╚═════════════╝ + + To access the index of a subset of wires in standard order we can use the ``indices`` + property: + + .. code-block:: + + >>> assert w.indices == (0,1,2,3,4,5,6,7,8) + >>> assert w.input.indices == (2,3,6,7,8) + + Another important application of the ``Wires`` class is to contract the wires of two components. + This is done using the ``@`` operator. The result is a new ``Wires`` object that combines the wires + of the two components. Here's an example of a contraction of a single-mode density matrix going + into a single-mode channel: + + .. code-block:: + + >>> rho = Wires(modes_out_bra={0}, modes_in_bra={0}) + >>> Phi = Wires(modes_out_bra={0}, modes_in_bra={0}, modes_out_ket={0}, modes_in_ket={0}) + >>> rho_out, perm = rho @ Phi + >>> assert rho_out.modes == {0} + + Here's a diagram of the result of the contraction: + + .. code-block:: + + ╔═══════╗ ╔═══════╗ + ║ ║─────▶║ ║─────▶ 0 + ║ rho ║ ║ Phi ║ + ║ ║─────▶║ ║─────▶ 0 + ╚═══════╝ ╚═══════╝ + + The permutations that standardize the CV and DV variables of the contracted representations are also returned. + + Args: + modes_out_bra: The output modes on the bra side. + modes_in_bra: The input modes on the bra side. + modes_out_ket: The output modes on the ket side. + modes_in_ket: The input modes on the ket side. + classical_out: The output modes for classical information. + classical_in: The input modes for classical information. + FOCK: The modes that are in Fock representation. + + Returns: + A ``Wires`` object, and the permutations that standardize the CV and DV variables. + """ + + def __init__( + self, + modes_out_bra: set[int] | None = None, + modes_in_bra: set[int] | None = None, + modes_out_ket: set[int] | None = None, + modes_in_ket: set[int] | None = None, + classical_out: set[int] | None = None, + classical_in: set[int] | None = None, + FOCK: set[int] | None = None, + ): + self.id = randint(0, 2**32 - 1) + self.quantum_wires = set() + self.classical_wires = set() + self.FOCK = FOCK or set() + + for i, m in enumerate(sorted(modes_out_bra or [])): + self.quantum_wires.add( + QuantumWire( + mode=m, + is_out=True, + is_ket=False, + repr=Repr.FOCK if m in self.FOCK else Repr.UNSPECIFIED, + index=i, + ) + ) + n = len(modes_out_bra or []) + for i, m in enumerate(sorted(modes_in_bra or [])): + self.quantum_wires.add( + QuantumWire( + mode=m, + is_out=False, + is_ket=False, + repr=Repr.FOCK if m in self.FOCK else Repr.UNSPECIFIED, + index=n + i, + ) + ) + n += len(modes_in_bra or []) + for i, m in enumerate(sorted(modes_out_ket or [])): + self.quantum_wires.add( + QuantumWire( + mode=m, + is_out=True, + is_ket=True, + repr=Repr.FOCK if m in self.FOCK else Repr.UNSPECIFIED, + index=n + i, + ) + ) + n += len(modes_out_ket or []) + for i, m in enumerate(sorted(modes_in_ket or [])): + self.quantum_wires.add( + QuantumWire( + mode=m, + is_out=False, + is_ket=True, + repr=Repr.FOCK if m in self.FOCK else Repr.UNSPECIFIED, + index=n + i, + ) + ) + n += len(modes_in_ket or []) + for i, m in enumerate(sorted(classical_out or [])): + self.classical_wires.add( + ClassicalWire( + mode=m, + is_out=True, + repr=Repr.FOCK if m in self.FOCK else Repr.UNSPECIFIED, + index=n + i, + ) + ) + n += len(classical_out or []) + for i, m in enumerate(sorted(classical_in or [])): + self.classical_wires.add( + ClassicalWire( + mode=m, + is_out=False, + repr=Repr.FOCK if m in self.FOCK else Repr.UNSPECIFIED, + index=n + i, + ) + ) + + def copy(self) -> Wires: + return deepcopy(self) + + ###### TRANSFORMATIONS ###### + + @cached_property + def adjoint(self) -> Wires: + r""" + New ``Wires`` object with the adjoint quantum wires (ket becomes bra and vice versa). + """ + w = self.copy() + for q in w.quantum_wires: + q.is_ket = not q.is_ket + return w + + @cached_property + def dual(self) -> Wires: + r""" + New ``Wires`` object with dual quantum and classical wires (input becomes output and vice versa). + """ + w = self.copy() + for q in w.quantum_wires: + q.is_out = not q.is_out + for c in w.classical_wires: + c.is_out = not c.is_out + return w + + ###### SUBSETS ###### + + @lru_cache + def __getitem__(self, modes: tuple[int, ...] | int) -> Wires: + """ + Returns the quantum and classical wires with the given modes. + """ + modes = {modes} if isinstance(modes, int) else set(modes) + w = Wires() + w.quantum_wires = {q for q in self.quantum_wires.copy() if q.mode in modes} + w.classical_wires = {c for c in self.classical_wires.copy() if c.mode in modes} + return w + + @cached_property + def classical(self) -> Wires: + r""" + New ``Wires`` object with only classical wires. + """ + w = Wires() + w.classical_wires = self.classical_wires.copy() + return w + + @cached_property + def quantum(self) -> Wires: + r""" + New ``Wires`` object with only quantum wires. + """ + w = Wires() + w.quantum_wires = self.quantum_wires.copy() + return w + + @cached_property + def bra(self) -> Wires: + r""" + New ``Wires`` object with only quantum bra wires. + """ + w = Wires() + w.quantum_wires = {q for q in self.quantum_wires.copy() if not q.is_ket} + return w + + @cached_property + def ket(self) -> Wires: + r""" + New ``Wires`` object with only quantum ket wires. + """ + w = Wires() + w.quantum_wires = {q for q in self.quantum_wires.copy() if q.is_ket} + return w + + @cached_property + def input(self) -> Wires: + r""" + New ``Wires`` object with only classical and quantum input wires. + """ + w = Wires() + w.quantum_wires = {q for q in self.quantum_wires.copy() if not q.is_out} + w.classical_wires = {c for c in self.classical_wires.copy() if not c.is_out} + return w + + @cached_property + def output(self) -> Wires: + r""" + New ``Wires`` object with only classical and quantum output wires. + """ + w = Wires() + w.quantum_wires = {q for q in self.quantum_wires.copy() if q.is_out} + w.classical_wires = {c for c in self.classical_wires.copy() if c.is_out} + return w + + ###### PROPERTIES ###### + + @cached_property + def ids(self) -> tuple[int, ...]: + r""" + The ids of the wires in standard order. + """ + return tuple(w.id for w in self.sorted_wires) + + @cached_property + def DV_indices(self) -> tuple[int, ...]: + r""" + The indices of the DV wires (both quantum and classical) in standard order. + """ + return tuple(q.index for q in self.DV_wires) + + @cached_property + def CV_indices(self) -> tuple[int, ...]: + r""" + The indices of the CV wires (both quantum and classical) in standard order. + """ + return tuple(q.index for q in self.CV_wires) + + @cached_property + def DV_wires(self) -> tuple[QuantumWire | ClassicalWire, ...]: + r""" + The DV wires in standard order. + """ + return tuple(w for w in self.sorted_wires if w.is_dv) + + @cached_property + def indices(self) -> tuple[int, ...]: + r""" + The indices of the wires in standard order. + """ + return tuple(w.index for w in self.sorted_wires) + + @cached_property + def CV_wires(self) -> tuple[QuantumWire | ClassicalWire, ...]: + r""" + The CV wires in standard order. + """ + return tuple(w for w in self.sorted_wires if not w.is_dv) + + @cached_property + def modes(self) -> set[int]: + r""" + The modes spanned by the wires. + """ + return {q.mode for q in self.quantum_wires} | {c.mode for c in self.classical_wires} + + @cached_property + def args(self) -> tuple[set[int], ...]: + r""" + The arguments to pass to ``Wires`` to create the same object. + """ + return ( + self.bra.output.modes, + self.bra.input.modes, + self.ket.output.modes, + self.ket.input.modes, + self.classical.output.modes, + self.classical.input.modes, + self.FOCK, + ) + + @cached_property + def wires(self) -> set[QuantumWire | ClassicalWire]: + r""" + A set of all wires. + """ + return {*self.quantum_wires, *self.classical_wires} + + @cached_property + def sorted_wires(self) -> list[QuantumWire | ClassicalWire]: + r""" + A list of all wires in standard order. + """ + return [ + *sorted(self.bra.output.wires, key=lambda s: s.mode), + *sorted(self.bra.input.wires, key=lambda s: s.mode), + *sorted(self.ket.output.wires, key=lambda s: s.mode), + *sorted(self.ket.input.wires, key=lambda s: s.mode), + *sorted(self.classical.output.wires, key=lambda s: s.mode), + *sorted(self.classical.input.wires, key=lambda s: s.mode), + ] + + ###### METHODS ###### + + def reindex(self) -> None: + r""" + Updates the indices of the wires according to the standard order. + """ + for i, w in enumerate(self.sorted_wires): + w.index = i + + def __add__(self, other: Wires) -> Wires: + r""" + New ``Wires`` object that combines the wires of self and other. + If there are overlapping wires (same mode, is_ket, is_out), raises a ValueError. + """ + if ovlp_classical := self.classical_wires & other.classical_wires: + raise ValueError(f"Overlapping classical wires {ovlp_classical}.") + if ovlp_quantum := self.quantum_wires & other.quantum_wires: + raise ValueError(f"Overlapping quantum wires {ovlp_quantum}.") + w = Wires() + w.quantum_wires = self.quantum_wires | other.quantum_wires + w.classical_wires = self.classical_wires | other.classical_wires + w.reindex() + return w + + def __sub__(self, other: Wires) -> Wires: + r""" + New ``Wires`` object that removes the wires of other from self, by mode. + Note it does not look at ket, bra, input or output: just the mode. Use with caution. + """ + w = Wires() + w.quantum_wires = {q for q in self.quantum_wires.copy() if q.mode not in other.modes} + w.classical_wires = {c for c in self.classical_wires.copy() if c.mode not in other.modes} + w.reindex() + return w + + def __bool__(self) -> bool: + return bool(self.quantum_wires) or bool(self.classical_wires) + + def __hash__(self) -> int: + return hash(tuple(tuple(sorted(subset)) for subset in self.args)) + + def __eq__(self, other: Wires) -> bool: + return ( + self.quantum_wires == other.quantum_wires + and self.classical_wires == other.classical_wires + ) + + def __len__(self) -> int: + return len(self.quantum_wires) + len(self.classical_wires) + + def __repr__(self) -> str: + return ( + f"Wires(modes_out_bra={self.output.bra.modes}, " + f"modes_in_bra={self.input.bra.modes}, " + f"modes_out_ket={self.output.ket.modes}, " + f"modes_in_ket={self.input.ket.modes}, " + f"classical_out={self.output.classical.modes}, " + f"classical_in={self.input.classical.modes}, " + f"FOCK={self.FOCK})" + ) + + def __matmul__(self, other: Wires) -> tuple[Wires, list[int], list[int]]: + r""" + Returns the ``Wires`` for the circuit component resulting from the composition of self and other. + Returns also the permutations of the CV and DV wires to reorder the wires to standard order. + Consider the following example: + + .. code-block:: + + ╔═══════╗ ╔═══════╗ + B───║ self ║───A D───║ other ║───C + b───║ ║───a d───║ ║───c + ╚═══════╝ ╚═══════╝ + + B and D-A must not overlap, same for b and d-a, etc. The result is a new ``Wires`` object + + .. code-block:: + + ╔═══════╗ + B+(D-A)────║self @ ║────C+(A-D) + b+(d-a)────║ other ║────c+(a-d) + ╚═══════╝ + + Using the permutations, it is possible to write: + + .. code-block:: + + ansatz = ansatz1[idx1] @ ansatz2[idx2] # not in standard order + wires, perm_CV, perm_DV = wires1 @ wires2 # matmul the wires + ansatz = ansatz.reorder(perm_CV, perm_DV) # now in standard order + + Args: + other: The wires of the other circuit component. + + Returns: + The wires of the circuit composition and the permutations. + """ + bra_out = other.output.bra + (self.output.bra - other.input.bra) + ket_out = other.output.ket + (self.output.ket - other.input.ket) + bra_in = self.input.bra + (other.input.bra - self.output.bra) + ket_in = self.input.ket + (other.input.ket - self.output.ket) + cl_out = other.classical.output + (self.classical.output - other.classical.input) + cl_in = self.classical.input + (other.classical.input - self.classical.output) + + # get the wires + w = Wires() + w.quantum_wires = (bra_out + bra_in + ket_out + ket_in).wires + w.classical_wires = (cl_out + cl_in).wires + w.reindex() + + # get the permutations + CV_ids = [w.id for w in w.CV_wires if w.id in self.ids] + [ + w.id for w in w.CV_wires if w.id in other.ids + ] + DV_ids = [w.id for w in w.DV_wires if w.id in self.ids] + [ + w.id for w in w.DV_wires if w.id in other.ids + ] + CV_perm = [CV_ids.index(w.id) for w in w.CV_wires] + DV_perm = [DV_ids.index(w.id) for w in w.DV_wires] + return w, CV_perm, DV_perm + + def _ipython_display_(self): + display(widgets.wires(self)) diff --git a/mrmustard/widgets/__init__.py b/mrmustard/widgets/__init__.py index 4c9026dd0..95c708b7f 100644 --- a/mrmustard/widgets/__init__.py +++ b/mrmustard/widgets/__init__.py @@ -177,7 +177,7 @@ def wires(obj): def mode_to_str(m): max_modes = 3 - result = ", ".join(list(map(str, m.modes))[:max_modes]) + result = ", ".join(list(map(str, sorted(m.modes)))[:max_modes]) return (result + ", ...") if len(m) > max_modes else result mode_div = """ @@ -202,7 +202,7 @@ def mode_to_str(m): label = labels[i] title_row = f'{label}' - table_data = [f"{m}{mode[m].indices[0]}" for m in mode.modes] + table_data = [f"{m}{mode[m].indices[0]}" for m in sorted(mode.modes)] wire_tables.append(title_row + "".join(table_data)) index_table = f""" diff --git a/tests/test_physics/test_wires.py b/tests/test_physics/test_wires.py index e0a537de0..3f9e96870 100644 --- a/tests/test_physics/test_wires.py +++ b/tests/test_physics/test_wires.py @@ -21,7 +21,7 @@ import pytest from ipywidgets import HTML -from mrmustard.physics.wires import Wires +from mrmustard.physics.new_wires import Wires class TestWires: @@ -31,42 +31,7 @@ class TestWires: def test_init(self): w = Wires({0, 1, 2}, {3, 4, 5}, {6, 7}, {8}, {9}, {10}) - assert w.args == ({0, 1, 2}, {3, 4, 5}, {6, 7}, {8}, {9}, {10}) - - def test_ids(self): - w = Wires({0, 1, 2}, {3, 4, 5}, {6, 7}, {8}) - assert w.ids == [w.id + i for i in range(9)] - - def test_ids_with_subsets(self): - w = Wires({0, 1, 2}, {3, 4, 5}, {6, 7}, {8}, {9, 10}, {11}) - - assert w.input.ids == [w.ids[3], w.ids[4], w.ids[5], w.ids[8], w.ids[11]] - assert w.output.ids == [ - w.ids[0], - w.ids[1], - w.ids[2], - w.ids[6], - w.ids[7], - w.ids[9], - w.ids[10], - ] - assert w.bra.ids == [w.ids[0], w.ids[1], w.ids[2], w.ids[3], w.ids[4], w.ids[5]] - assert w.ket.ids == [w.ids[6], w.ids[7], w.ids[8]] - assert w.quantum.ids == [ - w.ids[0], - w.ids[1], - w.ids[2], - w.ids[3], - w.ids[4], - w.ids[5], - w.ids[6], - w.ids[7], - w.ids[8], - ] - assert w.classical.ids == [w.ids[9], w.ids[10], w.ids[11]] - - assert w.output.bra.ids == [w.ids[0], w.ids[1], w.ids[2]] - assert w.input.bra.ids == [w.ids[3], w.ids[4], w.ids[5]] + assert w.args == ({0, 1, 2}, {3, 4, 5}, {6, 7}, {8}, {9}, {10}, set()) def test_indices(self): w = Wires({0, 10, 20}, {30, 40, 50}, {60, 70}, {80}) @@ -92,44 +57,6 @@ def test_wire_subsets(self): assert w.output.ket.modes == {2} assert w.input.ket.modes == {3} - def test_index_dicts(self): - w = Wires({0, 2, 1}, {6, 7, 8}, {3, 4}, {4}, {5}, {9}) - d = [{0: 0, 1: 1, 2: 2}, {6: 3, 7: 4, 8: 5}, {3: 6, 4: 7}, {4: 8}, {5: 9}, {9: 10}] - - assert w.index_dicts == d - assert w.input.index_dicts == d - assert w.input.bra.index_dicts == d - - def test_ids_dicts(self): - w = Wires({0, 2, 1}, {6, 7, 8}, {3, 4}, {4}, {5}, {9}) - d = [ - {0: w.id, 1: w.id + 1, 2: w.id + 2}, - {6: w.id + 3, 7: w.id + 4, 8: w.id + 5}, - {3: w.id + 6, 4: w.id + 7}, - {4: w.id + 8}, - {5: w.id + 9}, - {9: w.id + 10}, - ] - - assert w.ids_dicts == d - assert w.input.ids_dicts == d - assert w.input.bra.ids_dicts == d - - def test_ids_index_dicts(self): - w = Wires({0, 2, 1}, {6, 7, 8}, {3, 4}, {4}, {5}, {9}) - d = [ - {w.id: 0, w.id + 1: 1, w.id + 2: 2}, - {w.id + 3: 3, w.id + 4: 4, w.id + 5: 5}, - {w.id + 6: 6, w.id + 7: 7}, - {w.id + 8: 8}, - {w.id + 9: 9}, - {w.id + 10: 10}, - ] - - assert w.ids_index_dicts == d - assert w.input.ids_index_dicts == d - assert w.input.bra.ids_index_dicts == d - def test_adjoint(self): w = Wires({0, 1, 2}, {3, 4, 5}, {6, 7}, {8}) w_adj = w.adjoint @@ -168,19 +95,12 @@ def test_getitem(self): w0 = Wires({0}, {0}) assert w[0] == w0 - assert w._mode_cache == {(0,): w0} w1 = Wires({1}) assert w[1] == w1 - assert w._mode_cache == {(0,): w0, (1,): w1} w2 = Wires(set(), {2}) assert w[2] == w2 - assert w._mode_cache == { - (0,): w0, - (1,): w1, - (2,): w2, - } assert w[0].indices == (0, 2) assert w[1].indices == (1,) @@ -206,7 +126,7 @@ def test_matmul(self): # contracts 17,17 on classical u = Wires({1, 5}, {2, 6, 15}, {3, 7, 13}, {4, 8}, {16, 17}, {18}) v = Wires({0, 9, 14}, {1, 10}, {2, 11}, {13, 3, 12}, {19}, {17}) - new_wires, perm = u @ v + new_wires, CV_perm, DV_perm = u @ v assert new_wires.args == ( {0, 5, 9, 14}, {2, 6, 10, 15}, @@ -214,8 +134,9 @@ def test_matmul(self): {4, 8, 12}, {16, 19}, {18}, + set(), ) - assert perm == [9, 0, 10, 11, 1, 2, 12, 3, 13, 4, 14, 5, 6, 15, 7, 16, 8] + assert CV_perm == [9, 0, 10, 11, 1, 2, 12, 3, 13, 4, 14, 5, 6, 15, 7, 16, 8] def test_matmul_keeps_ids(self): U = Wires(set(), set(), {0}, {0}) From c57a72fd0bd930c41eae5a3a3a05d0c0c88bc55d Mon Sep 17 00:00:00 2001 From: Filippo Miatto Date: Thu, 14 Nov 2024 10:26:10 -0800 Subject: [PATCH 02/13] updates --- mrmustard/physics/new_wires.py | 155 +++++++++++++++++-------------- tests/test_physics/test_wires.py | 5 +- 2 files changed, 88 insertions(+), 72 deletions(-) diff --git a/mrmustard/physics/new_wires.py b/mrmustard/physics/new_wires.py index 4417167d2..5819c23dd 100644 --- a/mrmustard/physics/new_wires.py +++ b/mrmustard/physics/new_wires.py @@ -1,10 +1,11 @@ from __future__ import annotations from dataclasses import dataclass, field +from typing import Sequence from random import randint from copy import deepcopy -from functools import cached_property, lru_cache from enum import Enum, auto from IPython.display import display +from functools import lru_cache, cached_property from mrmustard import widgets @@ -20,7 +21,17 @@ class Repr(Enum): CHARACTERISTIC = auto() -@dataclass +class WiresType(Enum): + DM_LIKE = auto() # only output ket and bra on same modes + KET_LIKE = auto() # only output ket + UNITARY_LIKE = auto() # such that can map ket to ket + CHANNEL_LIKE = auto() # such that can map dm to dm + PROJ_MEAS_LIKE = auto() # only input ket + POVM_LIKE = auto() # only input ket and input bra on same modes + CLASSICAL_LIKE = auto() # only classical wires + + +@dataclass(slots=True) class QuantumWire: mode: int is_out: bool @@ -37,19 +48,15 @@ def __hash__(self) -> int: return hash((self.mode, self.is_out, self.is_ket)) def __repr__(self) -> str: - return f"QuantumWire(mode={self.mode}, out={self.is_out}, ket={self.is_ket}, dv={self.is_dv}, repr={self.repr}, index={self.index})" + return f"QuantumWire(mode={self.mode}, out={self.is_out}, ket={self.is_ket}, repr={self.repr}, index={self.index})" def __eq__(self, other: QuantumWire) -> bool: return ( - self.mode == other.mode - and self.is_out == other.is_out - and self.is_ket == other.is_ket - and self.is_dv == other.is_dv - and self.repr == other.repr + self.mode == other.mode and self.is_out == other.is_out and self.is_ket == other.is_ket ) -@dataclass +@dataclass(slots=True) class ClassicalWire: mode: int is_out: bool @@ -65,10 +72,10 @@ def __hash__(self) -> int: return hash((self.mode, self.is_out, self.is_dv)) def __repr__(self) -> str: - return f"ClassicalWire(mode={self.mode}, out={self.is_out}, dv={self.is_dv}, index={self.index})" + return f"ClassicalWire(mode={self.mode}, out={self.is_out}, repr={self.repr}, index={self.index})" def __eq__(self, other: ClassicalWire) -> bool: - return self.mode == other.mode and self.is_out == other.is_out and self.is_dv == other.is_dv + return self.mode == other.mode and self.is_out == other.is_out class Wires: @@ -196,7 +203,7 @@ class Wires: ║ ║─────▶║ ║─────▶ 0 ╚═══════╝ ╚═══════╝ - The permutations that standardize the CV and DV variables of the contracted representations are also returned. + The permutations that standardize the CV and DV variables of the contracted reprs are also returned. Args: modes_out_bra: The output modes on the bra side. @@ -205,7 +212,6 @@ class Wires: modes_in_ket: The input modes on the ket side. classical_out: The output modes for classical information. classical_in: The input modes for classical information. - FOCK: The modes that are in Fock representation. Returns: A ``Wires`` object, and the permutations that standardize the CV and DV variables. @@ -213,79 +219,70 @@ class Wires: def __init__( self, - modes_out_bra: set[int] | None = None, - modes_in_bra: set[int] | None = None, - modes_out_ket: set[int] | None = None, - modes_in_ket: set[int] | None = None, - classical_out: set[int] | None = None, - classical_in: set[int] | None = None, - FOCK: set[int] | None = None, + modes_out_bra: Sequence[int] = (), + modes_in_bra: Sequence[int] = (), + modes_out_ket: Sequence[int] = (), + modes_in_ket: Sequence[int] = (), + classical_out: Sequence[int] = (), + classical_in: Sequence[int] = (), ): - self.id = randint(0, 2**32 - 1) self.quantum_wires = set() self.classical_wires = set() - self.FOCK = FOCK or set() - for i, m in enumerate(sorted(modes_out_bra or [])): + for i, m in enumerate(sorted(modes_out_bra)): self.quantum_wires.add( QuantumWire( mode=m, is_out=True, is_ket=False, - repr=Repr.FOCK if m in self.FOCK else Repr.UNSPECIFIED, index=i, ) ) - n = len(modes_out_bra or []) - for i, m in enumerate(sorted(modes_in_bra or [])): + n = len(modes_out_bra) + for i, m in enumerate(sorted(modes_in_bra)): self.quantum_wires.add( QuantumWire( mode=m, is_out=False, is_ket=False, - repr=Repr.FOCK if m in self.FOCK else Repr.UNSPECIFIED, index=n + i, ) ) - n += len(modes_in_bra or []) - for i, m in enumerate(sorted(modes_out_ket or [])): + n += len(modes_in_bra) + for i, m in enumerate(sorted(modes_out_ket)): self.quantum_wires.add( QuantumWire( mode=m, is_out=True, is_ket=True, - repr=Repr.FOCK if m in self.FOCK else Repr.UNSPECIFIED, index=n + i, ) ) - n += len(modes_out_ket or []) - for i, m in enumerate(sorted(modes_in_ket or [])): + n += len(modes_out_ket) + for i, m in enumerate(sorted(modes_in_ket)): self.quantum_wires.add( QuantumWire( mode=m, is_out=False, is_ket=True, - repr=Repr.FOCK if m in self.FOCK else Repr.UNSPECIFIED, index=n + i, ) ) - n += len(modes_in_ket or []) - for i, m in enumerate(sorted(classical_out or [])): + n += len(modes_in_ket) + for i, m in enumerate(sorted(classical_out)): self.classical_wires.add( ClassicalWire( mode=m, is_out=True, - repr=Repr.FOCK if m in self.FOCK else Repr.UNSPECIFIED, index=n + i, ) ) - n += len(classical_out or []) - for i, m in enumerate(sorted(classical_in or [])): + n += len(classical_out) + for i, m in enumerate(sorted(classical_in)): self.classical_wires.add( ClassicalWire( mode=m, is_out=False, - repr=Repr.FOCK if m in self.FOCK else Repr.UNSPECIFIED, index=n + i, ) ) @@ -295,7 +292,7 @@ def copy(self) -> Wires: ###### TRANSFORMATIONS ###### - @cached_property + @property def adjoint(self) -> Wires: r""" New ``Wires`` object with the adjoint quantum wires (ket becomes bra and vice versa). @@ -305,7 +302,7 @@ def adjoint(self) -> Wires: q.is_ket = not q.is_ket return w - @cached_property + @property def dual(self) -> Wires: r""" New ``Wires`` object with dual quantum and classical wires (input becomes output and vice versa). @@ -326,8 +323,8 @@ def __getitem__(self, modes: tuple[int, ...] | int) -> Wires: """ modes = {modes} if isinstance(modes, int) else set(modes) w = Wires() - w.quantum_wires = {q for q in self.quantum_wires.copy() if q.mode in modes} - w.classical_wires = {c for c in self.classical_wires.copy() if c.mode in modes} + w.quantum_wires = {q for q in self.quantum_wires if q.mode in modes} + w.classical_wires = {c for c in self.classical_wires if c.mode in modes} return w @cached_property @@ -336,7 +333,7 @@ def classical(self) -> Wires: New ``Wires`` object with only classical wires. """ w = Wires() - w.classical_wires = self.classical_wires.copy() + w.classical_wires = self.classical_wires return w @cached_property @@ -345,7 +342,7 @@ def quantum(self) -> Wires: New ``Wires`` object with only quantum wires. """ w = Wires() - w.quantum_wires = self.quantum_wires.copy() + w.quantum_wires = self.quantum_wires return w @cached_property @@ -354,7 +351,7 @@ def bra(self) -> Wires: New ``Wires`` object with only quantum bra wires. """ w = Wires() - w.quantum_wires = {q for q in self.quantum_wires.copy() if not q.is_ket} + w.quantum_wires = {q for q in self.quantum_wires if not q.is_ket} return w @cached_property @@ -363,7 +360,7 @@ def ket(self) -> Wires: New ``Wires`` object with only quantum ket wires. """ w = Wires() - w.quantum_wires = {q for q in self.quantum_wires.copy() if q.is_ket} + w.quantum_wires = {q for q in self.quantum_wires if q.is_ket} return w @cached_property @@ -372,8 +369,8 @@ def input(self) -> Wires: New ``Wires`` object with only classical and quantum input wires. """ w = Wires() - w.quantum_wires = {q for q in self.quantum_wires.copy() if not q.is_out} - w.classical_wires = {c for c in self.classical_wires.copy() if not c.is_out} + w.quantum_wires = {q for q in self.quantum_wires if not q.is_out} + w.classical_wires = {c for c in self.classical_wires if not c.is_out} return w @cached_property @@ -382,12 +379,23 @@ def output(self) -> Wires: New ``Wires`` object with only classical and quantum output wires. """ w = Wires() - w.quantum_wires = {q for q in self.quantum_wires.copy() if q.is_out} - w.classical_wires = {c for c in self.classical_wires.copy() if c.is_out} + w.quantum_wires = {q for q in self.quantum_wires if q.is_out} + w.classical_wires = {c for c in self.classical_wires if c.is_out} return w ###### PROPERTIES ###### + @cached_property + def id(self) -> int: + return randint(0, 2**32 - 1) + + @cached_property + def modes(self) -> set[int]: + r""" + The modes spanned by the wires. + """ + return {q.mode for q in self.quantum_wires} | {c.mode for c in self.classical_wires} + @cached_property def ids(self) -> tuple[int, ...]: r""" @@ -395,6 +403,13 @@ def ids(self) -> tuple[int, ...]: """ return tuple(w.id for w in self.sorted_wires) + @cached_property + def indices(self) -> tuple[int, ...]: + r""" + The indices of the wires in standard order. + """ + return tuple(w.index for w in self.sorted_wires) + @cached_property def DV_indices(self) -> tuple[int, ...]: r""" @@ -414,33 +429,19 @@ def DV_wires(self) -> tuple[QuantumWire | ClassicalWire, ...]: r""" The DV wires in standard order. """ - return tuple(w for w in self.sorted_wires if w.is_dv) - - @cached_property - def indices(self) -> tuple[int, ...]: - r""" - The indices of the wires in standard order. - """ - return tuple(w.index for w in self.sorted_wires) + return tuple(w for w in self.sorted_wires.copy() if w.is_dv) @cached_property def CV_wires(self) -> tuple[QuantumWire | ClassicalWire, ...]: r""" The CV wires in standard order. """ - return tuple(w for w in self.sorted_wires if not w.is_dv) - - @cached_property - def modes(self) -> set[int]: - r""" - The modes spanned by the wires. - """ - return {q.mode for q in self.quantum_wires} | {c.mode for c in self.classical_wires} + return tuple(w for w in self.sorted_wires.copy() if not w.is_dv) @cached_property def args(self) -> tuple[set[int], ...]: r""" - The arguments to pass to ``Wires`` to create the same object. + The arguments to pass to ``Wires`` to create the same object with fresh wires. """ return ( self.bra.output.modes, @@ -449,7 +450,6 @@ def args(self) -> tuple[set[int], ...]: self.ket.input.modes, self.classical.output.modes, self.classical.input.modes, - self.FOCK, ) @cached_property @@ -475,6 +475,20 @@ def sorted_wires(self) -> list[QuantumWire | ClassicalWire]: ###### METHODS ###### + def wire(self, mode: int, is_out: bool, is_ket: bool) -> QuantumWire | ClassicalWire: + r""" + Returns the wire with the given mode, ket, and output status. + """ + if quantum := [ + w + for w in self.quantum_wires + if w.mode == mode and w.is_out == is_out and w.is_ket == is_ket + ]: + return quantum[0] + if classical := [w for w in self.classical_wires if w.mode == mode and w.is_out == is_out]: + return classical[0] + raise ValueError(f"No wire with mode {mode}, is_out {is_out}, and is_ket {is_ket}.") + def reindex(self) -> None: r""" Updates the indices of the wires according to the standard order. @@ -530,8 +544,7 @@ def __repr__(self) -> str: f"modes_out_ket={self.output.ket.modes}, " f"modes_in_ket={self.input.ket.modes}, " f"classical_out={self.output.classical.modes}, " - f"classical_in={self.input.classical.modes}, " - f"FOCK={self.FOCK})" + f"classical_in={self.input.classical.modes})" ) def __matmul__(self, other: Wires) -> tuple[Wires, list[int], list[int]]: diff --git a/tests/test_physics/test_wires.py b/tests/test_physics/test_wires.py index 3f9e96870..f9762d483 100644 --- a/tests/test_physics/test_wires.py +++ b/tests/test_physics/test_wires.py @@ -21,7 +21,7 @@ import pytest from ipywidgets import HTML -from mrmustard.physics.new_wires import Wires +from mrmustard.physics.new_wires import Wires, Repr class TestWires: @@ -33,6 +33,9 @@ def test_init(self): w = Wires({0, 1, 2}, {3, 4, 5}, {6, 7}, {8}, {9}, {10}) assert w.args == ({0, 1, 2}, {3, 4, 5}, {6, 7}, {8}, {9}, {10}, set()) + w = Wires({0, 1, 2}, {3, 4, 5}, {6, 7}, {8}, {9}, {10}, FOCK={1}) + assert w.wire(mode=1, is_ket=False, is_out=True).repr == Repr.FOCK + def test_indices(self): w = Wires({0, 10, 20}, {30, 40, 50}, {60, 70}, {80}) assert w.indices == (0, 1, 2, 3, 4, 5, 6, 7, 8) From a63c66f3bfd823640acce7c701355829ac39e17c Mon Sep 17 00:00:00 2001 From: Filippo Miatto Date: Thu, 14 Nov 2024 10:29:05 -0800 Subject: [PATCH 03/13] better format --- mrmustard/physics/new_wires.py | 52 ++++---------------------------- tests/test_physics/test_wires.py | 2 +- 2 files changed, 7 insertions(+), 47 deletions(-) diff --git a/mrmustard/physics/new_wires.py b/mrmustard/physics/new_wires.py index 5819c23dd..1ca6ba14c 100644 --- a/mrmustard/physics/new_wires.py +++ b/mrmustard/physics/new_wires.py @@ -230,62 +230,22 @@ def __init__( self.classical_wires = set() for i, m in enumerate(sorted(modes_out_bra)): - self.quantum_wires.add( - QuantumWire( - mode=m, - is_out=True, - is_ket=False, - index=i, - ) - ) + self.quantum_wires.add(QuantumWire(mode=m, is_out=True, is_ket=False, index=i)) n = len(modes_out_bra) for i, m in enumerate(sorted(modes_in_bra)): - self.quantum_wires.add( - QuantumWire( - mode=m, - is_out=False, - is_ket=False, - index=n + i, - ) - ) + self.quantum_wires.add(QuantumWire(mode=m, is_out=False, is_ket=False, index=n + i)) n += len(modes_in_bra) for i, m in enumerate(sorted(modes_out_ket)): - self.quantum_wires.add( - QuantumWire( - mode=m, - is_out=True, - is_ket=True, - index=n + i, - ) - ) + self.quantum_wires.add(QuantumWire(mode=m, is_out=True, is_ket=True, index=n + i)) n += len(modes_out_ket) for i, m in enumerate(sorted(modes_in_ket)): - self.quantum_wires.add( - QuantumWire( - mode=m, - is_out=False, - is_ket=True, - index=n + i, - ) - ) + self.quantum_wires.add(QuantumWire(mode=m, is_out=False, is_ket=True, index=n + i)) n += len(modes_in_ket) for i, m in enumerate(sorted(classical_out)): - self.classical_wires.add( - ClassicalWire( - mode=m, - is_out=True, - index=n + i, - ) - ) + self.classical_wires.add(ClassicalWire(mode=m, is_out=True, index=n + i)) n += len(classical_out) for i, m in enumerate(sorted(classical_in)): - self.classical_wires.add( - ClassicalWire( - mode=m, - is_out=False, - index=n + i, - ) - ) + self.classical_wires.add(ClassicalWire(mode=m, is_out=False, index=n + i)) def copy(self) -> Wires: return deepcopy(self) diff --git a/tests/test_physics/test_wires.py b/tests/test_physics/test_wires.py index f9762d483..f8f136207 100644 --- a/tests/test_physics/test_wires.py +++ b/tests/test_physics/test_wires.py @@ -21,7 +21,7 @@ import pytest from ipywidgets import HTML -from mrmustard.physics.new_wires import Wires, Repr +from mrmustard.physics.new_wires import Repr, Wires class TestWires: From a5dd9937e780856b64894b829eaf4dbd78447b02 Mon Sep 17 00:00:00 2001 From: Filippo Miatto Date: Thu, 14 Nov 2024 10:31:28 -0800 Subject: [PATCH 04/13] replace old wires --- mrmustard/physics/new_wires.py | 571 ----------------------------- mrmustard/physics/wires.py | 592 ++++++++++++++++--------------- tests/test_physics/test_wires.py | 2 +- 3 files changed, 308 insertions(+), 857 deletions(-) delete mode 100644 mrmustard/physics/new_wires.py diff --git a/mrmustard/physics/new_wires.py b/mrmustard/physics/new_wires.py deleted file mode 100644 index 1ca6ba14c..000000000 --- a/mrmustard/physics/new_wires.py +++ /dev/null @@ -1,571 +0,0 @@ -from __future__ import annotations -from dataclasses import dataclass, field -from typing import Sequence -from random import randint -from copy import deepcopy -from enum import Enum, auto -from IPython.display import display -from functools import lru_cache, cached_property - -from mrmustard import widgets - -__all__ = ["Wires"] - - -class Repr(Enum): - UNSPECIFIED = auto() - BARGMANN = auto() - FOCK = auto() - QUADRATURE = auto() - PHASESPACE = auto() - CHARACTERISTIC = auto() - - -class WiresType(Enum): - DM_LIKE = auto() # only output ket and bra on same modes - KET_LIKE = auto() # only output ket - UNITARY_LIKE = auto() # such that can map ket to ket - CHANNEL_LIKE = auto() # such that can map dm to dm - PROJ_MEAS_LIKE = auto() # only input ket - POVM_LIKE = auto() # only input ket and input bra on same modes - CLASSICAL_LIKE = auto() # only classical wires - - -@dataclass(slots=True) -class QuantumWire: - mode: int - is_out: bool - is_ket: bool - index: int - repr: Repr = Repr.UNSPECIFIED - id: int = field(default_factory=lambda: randint(0, 2**32 - 1)) - - @property - def is_dv(self) -> bool: - return self.repr == Repr.FOCK - - def __hash__(self) -> int: - return hash((self.mode, self.is_out, self.is_ket)) - - def __repr__(self) -> str: - return f"QuantumWire(mode={self.mode}, out={self.is_out}, ket={self.is_ket}, repr={self.repr}, index={self.index})" - - def __eq__(self, other: QuantumWire) -> bool: - return ( - self.mode == other.mode and self.is_out == other.is_out and self.is_ket == other.is_ket - ) - - -@dataclass(slots=True) -class ClassicalWire: - mode: int - is_out: bool - index: int - repr: Repr = Repr.UNSPECIFIED - id: int = field(default_factory=lambda: randint(0, 2**32 - 1)) - - @property - def is_dv(self) -> bool: - return self.repr == Repr.FOCK - - def __hash__(self) -> int: - return hash((self.mode, self.is_out, self.is_dv)) - - def __repr__(self) -> str: - return f"ClassicalWire(mode={self.mode}, out={self.is_out}, repr={self.repr}, index={self.index})" - - def __eq__(self, other: ClassicalWire) -> bool: - return self.mode == other.mode and self.is_out == other.is_out - - -class Wires: - r""" - A class with wire functionality for tensor network applications. - - In MrMustard, instances of ``CircuitComponent`` have a ``Wires`` attribute. - The wires describe how they connect with the surrounding components in a tensor network picture, - where states flow from left to right. ``CircuitComponent``\s can have wires on the - bra and/or on the ket side. Additionally, they may have classical wires. Here are some examples - for the types of components available on ``mrmustard.lab_dev``: - - .. code-block:: - - A channel acting on mode ``1`` has input and output wires on both ket and bra sides: - - ┌──────┐ 1 ╔═════════╗ 1 ┌───────┐ - │Bra in│─────▶║ ║─────▶│Bra out│ - └──────┘ ║ Channel ║ └───────┘ - ┌──────┐ 1 ║ ║ 1 ┌───────┐ - │Ket in│─────▶║ ║─────▶│Ket out│ - └──────┘ ╚═════════╝ └───────┘ - - - A unitary acting on mode ``2`` has input and output wires on the ket side: - - ┌──────┐ 2 ╔═════════╗ 2 ┌───────┐ - │Ket in│─────▶║ Unitary ║─────▶│Ket out│ - └──────┘ ╚═════════╝ └───────┘ - - - A density matrix representing the state of mode ``0`` has only output wires: - - ╔═════════╗ 0 ┌───────┐ - ║ ║─────▶│Bra out│ - ║ Density ║ └───────┘ - ║ Matrix ║ 0 ┌───────┐ - ║ ║─────▶│Ket out│ - ╚═════════╝ └───────┘ - - - A ket representing the state of mode ``1`` has only output wires: - - ╔═════════╗ 1 ┌───────┐ - ║ Ket ║─────▶│Ket out│ - ╚═════════╝ └───────┘ - - A measurement acting on mode ``0`` has input wires on the ket side and classical output wires: - - ┌──────┐ 0 ╔═════════════╗ 0 ┌─────────────┐ - │Ket in│─────▶║ Measurement ║─────▶│Classical out│ - └──────┘ ╚═════════════╝ └─────────────┘ - - The ``Wires`` class can then be used to create subsets of wires: - - .. code-block:: - - >>> from mrmustard.physics.wires import Wires - - >>> modes_out_bra={0, 1} - >>> modes_in_bra={1, 2} - >>> modes_out_ket={0, 13} - >>> modes_in_ket={1, 2, 13} - >>> w = Wires(modes_out_bra, modes_in_bra, modes_out_ket, modes_in_ket) - - >>> # all the modes - >>> modes = w.modes - >>> assert w.modes == {0, 1, 2, 13} - - >>> # input/output modes - >>> assert w.input.modes == {1, 2, 13} - >>> assert w.output.modes == {0, 1, 13} - - >>> # get ket/bra modes - >>> assert w.ket.modes == {0, 1, 2, 13} - >>> assert w.bra.modes == {0, 1, 2} - - >>> # combined subsets - >>> assert w.output.ket.modes == {0, 13} - >>> assert w.input.bra.modes == {1, 2} - - Here's a diagram of the original ``Wires`` object in the example above, - with the indices of the wires (the number in parenthesis) given in the "standard" order - (``bra_out``, ``bra_in``, ``ket_out``, ``ket_in``, and the modes in sorted increasing order): - - .. code-block:: - - ╔═════════════╗ - 1 (2) ─────▶ ║ ║─────▶ 0 (0) - 2 (3) ─────▶ ║ ║─────▶ 1 (1) - ║ ║ - ║ ``Wires`` ║ - 1 (6) ─────▶ ║ ║ - 2 (7) ─────▶ ║ ║─────▶ 0 (4) - 13 (8) ─────▶ ║ ║─────▶ 13 (5) - ╚═════════════╝ - - To access the index of a subset of wires in standard order we can use the ``indices`` - property: - - .. code-block:: - - >>> assert w.indices == (0,1,2,3,4,5,6,7,8) - >>> assert w.input.indices == (2,3,6,7,8) - - Another important application of the ``Wires`` class is to contract the wires of two components. - This is done using the ``@`` operator. The result is a new ``Wires`` object that combines the wires - of the two components. Here's an example of a contraction of a single-mode density matrix going - into a single-mode channel: - - .. code-block:: - - >>> rho = Wires(modes_out_bra={0}, modes_in_bra={0}) - >>> Phi = Wires(modes_out_bra={0}, modes_in_bra={0}, modes_out_ket={0}, modes_in_ket={0}) - >>> rho_out, perm = rho @ Phi - >>> assert rho_out.modes == {0} - - Here's a diagram of the result of the contraction: - - .. code-block:: - - ╔═══════╗ ╔═══════╗ - ║ ║─────▶║ ║─────▶ 0 - ║ rho ║ ║ Phi ║ - ║ ║─────▶║ ║─────▶ 0 - ╚═══════╝ ╚═══════╝ - - The permutations that standardize the CV and DV variables of the contracted reprs are also returned. - - Args: - modes_out_bra: The output modes on the bra side. - modes_in_bra: The input modes on the bra side. - modes_out_ket: The output modes on the ket side. - modes_in_ket: The input modes on the ket side. - classical_out: The output modes for classical information. - classical_in: The input modes for classical information. - - Returns: - A ``Wires`` object, and the permutations that standardize the CV and DV variables. - """ - - def __init__( - self, - modes_out_bra: Sequence[int] = (), - modes_in_bra: Sequence[int] = (), - modes_out_ket: Sequence[int] = (), - modes_in_ket: Sequence[int] = (), - classical_out: Sequence[int] = (), - classical_in: Sequence[int] = (), - ): - self.quantum_wires = set() - self.classical_wires = set() - - for i, m in enumerate(sorted(modes_out_bra)): - self.quantum_wires.add(QuantumWire(mode=m, is_out=True, is_ket=False, index=i)) - n = len(modes_out_bra) - for i, m in enumerate(sorted(modes_in_bra)): - self.quantum_wires.add(QuantumWire(mode=m, is_out=False, is_ket=False, index=n + i)) - n += len(modes_in_bra) - for i, m in enumerate(sorted(modes_out_ket)): - self.quantum_wires.add(QuantumWire(mode=m, is_out=True, is_ket=True, index=n + i)) - n += len(modes_out_ket) - for i, m in enumerate(sorted(modes_in_ket)): - self.quantum_wires.add(QuantumWire(mode=m, is_out=False, is_ket=True, index=n + i)) - n += len(modes_in_ket) - for i, m in enumerate(sorted(classical_out)): - self.classical_wires.add(ClassicalWire(mode=m, is_out=True, index=n + i)) - n += len(classical_out) - for i, m in enumerate(sorted(classical_in)): - self.classical_wires.add(ClassicalWire(mode=m, is_out=False, index=n + i)) - - def copy(self) -> Wires: - return deepcopy(self) - - ###### TRANSFORMATIONS ###### - - @property - def adjoint(self) -> Wires: - r""" - New ``Wires`` object with the adjoint quantum wires (ket becomes bra and vice versa). - """ - w = self.copy() - for q in w.quantum_wires: - q.is_ket = not q.is_ket - return w - - @property - def dual(self) -> Wires: - r""" - New ``Wires`` object with dual quantum and classical wires (input becomes output and vice versa). - """ - w = self.copy() - for q in w.quantum_wires: - q.is_out = not q.is_out - for c in w.classical_wires: - c.is_out = not c.is_out - return w - - ###### SUBSETS ###### - - @lru_cache - def __getitem__(self, modes: tuple[int, ...] | int) -> Wires: - """ - Returns the quantum and classical wires with the given modes. - """ - modes = {modes} if isinstance(modes, int) else set(modes) - w = Wires() - w.quantum_wires = {q for q in self.quantum_wires if q.mode in modes} - w.classical_wires = {c for c in self.classical_wires if c.mode in modes} - return w - - @cached_property - def classical(self) -> Wires: - r""" - New ``Wires`` object with only classical wires. - """ - w = Wires() - w.classical_wires = self.classical_wires - return w - - @cached_property - def quantum(self) -> Wires: - r""" - New ``Wires`` object with only quantum wires. - """ - w = Wires() - w.quantum_wires = self.quantum_wires - return w - - @cached_property - def bra(self) -> Wires: - r""" - New ``Wires`` object with only quantum bra wires. - """ - w = Wires() - w.quantum_wires = {q for q in self.quantum_wires if not q.is_ket} - return w - - @cached_property - def ket(self) -> Wires: - r""" - New ``Wires`` object with only quantum ket wires. - """ - w = Wires() - w.quantum_wires = {q for q in self.quantum_wires if q.is_ket} - return w - - @cached_property - def input(self) -> Wires: - r""" - New ``Wires`` object with only classical and quantum input wires. - """ - w = Wires() - w.quantum_wires = {q for q in self.quantum_wires if not q.is_out} - w.classical_wires = {c for c in self.classical_wires if not c.is_out} - return w - - @cached_property - def output(self) -> Wires: - r""" - New ``Wires`` object with only classical and quantum output wires. - """ - w = Wires() - w.quantum_wires = {q for q in self.quantum_wires if q.is_out} - w.classical_wires = {c for c in self.classical_wires if c.is_out} - return w - - ###### PROPERTIES ###### - - @cached_property - def id(self) -> int: - return randint(0, 2**32 - 1) - - @cached_property - def modes(self) -> set[int]: - r""" - The modes spanned by the wires. - """ - return {q.mode for q in self.quantum_wires} | {c.mode for c in self.classical_wires} - - @cached_property - def ids(self) -> tuple[int, ...]: - r""" - The ids of the wires in standard order. - """ - return tuple(w.id for w in self.sorted_wires) - - @cached_property - def indices(self) -> tuple[int, ...]: - r""" - The indices of the wires in standard order. - """ - return tuple(w.index for w in self.sorted_wires) - - @cached_property - def DV_indices(self) -> tuple[int, ...]: - r""" - The indices of the DV wires (both quantum and classical) in standard order. - """ - return tuple(q.index for q in self.DV_wires) - - @cached_property - def CV_indices(self) -> tuple[int, ...]: - r""" - The indices of the CV wires (both quantum and classical) in standard order. - """ - return tuple(q.index for q in self.CV_wires) - - @cached_property - def DV_wires(self) -> tuple[QuantumWire | ClassicalWire, ...]: - r""" - The DV wires in standard order. - """ - return tuple(w for w in self.sorted_wires.copy() if w.is_dv) - - @cached_property - def CV_wires(self) -> tuple[QuantumWire | ClassicalWire, ...]: - r""" - The CV wires in standard order. - """ - return tuple(w for w in self.sorted_wires.copy() if not w.is_dv) - - @cached_property - def args(self) -> tuple[set[int], ...]: - r""" - The arguments to pass to ``Wires`` to create the same object with fresh wires. - """ - return ( - self.bra.output.modes, - self.bra.input.modes, - self.ket.output.modes, - self.ket.input.modes, - self.classical.output.modes, - self.classical.input.modes, - ) - - @cached_property - def wires(self) -> set[QuantumWire | ClassicalWire]: - r""" - A set of all wires. - """ - return {*self.quantum_wires, *self.classical_wires} - - @cached_property - def sorted_wires(self) -> list[QuantumWire | ClassicalWire]: - r""" - A list of all wires in standard order. - """ - return [ - *sorted(self.bra.output.wires, key=lambda s: s.mode), - *sorted(self.bra.input.wires, key=lambda s: s.mode), - *sorted(self.ket.output.wires, key=lambda s: s.mode), - *sorted(self.ket.input.wires, key=lambda s: s.mode), - *sorted(self.classical.output.wires, key=lambda s: s.mode), - *sorted(self.classical.input.wires, key=lambda s: s.mode), - ] - - ###### METHODS ###### - - def wire(self, mode: int, is_out: bool, is_ket: bool) -> QuantumWire | ClassicalWire: - r""" - Returns the wire with the given mode, ket, and output status. - """ - if quantum := [ - w - for w in self.quantum_wires - if w.mode == mode and w.is_out == is_out and w.is_ket == is_ket - ]: - return quantum[0] - if classical := [w for w in self.classical_wires if w.mode == mode and w.is_out == is_out]: - return classical[0] - raise ValueError(f"No wire with mode {mode}, is_out {is_out}, and is_ket {is_ket}.") - - def reindex(self) -> None: - r""" - Updates the indices of the wires according to the standard order. - """ - for i, w in enumerate(self.sorted_wires): - w.index = i - - def __add__(self, other: Wires) -> Wires: - r""" - New ``Wires`` object that combines the wires of self and other. - If there are overlapping wires (same mode, is_ket, is_out), raises a ValueError. - """ - if ovlp_classical := self.classical_wires & other.classical_wires: - raise ValueError(f"Overlapping classical wires {ovlp_classical}.") - if ovlp_quantum := self.quantum_wires & other.quantum_wires: - raise ValueError(f"Overlapping quantum wires {ovlp_quantum}.") - w = Wires() - w.quantum_wires = self.quantum_wires | other.quantum_wires - w.classical_wires = self.classical_wires | other.classical_wires - w.reindex() - return w - - def __sub__(self, other: Wires) -> Wires: - r""" - New ``Wires`` object that removes the wires of other from self, by mode. - Note it does not look at ket, bra, input or output: just the mode. Use with caution. - """ - w = Wires() - w.quantum_wires = {q for q in self.quantum_wires.copy() if q.mode not in other.modes} - w.classical_wires = {c for c in self.classical_wires.copy() if c.mode not in other.modes} - w.reindex() - return w - - def __bool__(self) -> bool: - return bool(self.quantum_wires) or bool(self.classical_wires) - - def __hash__(self) -> int: - return hash(tuple(tuple(sorted(subset)) for subset in self.args)) - - def __eq__(self, other: Wires) -> bool: - return ( - self.quantum_wires == other.quantum_wires - and self.classical_wires == other.classical_wires - ) - - def __len__(self) -> int: - return len(self.quantum_wires) + len(self.classical_wires) - - def __repr__(self) -> str: - return ( - f"Wires(modes_out_bra={self.output.bra.modes}, " - f"modes_in_bra={self.input.bra.modes}, " - f"modes_out_ket={self.output.ket.modes}, " - f"modes_in_ket={self.input.ket.modes}, " - f"classical_out={self.output.classical.modes}, " - f"classical_in={self.input.classical.modes})" - ) - - def __matmul__(self, other: Wires) -> tuple[Wires, list[int], list[int]]: - r""" - Returns the ``Wires`` for the circuit component resulting from the composition of self and other. - Returns also the permutations of the CV and DV wires to reorder the wires to standard order. - Consider the following example: - - .. code-block:: - - ╔═══════╗ ╔═══════╗ - B───║ self ║───A D───║ other ║───C - b───║ ║───a d───║ ║───c - ╚═══════╝ ╚═══════╝ - - B and D-A must not overlap, same for b and d-a, etc. The result is a new ``Wires`` object - - .. code-block:: - - ╔═══════╗ - B+(D-A)────║self @ ║────C+(A-D) - b+(d-a)────║ other ║────c+(a-d) - ╚═══════╝ - - Using the permutations, it is possible to write: - - .. code-block:: - - ansatz = ansatz1[idx1] @ ansatz2[idx2] # not in standard order - wires, perm_CV, perm_DV = wires1 @ wires2 # matmul the wires - ansatz = ansatz.reorder(perm_CV, perm_DV) # now in standard order - - Args: - other: The wires of the other circuit component. - - Returns: - The wires of the circuit composition and the permutations. - """ - bra_out = other.output.bra + (self.output.bra - other.input.bra) - ket_out = other.output.ket + (self.output.ket - other.input.ket) - bra_in = self.input.bra + (other.input.bra - self.output.bra) - ket_in = self.input.ket + (other.input.ket - self.output.ket) - cl_out = other.classical.output + (self.classical.output - other.classical.input) - cl_in = self.classical.input + (other.classical.input - self.classical.output) - - # get the wires - w = Wires() - w.quantum_wires = (bra_out + bra_in + ket_out + ket_in).wires - w.classical_wires = (cl_out + cl_in).wires - w.reindex() - - # get the permutations - CV_ids = [w.id for w in w.CV_wires if w.id in self.ids] + [ - w.id for w in w.CV_wires if w.id in other.ids - ] - DV_ids = [w.id for w in w.DV_wires if w.id in self.ids] + [ - w.id for w in w.DV_wires if w.id in other.ids - ] - CV_perm = [CV_ids.index(w.id) for w in w.CV_wires] - DV_perm = [DV_ids.index(w.id) for w in w.DV_wires] - return w, CV_perm, DV_perm - - def _ipython_display_(self): - display(widgets.wires(self)) diff --git a/mrmustard/physics/wires.py b/mrmustard/physics/wires.py index db14a9f48..1ca6ba14c 100644 --- a/mrmustard/physics/wires.py +++ b/mrmustard/physics/wires.py @@ -1,28 +1,81 @@ -# Copyright 2023 Xanadu Quantum Technologies Inc. +from __future__ import annotations +from dataclasses import dataclass, field +from typing import Sequence +from random import randint +from copy import deepcopy +from enum import Enum, auto +from IPython.display import display +from functools import lru_cache, cached_property -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +from mrmustard import widgets -# http://www.apache.org/licenses/LICENSE-2.0 +__all__ = ["Wires"] -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""``Wires`` class for supporting tensor network functionalities.""" +class Repr(Enum): + UNSPECIFIED = auto() + BARGMANN = auto() + FOCK = auto() + QUADRATURE = auto() + PHASESPACE = auto() + CHARACTERISTIC = auto() -from __future__ import annotations -from functools import cached_property -import numpy as np -from IPython.display import display +class WiresType(Enum): + DM_LIKE = auto() # only output ket and bra on same modes + KET_LIKE = auto() # only output ket + UNITARY_LIKE = auto() # such that can map ket to ket + CHANNEL_LIKE = auto() # such that can map dm to dm + PROJ_MEAS_LIKE = auto() # only input ket + POVM_LIKE = auto() # only input ket and input bra on same modes + CLASSICAL_LIKE = auto() # only classical wires -from mrmustard import widgets -__all__ = ["Wires"] +@dataclass(slots=True) +class QuantumWire: + mode: int + is_out: bool + is_ket: bool + index: int + repr: Repr = Repr.UNSPECIFIED + id: int = field(default_factory=lambda: randint(0, 2**32 - 1)) + + @property + def is_dv(self) -> bool: + return self.repr == Repr.FOCK + + def __hash__(self) -> int: + return hash((self.mode, self.is_out, self.is_ket)) + + def __repr__(self) -> str: + return f"QuantumWire(mode={self.mode}, out={self.is_out}, ket={self.is_ket}, repr={self.repr}, index={self.index})" + + def __eq__(self, other: QuantumWire) -> bool: + return ( + self.mode == other.mode and self.is_out == other.is_out and self.is_ket == other.is_ket + ) + + +@dataclass(slots=True) +class ClassicalWire: + mode: int + is_out: bool + index: int + repr: Repr = Repr.UNSPECIFIED + id: int = field(default_factory=lambda: randint(0, 2**32 - 1)) + + @property + def is_dv(self) -> bool: + return self.repr == Repr.FOCK + + def __hash__(self) -> int: + return hash((self.mode, self.is_out, self.is_dv)) + + def __repr__(self) -> str: + return f"ClassicalWire(mode={self.mode}, out={self.is_out}, repr={self.repr}, index={self.index})" + + def __eq__(self, other: ClassicalWire) -> bool: + return self.mode == other.mode and self.is_out == other.is_out class Wires: @@ -150,7 +203,7 @@ class Wires: ║ ║─────▶║ ║─────▶ 0 ╚═══════╝ ╚═══════╝ - The permutation that takes the contracted representations to the standard order is also returned. + The permutations that standardize the CV and DV variables of the contracted reprs are also returned. Args: modes_out_bra: The output modes on the bra side. @@ -159,307 +212,305 @@ class Wires: modes_in_ket: The input modes on the ket side. classical_out: The output modes for classical information. classical_in: The input modes for classical information. + + Returns: + A ``Wires`` object, and the permutations that standardize the CV and DV variables. """ def __init__( self, - modes_out_bra: set[int] | None = None, - modes_in_bra: set[int] | None = None, - modes_out_ket: set[int] | None = None, - modes_in_ket: set[int] | None = None, - classical_out: set[int] | None = None, - classical_in: set[int] | None = None, - ) -> None: - self.args: tuple[set, ...] = ( - modes_out_bra or set(), - modes_in_bra or set(), - modes_out_ket or set(), - modes_in_ket or set(), - classical_out or set(), - classical_in or set(), - ) - self._len = None + modes_out_bra: Sequence[int] = (), + modes_in_bra: Sequence[int] = (), + modes_out_ket: Sequence[int] = (), + modes_in_ket: Sequence[int] = (), + classical_out: Sequence[int] = (), + classical_in: Sequence[int] = (), + ): + self.quantum_wires = set() + self.classical_wires = set() + + for i, m in enumerate(sorted(modes_out_bra)): + self.quantum_wires.add(QuantumWire(mode=m, is_out=True, is_ket=False, index=i)) + n = len(modes_out_bra) + for i, m in enumerate(sorted(modes_in_bra)): + self.quantum_wires.add(QuantumWire(mode=m, is_out=False, is_ket=False, index=n + i)) + n += len(modes_in_bra) + for i, m in enumerate(sorted(modes_out_ket)): + self.quantum_wires.add(QuantumWire(mode=m, is_out=True, is_ket=True, index=n + i)) + n += len(modes_out_ket) + for i, m in enumerate(sorted(modes_in_ket)): + self.quantum_wires.add(QuantumWire(mode=m, is_out=False, is_ket=True, index=n + i)) + n += len(modes_in_ket) + for i, m in enumerate(sorted(classical_out)): + self.classical_wires.add(ClassicalWire(mode=m, is_out=True, index=n + i)) + n += len(classical_out) + for i, m in enumerate(sorted(classical_in)): + self.classical_wires.add(ClassicalWire(mode=m, is_out=False, index=n + i)) + + def copy(self) -> Wires: + return deepcopy(self) + + ###### TRANSFORMATIONS ###### - # The "parent" wires object, if any. This is ``None`` for freshly initialized - # wires objects, and ``not None`` for subsets. - self._original = None - - # Adds elements to the cache when calling ``__getitem__`` - self._mode_cache = {} - - @cached_property + @property def adjoint(self) -> Wires: r""" - New ``Wires`` object obtained by swapping ket and bra wires. + New ``Wires`` object with the adjoint quantum wires (ket becomes bra and vice versa). """ - return Wires( - self.args[2], - self.args[3], - self.args[0], - self.args[1], - self.args[4], - self.args[5], - ) + w = self.copy() + for q in w.quantum_wires: + q.is_ket = not q.is_ket + return w - @cached_property - def bra(self) -> Wires: + @property + def dual(self) -> Wires: r""" - New ``Wires`` object with only bra wires. + New ``Wires`` object with dual quantum and classical wires (input becomes output and vice versa). + """ + w = self.copy() + for q in w.quantum_wires: + q.is_out = not q.is_out + for c in w.classical_wires: + c.is_out = not c.is_out + return w + + ###### SUBSETS ###### + + @lru_cache + def __getitem__(self, modes: tuple[int, ...] | int) -> Wires: + """ + Returns the quantum and classical wires with the given modes. """ - ret = Wires(modes_out_bra=self.args[0], modes_in_bra=self.args[1]) - ret._original = self.original or self - return ret + modes = {modes} if isinstance(modes, int) else set(modes) + w = Wires() + w.quantum_wires = {q for q in self.quantum_wires if q.mode in modes} + w.classical_wires = {c for c in self.classical_wires if c.mode in modes} + return w @cached_property def classical(self) -> Wires: r""" New ``Wires`` object with only classical wires. """ - ret = Wires(classical_out=self.args[4], classical_in=self.args[5]) - ret._original = self.original or self - return ret + w = Wires() + w.classical_wires = self.classical_wires + return w @cached_property def quantum(self) -> Wires: r""" New ``Wires`` object with only quantum wires. """ - ret = Wires( - modes_out_bra=self.args[0], - modes_in_bra=self.args[1], - modes_out_ket=self.args[2], - modes_in_ket=self.args[3], - ) - ret._original = self.original or self - return ret + w = Wires() + w.quantum_wires = self.quantum_wires + return w @cached_property - def dual(self) -> Wires: + def bra(self) -> Wires: r""" - New ``Wires`` object obtained by swapping input and output wires. + New ``Wires`` object with only quantum bra wires. """ - return Wires( - self.args[1], - self.args[0], - self.args[3], - self.args[2], - self.args[5], - self.args[4], - ) + w = Wires() + w.quantum_wires = {q for q in self.quantum_wires if not q.is_ket} + return w @cached_property - def id(self) -> int: + def ket(self) -> Wires: r""" - A numerical identifier for this ``Wires`` object. - - The ``id`` are random and unique, and are preserved when taking subsets. + New ``Wires`` object with only quantum ket wires. """ - if self.original: - return self.original.id - return np.random.randint(0, 2**31) + w = Wires() + w.quantum_wires = {q for q in self.quantum_wires if q.is_ket} + return w @cached_property - def ids(self) -> list[int]: + def input(self) -> Wires: r""" - A list of numerical identifier for the wires in this ``Wires`` object, in - the standard order. - - The ``ids`` are derived incrementally from the ``id`` and are unique. - - .. code-block:: - - >>> w = Wires(modes_in_ket = {0,1}, modes_out_ket = {0,1}) - >>> id = w.id - >>> ids = w.ids - >>> assert ids == [id, id+1, id+2, id+3] + New ``Wires`` object with only classical and quantum input wires. """ - if self.original: - return [self.original.ids[i] for i in self.indices] - return [id for d in self.ids_dicts for id in d.values()] + w = Wires() + w.quantum_wires = {q for q in self.quantum_wires if not q.is_out} + w.classical_wires = {c for c in self.classical_wires if not c.is_out} + return w @cached_property - def ids_dicts(self) -> list[dict[int, int]]: + def output(self) -> Wires: r""" - A list of dictionary mapping modes to ``ids``, one for each of the subsets - (``output.bra``, ``input.bra``, ``output.ket``, ``input.ket``, - ``output.classical``, and ``input.classical``). - - If subsets are taken, ``ids_dicts`` refers to the parent object rather than to the - child. + New ``Wires`` object with only classical and quantum output wires. """ - if self.original: - return self.original.ids_dicts - return [{m: i + self.id for m, i in d.items()} for d in self.index_dicts] + w = Wires() + w.quantum_wires = {q for q in self.quantum_wires if q.is_out} + w.classical_wires = {c for c in self.classical_wires if c.is_out} + return w + + ###### PROPERTIES ###### @cached_property - def index_dicts(self) -> list[dict[int, int]]: - r""" - A list of dictionary mapping modes to indices, one for each of the subsets - (``output.bra``, ``input.bra``, ``output.ket``, ``input.ket``, - ``output.classical``, and ``input.classical``). + def id(self) -> int: + return randint(0, 2**32 - 1) - If subsets are taken, ``index_dicts`` refers to the parent object rather than to the - child. + @cached_property + def modes(self) -> set[int]: + r""" + The modes spanned by the wires. """ - if self.original: - return self.original.index_dicts - return [ - {m: i + sum(len(s) for s in self.args[:t]) for i, m in enumerate(lst)} - for t, lst in enumerate(self.sorted_args) - ] + return {q.mode for q in self.quantum_wires} | {c.mode for c in self.classical_wires} - @property - def ids_index_dicts(self) -> list[dict[int, int]]: + @cached_property + def ids(self) -> tuple[int, ...]: r""" - A list of dictionary mapping ids to indices, one for each of the subsets - (``output.bra``, ``input.bra``, ``output.ket``, ``input.ket``, - ``output.classical``, and ``input.classical``). - - If subsets are taken, ``ids_index_dicts`` refers to the parent object rather than to the - child. + The ids of the wires in standard order. """ - if self.original: - return self.original.ids_index_dicts - return [ - {v: self.index_dicts[t][k] for k, v in self.ids_dicts[t].items()} - for t in (0, 1, 2, 3, 4, 5) - ] + return tuple(w.id for w in self.sorted_wires) @cached_property def indices(self) -> tuple[int, ...]: r""" - The array of indices of this ``Wires`` in the standard order. - When a subset is selected (e.g. ``.ket``), it doesn't include wires that do not belong - to the subset, but it still counts them because indices refer to the original modes. - - .. code-block:: - - >>> w = Wires(modes_in_ket = {0,1}, modes_out_ket = {0,1}) - >>> assert w.indices == (0,1,2,3) - >>> assert w.input.indices == (2,3) + The indices of the wires in standard order. """ - return tuple( - self.index_dicts[t][m] for t, modes in enumerate(self.sorted_args) for m in modes - ) + return tuple(w.index for w in self.sorted_wires) @cached_property - def input(self) -> Wires: + def DV_indices(self) -> tuple[int, ...]: r""" - New ``Wires`` object without output wires. + The indices of the DV wires (both quantum and classical) in standard order. """ - ret = Wires(set(), self.args[1], set(), self.args[3], set(), self.args[5]) - ret._original = self.original or self - return ret + return tuple(q.index for q in self.DV_wires) @cached_property - def ket(self) -> Wires: + def CV_indices(self) -> tuple[int, ...]: r""" - New ``Wires`` object with only ket wires. + The indices of the CV wires (both quantum and classical) in standard order. """ - ret = Wires(modes_out_ket=self.args[2], modes_in_ket=self.args[3]) - ret._original = self.original or self - return ret + return tuple(q.index for q in self.CV_wires) @cached_property - def modes(self) -> set[int]: + def DV_wires(self) -> tuple[QuantumWire | ClassicalWire, ...]: r""" - The modes spanned by the wires. + The DV wires in standard order. """ - return set.union(*self.args) + return tuple(w for w in self.sorted_wires.copy() if w.is_dv) - @property - def original(self): + @cached_property + def CV_wires(self) -> tuple[QuantumWire | ClassicalWire, ...]: r""" - The parent wire, if any. + The CV wires in standard order. """ - return self._original + return tuple(w for w in self.sorted_wires.copy() if not w.is_dv) @cached_property - def output(self) -> Wires: + def args(self) -> tuple[set[int], ...]: r""" - New ``Wires`` object with only output wires. + The arguments to pass to ``Wires`` to create the same object with fresh wires. """ - ret = Wires(self.args[0], set(), self.args[2], set(), self.args[4], set()) - ret._original = self.original or self - return ret + return ( + self.bra.output.modes, + self.bra.input.modes, + self.ket.output.modes, + self.ket.input.modes, + self.classical.output.modes, + self.classical.input.modes, + ) @cached_property - def sorted_args(self) -> tuple[list[int], ...]: + def wires(self) -> set[QuantumWire | ClassicalWire]: r""" - The sorted arguments. Allows to sort them only once. + A set of all wires. """ - return tuple(sorted(s) for s in self.args) + return {*self.quantum_wires, *self.classical_wires} - def contracted_indices(self, other: Wires): + @cached_property + def sorted_wires(self) -> list[QuantumWire | ClassicalWire]: r""" - Returns the indices being contracted between self and other when calling matmul. - - Args: - other: another Wires object + A list of all wires in standard order. """ - ovlp_bra, ovlp_ket = self.overlap(other) - idxA = self.output.bra[ovlp_bra].indices + self.output.ket[ovlp_ket].indices - idxB = other.input.bra[ovlp_bra].indices + other.input.ket[ovlp_ket].indices - return idxA, idxB + return [ + *sorted(self.bra.output.wires, key=lambda s: s.mode), + *sorted(self.bra.input.wires, key=lambda s: s.mode), + *sorted(self.ket.output.wires, key=lambda s: s.mode), + *sorted(self.ket.input.wires, key=lambda s: s.mode), + *sorted(self.classical.output.wires, key=lambda s: s.mode), + *sorted(self.classical.input.wires, key=lambda s: s.mode), + ] - def overlap(self, other: Wires) -> tuple[set[int], set[int]]: - r""" - Returns the modes that overlap between the two ``Wires`` objects. + ###### METHODS ###### - Args: - other: Another ``Wires`` object. + def wire(self, mode: int, is_out: bool, is_ket: bool) -> QuantumWire | ClassicalWire: + r""" + Returns the wire with the given mode, ket, and output status. """ - ovlp_ket = self.output.ket.modes & other.input.ket.modes - ovlp_bra = self.output.bra.modes & other.input.bra.modes - return ovlp_bra, ovlp_ket + if quantum := [ + w + for w in self.quantum_wires + if w.mode == mode and w.is_out == is_out and w.is_ket == is_ket + ]: + return quantum[0] + if classical := [w for w in self.classical_wires if w.mode == mode and w.is_out == is_out]: + return classical[0] + raise ValueError(f"No wire with mode {mode}, is_out {is_out}, and is_ket {is_ket}.") + + def reindex(self) -> None: + r""" + Updates the indices of the wires according to the standard order. + """ + for i, w in enumerate(self.sorted_wires): + w.index = i def __add__(self, other: Wires) -> Wires: r""" - New ``Wires`` object that combines the wires of ``self`` and those of ``other``. - - Raises: - ValueError: If any leftover wires would overlap. + New ``Wires`` object that combines the wires of self and other. + If there are overlapping wires (same mode, is_ket, is_out), raises a ValueError. """ - new_args = [] - for t, (m1, m2) in enumerate(zip(self.args, other.args)): - if m := m1 & m2: - raise ValueError(f"{t}-type wires overlap at mode {m}.") - new_args.append(m1 | m2) - return Wires(*new_args) - - def __bool__(self) -> bool: + if ovlp_classical := self.classical_wires & other.classical_wires: + raise ValueError(f"Overlapping classical wires {ovlp_classical}.") + if ovlp_quantum := self.quantum_wires & other.quantum_wires: + raise ValueError(f"Overlapping quantum wires {ovlp_quantum}.") + w = Wires() + w.quantum_wires = self.quantum_wires | other.quantum_wires + w.classical_wires = self.classical_wires | other.classical_wires + w.reindex() + return w + + def __sub__(self, other: Wires) -> Wires: r""" - Returns ``True`` if this ``Wires`` object has any wires, ``False`` otherwise. + New ``Wires`` object that removes the wires of other from self, by mode. + Note it does not look at ket, bra, input or output: just the mode. Use with caution. """ - return any(self.args) + w = Wires() + w.quantum_wires = {q for q in self.quantum_wires.copy() if q.mode not in other.modes} + w.classical_wires = {c for c in self.classical_wires.copy() if c.mode not in other.modes} + w.reindex() + return w - def __eq__(self, other) -> bool: - return self.args == other.args + def __bool__(self) -> bool: + return bool(self.quantum_wires) or bool(self.classical_wires) - def __getitem__(self, modes: tuple[int, ...] | int) -> Wires: - r""" - New ``Wires`` object with wires only on the given modes. - """ - modes = {modes} if isinstance(modes, int) else set(modes) - if tuple(modes) not in self._mode_cache: - w = Wires(*(self.args[t] & modes for t in (0, 1, 2, 3, 4, 5))) - w._original = self.original or self - self._mode_cache[tuple(modes)] = w - return self._mode_cache[tuple(modes)] + def __hash__(self) -> int: + return hash(tuple(tuple(sorted(subset)) for subset in self.args)) + + def __eq__(self, other: Wires) -> bool: + return ( + self.quantum_wires == other.quantum_wires + and self.classical_wires == other.classical_wires + ) def __len__(self) -> int: - r""" - The number of wires. - """ - if self._len is None: - self._len = sum(map(len, self.args)) - return self._len + return len(self.quantum_wires) + len(self.classical_wires) + + def __repr__(self) -> str: + return ( + f"Wires(modes_out_bra={self.output.bra.modes}, " + f"modes_in_bra={self.input.bra.modes}, " + f"modes_out_ket={self.output.ket.modes}, " + f"modes_in_ket={self.input.ket.modes}, " + f"classical_out={self.output.classical.modes}, " + f"classical_in={self.input.classical.modes})" + ) - def __matmul__(self, other: Wires) -> tuple[Wires, list[int]]: + def __matmul__(self, other: Wires) -> tuple[Wires, list[int], list[int]]: r""" - Returns the wires of the circuit composition of self and other without adding missing - adjoints. It also returns the permutation that takes the contracted representations - to the standard order. An exception is raised if any leftover wires would overlap. + Returns the ``Wires`` for the circuit component resulting from the composition of self and other. + Returns also the permutations of the CV and DV wires to reorder the wires to standard order. Consider the following example: .. code-block:: @@ -474,76 +525,47 @@ def __matmul__(self, other: Wires) -> tuple[Wires, list[int]]: .. code-block:: ╔═══════╗ - B|(D-A)────║self @ ║────C|(A-D) - b|(d-a)────║ other ║────c|(a-d) + B+(D-A)────║self @ ║────C+(A-D) + b+(d-a)────║ other ║────c+(a-d) ╚═══════╝ - In comparison, contracting the representations rather than the wires corresponds to - an order where we start from juxtaposing the objects and then removing pairs of contracted - indices, i.e. A-D, B, C, D-A and then the same for a-d, b, c, d-a. The returned permutation - is the one that takes the result of multiplying representations to the standard order. - - This way it is possible to write: + Using the permutations, it is possible to write: .. code-block:: ansatz = ansatz1[idx1] @ ansatz2[idx2] # not in standard order - wires, perm = wires1 @ wires2 # matmul the wires of each component - ansatz = ansatz.reorder(perm) # now in standard order + wires, perm_CV, perm_DV = wires1 @ wires2 # matmul the wires + ansatz = ansatz.reorder(perm_CV, perm_DV) # now in standard order Args: other: The wires of the other circuit component. Returns: - The wires of the circuit composition and the permutation. - - Raises: - ValueError: If any leftover wires would overlap. + The wires of the circuit composition and the permutations. """ - if self.original or other.original: - raise ValueError("Cannot contract a subset of wires.") - A, B, a, b, E, F = self.args - C, D, c, d, G, H = other.args - sets = (A - D, B, a - d, b, E - H, F, C, D - A, c, d - a, G, H - E) - if m := sets[0] & sets[6]: - raise ValueError(f"Output bra modes {m} overlap.") - if m := sets[1] & sets[7]: - raise ValueError(f"Input bra modes {m} overlap.") - if m := sets[2] & sets[8]: - raise ValueError(f"Output ket modes {m} overlap.") - if m := sets[3] & sets[9]: - raise ValueError(f"Input ket modes {m} overlap.") - if m := sets[4] & sets[10]: - raise ValueError(f"Output classical modes {m} overlap.") - if m := sets[5] & sets[11]: - raise ValueError(f"Input classical modes {m} overlap.") - bra_out = sets[0] | sets[6] # (self.output.bra - other.input.bra) | other.output.bra - bra_in = sets[1] | sets[7] # self.input.bra | (other.input.bra - self.output.bra) - ket_out = sets[2] | sets[8] # (self.output.ket - other.input.ket) | other.output.ket - ket_in = sets[3] | sets[9] # self.input.ket | (other.input.ket - self.output.ket) - classical_out = ( - sets[4] | sets[10] - ) # (self.output.classical - other.input.classical) | other.output.classical - classical_in = ( - sets[5] | sets[11] - ) # self.input.classical | (other.input.classical - self.output.classical) - w = Wires(bra_out, bra_in, ket_out, ket_in, classical_out, classical_in) - - # preserve ids - for t in (0, 1, 2, 3, 4, 5): - for m in w.args[t]: - w.ids_dicts[t][m] = self.ids_dicts[t][m] if m in sets[t] else other.ids_dicts[t][m] - - # calculate permutation - result_ids = [id for d in w.ids_dicts for id in d.values()] - self_other_ids = [ - self.ids_dicts[t][m] for t in (0, 1, 2, 3, 4, 5) for m in sorted(sets[t]) - ] + [other.ids_dicts[t][m] for t in (0, 1, 2, 3, 4, 5) for m in sorted(sets[t + 6])] - perm = [self_other_ids.index(id) for id in result_ids] - return w, perm - - def __repr__(self) -> str: - return f"Wires{self.args}" + bra_out = other.output.bra + (self.output.bra - other.input.bra) + ket_out = other.output.ket + (self.output.ket - other.input.ket) + bra_in = self.input.bra + (other.input.bra - self.output.bra) + ket_in = self.input.ket + (other.input.ket - self.output.ket) + cl_out = other.classical.output + (self.classical.output - other.classical.input) + cl_in = self.classical.input + (other.classical.input - self.classical.output) + + # get the wires + w = Wires() + w.quantum_wires = (bra_out + bra_in + ket_out + ket_in).wires + w.classical_wires = (cl_out + cl_in).wires + w.reindex() + + # get the permutations + CV_ids = [w.id for w in w.CV_wires if w.id in self.ids] + [ + w.id for w in w.CV_wires if w.id in other.ids + ] + DV_ids = [w.id for w in w.DV_wires if w.id in self.ids] + [ + w.id for w in w.DV_wires if w.id in other.ids + ] + CV_perm = [CV_ids.index(w.id) for w in w.CV_wires] + DV_perm = [DV_ids.index(w.id) for w in w.DV_wires] + return w, CV_perm, DV_perm def _ipython_display_(self): display(widgets.wires(self)) diff --git a/tests/test_physics/test_wires.py b/tests/test_physics/test_wires.py index f8f136207..8bc0b944f 100644 --- a/tests/test_physics/test_wires.py +++ b/tests/test_physics/test_wires.py @@ -21,7 +21,7 @@ import pytest from ipywidgets import HTML -from mrmustard.physics.new_wires import Repr, Wires +from mrmustard.physics.wires import Repr, Wires class TestWires: From 1ab2057d9ef49ca48a059e8f0d0fb11c35404169 Mon Sep 17 00:00:00 2001 From: Filippo Miatto Date: Thu, 14 Nov 2024 10:37:07 -0800 Subject: [PATCH 05/13] fix codefactor issues --- mrmustard/physics/wires.py | 39 +++++++++++++++++++++++++++++++- tests/test_physics/test_wires.py | 2 +- 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/mrmustard/physics/wires.py b/mrmustard/physics/wires.py index 1ca6ba14c..7c74f05a9 100644 --- a/mrmustard/physics/wires.py +++ b/mrmustard/physics/wires.py @@ -11,8 +11,16 @@ __all__ = ["Wires"] +""" +This module provides wire functionality for applications in MrMustard. +It defines the core classes for representing quantum and classical wires, and their +relationships in quantum optical circuits. +""" + class Repr(Enum): + """Enumeration of possible representations for quantum states and operations.""" + UNSPECIFIED = auto() BARGMANN = auto() FOCK = auto() @@ -22,6 +30,8 @@ class Repr(Enum): class WiresType(Enum): + """Enumeration of possible wire types in quantum circuits.""" + DM_LIKE = auto() # only output ket and bra on same modes KET_LIKE = auto() # only output ket UNITARY_LIKE = auto() # such that can map ket to ket @@ -33,6 +43,18 @@ class WiresType(Enum): @dataclass(slots=True) class QuantumWire: + """ + Represents a quantum wire in a circuit. + + Args: + mode: The mode number this wire represents + is_out: Whether this is an output wire + is_ket: Whether this wire is on the ket side + index: The index of this wire in the circuit + repr: The representation of this wire + id: Unique identifier for this wire + """ + mode: int is_out: bool is_ket: bool @@ -42,6 +64,7 @@ class QuantumWire: @property def is_dv(self) -> bool: + """Returns True if this wire uses discrete-variable representation.""" return self.repr == Repr.FOCK def __hash__(self) -> int: @@ -58,6 +81,17 @@ def __eq__(self, other: QuantumWire) -> bool: @dataclass(slots=True) class ClassicalWire: + """ + Represents a classical wire in a circuit. + + Args: + mode: The mode number this wire represents + is_out: Whether this is an output wire + index: The index of this wire in the circuit + repr: The representation of this wire + id: Unique identifier for this wire + """ + mode: int is_out: bool index: int @@ -66,6 +100,7 @@ class ClassicalWire: @property def is_dv(self) -> bool: + """Returns True if this wire uses discrete-variable representation.""" return self.repr == Repr.FOCK def __hash__(self) -> int: @@ -78,7 +113,7 @@ def __eq__(self, other: ClassicalWire) -> bool: return self.mode == other.mode and self.is_out == other.is_out -class Wires: +class Wires: # pylint: disable=too-many-public-methods r""" A class with wire functionality for tensor network applications. @@ -248,6 +283,7 @@ def __init__( self.classical_wires.add(ClassicalWire(mode=m, is_out=False, index=n + i)) def copy(self) -> Wires: + """Returns a deep copy of this Wires object.""" return deepcopy(self) ###### TRANSFORMATIONS ###### @@ -347,6 +383,7 @@ def output(self) -> Wires: @cached_property def id(self) -> int: + """Returns a unique identifier for this Wires object.""" return randint(0, 2**32 - 1) @cached_property diff --git a/tests/test_physics/test_wires.py b/tests/test_physics/test_wires.py index 8bc0b944f..a09b1b636 100644 --- a/tests/test_physics/test_wires.py +++ b/tests/test_physics/test_wires.py @@ -129,7 +129,7 @@ def test_matmul(self): # contracts 17,17 on classical u = Wires({1, 5}, {2, 6, 15}, {3, 7, 13}, {4, 8}, {16, 17}, {18}) v = Wires({0, 9, 14}, {1, 10}, {2, 11}, {13, 3, 12}, {19}, {17}) - new_wires, CV_perm, DV_perm = u @ v + new_wires, CV_perm, _ = u @ v assert new_wires.args == ( {0, 5, 9, 14}, {2, 6, 10, 15}, From 6b0f0744e5b1b06998d4067532258b3c4f443c39 Mon Sep 17 00:00:00 2001 From: Filippo Miatto Date: Thu, 14 Nov 2024 10:40:21 -0800 Subject: [PATCH 06/13] fixes wires tests --- tests/test_physics/test_wires.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/test_physics/test_wires.py b/tests/test_physics/test_wires.py index a09b1b636..b77aa662f 100644 --- a/tests/test_physics/test_wires.py +++ b/tests/test_physics/test_wires.py @@ -31,10 +31,7 @@ class TestWires: def test_init(self): w = Wires({0, 1, 2}, {3, 4, 5}, {6, 7}, {8}, {9}, {10}) - assert w.args == ({0, 1, 2}, {3, 4, 5}, {6, 7}, {8}, {9}, {10}, set()) - - w = Wires({0, 1, 2}, {3, 4, 5}, {6, 7}, {8}, {9}, {10}, FOCK={1}) - assert w.wire(mode=1, is_ket=False, is_out=True).repr == Repr.FOCK + assert w.args == ({0, 1, 2}, {3, 4, 5}, {6, 7}, {8}, {9}, {10}) def test_indices(self): w = Wires({0, 10, 20}, {30, 40, 50}, {60, 70}, {80}) @@ -137,7 +134,6 @@ def test_matmul(self): {4, 8, 12}, {16, 19}, {18}, - set(), ) assert CV_perm == [9, 0, 10, 11, 1, 2, 12, 3, 13, 4, 14, 5, 6, 15, 7, 16, 8] From c3ea39bf19be42424e093bce5165d0144fa31c4a Mon Sep 17 00:00:00 2001 From: Filippo Miatto Date: Fri, 15 Nov 2024 10:05:07 -0800 Subject: [PATCH 07/13] default wires for components --- .../lab_dev/circuit_components_utils/b_to_ps.py | 9 ++++----- mrmustard/lab_dev/circuit_components_utils/b_to_q.py | 12 +++++++----- mrmustard/lab_dev/states/number.py | 5 +++++ mrmustard/lab_dev/states/quadrature_eigenstate.py | 5 +++++ mrmustard/lab_dev/states/sauron.py | 6 +++++- 5 files changed, 26 insertions(+), 11 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py b/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py index 22e96dfbf..b84288758 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py @@ -23,7 +23,7 @@ from ..transformations.base import Map from ...physics.ansatz import PolyExpAnsatz -from ...physics.representations import RepEnum +from ...physics.wires import ReprEnum from ..utils import make_parameter __all__ = ["BtoPS"] @@ -53,10 +53,9 @@ def __init__( fn=triples.displacement_map_s_parametrized_Abc, s=self.s, n_modes=len(modes) ), ).representation - for i in self.wires.input.indices: - self.representation._idx_reps[i] = (RepEnum.BARGMANN, None) - for i in self.wires.output.indices: - self.representation._idx_reps[i] = (RepEnum.PHASESPACE, float(self.s.value)) + for w in self.representation.wires.output.wires: + w.repr = ReprEnum.CHARACTERISTIC + w.repr_params = float(self.s.value) def inverse(self): ret = BtoPS(self.modes, self.s) diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py index 3cf4d2d5a..c6e7855ff 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py @@ -22,7 +22,7 @@ from ..transformations.base import Operation from ...physics.ansatz import PolyExpAnsatz -from ...physics.representations import RepEnum +from ...physics.wires import ReprEnum from ..utils import make_parameter __all__ = ["BtoQ"] @@ -53,10 +53,12 @@ def __init__( fn=triples.bargmann_to_quadrature_Abc, n_modes=len(modes), phi=self.phi ), ).representation - for i in self.wires.input.indices: - self.representation._idx_reps[i] = (RepEnum.BARGMANN, None) - for i in self.wires.output.indices: - self.representation._idx_reps[i] = (RepEnum.QUADRATURE, float(self.phi.value)) + for w in self.representation.wires.input.wires: + w.repr = ReprEnum.BARGMANN + w.repr_params = None + for w in self.representation.wires.output.wires: + w.repr = ReprEnum.QUADRATURE + w.repr_params = float(self.phi.value) def inverse(self): ret = BtoQ(self.modes, self.phi) diff --git a/mrmustard/lab_dev/states/number.py b/mrmustard/lab_dev/states/number.py index 9f55c72b5..ffa9db1a8 100644 --- a/mrmustard/lab_dev/states/number.py +++ b/mrmustard/lab_dev/states/number.py @@ -21,6 +21,7 @@ from typing import Sequence from mrmustard.physics.ansatz import ArrayAnsatz +from mrmustard.physics.wires import ReprEnum from mrmustard.physics.fock_utils import fock_state from .ket import Ket from ..utils import make_parameter, reshape_params @@ -81,3 +82,7 @@ def __init__( self.short_name = [str(int(n)) for n in self.n.value] for i, cutoff in enumerate(self.cutoffs.value): self.manual_shape[i] = int(cutoff) + 1 + + for w in self.representation.wires.input.wires: + w.repr = ReprEnum.FOCK + w.repr_params = [int(self.n.value[w.mode])] diff --git a/mrmustard/lab_dev/states/quadrature_eigenstate.py b/mrmustard/lab_dev/states/quadrature_eigenstate.py index 29acac3ad..2c1671e0e 100644 --- a/mrmustard/lab_dev/states/quadrature_eigenstate.py +++ b/mrmustard/lab_dev/states/quadrature_eigenstate.py @@ -24,6 +24,7 @@ from mrmustard.physics.ansatz import PolyExpAnsatz from mrmustard.physics import triples +from mrmustard.physics.wires import ReprEnum from .ket import Ket from ..utils import make_parameter, reshape_params @@ -77,6 +78,10 @@ def __init__( ), ).representation + for w in self.representation.wires.input.wires: + w.repr = ReprEnum.QUADRATURE + w.repr_params = [float(self.x.value[w.mode]), float(self.phi.value[w.mode])] + @property def L2_norm(self): r""" diff --git a/mrmustard/lab_dev/states/sauron.py b/mrmustard/lab_dev/states/sauron.py index e4476f2a8..ffe040a17 100644 --- a/mrmustard/lab_dev/states/sauron.py +++ b/mrmustard/lab_dev/states/sauron.py @@ -18,7 +18,7 @@ from mrmustard.lab_dev.states.ket import Ket from mrmustard.physics.ansatz import PolyExpAnsatz from mrmustard.physics import triples - +from mrmustard.physics.wires import ReprEnum from ..utils import make_parameter @@ -50,3 +50,7 @@ def __init__(self, modes: Sequence[int], n: int, epsilon: float = 0.1): triples.sauron_state_Abc, n=self.n.value, epsilon=self.epsilon.value ), ).representation + + for w in self.representation.wires.input.wires: + w.repr = ReprEnum.FOCK + w.repr_params = [float(self.n.value[w.mode]), float(self.epsilon.value)] From 9846ecddb7d9a06c5a8ed96b578a0f560fb80821 Mon Sep 17 00:00:00 2001 From: Filippo Miatto Date: Fri, 15 Nov 2024 10:06:13 -0800 Subject: [PATCH 08/13] fix set issue --- mrmustard/physics/wires.py | 113 ++++++++++++++++++++++++------------- 1 file changed, 74 insertions(+), 39 deletions(-) diff --git a/mrmustard/physics/wires.py b/mrmustard/physics/wires.py index 7c74f05a9..9a0d44e3e 100644 --- a/mrmustard/physics/wires.py +++ b/mrmustard/physics/wires.py @@ -1,8 +1,7 @@ from __future__ import annotations -from dataclasses import dataclass, field -from typing import Sequence +from dataclasses import dataclass, field, replace +from typing import Sequence, Any from random import randint -from copy import deepcopy from enum import Enum, auto from IPython.display import display from functools import lru_cache, cached_property @@ -18,7 +17,15 @@ """ -class Repr(Enum): +class LegibleEnum(Enum): + def __str__(self) -> str: + return self.name + + def __repr__(self) -> str: + return self.name + + +class ReprEnum(LegibleEnum): """Enumeration of possible representations for quantum states and operations.""" UNSPECIFIED = auto() @@ -29,7 +36,7 @@ class Repr(Enum): CHARACTERISTIC = auto() -class WiresType(Enum): +class WiresType(LegibleEnum): """Enumeration of possible wire types in quantum circuits.""" DM_LIKE = auto() # only output ket and bra on same modes @@ -59,23 +66,38 @@ class QuantumWire: is_out: bool is_ket: bool index: int - repr: Repr = Repr.UNSPECIFIED - id: int = field(default_factory=lambda: randint(0, 2**32 - 1)) + repr: ReprEnum = ReprEnum.BARGMANN + repr_params: Any = None + id: int = field(default_factory=lambda: randint(0, 2**32 - 1), compare=False) @property def is_dv(self) -> bool: """Returns True if this wire uses discrete-variable representation.""" - return self.repr == Repr.FOCK + return self.repr == ReprEnum.FOCK def __hash__(self) -> int: - return hash((self.mode, self.is_out, self.is_ket)) + return hash((self.mode, self.is_out, self.is_ket, self.repr)) def __repr__(self) -> str: return f"QuantumWire(mode={self.mode}, out={self.is_out}, ket={self.is_ket}, repr={self.repr}, index={self.index})" def __eq__(self, other: QuantumWire) -> bool: return ( - self.mode == other.mode and self.is_out == other.is_out and self.is_ket == other.is_ket + self.mode == other.mode + and self.is_out == other.is_out + and self.is_ket == other.is_ket + and self.repr == other.repr + ) + + def copy(self) -> QuantumWire: + return QuantumWire( + mode=self.mode, + is_out=self.is_out, + is_ket=self.is_ket, + index=self.index, + repr=self.repr, + repr_params=self.repr_params, + id=self.id, ) @@ -95,22 +117,33 @@ class ClassicalWire: mode: int is_out: bool index: int - repr: Repr = Repr.UNSPECIFIED + repr: ReprEnum = ReprEnum.UNSPECIFIED + repr_params: Any = None id: int = field(default_factory=lambda: randint(0, 2**32 - 1)) @property def is_dv(self) -> bool: """Returns True if this wire uses discrete-variable representation.""" - return self.repr == Repr.FOCK + return self.repr == ReprEnum.FOCK def __hash__(self) -> int: - return hash((self.mode, self.is_out, self.is_dv)) + return hash((self.mode, self.is_out, self.repr)) def __repr__(self) -> str: return f"ClassicalWire(mode={self.mode}, out={self.is_out}, repr={self.repr}, index={self.index})" def __eq__(self, other: ClassicalWire) -> bool: - return self.mode == other.mode and self.is_out == other.is_out + return self.mode == other.mode and self.is_out == other.is_out and self.repr == other.repr + + def copy(self) -> ClassicalWire: + return ClassicalWire( + mode=self.mode, + is_out=self.is_out, + index=self.index, + repr=self.repr, + repr_params=self.repr_params, + id=self.id, + ) class Wires: # pylint: disable=too-many-public-methods @@ -284,30 +317,30 @@ def __init__( def copy(self) -> Wires: """Returns a deep copy of this Wires object.""" - return deepcopy(self) + w = Wires() + w.quantum_wires = {q.copy() for q in self.quantum_wires} + w.classical_wires = {c.copy() for c in self.classical_wires} + return w ###### TRANSFORMATIONS ###### - @property + @cached_property def adjoint(self) -> Wires: r""" New ``Wires`` object with the adjoint quantum wires (ket becomes bra and vice versa). """ w = self.copy() - for q in w.quantum_wires: - q.is_ket = not q.is_ket + w.quantum_wires = {replace(q, is_ket=not q.is_ket) for q in w.quantum_wires} return w - @property + @cached_property def dual(self) -> Wires: r""" New ``Wires`` object with dual quantum and classical wires (input becomes output and vice versa). """ w = self.copy() - for q in w.quantum_wires: - q.is_out = not q.is_out - for c in w.classical_wires: - c.is_out = not c.is_out + w.quantum_wires = {replace(q, is_out=not q.is_out) for q in w.quantum_wires} + w.classical_wires = {replace(c, is_out=not c.is_out) for c in w.classical_wires} return w ###### SUBSETS ###### @@ -386,35 +419,35 @@ def id(self) -> int: """Returns a unique identifier for this Wires object.""" return randint(0, 2**32 - 1) - @cached_property + @property def modes(self) -> set[int]: r""" The modes spanned by the wires. """ return {q.mode for q in self.quantum_wires} | {c.mode for c in self.classical_wires} - @cached_property + @property def ids(self) -> tuple[int, ...]: r""" The ids of the wires in standard order. """ return tuple(w.id for w in self.sorted_wires) - @cached_property + @property def indices(self) -> tuple[int, ...]: r""" The indices of the wires in standard order. """ return tuple(w.index for w in self.sorted_wires) - @cached_property + @property def DV_indices(self) -> tuple[int, ...]: r""" The indices of the DV wires (both quantum and classical) in standard order. """ return tuple(q.index for q in self.DV_wires) - @cached_property + @property def CV_indices(self) -> tuple[int, ...]: r""" The indices of the CV wires (both quantum and classical) in standard order. @@ -435,7 +468,7 @@ def CV_wires(self) -> tuple[QuantumWire | ClassicalWire, ...]: """ return tuple(w for w in self.sorted_wires.copy() if not w.is_dv) - @cached_property + @property def args(self) -> tuple[set[int], ...]: r""" The arguments to pass to ``Wires`` to create the same object with fresh wires. @@ -486,11 +519,11 @@ def wire(self, mode: int, is_out: bool, is_ket: bool) -> QuantumWire | Classical return classical[0] raise ValueError(f"No wire with mode {mode}, is_out {is_out}, and is_ket {is_ket}.") - def reindex(self) -> None: + def _reindex(self) -> None: r""" Updates the indices of the wires according to the standard order. """ - for i, w in enumerate(self.sorted_wires): + for i, w in enumerate(self.wires): w.index = i def __add__(self, other: Wires) -> Wires: @@ -505,7 +538,7 @@ def __add__(self, other: Wires) -> Wires: w = Wires() w.quantum_wires = self.quantum_wires | other.quantum_wires w.classical_wires = self.classical_wires | other.classical_wires - w.reindex() + w._reindex() return w def __sub__(self, other: Wires) -> Wires: @@ -516,7 +549,7 @@ def __sub__(self, other: Wires) -> Wires: w = Wires() w.quantum_wires = {q for q in self.quantum_wires.copy() if q.mode not in other.modes} w.classical_wires = {c for c in self.classical_wires.copy() if c.mode not in other.modes} - w.reindex() + w._reindex() return w def __bool__(self) -> bool: @@ -591,18 +624,20 @@ def __matmul__(self, other: Wires) -> tuple[Wires, list[int], list[int]]: w = Wires() w.quantum_wires = (bra_out + bra_in + ket_out + ket_in).wires w.classical_wires = (cl_out + cl_in).wires - w.reindex() + w._reindex() # get the permutations CV_ids = [w.id for w in w.CV_wires if w.id in self.ids] + [ w.id for w in w.CV_wires if w.id in other.ids ] - DV_ids = [w.id for w in w.DV_wires if w.id in self.ids] + [ - w.id for w in w.DV_wires if w.id in other.ids - ] CV_perm = [CV_ids.index(w.id) for w in w.CV_wires] - DV_perm = [DV_ids.index(w.id) for w in w.DV_wires] - return w, CV_perm, DV_perm + + # NOTE: this is for when BtoF is in + # DV_ids = [w.id for w in w.DV_wires if w.id in self.ids] + [ + # w.id for w in w.DV_wires if w.id in other.ids + # ] + # DV_perm = [DV_ids.index(w.id) for w in w.DV_wires] + return w, CV_perm def _ipython_display_(self): display(widgets.wires(self)) From 5632af6f0f5cea217cc1ed05debebc3086f2d7f0 Mon Sep 17 00:00:00 2001 From: Filippo Miatto Date: Fri, 15 Nov 2024 10:06:52 -0800 Subject: [PATCH 09/13] fix attribute name --- tests/test_lab_dev/test_circuit_components.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_lab_dev/test_circuit_components.py b/tests/test_lab_dev/test_circuit_components.py index 200479614..b94652ab2 100644 --- a/tests/test_lab_dev/test_circuit_components.py +++ b/tests/test_lab_dev/test_circuit_components.py @@ -567,7 +567,7 @@ def test_serialize_default_behaviour(self): kwargs, arrays = cc._serialize() assert kwargs == { "class": f"{CircuitComponent.__module__}.CircuitComponent", - "wires": cc.wires.sorted_args, + "wires": cc.wires.args, "ansatz_cls": f"{PolyExpAnsatz.__module__}.PolyExpAnsatz", "name": name, } From 5d84f9f5cefc6b6609f253d44ab8e00c04bffd2d Mon Sep 17 00:00:00 2001 From: Filippo Miatto Date: Fri, 15 Nov 2024 10:31:54 -0800 Subject: [PATCH 10/13] update representation and tests --- mrmustard/physics/representations.py | 92 ++-------------------- tests/test_physics/test_representations.py | 41 +++------- 2 files changed, 15 insertions(+), 118 deletions(-) diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index 053495b69..468607e7a 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -18,7 +18,6 @@ from __future__ import annotations from typing import Sequence -from enum import Enum import numpy as np @@ -32,41 +31,11 @@ from .ansatz import Ansatz, PolyExpAnsatz, ArrayAnsatz from .triples import identity_Abc -from .wires import Wires +from .wires import Wires, ReprEnum __all__ = ["Representation"] -class RepEnum(Enum): - r""" - An enum to represent what representation a wire is in. - """ - - NONETYPE = 0 - BARGMANN = 1 - FOCK = 2 - QUADRATURE = 3 - PHASESPACE = 4 - - @classmethod - def from_ansatz(cls, ansatz: Ansatz): - r""" - Returns a ``RepEnum`` from an ``Ansatz``. - - Args: - ansatz: The ansatz. - """ - if isinstance(ansatz, PolyExpAnsatz): - return cls(1) - elif isinstance(ansatz, ArrayAnsatz): - return cls(2) - else: - return cls(0) - - def __repr__(self) -> str: - return self.name - - class Representation: r""" A class for representations. @@ -87,10 +56,7 @@ class Representation: """ def __init__( - self, - ansatz: Ansatz | None = None, - wires: Wires | Sequence[tuple[int]] | None = None, - idx_reps: dict | None = None, + self, ansatz: Ansatz | None = None, wires: Wires | Sequence[tuple[int]] | None = None ) -> None: self._ansatz = ansatz @@ -128,9 +94,6 @@ def __init__( self._ansatz = ansatz.reorder(tuple(perm)) self._wires = wires - self._idx_reps = idx_reps or dict.fromkeys( - wires.indices, (RepEnum.from_ansatz(ansatz), None) - ) @property def adjoint(self) -> Representation: @@ -142,12 +105,7 @@ def adjoint(self) -> Representation: kets = self.wires.ket.indices ansatz = self.ansatz.reorder(kets + bras).conj if self.ansatz else None wires = self.wires.adjoint - idx_reps = {} - for i, j in enumerate(kets): - idx_reps[i] = self._idx_reps[j] - for i, j in enumerate(bras): - idx_reps[i + len(kets)] = self._idx_reps[j] - return Representation(ansatz, wires, idx_reps) + return Representation(ansatz, wires) @property def ansatz(self) -> Ansatz | None: @@ -168,16 +126,7 @@ def dual(self) -> Representation: ob = self.wires.bra.output.indices ansatz = self.ansatz.reorder(ib + ob + ik + ok).conj if self.ansatz else None wires = self.wires.dual - idx_reps = {} - for i, j in enumerate(ib): - idx_reps[i] = self._idx_reps[j] - for i, j in enumerate(ob): - idx_reps[i + len(ib)] = self._idx_reps[j] - for i, j in enumerate(ik): - idx_reps[i + len(ib + ob)] = self._idx_reps[j] - for i, j in enumerate(ok): - idx_reps[i + len(ib + ob + ik)] = self._idx_reps[j] - return Representation(ansatz, wires, idx_reps) + return Representation(ansatz, wires) @property def wires(self) -> Wires | None: @@ -301,37 +250,9 @@ def _matmul_indices(self, other: Representation) -> tuple[tuple[int, ...], tuple idx_zconj += other.wires.ket.input[ket_modes].indices return idx_z, idx_zconj - def _matmul_idx_reps(self, wires_result: Wires, other: Representation): - r""" - Returns the new representation mappings when contracting ``self`` and ``other``. - - Args: - wires_result: The resulting wires after contraction. - other: The representation contracting with. - """ - idx_reps = {} - for id in wires_result.ids: - if id in other.wires.ids: - temp_rep = other - else: - temp_rep = self - for t in (0, 1, 2, 3, 4, 5): - try: - idx = temp_rep.wires.ids_index_dicts[t][id] - n_idx = wires_result.ids_index_dicts[t][id] - idx_reps[n_idx] = temp_rep._idx_reps[idx] - break - except KeyError: - continue - return idx_reps - def __eq__(self, other): if isinstance(other, Representation): - return ( - self.ansatz == other.ansatz - and self.wires == other.wires - and self._idx_reps == other._idx_reps - ) + return self.ansatz == other.ansatz and self.wires == other.wires return False def __matmul__(self, other: Representation): @@ -347,5 +268,4 @@ def __matmul__(self, other: Representation): rep = self_ansatz[idx_z] @ other_ansatz[idx_zconj] rep = rep.reorder(perm) if perm else rep - idx_reps = self._matmul_idx_reps(wires_result, other) - return Representation(rep, wires_result, idx_reps) + return Representation(rep, wires_result) diff --git a/tests/test_physics/test_representations.py b/tests/test_physics/test_representations.py index 52d047bad..9054d550f 100644 --- a/tests/test_physics/test_representations.py +++ b/tests/test_physics/test_representations.py @@ -58,56 +58,33 @@ def test_init(self, triple): empty_rep = Representation() assert empty_rep.ansatz is None assert empty_rep.wires == Wires() - assert empty_rep._idx_reps == {} ansatz = PolyExpAnsatz(*triple) wires = Wires(set([0, 1])) rep = Representation(ansatz, wires) assert rep.ansatz == ansatz assert rep.wires == wires - assert rep._idx_reps == dict.fromkeys(wires.indices, (RepEnum.from_ansatz(ansatz), None)) - - @pytest.mark.parametrize("triple", [Abc_n2]) - def test_adjoint_idx_reps(self, triple): - ansatz = PolyExpAnsatz(*triple) - wires = Wires(modes_out_bra=set([0]), modes_out_ket=set([0])) - idx_reps = {0: (RepEnum.BARGMANN, None), 1: (RepEnum.QUADRATURE, 0.1)} - rep = Representation(ansatz, wires, idx_reps) - adj_rep = rep.adjoint - assert adj_rep._idx_reps == { - 1: (RepEnum.BARGMANN, None), - 0: (RepEnum.QUADRATURE, 0.1), - } - - @pytest.mark.parametrize("triple", [Abc_n2]) - def test_dual_idx_reps(self, triple): - ansatz = PolyExpAnsatz(*triple) - wires = Wires(modes_out_bra=set([0]), modes_in_bra=set([0])) - idx_reps = {0: (RepEnum.BARGMANN, None), 1: (RepEnum.QUADRATURE, 0.1)} - rep = Representation(ansatz, wires, idx_reps) - adj_rep = rep.dual - assert adj_rep._idx_reps == { - 1: (RepEnum.BARGMANN, None), - 0: (RepEnum.QUADRATURE, 0.1), - } def test_matmul_btoq(self, d_gate_rep, btoq_rep): q_dgate = d_gate_rep @ btoq_rep - assert q_dgate._idx_reps == { - 0: (RepEnum.QUADRATURE, 0.2), - 1: (RepEnum.BARGMANN, None), - } + for w in q_dgate.wires.input.wires: + assert w.repr == RepEnum.BARGMANN + for w in q_dgate.wires.output.wires: + assert w.repr == RepEnum.QUADRATURE + assert w.param == [0.2] def test_to_bargmann(self, d_gate_rep): d_fock = d_gate_rep.to_fock(shape=(4, 6)) d_barg = d_fock.to_bargmann() assert d_fock.ansatz._original_abc_data == d_gate_rep.ansatz.triple assert d_barg == d_gate_rep - assert all((k[0] == RepEnum.BARGMANN for k in d_barg._idx_reps.values())) + for w in d_barg.wires.wires: + assert w.repr == RepEnum.BARGMANN def test_to_fock(self, d_gate_rep): d_fock = d_gate_rep.to_fock(shape=(4, 6)) assert d_fock.ansatz == ArrayAnsatz( math.hermite_renormalized(*displacement_gate_Abc(x=0.1, y=0.1), shape=(4, 6)) ) - assert all((k[0] == RepEnum.FOCK for k in d_fock._idx_reps.values())) + for w in d_fock.wires.wires: + assert w.repr == RepEnum.FOCK From dbeec3195d402203475a3dcd29ccef1eddef35c6 Mon Sep 17 00:00:00 2001 From: Filippo Miatto Date: Fri, 15 Nov 2024 14:44:10 -0800 Subject: [PATCH 11/13] fixed permutation --- mrmustard/lab_dev/circuit_components.py | 6 +- mrmustard/lab_dev/states/number.py | 4 +- .../lab_dev/states/quadrature_eigenstate.py | 2 +- mrmustard/lab_dev/states/sauron.py | 2 +- mrmustard/physics/representations.py | 49 +---- mrmustard/physics/wires.py | 172 ++++++++++-------- tests/test_lab_dev/test_circuit_components.py | 20 +- tests/test_lab_dev/test_states/test_dm.py | 2 +- tests/test_lab_dev/test_states/test_ket.py | 2 +- 9 files changed, 125 insertions(+), 134 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index d0e1297df..0ad141142 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -84,7 +84,7 @@ def _serialize(self) -> tuple[dict[str, Any], dict[str, ArrayLike]]: if "name" in params: # assume abstract type, serialize the representation ansatz_cls = type(self.ansatz) serializable["name"] = self.name - serializable["wires"] = self.wires.sorted_args + serializable["wires"] = self.wires.args serializable["ansatz_cls"] = f"{ansatz_cls.__module__}.{ansatz_cls.__qualname__}" return serializable, self.ansatz.to_dict() @@ -698,9 +698,9 @@ def __rshift__(self, other: CircuitComponent | numbers.Number) -> CircuitCompone if only_ket or only_bra or both_sides: ret = self @ other elif self_needs_bra or self_needs_ket: - ret = (self.adjoint @ self) @ other + ret = self.adjoint @ (self @ other) elif other_needs_bra or other_needs_ket: - ret = self @ (other @ other.adjoint) + ret = (self @ other) @ other.adjoint else: msg = f"``>>`` not supported between {self} and {other} because it's not clear " msg += "whether or where to add bra wires. Use ``@`` instead and specify all the components." diff --git a/mrmustard/lab_dev/states/number.py b/mrmustard/lab_dev/states/number.py index ffa9db1a8..d6a64b4fe 100644 --- a/mrmustard/lab_dev/states/number.py +++ b/mrmustard/lab_dev/states/number.py @@ -83,6 +83,6 @@ def __init__( for i, cutoff in enumerate(self.cutoffs.value): self.manual_shape[i] = int(cutoff) + 1 - for w in self.representation.wires.input.wires: + for w in self.representation.wires.output.wires: w.repr = ReprEnum.FOCK - w.repr_params = [int(self.n.value[w.mode])] + w.repr_params = [int(self.n.value[w.index])] diff --git a/mrmustard/lab_dev/states/quadrature_eigenstate.py b/mrmustard/lab_dev/states/quadrature_eigenstate.py index 2c1671e0e..b00ffeeb3 100644 --- a/mrmustard/lab_dev/states/quadrature_eigenstate.py +++ b/mrmustard/lab_dev/states/quadrature_eigenstate.py @@ -80,7 +80,7 @@ def __init__( for w in self.representation.wires.input.wires: w.repr = ReprEnum.QUADRATURE - w.repr_params = [float(self.x.value[w.mode]), float(self.phi.value[w.mode])] + w.repr_params = [float(self.x.value[w.index]), float(self.phi.value[w.index])] @property def L2_norm(self): diff --git a/mrmustard/lab_dev/states/sauron.py b/mrmustard/lab_dev/states/sauron.py index ffe040a17..f7374f8e1 100644 --- a/mrmustard/lab_dev/states/sauron.py +++ b/mrmustard/lab_dev/states/sauron.py @@ -53,4 +53,4 @@ def __init__(self, modes: Sequence[int], n: int, epsilon: float = 0.1): for w in self.representation.wires.input.wires: w.repr = ReprEnum.FOCK - w.repr_params = [float(self.n.value[w.mode]), float(self.epsilon.value)] + w.repr_params = [float(self.n.value[w.index]), float(self.epsilon.value[w.index])] diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index 468607e7a..0e25a3a2c 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -48,52 +48,14 @@ class Representation: Args: ansatz: An ansatz for this representation. - wires: The wires of this representation. Alternatively, can be - a ``(modes_out_bra, modes_in_bra, modes_out_ket, modes_in_ket)`` - sequence where if any of the modes are out of order the ansatz - will be reordered. - idx_reps: An optional dictionary for keeping track of each wire's representation. + wires: The wires of this representation. """ - def __init__( - self, ansatz: Ansatz | None = None, wires: Wires | Sequence[tuple[int]] | None = None - ) -> None: + def __init__(self, ansatz: Ansatz | None = None, wires: Wires | None = None) -> None: self._ansatz = ansatz - - if wires is None: - wires = Wires(set(), set(), set(), set()) - elif not isinstance(wires, Wires): - modes_out_bra, modes_in_bra, modes_out_ket, modes_in_ket = [ - tuple(elem) for elem in wires - ] - wires = Wires( - set(modes_out_bra), - set(modes_in_bra), - set(modes_out_ket), - set(modes_in_ket), - ) - # handle out-of-order modes - ob = tuple(sorted(modes_out_bra)) - ib = tuple(sorted(modes_in_bra)) - ok = tuple(sorted(modes_out_ket)) - ik = tuple(sorted(modes_in_ket)) - if ( - ob != modes_out_bra - or ib != modes_in_bra - or ok != modes_out_ket - or ik != modes_in_ket - ): - offsets = [len(ob), len(ob) + len(ib), len(ob) + len(ib) + len(ok)] - perm = ( - tuple(np.argsort(modes_out_bra)) - + tuple(np.argsort(modes_in_bra) + offsets[0]) - + tuple(np.argsort(modes_out_ket) + offsets[1]) - + tuple(np.argsort(modes_in_ket) + offsets[2]) - ) - if ansatz is not None: - self._ansatz = ansatz.reorder(tuple(perm)) - - self._wires = wires + self._wires = wires or Wires(set(), set(), set(), set()) + if (perm := self.wires.perm()) and self.ansatz is not None: + self._ansatz = self.ansatz.reorder(perm) @property def adjoint(self) -> Representation: @@ -265,7 +227,6 @@ def __matmul__(self, other: Representation): else: self_ansatz = self.to_bargmann().ansatz other_ansatz = other.to_bargmann().ansatz - rep = self_ansatz[idx_z] @ other_ansatz[idx_zconj] rep = rep.reorder(perm) if perm else rep return Representation(rep, wires_result) diff --git a/mrmustard/physics/wires.py b/mrmustard/physics/wires.py index 9a0d44e3e..8cf52d487 100644 --- a/mrmustard/physics/wires.py +++ b/mrmustard/physics/wires.py @@ -5,7 +5,7 @@ from enum import Enum, auto from IPython.display import display from functools import lru_cache, cached_property - +import numpy as np from mrmustard import widgets __all__ = ["Wires"] @@ -79,7 +79,7 @@ def __hash__(self) -> int: return hash((self.mode, self.is_out, self.is_ket, self.repr)) def __repr__(self) -> str: - return f"QuantumWire(mode={self.mode}, out={self.is_out}, ket={self.is_ket}, repr={self.repr}, index={self.index})" + return f"QuantumWire(mode={self.mode}, {'out' if self.is_out else 'in'}, {'ket' if self.is_ket else 'bra'}, repr={self.repr}, index={self.index})" def __eq__(self, other: QuantumWire) -> bool: return ( @@ -100,6 +100,13 @@ def copy(self) -> QuantumWire: id=self.id, ) + def _order(self) -> int: + """ + Artificial ordering for sorting quantum wires. + Order achieved is by bra/ket, then out/in, then mode. + """ + return self.mode + 10_000 * (1 - 2 * self.is_out) - 100_000 * (1 - 2 * self.is_ket) + @dataclass(slots=True) class ClassicalWire: @@ -145,6 +152,13 @@ def copy(self) -> ClassicalWire: id=self.id, ) + def _order(self) -> int: + """ + Artificial ordering for sorting classical wires. + Order achieved is by out/in, then mode and they always come after quantum wires. + """ + return 1000_000 + self.mode + 10_000 * (1 - 2 * self.is_out) + class Wires: # pylint: disable=too-many-public-methods r""" @@ -287,16 +301,24 @@ class Wires: # pylint: disable=too-many-public-methods def __init__( self, - modes_out_bra: Sequence[int] = (), - modes_in_bra: Sequence[int] = (), - modes_out_ket: Sequence[int] = (), - modes_in_ket: Sequence[int] = (), - classical_out: Sequence[int] = (), - classical_in: Sequence[int] = (), + modes_out_bra: set[int] | Sequence[int] = set(), + modes_in_bra: set[int] | Sequence[int] = set(), + modes_out_ket: set[int] | Sequence[int] = set(), + modes_in_ket: set[int] | Sequence[int] = set(), + classical_out: set[int] | Sequence[int] = set(), + classical_in: set[int] | Sequence[int] = set(), ): self.quantum_wires = set() self.classical_wires = set() + # store the permutation to reorder wires if they are given out of order + self._initial_perm = [] + groups = [modes_out_bra, modes_in_bra, modes_out_ket, modes_in_ket] + if any(not isinstance(group, set) for group in groups): + for group in groups: + self._initial_perm.extend(tuple(np.argsort(group) + len(self._initial_perm))) + + # go ahead and add the wires in standard order for i, m in enumerate(sorted(modes_out_bra)): self.quantum_wires.add(QuantumWire(mode=m, is_out=True, is_ket=False, index=i)) n = len(modes_out_bra) @@ -329,8 +351,9 @@ def adjoint(self) -> Wires: r""" New ``Wires`` object with the adjoint quantum wires (ket becomes bra and vice versa). """ - w = self.copy() - w.quantum_wires = {replace(q, is_ket=not q.is_ket) for q in w.quantum_wires} + w = Wires() + w.quantum_wires = {replace(q, is_ket=not q.is_ket) for q in self.quantum_wires} + w.classical_wires = {c.copy() for c in self.classical_wires} return w @cached_property @@ -338,17 +361,17 @@ def dual(self) -> Wires: r""" New ``Wires`` object with dual quantum and classical wires (input becomes output and vice versa). """ - w = self.copy() - w.quantum_wires = {replace(q, is_out=not q.is_out) for q in w.quantum_wires} - w.classical_wires = {replace(c, is_out=not c.is_out) for c in w.classical_wires} + w = Wires() + w.quantum_wires = {replace(q, is_out=not q.is_out) for q in self.quantum_wires} + w.classical_wires = {replace(c, is_out=not c.is_out) for c in self.classical_wires} return w ###### SUBSETS ###### - @lru_cache + # @lru_cache def __getitem__(self, modes: tuple[int, ...] | int) -> Wires: """ - Returns the quantum and classical wires with the given modes. + Returns a new Wires object with references to the quantum and classical wires with the given modes. """ modes = {modes} if isinstance(modes, int) else set(modes) w = Wires() @@ -359,7 +382,8 @@ def __getitem__(self, modes: tuple[int, ...] | int) -> Wires: @cached_property def classical(self) -> Wires: r""" - New ``Wires`` object with only classical wires. + New ``Wires`` object with references to only classical wires. + Note that the wires are not copied. """ w = Wires() w.classical_wires = self.classical_wires @@ -368,7 +392,8 @@ def classical(self) -> Wires: @cached_property def quantum(self) -> Wires: r""" - New ``Wires`` object with only quantum wires. + New ``Wires`` object with references to only quantum wires. + Note that the wires are not copied. """ w = Wires() w.quantum_wires = self.quantum_wires @@ -377,7 +402,8 @@ def quantum(self) -> Wires: @cached_property def bra(self) -> Wires: r""" - New ``Wires`` object with only quantum bra wires. + New ``Wires`` object with references to only quantum bra wires. + Note that the wires are not copied. """ w = Wires() w.quantum_wires = {q for q in self.quantum_wires if not q.is_ket} @@ -386,7 +412,8 @@ def bra(self) -> Wires: @cached_property def ket(self) -> Wires: r""" - New ``Wires`` object with only quantum ket wires. + New ``Wires`` object with references to only quantum ket wires. + Note that the wires are not copied. """ w = Wires() w.quantum_wires = {q for q in self.quantum_wires if q.is_ket} @@ -395,7 +422,8 @@ def ket(self) -> Wires: @cached_property def input(self) -> Wires: r""" - New ``Wires`` object with only classical and quantum input wires. + New ``Wires`` object with references to only classical and quantum input wires. + Note that the wires are not copied. """ w = Wires() w.quantum_wires = {q for q in self.quantum_wires if not q.is_out} @@ -405,7 +433,8 @@ def input(self) -> Wires: @cached_property def output(self) -> Wires: r""" - New ``Wires`` object with only classical and quantum output wires. + New ``Wires`` object with references to only classical and quantum output wires. + Note that the wires are not copied. """ w = Wires() w.quantum_wires = {q for q in self.quantum_wires if q.is_out} @@ -419,35 +448,35 @@ def id(self) -> int: """Returns a unique identifier for this Wires object.""" return randint(0, 2**32 - 1) - @property + @cached_property def modes(self) -> set[int]: r""" The modes spanned by the wires. """ return {q.mode for q in self.quantum_wires} | {c.mode for c in self.classical_wires} - @property + @cached_property def ids(self) -> tuple[int, ...]: r""" The ids of the wires in standard order. """ return tuple(w.id for w in self.sorted_wires) - @property + @cached_property def indices(self) -> tuple[int, ...]: r""" The indices of the wires in standard order. """ return tuple(w.index for w in self.sorted_wires) - @property + @cached_property def DV_indices(self) -> tuple[int, ...]: r""" The indices of the DV wires (both quantum and classical) in standard order. """ return tuple(q.index for q in self.DV_wires) - @property + @cached_property def CV_indices(self) -> tuple[int, ...]: r""" The indices of the CV wires (both quantum and classical) in standard order. @@ -459,16 +488,16 @@ def DV_wires(self) -> tuple[QuantumWire | ClassicalWire, ...]: r""" The DV wires in standard order. """ - return tuple(w for w in self.sorted_wires.copy() if w.is_dv) + return tuple(w for w in self.sorted_wires if w.is_dv) @cached_property def CV_wires(self) -> tuple[QuantumWire | ClassicalWire, ...]: r""" The CV wires in standard order. """ - return tuple(w for w in self.sorted_wires.copy() if not w.is_dv) + return tuple(w for w in self.sorted_wires if not w.is_dv) - @property + @cached_property def args(self) -> tuple[set[int], ...]: r""" The arguments to pass to ``Wires`` to create the same object with fresh wires. @@ -494,42 +523,45 @@ def sorted_wires(self) -> list[QuantumWire | ClassicalWire]: r""" A list of all wires in standard order. """ - return [ - *sorted(self.bra.output.wires, key=lambda s: s.mode), - *sorted(self.bra.input.wires, key=lambda s: s.mode), - *sorted(self.ket.output.wires, key=lambda s: s.mode), - *sorted(self.ket.input.wires, key=lambda s: s.mode), - *sorted(self.classical.output.wires, key=lambda s: s.mode), - *sorted(self.classical.input.wires, key=lambda s: s.mode), - ] + return sorted(self.wires, key=lambda s: s._order()) ###### METHODS ###### - def wire(self, mode: int, is_out: bool, is_ket: bool) -> QuantumWire | ClassicalWire: + def perm(self) -> tuple[int, ...] | None: r""" - Returns the wire with the given mode, ket, and output status. + The permutation that standardizes the wires with respect to how they were initialized. + None if already in standard order. """ - if quantum := [ - w - for w in self.quantum_wires - if w.mode == mode and w.is_out == is_out and w.is_ket == is_ket - ]: - return quantum[0] - if classical := [w for w in self.classical_wires if w.mode == mode and w.is_out == is_out]: - return classical[0] - raise ValueError(f"No wire with mode {mode}, is_out {is_out}, and is_ket {is_ket}.") + return ( + tuple(self._initial_perm) if sorted(self._initial_perm) != self._initial_perm else None + ) + + # def wire(self, mode: int, is_out: bool, is_ket: bool) -> QuantumWire | ClassicalWire: + # r""" + # Returns the wire with the given mode, ket, and output status. + # """ + # if quantum := [ + # w + # for w in self.quantum_wires + # if w.mode == mode and w.is_out == is_out and w.is_ket == is_ket + # ]: + # return quantum[0] + # if classical := [w for w in self.classical_wires if w.mode == mode and w.is_out == is_out]: + # return classical[0] + # raise ValueError(f"No wire with mode {mode}, is_out {is_out}, and is_ket {is_ket}.") def _reindex(self) -> None: r""" Updates the indices of the wires according to the standard order. """ - for i, w in enumerate(self.wires): + for i, w in enumerate(self.sorted_wires): w.index = i def __add__(self, other: Wires) -> Wires: r""" - New ``Wires`` object that combines the wires of self and other. + New ``Wires`` object with references to the wires of self and other. If there are overlapping wires (same mode, is_ket, is_out), raises a ValueError. + Note that the wires are not reindexed nor copied. Use with caution. """ if ovlp_classical := self.classical_wires & other.classical_wires: raise ValueError(f"Overlapping classical wires {ovlp_classical}.") @@ -538,25 +570,26 @@ def __add__(self, other: Wires) -> Wires: w = Wires() w.quantum_wires = self.quantum_wires | other.quantum_wires w.classical_wires = self.classical_wires | other.classical_wires - w._reindex() return w + def __iter__(self) -> Iterator[QuantumWire | ClassicalWire]: + return iter(self.sorted_wires) + def __sub__(self, other: Wires) -> Wires: r""" - New ``Wires`` object that removes the wires of other from self, by mode. - Note it does not look at ket, bra, input or output: just the mode. Use with caution. + New ``Wires`` object with references to the wires of self whose modes are not in other. + Note that the wires are not reindexed nor copied. Use with caution. """ w = Wires() - w.quantum_wires = {q for q in self.quantum_wires.copy() if q.mode not in other.modes} - w.classical_wires = {c for c in self.classical_wires.copy() if c.mode not in other.modes} - w._reindex() + w.quantum_wires = {q for q in self.quantum_wires if q.mode not in other.modes} + w.classical_wires = {c for c in self.classical_wires if c.mode not in other.modes} return w def __bool__(self) -> bool: return bool(self.quantum_wires) or bool(self.classical_wires) def __hash__(self) -> int: - return hash(tuple(tuple(sorted(subset)) for subset in self.args)) + return hash((tuple(self.classical_wires), tuple(self.quantum_wires))) def __eq__(self, other: Wires) -> bool: return ( @@ -621,23 +654,20 @@ def __matmul__(self, other: Wires) -> tuple[Wires, list[int], list[int]]: cl_in = self.classical.input + (other.classical.input - self.classical.output) # get the wires - w = Wires() - w.quantum_wires = (bra_out + bra_in + ket_out + ket_in).wires - w.classical_wires = (cl_out + cl_in).wires - w._reindex() + new_wires = Wires() + new_wires.quantum_wires = { + q.copy() for q in bra_out.wires | bra_in.wires | ket_out.wires | ket_in.wires + } + new_wires.classical_wires = {c.copy() for c in cl_out.wires | cl_in.wires} + new_wires._reindex() # get the permutations - CV_ids = [w.id for w in w.CV_wires if w.id in self.ids] + [ - w.id for w in w.CV_wires if w.id in other.ids + CV_combined = [w for w in self.CV_wires if w.id in new_wires.ids] + [ + w for w in other.CV_wires if w.id in new_wires.ids ] - CV_perm = [CV_ids.index(w.id) for w in w.CV_wires] - - # NOTE: this is for when BtoF is in - # DV_ids = [w.id for w in w.DV_wires if w.id in self.ids] + [ - # w.id for w in w.DV_wires if w.id in other.ids - # ] - # DV_perm = [DV_ids.index(w.id) for w in w.DV_wires] - return w, CV_perm + CV_perm = [CV_combined.index(w) for w in new_wires.CV_wires] + + return new_wires, CV_perm def _ipython_display_(self): display(widgets.wires(self)) diff --git a/tests/test_lab_dev/test_circuit_components.py b/tests/test_lab_dev/test_circuit_components.py index b94652ab2..c01358780 100644 --- a/tests/test_lab_dev/test_circuit_components.py +++ b/tests/test_lab_dev/test_circuit_components.py @@ -66,7 +66,7 @@ class TestCircuitComponent: def test_init(self, x, y): name = "my_component" ansatz = PolyExpAnsatz(*displacement_gate_Abc(x, y)) - cc = CircuitComponent(Representation(ansatz, [(), (), (1, 8), (1, 8)]), name=name) + cc = CircuitComponent(Representation(ansatz, Wires((), (), (1, 8), (1, 8))), name=name) assert cc.name == name assert list(cc.modes) == [1, 8] @@ -77,7 +77,7 @@ def test_init(self, x, y): def test_missing_name(self): cc = CircuitComponent( Representation( - PolyExpAnsatz(*displacement_gate_Abc(0.1, 0.2)), [(), (), (1, 8), (1, 8)] + PolyExpAnsatz(*displacement_gate_Abc(0.1, 0.2)), Wires((), (), (1, 8), (1, 8)) ) ) cc._name = None @@ -94,13 +94,13 @@ def test_modes_init_out_of_order(self): a1 = PolyExpAnsatz(*displacement_gate_Abc(x=[0.1, 0.2])) a2 = PolyExpAnsatz(*displacement_gate_Abc(x=[0.2, 0.1])) - cc1 = CircuitComponent(Representation(a1, wires=[(), (), m1, m1])) - cc2 = CircuitComponent(Representation(a2, wires=[(), (), m2, m2])) + cc1 = CircuitComponent(Representation(a1, Wires((), (), m1, m1))) + cc2 = CircuitComponent(Representation(a2, Wires((), (), m2, m2))) assert cc1 == cc2 a3 = (cc1.adjoint @ cc1).ansatz - cc3 = CircuitComponent(Representation(a3, wires=[m2, m2, m2, m1])) - cc4 = CircuitComponent(Representation(a3, wires=[m2, m2, m2, m2])) + cc3 = CircuitComponent(Representation(a3, Wires(m2, m2, m2, m1))) + cc4 = CircuitComponent(Representation(a3, Wires(m2, m2, m2, m2))) assert cc3.ansatz == cc4.ansatz.reorder([0, 1, 2, 3, 4, 5, 7, 6]) @pytest.mark.parametrize("x", [0.1, [0.2, 0.3]]) @@ -165,7 +165,7 @@ def test_dual(self): def test_light_copy(self): d1 = CircuitComponent( Representation( - PolyExpAnsatz(*displacement_gate_Abc(0.1, 0.1)), wires=[(), (), (1,), (1,)] + PolyExpAnsatz(*displacement_gate_Abc(0.1, 0.1)), Wires((), (), (1,), (1,)) ) ) d1_cp = d1._light_copy() @@ -226,7 +226,7 @@ def test_to_fock_poly_exp(self): A, b, _ = Abc_triple(3) c = np.random.random((1, 5)) polyexp = PolyExpAnsatz(A, b, c) - fock_cc = CircuitComponent(Representation(polyexp, wires=[(), (), (0, 1), ()])).to_fock( + fock_cc = CircuitComponent(Representation(polyexp, Wires((), (), (0, 1), ()))).to_fock( shape=(10, 10) ) poly = math.hermite_renormalized(A, b, 1, (10, 10, 5)) @@ -563,7 +563,7 @@ def test_serialize_default_behaviour(self): """Test the default serializer.""" name = "my_component" ansatz = PolyExpAnsatz(*displacement_gate_Abc(0.1, 0.4)) - cc = CircuitComponent(Representation(ansatz, wires=[(), (), (1, 8), (1, 8)]), name=name) + cc = CircuitComponent(Representation(ansatz, Wires((), (), (1, 8), (1, 8))), name=name) kwargs, arrays = cc._serialize() assert kwargs == { "class": f"{CircuitComponent.__module__}.CircuitComponent", @@ -581,7 +581,7 @@ class MyComponent(CircuitComponent): def __init__(self, ansatz, custom_modes): super().__init__( - Representation(ansatz, wires=[custom_modes] * 4), name="my_component" + Representation(ansatz, Wires(*(custom_modes * 4))), name="my_component" ) cc = MyComponent(PolyExpAnsatz(*displacement_gate_Abc(0.1, 0.4)), [0, 1]) diff --git a/tests/test_lab_dev/test_states/test_dm.py b/tests/test_lab_dev/test_states/test_dm.py index 61ef0047c..cb1f92b3f 100644 --- a/tests/test_lab_dev/test_states/test_dm.py +++ b/tests/test_lab_dev/test_states/test_dm.py @@ -336,7 +336,7 @@ def test_expectation_error(self): with pytest.raises(ValueError, match="Cannot calculate the expectation value"): dm.expectation(op1) - op2 = CircuitComponent(Representation(wires=[(), (), (1,), (0,)])) + op2 = CircuitComponent(Representation(wires=Wires((), (), (1,), (0,)))) with pytest.raises(ValueError, match="different modes"): dm.expectation(op2) diff --git a/tests/test_lab_dev/test_states/test_ket.py b/tests/test_lab_dev/test_states/test_ket.py index 071ef0c32..4e43d5c51 100644 --- a/tests/test_lab_dev/test_states/test_ket.py +++ b/tests/test_lab_dev/test_states/test_ket.py @@ -330,7 +330,7 @@ def test_expectation_error(self): with pytest.raises(ValueError, match="Cannot calculate the expectation value"): ket.expectation(op1) - op2 = CircuitComponent(Representation(wires=[(), (), (1,), (0,)])) + op2 = CircuitComponent(Representation(wires=Wires((), (), (1,), (0,)))) with pytest.raises(ValueError, match="different modes"): ket.expectation(op2) From 90333ff78c5d579e74faa86bd31c92565ece9802 Mon Sep 17 00:00:00 2001 From: Filippo Miatto Date: Fri, 15 Nov 2024 14:45:39 -0800 Subject: [PATCH 12/13] updated gitignore --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 4ca2582f2..b14b2985f 100644 --- a/.gitignore +++ b/.gitignore @@ -21,5 +21,5 @@ doc/code/api/* coverage.xml .coverage /.serialize_cache/ - +.cursorrules .venv From ea6ed3f6dcb51d019d0e7c5ec49d0a61dd4b22ab Mon Sep 17 00:00:00 2001 From: Filippo Miatto Date: Mon, 18 Nov 2024 09:16:22 -0800 Subject: [PATCH 13/13] sync --- mrmustard/physics/wires.py | 35 ++++++++----------- tests/test_lab_dev/test_circuit_components.py | 2 +- 2 files changed, 15 insertions(+), 22 deletions(-) diff --git a/mrmustard/physics/wires.py b/mrmustard/physics/wires.py index 8cf52d487..25bcebb3d 100644 --- a/mrmustard/physics/wires.py +++ b/mrmustard/physics/wires.py @@ -460,14 +460,14 @@ def ids(self) -> tuple[int, ...]: r""" The ids of the wires in standard order. """ - return tuple(w.id for w in self.sorted_wires) + return tuple(w.id for w in self.wires) @cached_property def indices(self) -> tuple[int, ...]: r""" The indices of the wires in standard order. """ - return tuple(w.index for w in self.sorted_wires) + return tuple(w.index for w in self.wires) @cached_property def DV_indices(self) -> tuple[int, ...]: @@ -488,14 +488,14 @@ def DV_wires(self) -> tuple[QuantumWire | ClassicalWire, ...]: r""" The DV wires in standard order. """ - return tuple(w for w in self.sorted_wires if w.is_dv) + return tuple(w for w in self.wires if w.is_dv) @cached_property def CV_wires(self) -> tuple[QuantumWire | ClassicalWire, ...]: r""" The CV wires in standard order. """ - return tuple(w for w in self.sorted_wires if not w.is_dv) + return tuple(w for w in self.wires if not w.is_dv) @cached_property def args(self) -> tuple[set[int], ...]: @@ -512,18 +512,11 @@ def args(self) -> tuple[set[int], ...]: ) @cached_property - def wires(self) -> set[QuantumWire | ClassicalWire]: - r""" - A set of all wires. - """ - return {*self.quantum_wires, *self.classical_wires} - - @cached_property - def sorted_wires(self) -> list[QuantumWire | ClassicalWire]: + def wires(self) -> list[QuantumWire | ClassicalWire]: r""" A list of all wires in standard order. """ - return sorted(self.wires, key=lambda s: s._order()) + return sorted({*self.quantum_wires, *self.classical_wires}, key=lambda s: s._order()) ###### METHODS ###### @@ -554,7 +547,7 @@ def _reindex(self) -> None: r""" Updates the indices of the wires according to the standard order. """ - for i, w in enumerate(self.sorted_wires): + for i, w in enumerate(self.wires): w.index = i def __add__(self, other: Wires) -> Wires: @@ -573,7 +566,7 @@ def __add__(self, other: Wires) -> Wires: return w def __iter__(self) -> Iterator[QuantumWire | ClassicalWire]: - return iter(self.sorted_wires) + return iter(self.wires) def __sub__(self, other: Wires) -> Wires: r""" @@ -656,18 +649,18 @@ def __matmul__(self, other: Wires) -> tuple[Wires, list[int], list[int]]: # get the wires new_wires = Wires() new_wires.quantum_wires = { - q.copy() for q in bra_out.wires | bra_in.wires | ket_out.wires | ket_in.wires + q.copy() for q in bra_out.wires + bra_in.wires + ket_out.wires + ket_in.wires } - new_wires.classical_wires = {c.copy() for c in cl_out.wires | cl_in.wires} + new_wires.classical_wires = {c.copy() for c in cl_out.wires + cl_in.wires} new_wires._reindex() # get the permutations - CV_combined = [w for w in self.CV_wires if w.id in new_wires.ids] + [ - w for w in other.CV_wires if w.id in new_wires.ids + combined = [w for w in self.wires if w.id in new_wires.ids] + [ + w for w in other.wires if w.id in new_wires.ids ] - CV_perm = [CV_combined.index(w) for w in new_wires.CV_wires] + perm = [combined.index(w) for w in new_wires.wires] - return new_wires, CV_perm + return new_wires, perm def _ipython_display_(self): display(widgets.wires(self)) diff --git a/tests/test_lab_dev/test_circuit_components.py b/tests/test_lab_dev/test_circuit_components.py index c01358780..64f824641 100644 --- a/tests/test_lab_dev/test_circuit_components.py +++ b/tests/test_lab_dev/test_circuit_components.py @@ -581,7 +581,7 @@ class MyComponent(CircuitComponent): def __init__(self, ansatz, custom_modes): super().__init__( - Representation(ansatz, Wires(*(custom_modes * 4))), name="my_component" + Representation(ansatz, Wires(*([custom_modes] * 4))), name="my_component" ) cc = MyComponent(PolyExpAnsatz(*displacement_gate_Abc(0.1, 0.4)), [0, 1])