Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rich wires #523

Open
wants to merge 14 commits into
base: develop
Choose a base branch
from
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@ doc/code/api/*
coverage.xml
.coverage
/.serialize_cache/

.cursorrules
.venv
6 changes: 3 additions & 3 deletions mrmustard/lab_dev/circuit_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _serialize(self) -> tuple[dict[str, Any], dict[str, ArrayLike]]:
if "name" in params: # assume abstract type, serialize the representation
ansatz_cls = type(self.ansatz)
serializable["name"] = self.name
serializable["wires"] = self.wires.sorted_args
serializable["wires"] = self.wires.args
serializable["ansatz_cls"] = f"{ansatz_cls.__module__}.{ansatz_cls.__qualname__}"
return serializable, self.ansatz.to_dict()

Expand Down Expand Up @@ -698,9 +698,9 @@ def __rshift__(self, other: CircuitComponent | numbers.Number) -> CircuitCompone
if only_ket or only_bra or both_sides:
ret = self @ other
elif self_needs_bra or self_needs_ket:
ret = (self.adjoint @ self) @ other
ret = self.adjoint @ (self @ other)
elif other_needs_bra or other_needs_ket:
ret = self @ (other @ other.adjoint)
ret = (self @ other) @ other.adjoint
else:
msg = f"``>>`` not supported between {self} and {other} because it's not clear "
msg += "whether or where to add bra wires. Use ``@`` instead and specify all the components."
Expand Down
9 changes: 4 additions & 5 deletions mrmustard/lab_dev/circuit_components_utils/b_to_ps.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from ..transformations.base import Map
from ...physics.ansatz import PolyExpAnsatz
from ...physics.representations import RepEnum
from ...physics.wires import ReprEnum
from ..utils import make_parameter

__all__ = ["BtoPS"]
Expand Down Expand Up @@ -53,10 +53,9 @@ def __init__(
fn=triples.displacement_map_s_parametrized_Abc, s=self.s, n_modes=len(modes)
),
).representation
for i in self.wires.input.indices:
self.representation._idx_reps[i] = (RepEnum.BARGMANN, None)
for i in self.wires.output.indices:
self.representation._idx_reps[i] = (RepEnum.PHASESPACE, float(self.s.value))
for w in self.representation.wires.output.wires:
w.repr = ReprEnum.CHARACTERISTIC
w.repr_params = float(self.s.value)

def inverse(self):
ret = BtoPS(self.modes, self.s)
Expand Down
12 changes: 7 additions & 5 deletions mrmustard/lab_dev/circuit_components_utils/b_to_q.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from ..transformations.base import Operation
from ...physics.ansatz import PolyExpAnsatz
from ...physics.representations import RepEnum
from ...physics.wires import ReprEnum
from ..utils import make_parameter

__all__ = ["BtoQ"]
Expand Down Expand Up @@ -53,10 +53,12 @@ def __init__(
fn=triples.bargmann_to_quadrature_Abc, n_modes=len(modes), phi=self.phi
),
).representation
for i in self.wires.input.indices:
self.representation._idx_reps[i] = (RepEnum.BARGMANN, None)
for i in self.wires.output.indices:
self.representation._idx_reps[i] = (RepEnum.QUADRATURE, float(self.phi.value))
for w in self.representation.wires.input.wires:
w.repr = ReprEnum.BARGMANN
w.repr_params = None
for w in self.representation.wires.output.wires:
w.repr = ReprEnum.QUADRATURE
w.repr_params = float(self.phi.value)

def inverse(self):
ret = BtoQ(self.modes, self.phi)
Expand Down
5 changes: 5 additions & 0 deletions mrmustard/lab_dev/states/number.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import Sequence

from mrmustard.physics.ansatz import ArrayAnsatz
from mrmustard.physics.wires import ReprEnum
from mrmustard.physics.fock_utils import fock_state
from .ket import Ket
from ..utils import make_parameter, reshape_params
Expand Down Expand Up @@ -81,3 +82,7 @@ def __init__(
self.short_name = [str(int(n)) for n in self.n.value]
for i, cutoff in enumerate(self.cutoffs.value):
self.manual_shape[i] = int(cutoff) + 1

for w in self.representation.wires.output.wires:
w.repr = ReprEnum.FOCK
w.repr_params = [int(self.n.value[w.index])]
5 changes: 5 additions & 0 deletions mrmustard/lab_dev/states/quadrature_eigenstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from mrmustard.physics.ansatz import PolyExpAnsatz
from mrmustard.physics import triples
from mrmustard.physics.wires import ReprEnum
from .ket import Ket
from ..utils import make_parameter, reshape_params

Expand Down Expand Up @@ -77,6 +78,10 @@ def __init__(
),
).representation

for w in self.representation.wires.input.wires:
w.repr = ReprEnum.QUADRATURE
w.repr_params = [float(self.x.value[w.index]), float(self.phi.value[w.index])]

@property
def L2_norm(self):
r"""
Expand Down
6 changes: 5 additions & 1 deletion mrmustard/lab_dev/states/sauron.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from mrmustard.lab_dev.states.ket import Ket
from mrmustard.physics.ansatz import PolyExpAnsatz
from mrmustard.physics import triples

from mrmustard.physics.wires import ReprEnum
from ..utils import make_parameter


Expand Down Expand Up @@ -50,3 +50,7 @@ def __init__(self, modes: Sequence[int], n: int, epsilon: float = 0.1):
triples.sauron_state_Abc, n=self.n.value, epsilon=self.epsilon.value
),
).representation

for w in self.representation.wires.input.wires:
w.repr = ReprEnum.FOCK
w.repr_params = [float(self.n.value[w.index]), float(self.epsilon.value[w.index])]
139 changes: 10 additions & 129 deletions mrmustard/physics/representations.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

from __future__ import annotations
from typing import Sequence
from enum import Enum

import numpy as np

Expand All @@ -32,41 +31,11 @@

from .ansatz import Ansatz, PolyExpAnsatz, ArrayAnsatz
from .triples import identity_Abc
from .wires import Wires
from .wires import Wires, ReprEnum

__all__ = ["Representation"]


class RepEnum(Enum):
r"""
An enum to represent what representation a wire is in.
"""

NONETYPE = 0
BARGMANN = 1
FOCK = 2
QUADRATURE = 3
PHASESPACE = 4

@classmethod
def from_ansatz(cls, ansatz: Ansatz):
r"""
Returns a ``RepEnum`` from an ``Ansatz``.

Args:
ansatz: The ansatz.
"""
if isinstance(ansatz, PolyExpAnsatz):
return cls(1)
elif isinstance(ansatz, ArrayAnsatz):
return cls(2)
else:
return cls(0)

def __repr__(self) -> str:
return self.name


class Representation:
r"""
A class for representations.
Expand All @@ -79,58 +48,14 @@ class Representation:

Args:
ansatz: An ansatz for this representation.
wires: The wires of this representation. Alternatively, can be
a ``(modes_out_bra, modes_in_bra, modes_out_ket, modes_in_ket)``
sequence where if any of the modes are out of order the ansatz
will be reordered.
idx_reps: An optional dictionary for keeping track of each wire's representation.
wires: The wires of this representation.
"""

def __init__(
self,
ansatz: Ansatz | None = None,
wires: Wires | Sequence[tuple[int]] | None = None,
idx_reps: dict | None = None,
) -> None:
def __init__(self, ansatz: Ansatz | None = None, wires: Wires | None = None) -> None:
self._ansatz = ansatz

if wires is None:
wires = Wires(set(), set(), set(), set())
elif not isinstance(wires, Wires):
modes_out_bra, modes_in_bra, modes_out_ket, modes_in_ket = [
tuple(elem) for elem in wires
]
wires = Wires(
set(modes_out_bra),
set(modes_in_bra),
set(modes_out_ket),
set(modes_in_ket),
)
# handle out-of-order modes
ob = tuple(sorted(modes_out_bra))
ib = tuple(sorted(modes_in_bra))
ok = tuple(sorted(modes_out_ket))
ik = tuple(sorted(modes_in_ket))
if (
ob != modes_out_bra
or ib != modes_in_bra
or ok != modes_out_ket
or ik != modes_in_ket
):
offsets = [len(ob), len(ob) + len(ib), len(ob) + len(ib) + len(ok)]
perm = (
tuple(np.argsort(modes_out_bra))
+ tuple(np.argsort(modes_in_bra) + offsets[0])
+ tuple(np.argsort(modes_out_ket) + offsets[1])
+ tuple(np.argsort(modes_in_ket) + offsets[2])
)
if ansatz is not None:
self._ansatz = ansatz.reorder(tuple(perm))

self._wires = wires
self._idx_reps = idx_reps or dict.fromkeys(
wires.indices, (RepEnum.from_ansatz(ansatz), None)
)
self._wires = wires or Wires(set(), set(), set(), set())
if (perm := self.wires.perm()) and self.ansatz is not None:
self._ansatz = self.ansatz.reorder(perm)

@property
def adjoint(self) -> Representation:
Expand All @@ -142,12 +67,7 @@ def adjoint(self) -> Representation:
kets = self.wires.ket.indices
ansatz = self.ansatz.reorder(kets + bras).conj if self.ansatz else None
wires = self.wires.adjoint
idx_reps = {}
for i, j in enumerate(kets):
idx_reps[i] = self._idx_reps[j]
for i, j in enumerate(bras):
idx_reps[i + len(kets)] = self._idx_reps[j]
return Representation(ansatz, wires, idx_reps)
return Representation(ansatz, wires)

@property
def ansatz(self) -> Ansatz | None:
Expand All @@ -168,16 +88,7 @@ def dual(self) -> Representation:
ob = self.wires.bra.output.indices
ansatz = self.ansatz.reorder(ib + ob + ik + ok).conj if self.ansatz else None
wires = self.wires.dual
idx_reps = {}
for i, j in enumerate(ib):
idx_reps[i] = self._idx_reps[j]
for i, j in enumerate(ob):
idx_reps[i + len(ib)] = self._idx_reps[j]
for i, j in enumerate(ik):
idx_reps[i + len(ib + ob)] = self._idx_reps[j]
for i, j in enumerate(ok):
idx_reps[i + len(ib + ob + ik)] = self._idx_reps[j]
return Representation(ansatz, wires, idx_reps)
return Representation(ansatz, wires)

@property
def wires(self) -> Wires | None:
Expand Down Expand Up @@ -301,37 +212,9 @@ def _matmul_indices(self, other: Representation) -> tuple[tuple[int, ...], tuple
idx_zconj += other.wires.ket.input[ket_modes].indices
return idx_z, idx_zconj

def _matmul_idx_reps(self, wires_result: Wires, other: Representation):
r"""
Returns the new representation mappings when contracting ``self`` and ``other``.

Args:
wires_result: The resulting wires after contraction.
other: The representation contracting with.
"""
idx_reps = {}
for id in wires_result.ids:
if id in other.wires.ids:
temp_rep = other
else:
temp_rep = self
for t in (0, 1, 2, 3, 4, 5):
try:
idx = temp_rep.wires.ids_index_dicts[t][id]
n_idx = wires_result.ids_index_dicts[t][id]
idx_reps[n_idx] = temp_rep._idx_reps[idx]
break
except KeyError:
continue
return idx_reps

def __eq__(self, other):
if isinstance(other, Representation):
return (
self.ansatz == other.ansatz
and self.wires == other.wires
and self._idx_reps == other._idx_reps
)
return self.ansatz == other.ansatz and self.wires == other.wires
return False

def __matmul__(self, other: Representation):
Expand All @@ -344,8 +227,6 @@ def __matmul__(self, other: Representation):
else:
self_ansatz = self.to_bargmann().ansatz
other_ansatz = other.to_bargmann().ansatz

rep = self_ansatz[idx_z] @ other_ansatz[idx_zconj]
rep = rep.reorder(perm) if perm else rep
idx_reps = self._matmul_idx_reps(wires_result, other)
return Representation(rep, wires_result, idx_reps)
return Representation(rep, wires_result)
Loading
Loading