From a121ebe8f48426d67788e97ed0c9c65163aeb6e3 Mon Sep 17 00:00:00 2001 From: Anthony Date: Mon, 9 Sep 2024 14:49:22 -0400 Subject: [PATCH 01/87] rename --- mrmustard/lab/abstract/state.py | 50 +- mrmustard/lab/abstract/transformation.py | 26 +- mrmustard/lab/detectors.py | 6 +- mrmustard/lab/gates.py | 18 +- mrmustard/lab/states.py | 6 +- mrmustard/lab_dev/circuit_components.py | 2 +- mrmustard/lab_dev/states/base.py | 4 +- mrmustard/lab_dev/states/number.py | 2 +- mrmustard/lab_dev/transformations/base.py | 2 +- mrmustard/physics/__init__.py | 14 +- mrmustard/physics/bargmann.py | 251 ---- mrmustard/physics/fock.py | 1033 ----------------- tests/test_lab/test_gates_fock.py | 26 +- .../test_circuit_components_utils.py | 2 +- tests/test_lab_dev/test_states/test_number.py | 2 +- tests/test_math/test_compactFock.py | 2 +- tests/test_math/test_lattice.py | 2 +- tests/test_physics/test_ansatz.py | 2 +- tests/test_physics/test_bargmann.py | 2 +- tests/test_physics/test_fidelity.py | 2 +- tests/test_physics/test_fock.py | 68 +- 21 files changed, 126 insertions(+), 1396 deletions(-) delete mode 100644 mrmustard/physics/bargmann.py delete mode 100644 mrmustard/physics/fock.py diff --git a/mrmustard/lab/abstract/state.py b/mrmustard/lab/abstract/state.py index e52a52d66..94cec2307 100644 --- a/mrmustard/lab/abstract/state.py +++ b/mrmustard/lab/abstract/state.py @@ -28,7 +28,7 @@ from mrmustard import math, settings from mrmustard.math.parameters import Constant, Variable -from mrmustard.physics import bargmann, fock, gaussian +from mrmustard.physics import bargmann_utils, fock_utils, gaussian from mrmustard.physics.wigner import wigner_discretized from mrmustard.utils.typing import ( ComplexMatrix, @@ -157,7 +157,7 @@ def purity(self) -> float: if self.is_gaussian: self._purity = gaussian.purity(self.cov) else: - self._purity = fock.purity(self._dm) + self._purity = fock_utils.purity(self._dm) return self._purity @property @@ -187,7 +187,7 @@ def number_stdev(self) -> RealVector: return math.sqrt(math.diag_part(self.number_cov)) return math.sqrt( - fock.number_variances(self.fock, is_dm=len(self.fock.shape) == self.num_modes * 2) + fock_utils.number_variances(self.fock, is_dm=len(self.fock.shape) == self.num_modes * 2) ) @property @@ -195,7 +195,7 @@ def cutoffs(self) -> list[int]: r"""Returns the Hilbert space dimension of each mode.""" if self._cutoffs is None: if self._ket is None and self._dm is None: - self._cutoffs = fock.autocutoffs( + self._cutoffs = fock_utils.autocutoffs( self.cov, self.means, settings.AUTOSHAPE_PROBABILITY ) else: @@ -223,7 +223,7 @@ def shape(self) -> list[int]: def fock(self) -> ComplexTensor: r"""Returns the Fock representation of the state.""" if self._dm is None and self._ket is None: - _fock = fock.wigner_to_fock_state( + _fock = fock_utils.wigner_to_fock_state( self.cov, self.means, shape=self.shape, @@ -243,7 +243,7 @@ def number_means(self) -> RealVector: if self.is_gaussian: return gaussian.number_means(self.cov, self.means) - return fock.number_means(tensor=self.fock, is_dm=self.is_mixed) + return fock_utils.number_means(tensor=self.fock, is_dm=self.is_mixed) @property def number_cov(self) -> RealMatrix: @@ -258,7 +258,7 @@ def norm(self) -> float: r"""Returns the norm of the state.""" if self.is_gaussian: return self._norm - return fock.norm(self.fock, not self.is_hilbert_vector) + return fock_utils.norm(self.fock, not self.is_hilbert_vector) @property def probability(self) -> float: @@ -296,7 +296,7 @@ def ket( cutoffs = [c if c is not None else self.cutoffs[i] for i, c in enumerate(cutoffs)] if self.is_gaussian: - self._ket = fock.wigner_to_fock_state( + self._ket = fock_utils.wigner_to_fock_state( self.cov, self.means, shape=cutoffs, @@ -308,12 +308,12 @@ def ket( if self._ket is None: # if state is pure and has a density matrix, calculate the ket if self.is_pure: - self._ket = fock.dm_to_ket(self._dm) + self._ket = fock_utils.dm_to_ket(self._dm) current_cutoffs = [int(s) for s in self._ket.shape] if cutoffs != current_cutoffs: paddings = [(0, max(0, new - old)) for new, old in zip(cutoffs, current_cutoffs)] if any(p != (0, 0) for p in paddings): - padded = fock.math.pad(self._ket, paddings, mode="constant") + padded = fock_utils.math.pad(self._ket, paddings, mode="constant") else: padded = self._ket return padded[tuple(slice(s) for s in cutoffs)] @@ -336,16 +336,16 @@ def dm(self, cutoffs: list[int] | None = None) -> ComplexTensor: if self.is_pure: ket = self.ket(cutoffs=cutoffs) if ket is not None: - return fock.ket_to_dm(ket) + return fock_utils.ket_to_dm(ket) else: if self.is_gaussian: - self._dm = fock.wigner_to_fock_state( + self._dm = fock_utils.wigner_to_fock_state( self.cov, self.means, shape=cutoffs + cutoffs, return_dm=True ) elif cutoffs != (current_cutoffs := list(self._dm.shape[: self.num_modes])): paddings = [(0, max(0, new - old)) for new, old in zip(cutoffs, current_cutoffs)] if any(p != (0, 0) for p in paddings): - padded = fock.math.pad(self._dm, paddings + paddings, mode="constant") + padded = fock_utils.math.pad(self._dm, paddings + paddings, mode="constant") else: padded = self._dm return padded[tuple(slice(s) for s in cutoffs + cutoffs)] @@ -366,10 +366,10 @@ def fock_probabilities(self, cutoffs: Sequence[int]) -> RealTensor: if self._fock_probabilities is None: if self.is_mixed: dm = self.dm(cutoffs=cutoffs) - self._fock_probabilities = fock.dm_to_probs(dm) + self._fock_probabilities = fock_utils.dm_to_probs(dm) else: ket = self.ket(cutoffs=cutoffs) - self._fock_probabilities = fock.ket_to_probs(ket) + self._fock_probabilities = fock_utils.ket_to_probs(ket) return self._fock_probabilities def primal(self, other: State | Transformation) -> State: @@ -430,9 +430,9 @@ def _project_onto_fock(self, other: State) -> State | float: # return the probability (norm) of the state when there are no modes left return ( - fock.math.abs(out_fock) ** 2 + fock_utils.math.abs(out_fock) ** 2 if other.is_pure and self.is_pure - else fock.math.abs(out_fock) + else fock_utils.math.abs(out_fock) ) def _contract_with_other(self, other): @@ -444,7 +444,7 @@ def _contract_with_other(self, other): else: # matching other's cutoffs self_cutoffs = [other.cutoffs[other.indices(m)] for m in self.modes] - out_fock = fock.contract_states( + out_fock = fock_utils.contract_states( stateA=(other.ket(other_cutoffs) if other.is_pure else other.dm(other_cutoffs)), stateB=(self.ket(self_cutoffs) if self.is_pure else self.dm(self_cutoffs)), a_is_dm=other.is_mixed, @@ -499,7 +499,7 @@ def __and__(self, other: State) -> State: if self.is_mixed or other.is_mixed: self_fock = self.dm() other_fock = other.dm() - dm = fock.math.tensordot(self_fock, other_fock, [[], []]) + dm = fock_utils.math.tensordot(self_fock, other_fock, [[], []]) # e.g. self has shape [1,3,1,3] and other has shape [2,2] # we want self & other to have shape [1,3,2,1,3,2] # before transposing shape is [1,3,1,3]+[2,2] @@ -519,7 +519,7 @@ def __and__(self, other: State) -> State: self_fock = self.ket() other_fock = other.ket() return State( - ket=fock.math.tensordot(self_fock, other_fock, [[], []]), + ket=fock_utils.math.tensordot(self_fock, other_fock, [[], []]), modes=self.modes + [m + max(self.modes) + 1 for m in other.modes], ) cov = gaussian.join_covs([self.cov, other.cov]) @@ -551,9 +551,9 @@ def bargmann(self, numpy=False) -> tuple[ComplexMatrix, ComplexVector, complex] """ if self.is_gaussian: if self.is_pure: - A, B, C = bargmann.wigner_to_bargmann_psi(self.cov, self.means) + A, B, C = bargmann_utils.wigner_to_bargmann_psi(self.cov, self.means) else: - A, B, C = bargmann.wigner_to_bargmann_rho(self.cov, self.means) + A, B, C = bargmann_utils.wigner_to_bargmann_rho(self.cov, self.means) else: return None if numpy: @@ -582,7 +582,7 @@ def get_modes(self, item) -> State: means, _ = gaussian.partition_means(self.means, item_idx) return State(cov=cov, means=means, modes=item) - fock_partitioned = fock.trace(self.dm(self.cutoffs), keep=item_idx) + fock_partitioned = fock_utils.trace(self.dm(self.cutoffs), keep=item_idx) return State(dm=fock_partitioned, modes=item) def __eq__(self, other) -> bool: # pylint: disable=too-many-return-statements @@ -736,8 +736,8 @@ def mikkel_plot( if plot_args["ytick_labels"] is None: plot_args["ytick_labels"] = plot_args["yticks"] - q, ProbX = fock.quadrature_distribution(rho) - p, ProbP = fock.quadrature_distribution(rho, np.pi / 2) + q, ProbX = fock_utils.quadrature_distribution(rho) + p, ProbP = fock_utils.quadrature_distribution(rho, np.pi / 2) xvec = np.linspace(*xbounds, plot_args["resolution"]) pvec = np.linspace(*ybounds, plot_args["resolution"]) diff --git a/mrmustard/lab/abstract/transformation.py b/mrmustard/lab/abstract/transformation.py index 93044b4b5..62e236664 100644 --- a/mrmustard/lab/abstract/transformation.py +++ b/mrmustard/lab/abstract/transformation.py @@ -27,7 +27,7 @@ from mrmustard.math.parameter_set import ParameterSet from mrmustard.math.parameters import Constant, Variable from mrmustard.math.tensor_networks import Tensor -from mrmustard.physics import bargmann, fock, gaussian +from mrmustard.physics import bargmann_utils, fock_utils, gaussian from mrmustard.utils.typing import RealMatrix, RealVector from .state import State @@ -172,9 +172,9 @@ def d_vector_dual(self) -> RealVector | None: def bargmann(self, numpy=False): X, Y, d = self.XYd(allow_none=False) if self.is_unitary: - A, B, C = bargmann.wigner_to_bargmann_U(X, d) + A, B, C = bargmann_utils.wigner_to_bargmann_U(X, d) else: - A, B, C = bargmann.wigner_to_bargmann_Choi(X, Y, d) + A, B, C = bargmann_utils.wigner_to_bargmann_Choi(X, Y, d) if numpy: return math.asnumpy(A), math.asnumpy(B), math.asnumpy(C) return A, B, C @@ -208,11 +208,11 @@ def choi( U = self.U(shape[: self.num_modes]) Udual = self.U(shape[self.num_modes :]) if dual: - return fock.U_to_choi(U=Udual, Udual=U) - return fock.U_to_choi(U=U, Udual=Udual) + return fock_utils.U_to_choi(U=Udual, Udual=U) + return fock_utils.U_to_choi(U=U, Udual=Udual) X, Y, d = self.XYd(allow_none=False) - choi = fock.wigner_to_fock_Choi(X, Y, d, shape=shape) + choi = fock_utils.wigner_to_fock_Choi(X, Y, d, shape=shape) if dual: n = len(shape) // 4 N0 = list(range(0, n)) @@ -382,8 +382,10 @@ def _transform_fock(self, state: State, dual=False) -> State: op_idx = [state.modes.index(m) for m in self.modes] U = self.U(cutoffs=[state.cutoffs[i] for i in op_idx]) if state.is_hilbert_vector: - return State(ket=fock.apply_kraus_to_ket(U, state.ket(), op_idx), modes=state.modes) - return State(dm=fock.apply_kraus_to_dm(U, state.dm(), op_idx), modes=state.modes) + return State( + ket=fock_utils.apply_kraus_to_ket(U, state.ket(), op_idx), modes=state.modes + ) + return State(dm=fock_utils.apply_kraus_to_dm(U, state.dm(), op_idx), modes=state.modes) def U( self, @@ -411,7 +413,7 @@ def U( raise ValueError(f"len(cutoffs) must be {self.num_modes} (got {len(cutoffs)})") shape = shape or tuple(cutoffs) * 2 X, _, d = self.XYd(allow_none=False) - return fock.wigner_to_fock_U(X, d, shape=shape) + return fock_utils.wigner_to_fock_U(X, d, shape=shape) def __eq__(self, other): r"""Returns ``True`` if the two transformations are equal.""" @@ -453,8 +455,10 @@ def _transform_fock(self, state: State, dual: bool = False) -> State: op_idx = [state.modes.index(m) for m in self.modes] choi = self.choi(cutoffs=[state.cutoffs[i] for i in op_idx], dual=dual) if state.is_hilbert_vector: - return State(dm=fock.apply_choi_to_ket(choi, state.ket(), op_idx), modes=state.modes) - return State(dm=fock.apply_choi_to_dm(choi, state.dm(), op_idx), modes=state.modes) + return State( + dm=fock_utils.apply_choi_to_ket(choi, state.ket(), op_idx), modes=state.modes + ) + return State(dm=fock_utils.apply_choi_to_dm(choi, state.dm(), op_idx), modes=state.modes) def value(self, shape: tuple[int]): return self.choi(shape=shape) diff --git a/mrmustard/lab/detectors.py b/mrmustard/lab/detectors.py index 3d8869ab9..22304b01a 100644 --- a/mrmustard/lab/detectors.py +++ b/mrmustard/lab/detectors.py @@ -21,7 +21,7 @@ from typing import Iterable from mrmustard import settings -from mrmustard.physics import fock, gaussian +from mrmustard.physics import fock_utils, gaussian from mrmustard.utils.typing import RealMatrix, RealVector from mrmustard import math @@ -407,7 +407,7 @@ def _measure_fock(self, other) -> State | float: reduced_state = other.get_modes(self.modes) # build pdf and sample homodyne outcome - x_outcome, probability = fock.sample_homodyne( + x_outcome, probability = fock_utils.sample_homodyne( state=reduced_state.ket() if reduced_state.is_pure else reduced_state.dm(), quadrature_angle=self.quadrature_angle, ) @@ -427,7 +427,7 @@ def _measure_fock(self, other) -> State | float: other_cutoffs = [ None if m not in self.modes else other.cutoffs[other.indices(m)] for m in other.modes ] - out_fock = fock.contract_states( + out_fock = fock_utils.contract_states( stateA=(other.ket(other_cutoffs) if other.is_pure else other.dm(other_cutoffs)), stateB=self.state.ket(self_cutoffs), a_is_dm=other.is_mixed, diff --git a/mrmustard/lab/gates.py b/mrmustard/lab/gates.py index 9360bdb77..e8863fcac 100644 --- a/mrmustard/lab/gates.py +++ b/mrmustard/lab/gates.py @@ -24,7 +24,7 @@ import numpy as np from mrmustard import settings -from mrmustard.physics import gaussian, fock +from mrmustard.physics import fock_utils, gaussian from mrmustard.utils.typing import ComplexMatrix, RealMatrix from mrmustard import math from mrmustard.math.parameters import ( @@ -146,9 +146,9 @@ def U( Ud = None for idx, out_in in enumerate(zip(shape[:N], shape[N:])): if Ud is None: - Ud = fock.displacement(x[idx], y[idx], shape=out_in) + Ud = fock_utils.displacement(x[idx], y[idx], shape=out_in) else: - U_next = fock.displacement(x[idx], y[idx], shape=out_in) + U_next = fock_utils.displacement(x[idx], y[idx], shape=out_in) Ud = math.outer(Ud, U_next) return math.transpose( @@ -156,7 +156,7 @@ def U( list(range(0, 2 * N, 2)) + list(range(1, 2 * N, 2)), ) else: - return fock.displacement(x[0], y[0], shape=shape) + return fock_utils.displacement(x[0], y[0], shape=shape) class Sgate(Unitary): @@ -244,16 +244,16 @@ def U( Us = None for idx, single_shape in enumerate(zip(shape[:N], shape[N:])): if Us is None: - Us = fock.squeezer(r[idx], phi[idx], shape=single_shape) + Us = fock_utils.squeezer(r[idx], phi[idx], shape=single_shape) else: - U_next = fock.squeezer(r[idx], phi[idx], shape=single_shape) + U_next = fock_utils.squeezer(r[idx], phi[idx], shape=single_shape) Us = math.outer(Us, U_next) return math.transpose( Us, list(range(0, 2 * N, 2)) + list(range(1, 2 * N, 2)), ) else: - return fock.squeezer(r[0], phi[0], shape=shape) + return fock_utils.squeezer(r[0], phi[0], shape=shape) @property def X_matrix(self): @@ -541,7 +541,7 @@ def U( shape = shape or cutoffs - return fock.beamsplitter( + return fock_utils.beamsplitter( self.theta.value, self.phi.value, shape=shape, @@ -1049,7 +1049,7 @@ def primal(self, state): idx = state.modes.index(self.modes[0]) if state.is_pure: ket = state.ket() - dm = fock.ket_to_dm(ket) + dm = fock_utils.ket_to_dm(ket) else: dm = state.dm() diff --git a/mrmustard/lab/states.py b/mrmustard/lab/states.py index 60dfa493e..2d847556a 100644 --- a/mrmustard/lab/states.py +++ b/mrmustard/lab/states.py @@ -23,7 +23,7 @@ from mrmustard import math, settings from mrmustard.math.parameter_set import ParameterSet from mrmustard.math.parameters import update_symplectic -from mrmustard.physics import fock, gaussian +from mrmustard.physics import fock_utils, gaussian from mrmustard.utils.typing import RealMatrix, Scalar, Vector from .abstract import State @@ -453,7 +453,7 @@ def __init__( cutoffs: Sequence[int] | None = None, normalize: bool = False, ): - super().__init__(ket=fock.fock_state(n), cutoffs=cutoffs) + super().__init__(ket=fock_utils.fock_state(n), cutoffs=cutoffs) self._n = [n] if isinstance(n, int) else n self._modes = modes @@ -486,5 +486,5 @@ def _preferred_projection(self, other: State, mode_indices: Sequence[int]): else other.dm(cutoffs)[tuple(getitem) * 2] ) if self._normalize: - return fock.normalize(output, is_dm=other.is_mixed) + return fock_utils.normalize(output, is_dm=other.is_mixed) return output diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index 61b8616e1..f88f71448 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -40,7 +40,7 @@ Batch, ) from mrmustard.physics.representations import Representation, Bargmann, Fock -from mrmustard.physics.fock import quadrature_basis +from mrmustard.physics.fock_utils import quadrature_basis from mrmustard.math.parameter_set import ParameterSet from mrmustard.math.parameters import Constant, Variable from mrmustard.lab_dev.wires import Wires diff --git a/mrmustard/lab_dev/states/base.py b/mrmustard/lab_dev/states/base.py index 49c248a43..0237958ca 100644 --- a/mrmustard/lab_dev/states/base.py +++ b/mrmustard/lab_dev/states/base.py @@ -36,7 +36,7 @@ import plotly.graph_objects as go from mrmustard import math, settings, widgets -from mrmustard.physics.fock import quadrature_distribution +from mrmustard.physics.fock_utils import quadrature_distribution from mrmustard.physics.wigner import wigner_discretized from mrmustard.utils.typing import ( Batch, @@ -47,7 +47,7 @@ Scalar, Vector, ) -from mrmustard.physics.bargmann import ( +from mrmustard.physics.bargmann_utils import ( wigner_to_bargmann_psi, wigner_to_bargmann_rho, ) diff --git a/mrmustard/lab_dev/states/number.py b/mrmustard/lab_dev/states/number.py index 07261c8ae..da6623443 100644 --- a/mrmustard/lab_dev/states/number.py +++ b/mrmustard/lab_dev/states/number.py @@ -21,7 +21,7 @@ from typing import Sequence from mrmustard.physics.representations import Fock -from mrmustard.physics.fock import fock_state +from mrmustard.physics.fock_utils import fock_state from .base import Ket from ..utils import make_parameter, reshape_params diff --git a/mrmustard/lab_dev/transformations/base.py b/mrmustard/lab_dev/transformations/base.py index 0d4ff51ae..2367c9623 100644 --- a/mrmustard/lab_dev/transformations/base.py +++ b/mrmustard/lab_dev/transformations/base.py @@ -29,7 +29,7 @@ from mrmustard import math, settings from mrmustard.physics.representations import Bargmann, Fock from mrmustard.utils.typing import ComplexMatrix -from mrmustard.physics.bargmann import au2Symplectic, symplectic2Au, XY_of_channel +from mrmustard.physics.bargmann_utils import au2Symplectic, symplectic2Au, XY_of_channel from ..circuit_components import CircuitComponent diff --git a/mrmustard/physics/__init__.py b/mrmustard/physics/__init__.py index e8e91337c..03bcd4422 100644 --- a/mrmustard/physics/__init__.py +++ b/mrmustard/physics/__init__.py @@ -20,7 +20,7 @@ optimization routine. """ -from mrmustard.physics import fock, gaussian +from mrmustard.physics import fock_utils, gaussian # pylint: disable=protected-access @@ -36,7 +36,7 @@ def fidelity(A, B) -> float: """ if A.is_gaussian and B.is_gaussian: return gaussian.fidelity(A.means, A.cov, B.means, B.cov) - return fock.fidelity(A.fock, B.fock, a_ket=A._ket is not None, b_ket=B._ket is not None) + return fock_utils.fidelity(A.fock, B.fock, a_ket=A._ket is not None, b_ket=B._ket is not None) def normalize(A): @@ -53,9 +53,9 @@ def normalize(A): return A if A.is_mixed: - return A.__class__(dm=fock.normalize(A.dm(), is_dm=True)) + return A.__class__(dm=fock_utils.normalize(A.dm(), is_dm=True)) - return A.__class__(ket=fock.normalize(A.ket(), is_dm=False)) + return A.__class__(ket=fock_utils.normalize(A.ket(), is_dm=False)) def norm(A) -> float: @@ -72,7 +72,7 @@ def norm(A) -> float: """ if A.is_gaussian: return A._norm - return fock.norm(A.fock, is_dm=A.is_mixed) + return fock_utils.norm(A.fock, is_dm=A.is_mixed) def overlap(A, B) -> float: @@ -102,7 +102,7 @@ def von_neumann_entropy(A) -> float: """ if A.is_gaussian: return gaussian.von_neumann_entropy(A.cov) - return fock.von_neumann_entropy(A.fock, a_dm=A.is_mixed) + return fock_utils.von_neumann_entropy(A.fock, a_dm=A.is_mixed) def relative_entropy(A, B) -> float: @@ -130,4 +130,4 @@ def trace_distance(A, B) -> float: """ if A.is_gaussian and B.is_gaussian: return gaussian.trace_distance(A.means, A.cov, B.means, B.cov) - return fock.trace_distance(A.fock, B.fock, a_dm=A.is_mixed, b_dm=B.is_mixed) + return fock_utils.trace_distance(A.fock, B.fock, a_dm=A.is_mixed, b_dm=B.is_mixed) diff --git a/mrmustard/physics/bargmann.py b/mrmustard/physics/bargmann.py deleted file mode 100644 index 6600cf371..000000000 --- a/mrmustard/physics/bargmann.py +++ /dev/null @@ -1,251 +0,0 @@ -# Copyright 2023 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This module contains functions for performing calculations on objects in the Bargmann representations. -""" - -import numpy as np - -from mrmustard import math, settings -from mrmustard.physics.husimi import pq_to_aadag, wigner_to_husimi -from mrmustard.utils.typing import ComplexMatrix - - -def cayley(X, c): - r"""Returns the Cayley transform of a matrix: - :math:`cay(X) = (X - cI)(X + cI)^{-1}` - - Args: - c (float): the parameter of the Cayley transform - X (Tensor): a matrix - - Returns: - Tensor: the Cayley transform of X - """ - I = math.eye(X.shape[0], dtype=X.dtype) - return math.solve(X + c * I, X - c * I) - - -def wigner_to_bargmann_rho(cov, means): - r"""Converts the wigner representation in terms of covariance matrix and mean vector into the Bargmann `A,B,C` triple - for a density matrix (i.e. for `M` modes, `A` has shape `2M x 2M` and `B` has shape `2M`). - The order of the rows/columns of A and B corresponds to a density matrix with the usual ordering of the indices. - - Note that here A and B are defined with respect to the literature. - """ - N = cov.shape[-1] // 2 - A = math.matmul(math.Xmat(N), cayley(pq_to_aadag(cov), c=0.5)) - Q, beta = wigner_to_husimi(cov, means) - b = math.solve(Q, beta) - B = math.conj(b) - num_C = math.exp(-0.5 * math.sum(math.conj(beta) * b)) - detQ = math.det(Q) - den_C = math.sqrt(detQ, dtype=num_C.dtype) - C = num_C / den_C - return A, B, C - - -def wigner_to_bargmann_psi(cov, means): - r"""Converts the wigner representation in terms of covariance matrix and mean vector into the Bargmann A,B,C triple - for a Hilbert vector (i.e. for M modes, A has shape M x M and B has shape M). - """ - N = cov.shape[-1] // 2 - A, B, C = wigner_to_bargmann_rho(cov, means) - return A[N:, N:], B[N:], math.sqrt(C) - # NOTE: c for th psi is to calculated from the global phase formula. - - -def wigner_to_bargmann_Choi(X, Y, d): - r"""Converts the wigner representation in terms of covariance matrix and mean vector into the Bargmann `A,B,C` triple - for a channel (i.e. for M modes, A has shape 4M x 4M and B has shape 4M).""" - N = X.shape[-1] // 2 - I2 = math.eye(2 * N, dtype=X.dtype) - XT = math.transpose(X) - xi = 0.5 * (I2 + math.matmul(X, XT) + 2 * Y / settings.HBAR) - detxi = math.det(xi) - xi_inv = math.inv(xi) - A = math.block( - [ - [I2 - xi_inv, math.matmul(xi_inv, X)], - [math.matmul(XT, xi_inv), I2 - math.matmul(math.matmul(XT, xi_inv), X)], - ] - ) - I = math.eye(N, dtype=math.complex128) - o = math.zeros_like(I) - R = math.block( - [[I, 1j * I, o, o], [o, o, I, -1j * I], [I, -1j * I, o, o], [o, o, I, 1j * I]] - ) / np.sqrt(2) - A = math.matmul(math.matmul(R, A), math.dagger(R)) - A = math.matmul(math.Xmat(2 * N), A) - b = math.matvec(xi_inv, d) - B = math.matvec(math.conj(R), math.concat([b, -math.matvec(XT, b)], axis=-1)) / math.sqrt( - settings.HBAR, dtype=R.dtype - ) - C = math.exp(-0.5 * math.sum(d * b) / settings.HBAR) / math.sqrt(detxi, dtype=b.dtype) - # now A and B have order [out_r, in_r out_l, in_l]. - return A, B, math.cast(C, math.complex128) - - -def wigner_to_bargmann_U(X, d): - r"""Converts the wigner representation in terms of covariance matrix and mean vector into the Bargmann `A,B,C` triple - for a unitary (i.e. for `M` modes, `A` has shape `2M x 2M` and `B` has shape `2M`). - """ - N = X.shape[-1] // 2 - A, B, C = wigner_to_bargmann_Choi(X, math.zeros_like(X), d) - return A[2 * N :, 2 * N :], B[2 * N :], math.sqrt(C) - - -def norm_ket(A, b, c): - r"""Calculates the l2 norm of a Ket with a representation given by the Bargmann triple A,b,c.""" - M = math.block([[math.conj(A), -math.eye_like(A)], [-math.eye_like(A), A]]) - B = math.concat([math.conj(b), b], 0) - norm_squared = ( - math.abs(c) ** 2 - * math.exp(-0.5 * math.sum(B * math.matvec(math.inv(M), B))) - / math.sqrt((-1) ** A.shape[-1] * math.det(M)) - ) - return math.real(math.sqrt(norm_squared)) - - -def trace_dm(A, b, c): - r"""Calculates the total trace of the density matrix with representation given by the Bargmann triple A,b,c.""" - M = A - math.Xmat(A.shape[-1] // 2) - trace = ( - c - * math.exp(-0.5 * math.sum(b * math.matvec(math.inv(M), b))) - / math.sqrt((-1) ** (A.shape[-1] // 2) * math.det(M)) - ) - return math.real(trace) - - -def au2Symplectic(A): - r""" - helper for finding the Au of a unitary from its symplectic rep. - Au : in bra-ket order - """ - # A represents the A matrix corresponding to unitary U - A = A * (1.0 + 0.0 * 1j) - m = A.shape[-1] - m = m // 2 - - # identifying blocks of A_u - u_2 = A[..., :m, m:] - u_3 = A[..., m:, m:] - - # The formula to apply comes here - S_1 = math.conj(math.inv(math.transpose(u_2))) - S_2 = -S_1 @ math.conj(u_3) - S_3 = math.conj(S_2) - S_4 = math.conj(S_1) - - S = math.block([[S_1, S_2], [S_3, S_4]]) - - transformation = ( - 1 - / np.sqrt(2) - * math.block( - [ - [math.eye(m, dtype=math.complex128), math.eye(m, dtype=math.complex128)], - [-1j * math.eye(m, dtype=math.complex128), 1j * math.eye(m, dtype=math.complex128)], - ] - ) - ) - - return math.real(transformation @ S @ math.conj(math.transpose(transformation))) - - -def symplectic2Au(S): - r""" - The inverse of au2Symplectic i.e., returns symplectic, given Au - - S: symplectic in XXPP order - """ - m = S.shape[-1] - m = m // 2 - # the following lines of code transform the quadrature symplectic matrix to - # the annihilation one - transformation = ( - 1 - / np.sqrt(2) - * math.block( - [ - [math.eye(m, dtype=math.complex128), math.eye(m, dtype=math.complex128)], - [-1j * math.eye(m, dtype=math.complex128), 1j * math.eye(m, dtype=math.complex128)], - ] - ) - ) - S = np.conjugate(math.transpose(transformation)) @ S @ transformation - # identifying blocks of S - S_1 = S[:m, :m] - S_2 = S[:m, m:] - - # TODO: broadcasting/batch stuff consider a batch dimension - - # the formula to apply comes here - A_1 = S_2 @ math.conj(math.inv(S_1)) # use solve for inverse - A_2 = math.conj(math.inv(math.transpose(S_1))) - A_3 = math.transpose(A_2) - A_4 = -math.conj(math.solve(S_1, S_2)) - # -np.conjugate(np.linalg.pinv(S_1)) @ np.conjugate(S_2) - - A = math.block([[A_1, A_2], [A_3, A_4]]) - - return A - - -def XY_of_channel(A: ComplexMatrix): - r""" - Outputting the X and Y matrices corresponding to a channel determined by the "A" - matrix. - - Args: - A: the A matrix of the channel - """ - n = A.shape[-1] // 2 - m = n // 2 - - # here we transform to the other convention for wires i.e. {out-bra, out-ket, in-bra, in-ket} - A_out = math.block( - [[A[:m, :m], A[:m, 2 * m : 3 * m]], [A[2 * m : 3 * m, :m], A[2 * m : 3 * m, 2 * m : 3 * m]]] - ) - R = math.block( - [ - [A[:m, m : 2 * m], A[:m, 3 * m :]], - [A[2 * m : 3 * m, m : 2 * m], A[2 * m : 3 * m, 3 * m :]], - ] - ) - X_tilde = -math.inv(np.eye(n) - math.Xmat(m) @ A_out) @ math.Xmat(m) @ R @ math.Xmat(m) - transformation = math.block( - [ - [math.eye(m, dtype=math.complex128), math.eye(m, dtype=math.complex128)], - [-1j * math.eye(m, dtype=math.complex128), 1j * math.eye(m, dtype=math.complex128)], - ] - ) - X = -transformation @ X_tilde @ math.conj(transformation).T / 2 - - sigma_H = math.inv(math.eye(n) - math.Xmat(m) @ A_out) # the complex-Husimi covariance matrix - - N = sigma_H[m:, m:] - M = sigma_H[:m, m:] - sigma = ( - math.block([[math.real(N + M), math.imag(N + M)], [math.imag(M - N), math.real(N - M)]]) - - math.eye(n) / 2 - ) - Y = sigma - X @ X.T / 2 - if math.norm(math.imag(X)) > settings.ATOL or math.norm(math.imag(Y)) > settings.ATOL: - raise ValueError( - "Invalid input for the A matrix of channel, caused imaginary X and/or Y matrices." - ) - return math.real(X), math.real(Y) diff --git a/mrmustard/physics/fock.py b/mrmustard/physics/fock.py deleted file mode 100644 index 13c45b1fa..000000000 --- a/mrmustard/physics/fock.py +++ /dev/null @@ -1,1033 +0,0 @@ -# Copyright 2021 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# pylint: disable=redefined-outer-name - -""" -This module contains functions for performing calculations on objects in the Fock representations. -""" - -from __future__ import annotations - -from functools import lru_cache -from typing import Sequence, Iterable - -import numpy as np - -from mrmustard import math, settings -from mrmustard.math.lattice import strategies -from mrmustard.math.caching import tensor_int_cache -from mrmustard.math.tensor_wrappers.mmtensor import MMTensor -from mrmustard.physics.bargmann import ( - wigner_to_bargmann_Choi, - wigner_to_bargmann_psi, - wigner_to_bargmann_rho, - wigner_to_bargmann_U, -) -from mrmustard.utils.typing import ComplexTensor, Matrix, Scalar, Tensor, Vector, Batch - -SQRT = np.sqrt(np.arange(1e6)) - -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# ~~~~~~~~~~~~~~ static functions ~~~~~~~~~~~~~~ -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - -def fock_state(n: Sequence[int], cutoffs: int | Sequence[int] | None = None) -> Tensor: - r""" - The Fock array of a tensor product of one-mode ``Number`` states. - - Args: - n: The photon numbers of the number states. - cutoffs: The cutoffs of the arrays for the number states. If it is given as - an ``int``, it is broadcasted to all the states. If ``None``, it - defaults to ``[n1+1, n2+1, ...]``, where ``ni`` is the photon number - of the ``i``th mode. - - Returns: - The Fock array of a tensor product of one-mode ``Number`` states. - """ - n = math.atleast_1d(n) - if cutoffs is None: - cutoffs = list(n) - elif isinstance(cutoffs, int): - cutoffs = [cutoffs] * len(n) - - if len(cutoffs) != len(n): - msg = f"Expected ``len(cutoffs)={len(n)}`` but found ``{len(cutoffs)}``." - raise ValueError(msg) - - shape = tuple([c + 1 for c in cutoffs]) - array = np.zeros(shape, dtype=np.complex128) - - try: - array[tuple(n)] = 1 - except IndexError: - msg = "Photon numbers cannot be larger than the corresponding cutoffs." - raise ValueError(msg) - - return math.astensor(array) - - -def autocutoffs(cov: Matrix, means: Vector, probability: float): - r"""Returns the cutoffs of a Gaussian state by computing the 1-mode marginals until - the probability of the marginal is less than ``probability``. - - Args: - cov: the covariance matrix - means: the means vector - probability: the cutoff probability - - Returns: - Tuple[int, ...]: the suggested cutoffs - """ - M = len(means) // 2 - cutoffs = [] - for i in range(M): - cov_i = np.array([[cov[i, i], cov[i, i + M]], [cov[i + M, i], cov[i + M, i + M]]]) - means_i = np.array([means[i], means[i + M]]) - # apply 1-d recursion until probability is less than 0.99 - A, B, C = [math.asnumpy(x) for x in wigner_to_bargmann_rho(cov_i, means_i)] - diag = math.hermite_renormalized_diagonal(A, B, C, cutoffs=[settings.AUTOCUTOFF_MAX_CUTOFF]) - # find at what index in the cumsum the probability is more than 0.99 - for i, val in enumerate(np.cumsum(diag)): - if val > probability: - cutoffs.append(max(i + 1, settings.AUTOCUTOFF_MIN_CUTOFF)) - break - else: - cutoffs.append(settings.AUTOCUTOFF_MAX_CUTOFF) - return cutoffs - - -def wigner_to_fock_state( - cov: Matrix, - means: Vector, - shape: Sequence[int], - max_prob: float = 1.0, - max_photons: int | None = None, - return_dm: bool = True, -) -> Tensor: - r"""Returns the Fock representation of a Gaussian state. - Use with caution: if the cov matrix is that of a mixed state, - setting return_dm to False will produce nonsense. - If return_dm=False, we can apply max_prob and max_photons to stop the - computation of the Fock representation early, when those conditions are met. - - * If the state is pure it can return the state vector (ket) or the density matrix. - The index ordering is going to be [i's] in ket_i - * If the state is mixed it can return the density matrix. - The index order is going to be [i's,j's] in dm_ij - - Args: - cov: the Wigner covariance matrix - means: the Wigner means vector - shape: the shape of the tensor - max_prob: the maximum probability of a the state (applies only if the ket is returned) - max_photons: the maximum number of photons in the state (applies only if the ket is returned) - return_dm: whether to return the density matrix (otherwise it returns the ket) - - Returns: - Tensor: the fock representation - """ - if return_dm: - A, B, C = wigner_to_bargmann_rho(cov, means) - # NOTE: change the order of the index in AB - Xmat = math.Xmat(A.shape[-1] // 2) - A = math.matmul(math.matmul(Xmat, A), Xmat) - B = math.matvec(Xmat, B) - return math.hermite_renormalized(A, B, C, shape=shape) - else: # here we can apply max prob and max photons - A, B, C = wigner_to_bargmann_psi(cov, means) - if max_photons is None: - max_photons = sum(shape) - len(shape) - if max_prob < 1.0 or max_photons < sum(shape) - len(shape): - return math.hermite_renormalized_binomial( - A, B, C, shape=shape, max_l2=max_prob, global_cutoff=max_photons + 1 - ) - return math.hermite_renormalized(A, B, C, shape=tuple(shape)) - - -def wigner_to_fock_U(X, d, shape): - r"""Returns the Fock representation of a Gaussian unitary transformation. - The index order is out_l, in_l, where in_l is to be contracted with the indices of a ket, - or with the left indices of a density matrix. - - Arguments: - X: the X matrix - d: the d vector - shape: the shape of the tensor - - Returns: - Tensor: the fock representation of the unitary transformation - """ - A, B, C = wigner_to_bargmann_U(X, d) - return math.hermite_renormalized(A, B, C, shape=tuple(shape)) - - -def wigner_to_fock_Choi(X, Y, d, shape): - r"""Returns the Fock representation of a Gaussian Choi matrix. - The order of choi indices is :math:`[\mathrm{out}_l, \mathrm{in}_l, \mathrm{out}_r, \mathrm{in}_r]` - where :math:`\mathrm{in}_l` and :math:`\mathrm{in}_r` are to be contracted with the left and right indices of a density matrix. - - Arguments: - X: the X matrix - Y: the Y matrix - d: the d vector - shape: the shape of the tensor - - Returns: - Tensor: the fock representation of the Choi matrix - """ - A, B, C = wigner_to_bargmann_Choi(X, Y, d) - # NOTE: change the order of the index in AB - Xmat = math.Xmat(A.shape[-1] // 2) - A = math.matmul(math.matmul(Xmat, A), Xmat) - N = B.shape[-1] // 2 - B = math.concat([B[N:], B[:N]], axis=-1) - return math.hermite_renormalized(A, B, C, shape=tuple(shape)) - - -def ket_to_dm(ket: Tensor) -> Tensor: - r"""Maps a ket to a density matrix. - - Args: - ket: the ket - - Returns: - Tensor: the density matrix - """ - return math.outer(ket, math.conj(ket)) - - -def dm_to_ket(dm: Tensor) -> Tensor: - r"""Maps a density matrix to a ket if the state is pure. - - If the state is pure :math:`\hat \rho= |\psi\rangle\langle \psi|` then the - ket is the eigenvector of :math:`\rho` corresponding to the eigenvalue 1. - - Args: - dm (Tensor): the density matrix - - Returns: - Tensor: the ket - - Raises: - ValueError: if ket for mixed states cannot be calculated - """ - - is_pure_dm = np.isclose(purity(dm), 1.0, atol=1e-6) - if not is_pure_dm: - raise ValueError("Cannot calculate ket for mixed states.") - - cutoffs = dm.shape[: len(dm.shape) // 2] - d = int(np.prod(cutoffs)) - dm = math.reshape(dm, (d, d)) - - eigvals, eigvecs = math.eigh(dm) - # eigenvalues and related eigenvectors are sorted in non-decreasing order, - # meaning the associated eigvec to largest eigval is stored last. - ket = eigvecs[:, -1] * math.sqrt(eigvals[-1]) - ket = math.reshape(ket, cutoffs) - - return ket - - -def ket_to_probs(ket: Tensor) -> Tensor: - r"""Maps a ket to probabilities. - - Args: - ket: the ket - - Returns: - Tensor: the probabilities vector - """ - return math.abs(ket) ** 2 - - -def dm_to_probs(dm: Tensor) -> Tensor: - r"""Extracts the diagonals of a density matrix. - - Args: - dm: the density matrix - - Returns: - Tensor: the probabilities vector - """ - return math.all_diagonals(dm, real=True) - - -def U_to_choi(U: Tensor, Udual: Tensor | None = None) -> Tensor: - r"""Converts a unitary transformation to a Choi tensor. - - Args: - U: the unitary transformation - Udual: the dual unitary transformation (optional, will use conj U if not provided) - - Returns: - Tensor: the Choi tensor. The index order is going to be :math:`[\mathrm{out}_l, \mathrm{in}_l, \mathrm{out}_r, \mathrm{in}_r]` - where :math:`\mathrm{in}_l` and :math:`\mathrm{in}_r` are to be contracted with the left and right indices of the density matrix. - """ - return math.outer(U, math.conj(U) if Udual is None else Udual) - - -def fidelity(state_a, state_b, a_ket: bool, b_ket: bool) -> Scalar: - r"""Computes the fidelity between two states in Fock representation.""" - if a_ket and b_ket: - min_cutoffs = [slice(min(a, b)) for a, b in zip(state_a.shape, state_b.shape)] - state_a = state_a[tuple(min_cutoffs)] - state_b = state_b[tuple(min_cutoffs)] - return math.abs(math.sum(math.conj(state_a) * state_b)) ** 2 - - if a_ket: - min_cutoffs = [ - slice(min(a, b)) - for a, b in zip(state_a.shape, state_b.shape[: len(state_b.shape) // 2]) - ] - state_a = state_a[tuple(min_cutoffs)] - state_b = state_b[tuple(min_cutoffs * 2)] - a = math.reshape(state_a, -1) - return math.real( - math.sum(math.conj(a) * math.matvec(math.reshape(state_b, (len(a), len(a))), a)) - ) - - if b_ket: - min_cutoffs = [ - slice(min(a, b)) - for a, b in zip(state_a.shape[: len(state_a.shape) // 2], state_b.shape) - ] - state_a = state_a[tuple(min_cutoffs * 2)] - state_b = state_b[tuple(min_cutoffs)] - b = math.reshape(state_b, -1) - return math.real( - math.sum(math.conj(b) * math.matvec(math.reshape(state_a, (len(b), len(b))), b)) - ) - - # mixed state - # Richard Jozsa (1994) Fidelity for Mixed Quantum States, Journal of Modern Optics, 41:12, 2315-2323, DOI: 10.1080/09500349414552171 - - # trim states to have same cutoff - min_cutoffs = [ - slice(min(a, b)) - for a, b in zip( - state_a.shape[: len(state_a.shape) // 2], - state_b.shape[: len(state_b.shape) // 2], - ) - ] - state_a = state_a[tuple(min_cutoffs * 2)] - state_b = state_b[tuple(min_cutoffs * 2)] - return math.abs( - ( - math.trace( - math.sqrtm( - math.matmul(math.matmul(math.sqrtm(state_a), state_b), math.sqrtm(state_a)) - ) - ) - ** 2 - ) - ) - - -def number_means(tensor, is_dm: bool): - r"""Returns the mean of the number operator in each mode.""" - probs = math.all_diagonals(tensor, real=True) if is_dm else math.abs(tensor) ** 2 - modes = list(range(len(probs.shape))) - marginals = [math.sum(probs, axes=modes[:k] + modes[k + 1 :]) for k in range(len(modes))] - return math.astensor( - [ - math.sum(marginal * math.arange(len(marginal), dtype=math.float64)) - for marginal in marginals - ] - ) - - -def number_variances(tensor, is_dm: bool): - r"""Returns the variance of the number operator in each mode.""" - probs = math.all_diagonals(tensor, real=True) if is_dm else math.abs(tensor) ** 2 - modes = list(range(len(probs.shape))) - marginals = [math.sum(probs, axes=modes[:k] + modes[k + 1 :]) for k in range(len(modes))] - return math.astensor( - [ - ( - math.sum(marginal * math.arange(marginal.shape[0], dtype=marginal.dtype) ** 2) - - math.sum(marginal * math.arange(marginal.shape[0], dtype=marginal.dtype)) ** 2 - ) - for marginal in marginals - ] - ) - - -def purity(dm: Tensor) -> Scalar: - r"""Returns the purity of a density matrix.""" - cutoffs = dm.shape[: len(dm.shape) // 2] - d = int(np.prod(cutoffs)) # combined cutoffs in all modes - dm = math.reshape(dm, (d, d)) - dm = dm / math.trace(dm) # assumes all nonzero values are included in the density matrix - return math.abs(math.sum(math.transpose(dm) * dm)) # tr(rho^2) - - -def validate_contraction_indices(in_idx, out_idx, M, name): - r"""Validates the indices used for the contraction of a tensor.""" - if len(set(in_idx)) != len(in_idx): - raise ValueError(f"{name}_in_idx should not contain repeated indices.") - if len(set(out_idx)) != len(out_idx): - raise ValueError(f"{name}_out_idx should not contain repeated indices.") - if not set(range(M)).intersection(out_idx).issubset(set(in_idx)): - wrong_indices = set(range(M)).intersection(out_idx) - set(in_idx) - raise ValueError( - f"Indices {wrong_indices} in {name}_out_idx are trying to replace uncontracted indices." - ) - - -def apply_kraus_to_ket(kraus, ket, kraus_in_modes, kraus_out_modes=None): - r"""Applies a kraus operator to a ket. - It assumes that the ket is indexed as left_1, ..., left_n. - - The kraus op has indices that contract with the ket (kraus_in_modes) and indices that are left over (kraus_out_modes). - The final index order will be sorted (note that an index appearing in both kraus_in_modes and kraus_out_modes will replace the original index). - - Args: - kraus (array): the kraus operator to be applied - ket (array): the ket to which the operator is applied - kraus_in_modes (list of ints): the indices (counting from 0) of the kraus operator that contract with the ket - kraus_out_modes (list of ints): the indices (counting from 0) of the kraus operator that are leftover - - Returns: - array: the resulting ket with indices as kraus_out_modes + uncontracted ket indices - """ - if kraus_out_modes is None: - kraus_out_modes = kraus_in_modes - - if not set(kraus_in_modes).issubset(range(ket.ndim)): - raise ValueError("kraus_in_modes should be a subset of the ket indices.") - - # check that there are no repeated indices in kraus_in_modes and kraus_out_modes (separately) - validate_contraction_indices(kraus_in_modes, kraus_out_modes, ket.ndim, "kraus") - - ket = MMTensor(ket, axis_labels=[f"in_left_{i}" for i in range(ket.ndim)]) - kraus = MMTensor( - kraus, - axis_labels=[f"out_left_{i}" for i in kraus_out_modes] - + [f"in_left_{i}" for i in kraus_in_modes], - ) - - # contract the operator with the ket. - # now the leftover indices are in the order kraus_out_modes + uncontracted ket indices - kraus_ket = kraus @ ket - - # sort kraus_ket.axis_labels by the int at the end of each label. - # Each label is guaranteed to have a unique int at the end. - new_axis_labels = sorted(kraus_ket.axis_labels, key=lambda x: int(x.split("_")[-1])) - - return kraus_ket.transpose(new_axis_labels).tensor - - -def apply_kraus_to_dm(kraus, dm, kraus_in_modes, kraus_out_modes=None): - r"""Applies a kraus operator to a density matrix. - It assumes that the density matrix is indexed as left_1, ..., left_n, right_1, ..., right_n. - - The kraus operator has indices that contract with the density matrix (kraus_in_modes) and indices that are leftover (kraus_out_modes). - `kraus` will contract from the left and from the right with the density matrix. For right contraction the kraus op is conjugated. - - Args: - kraus (array): the operator to be applied - dm (array): the density matrix to which the operator is applied - kraus_in_modes (list of ints): the indices (counting from 0) of the kraus operator that contract with the density matrix - kraus_out_modes (list of ints): the indices (counting from 0) of the kraus operator that are leftover (default None, in which case kraus_out_modes = kraus_in_modes) - - Returns: - array: the resulting density matrix - """ - if kraus_out_modes is None: - kraus_out_modes = kraus_in_modes - - if not set(kraus_in_modes).issubset(range(dm.ndim // 2)): - raise ValueError("kraus_in_modes should be a subset of the density matrix indices.") - - # check that there are no repeated indices in kraus_in_modes and kraus_out_modes (separately) - validate_contraction_indices(kraus_in_modes, kraus_out_modes, dm.ndim // 2, "kraus") - - dm = MMTensor( - dm, - axis_labels=[f"left_{i}" for i in range(dm.ndim // 2)] - + [f"right_{i}" for i in range(dm.ndim // 2)], - ) - kraus = MMTensor( - kraus, - axis_labels=[f"out_left_{i}" for i in kraus_out_modes] - + [f"left_{i}" for i in kraus_in_modes], - ) - kraus_conj = MMTensor( - math.conj(kraus.tensor), - axis_labels=[f"out_right_{i}" for i in kraus_out_modes] - + [f"right_{i}" for i in kraus_in_modes], - ) - - # contract the kraus operator with the density matrix from the left and from the right. - k_dm_k = kraus @ dm @ kraus_conj - # now the leftover indices are in the order: - # out_left_modes + uncontracted left indices + uncontracted right indices + out_right_modes - - # sort k_dm_k.axis_labels by the int at the end of each label, first left, then right - N = k_dm_k.tensor.ndim // 2 - left = sorted(k_dm_k.axis_labels[:N], key=lambda x: int(x.split("_")[-1])) - right = sorted(k_dm_k.axis_labels[N:], key=lambda x: int(x.split("_")[-1])) - - return k_dm_k.transpose(left + right).tensor - - -def apply_choi_to_dm( - choi: ComplexTensor, - dm: ComplexTensor, - choi_in_modes: Sequence[int], - choi_out_modes: Sequence[int] | None = None, -): - r"""Applies a choi operator to a density matrix. - It assumes that the density matrix is indexed as left_1, ..., left_n, right_1, ..., right_n. - - The choi operator has indices that contract with the density matrix (choi_in_modes) and indices that are left over (choi_out_modes). - `choi` will contract choi_in_modes from the left and from the right with the density matrix. - - Args: - choi (array): the choi operator to be applied - dm (array): the density matrix to which the choi operator is applied - choi_in_modes (list of ints): the input modes of the choi operator that contract with the density matrix - choi_out_modes (list of ints): the output modes of the choi operator - - Returns: - array: the resulting density matrix - """ - if choi_out_modes is None: - choi_out_modes = choi_in_modes - if not set(choi_in_modes).issubset(range(dm.ndim // 2)): - raise ValueError("choi_in_modes should be a subset of the density matrix indices.") - - # check that there are no repeated indices in kraus_in_modes and kraus_out_modes (separately) - validate_contraction_indices(choi_in_modes, choi_out_modes, dm.ndim // 2, "choi") - - dm = MMTensor( - dm, - axis_labels=[f"in_left_{i}" for i in range(dm.ndim // 2)] - + [f"in_right_{i}" for i in range(dm.ndim // 2)], - ) - choi = MMTensor( - choi, - axis_labels=[f"out_left_{i}" for i in choi_out_modes] - + [f"in_left_{i}" for i in choi_in_modes] - + [f"out_right_{i}" for i in choi_out_modes] - + [f"in_right_{i}" for i in choi_in_modes], - ) - - # contract the choi matrix with the density matrix. - # now the leftover indices are in the order out_left_modes + out_right_modes + uncontracted left indices + uncontracted right indices - choi_dm = choi @ dm - - # sort choi_dm.axis_labels by the int at the end of each label, first left, then right - left_labels = [label for label in choi_dm.axis_labels if "left" in label] - left = sorted(left_labels, key=lambda x: int(x.split("_")[-1])) - right_labels = [label for label in choi_dm.axis_labels if "right" in label] - right = sorted(right_labels, key=lambda x: int(x.split("_")[-1])) - - return choi_dm.transpose(left + right).tensor - - -def apply_choi_to_ket(choi, ket, choi_in_modes, choi_out_modes=None): - r"""Applies a choi operator to a ket. - It assumes that the ket is indexed as left_1, ..., left_n. - - The choi operator has indices that contract with the ket (choi_in_modes) and indices that are left over (choi_out_modes). - `choi` will contract choi_in_modes from the left and from the right with the ket. - - Args: - choi (array): the choi operator to be applied - ket (array): the ket to which the choi operator is applied - choi_in_modes (list of ints): the indices of the choi operator that contract with the ket - choi_out_modes (list of ints): the indices of the choi operator that re leftover - - Returns: - array: the resulting ket - """ - if choi_out_modes is None: - choi_out_modes = choi_in_modes - - if not set(choi_in_modes).issubset(range(ket.ndim)): - raise ValueError("choi_in_modes should be a subset of the ket indices.") - - # check that there are no repeated indices in kraus_in_modes and kraus_out_modes (separately) - validate_contraction_indices(choi_in_modes, choi_out_modes, ket.ndim, "choi") - - ket = MMTensor(ket, axis_labels=[f"left_{i}" for i in range(ket.ndim)]) - ket_dual = MMTensor(math.conj(ket.tensor), axis_labels=[f"right_{i}" for i in range(ket.ndim)]) - choi = MMTensor( - choi, - axis_labels=[f"out_left_{i}" for i in choi_out_modes] - + [f"left_{i}" for i in choi_in_modes] - + [f"out_right_{i}" for i in choi_out_modes] - + [f"right_{i}" for i in choi_in_modes], - ) - - # contract the choi matrix with the ket and its dual, like choi @ |ket> Tensor: - r"""Harmonic oscillator eigenstate wavefunction `\psi_n(q) = `. - - Args: - q (Vector): a vector containing the q points at which the function is evaluated (units of \sqrt{\hbar}) - cutoff (int): maximum number of photons - - Returns: - Tensor: a tensor of size ``len(q)*cutoff``. Each entry with index ``[i, j]`` represents the eigenstate evaluated - with number of photons ``i`` evaluated at position ``q[j]``, i.e., `\psi_i(q_j)`. - - .. details:: - - .. admonition:: Definition - :class: defn - - The q-quadrature eigenstates are defined as - - .. math:: - - \psi_n(x) = 1/sqrt[2^n n!](\frac{\omega}{\pi \hbar})^{1/4} - \exp{-\frac{\omega}{2\hbar} x^2} H_n(\sqrt{\frac{\omega}{\pi}} x) - - where :math:`H_n(x)` is the (physicists) `n`-th Hermite polynomial. - """ - hbar = settings.HBAR - x = math.cast(q / np.sqrt(hbar), math.complex128) # unit-less vector - - # prefactor term (\Omega/\hbar \pi)**(1/4) * 1 / sqrt(2**n) - prefactor = math.cast( - (np.pi * hbar) ** (-0.25) * math.pow(0.5, math.arange(0, cutoff) / 2), - math.complex128, - ) - - # Renormalized physicist hermite polys: Hn / sqrt(n!) - R = -np.array([[2 + 0j]]) # to get the physicist polys - - def f_hermite_polys(xi): - return math.hermite_renormalized(R, math.astensor([2 * xi]), 1 + 0j, [cutoff]) - - hermite_polys = math.map_fn(f_hermite_polys, x) - - # (real) wavefunction - psi = math.exp(-(x**2 / 2)) * math.transpose(prefactor * hermite_polys) - return psi - - -@lru_cache -def estimate_dx(cutoff, period_resolution=20): - r"""Estimates a suitable quadrature discretization interval `dx`. Uses the fact - that Fock state `n` oscillates with angular frequency :math:`\sqrt{2(n + 1)}`, - which follows from the relation - - .. math:: - - \psi^{[n]}'(q) = q - sqrt(2*(n + 1))*\psi^{[n+1]}(q) - - by setting q = 0, and approximating the oscillation amplitude by `\psi^{[n+1]}(0) - - Ref: https://en.wikipedia.org/wiki/Hermite_polynomials#Hermite_functions - - Args - cutoff (int): Fock cutoff - period_resolution (int): Number of points used to sample one Fock - wavefunction oscillation. Larger values yields better approximations - and thus smaller `dx`. - - Returns - (float): discretization value of quadrature - """ - fock_cutoff_frequency = np.sqrt(2 * (cutoff + 1)) - fock_cutoff_period = 2 * np.pi / fock_cutoff_frequency - dx_estimate = fock_cutoff_period / period_resolution - return dx_estimate - - -@lru_cache -def estimate_xmax(cutoff, minimum=5): - r"""Estimates a suitable quadrature axis length - - Args - cutoff (int): Fock cutoff - minimum (float): Minimum value of the returned xmax - - Returns - (float): maximum quadrature value - """ - if cutoff == 0: - xmax_estimate = 3 - else: - # maximum q for a classical particle with energy n=cutoff - classical_endpoint = np.sqrt(2 * cutoff) - # approximate probability of finding particle outside classical region - excess_probability = 1 / (7.464 * cutoff ** (1 / 3)) - # Emperical factor that yields reasonable results - A = 5 - xmax_estimate = classical_endpoint * (1 + A * excess_probability) - return max(minimum, xmax_estimate) - - -@lru_cache -def estimate_quadrature_axis(cutoff, minimum=5, period_resolution=20): - """Generates a suitable quadrature axis. - - Args - cutoff (int): Fock cutoff - minimum (float): Minimum value of the returned xmax - period_resolution (int): Number of points used to sample one Fock - wavefunction oscillation. Larger values yields better approximations - and thus smaller dx. - - Returns - (array): quadrature axis - """ - xmax = estimate_xmax(cutoff, minimum=minimum) - dx = estimate_dx(cutoff, period_resolution=period_resolution) - xaxis = np.arange(-xmax, xmax, dx) - xaxis = np.append(xaxis, xaxis[-1] + dx) - xaxis = xaxis - np.mean(xaxis) # center around 0 - return xaxis - - -def quadrature_basis( - fock_array: Tensor, - quad: Batch[Vector], - conjugates: bool | list[bool] = False, - phi: Scalar = 0.0, -): - r"""Given the Fock basis representation return the quadrature basis representation. - - Args: - fock_array (Tensor): fock tensor amplitudes - quad (Batch[Vector]): points at which the quadrature basis is evaluated - conjugates (list[bool]): which dimensions of the array to conjugate based on - whether it is a bra or a ket - phi (float): angle of the quadrature basis vector - - Returns: - tuple(Tensor): quadrature basis representation at the points in quad - """ - dims = len(fock_array.shape) - - if quad.shape[-1] != dims: - raise ValueError( - f"Input fock array has dimension {dims} whereas ``quad`` has {quad.shape[-1]}." - ) - - conjugates = conjugates if isinstance(conjugates, Iterable) else [conjugates] * dims - - # construct quadrature basis vectors - shapes = fock_array.shape - quad_basis_vecs = [] - for dim in range(dims): - q_to_n = oscillator_eigenstate(quad[..., dim], shapes[dim]) - if not np.isclose(phi, 0.0): - theta = -math.arange(shapes[dim]) * phi - Ur = math.make_complex(math.cos(theta), math.sin(theta)) - q_to_n = math.einsum("a,ab->ab", Ur, q_to_n) - if conjugates[dim]: - q_to_n = math.conj(q_to_n) - quad_basis_vecs += [math.cast(q_to_n, "complex128")] - - # Convert each dimension to quadrature - subscripts = [chr(i) for i in range(98, 98 + dims)] - fock_string = "".join(subscripts[:dims]) #'bcd....' - q_string = "".join([fock_string[i] + "a," for i in range(dims - 1)] + [fock_string[-1] + "a"]) - quad_array = math.einsum( - fock_string + "," + q_string + "->" + "a", fock_array, *quad_basis_vecs - ) - - return quad_array - - -def quadrature_distribution( - state: Tensor, - quadrature_angle: float = 0.0, - x: Vector | None = None, -): - r"""Given the ket or density matrix of a single-mode state, it generates the probability - density distribution :math:`\tr [ \rho |x_\phi> the quadrature eigenvector with angle `\phi` - equal to ``quadrature_angle``. - - Args: - state (Tensor): single mode state ket or density matrix - quadrature_angle (float): angle of the quadrature basis vector - x (Vector): points at which the quadrature distribution is evaluated - - Returns: - tuple(Vector, Vector): coordinates at which the pdf is evaluated and the probability distribution - """ - cutoff = state.shape[0] - if x is None: - x = np.sqrt(settings.HBAR) * math.new_constant(estimate_quadrature_axis(cutoff), "q_tensor") - - dims = len(state.shape) - is_dm = dims == 2 - - quad = math.transpose(math.astensor([x] * dims)) - conjugates = [True, False] if is_dm else [False] - quad_basis = quadrature_basis(state, quad, conjugates, quadrature_angle) - pdf = quad_basis if is_dm else math.abs(quad_basis) ** 2 - - return x, math.real(pdf) - - -def sample_homodyne(state: Tensor, quadrature_angle: float = 0.0) -> tuple[float, float]: - r"""Given a single-mode state, it generates the pdf of :math:`\tr [ \rho |x_\phi> 2: - raise ValueError( - "Input state has dimension {state.shape}. Make sure is either a single-mode ket or dm." - ) - - x, pdf = quadrature_distribution(state, quadrature_angle) - probs = pdf * (x[1] - x[0]) - - # draw a sample from the distribution - pdf = math.Categorical(probs=probs, name="homodyne_dist") - sample_idx = pdf.sample() - homodyne_sample = math.gather(x, sample_idx) - probability_sample = math.gather(probs, sample_idx) - - return homodyne_sample, probability_sample - - -@math.custom_gradient -def displacement(x, y, shape, tol=1e-15): - r"""creates a single mode displacement matrix""" - alpha = math.asnumpy(x) + 1j * math.asnumpy(y) - - if np.sqrt(x * x + y * y) > tol: - gate = strategies.displacement(tuple(shape), alpha) - else: - gate = math.eye(max(shape), dtype="complex128")[: shape[0], : shape[1]] - - ret = math.astensor(gate, dtype=gate.dtype.name) - if math.backend_name == "numpy": - return ret - - def grad(dL_dDc): - dD_da, dD_dac = strategies.jacobian_displacement(math.asnumpy(gate), alpha) - dL_dac = np.sum(np.conj(dL_dDc) * dD_dac + dL_dDc * np.conj(dD_da)) - dLdx = 2 * np.real(dL_dac) - dLdy = 2 * np.imag(dL_dac) - return math.astensor(dLdx, dtype=x.dtype), math.astensor(dLdy, dtype=y.dtype) - - return ret, grad - - -@math.custom_gradient -def beamsplitter(theta: float, phi: float, shape: Sequence[int], method: str): - r"""Creates a beamsplitter tensor with given cutoffs using a numba-based fock lattice strategy. - - Args: - theta (float): transmittivity angle of the beamsplitter - phi (float): phase angle of the beamsplitter - cutoffs (int,int): cutoff dimensions of the two modes - """ - if method == "vanilla": - bs_unitary = strategies.beamsplitter(shape, math.asnumpy(theta), math.asnumpy(phi)) - elif method == "schwinger": - bs_unitary = strategies.beamsplitter_schwinger( - shape, math.asnumpy(theta), math.asnumpy(phi) - ) - else: - raise ValueError( - f"Unknown beamsplitter method {method}. Options are 'vanilla' and 'schwinger'." - ) - - ret = math.astensor(bs_unitary, dtype=bs_unitary.dtype.name) - if math.backend_name == "numpy": - return ret - - def vjp(dLdGc): - dtheta, dphi = strategies.beamsplitter_vjp( - math.asnumpy(bs_unitary), - math.asnumpy(math.conj(dLdGc)), - math.asnumpy(theta), - math.asnumpy(phi), - ) - return math.astensor(dtheta, dtype=theta.dtype), math.astensor(dphi, dtype=phi.dtype) - - return ret, vjp - - -@math.custom_gradient -def squeezer(r, phi, shape): - r"""creates a single mode squeezer matrix using a numba-based fock lattice strategy""" - sq_unitary = strategies.squeezer(shape, math.asnumpy(r), math.asnumpy(phi)) - - ret = math.astensor(sq_unitary, dtype=sq_unitary.dtype.name) - if math.backend_name == "numpy": - return ret - - def vjp(dLdGc): - dr, dphi = strategies.squeezer_vjp( - math.asnumpy(sq_unitary), - math.asnumpy(math.conj(dLdGc)), - math.asnumpy(r), - math.asnumpy(phi), - ) - return math.astensor(dr, dtype=r.dtype), math.astensor(dphi, phi.dtype) - - return ret, vjp - - -@math.custom_gradient -def squeezed(r, phi, shape): - r"""creates a single mode squeezed state using a numba-based fock lattice strategy""" - sq_ket = strategies.squeezed(shape, math.asnumpy(r), math.asnumpy(phi)) - - ret = math.astensor(sq_ket, dtype=sq_ket.dtype.name) - if math.backend_name == "numpy": - return ret - - def vjp(dLdGc): - dr, dphi = strategies.squeezed_vjp( - math.asnumpy(sq_ket), - math.asnumpy(math.conj(dLdGc)), - math.asnumpy(r), - math.asnumpy(phi), - ) - return math.astensor(dr, dtype=r.dtype), math.astensor(dphi, phi.dtype) - - return ret, vjp diff --git a/tests/test_lab/test_gates_fock.py b/tests/test_lab/test_gates_fock.py index 89ab6506a..c499d8b9e 100644 --- a/tests/test_lab/test_gates_fock.py +++ b/tests/test_lab/test_gates_fock.py @@ -45,7 +45,7 @@ ) from mrmustard.lab.states import TMSV, Fock, SqueezedVacuum, State from mrmustard.math.lattice import strategies -from mrmustard.physics import fock +from mrmustard.physics import fock_utils from tests.random import ( angle, array_of_, @@ -163,7 +163,7 @@ def test_fock_representation_displacement(cutoffs, x, y): # compare with the standard way of calculating # transformation unitaries using the Choi isomorphism X, _, d = dgate.XYd(allow_none=False) - expected_Ud = fock.wigner_to_fock_U(X, d, cutoffs) + expected_Ud = fock_utils.wigner_to_fock_U(X, d, cutoffs) assert np.allclose(Ud, expected_Ud, atol=1e-5) @@ -188,7 +188,7 @@ def test_squeezer_grad_against_finite_differences(): dUdr = (Sgate(r + delta, phi).U(cutoffs) - Sgate(r - delta, phi).U(cutoffs)) / (2 * delta) dUdphi = (Sgate(r, phi + delta).U(cutoffs) - Sgate(r, phi - delta).U(cutoffs)) / (2 * delta) _, (gradr, gradphi) = math.value_and_gradients( - lambda: fock.squeezer(r, phi, shape=cutoffs), [r, phi] + lambda: fock_utils.squeezer(r, phi, shape=cutoffs), [r, phi] ) assert np.allclose(gradr, 2 * np.real(np.sum(dUdr))) assert np.allclose(gradphi, 2 * np.real(np.sum(dUdphi))) @@ -201,14 +201,16 @@ def test_displacement_grad(): y = math.new_variable(0.1, None, "y") alpha = math.asnumpy(math.make_complex(x, y)) delta = 1e-6 - dUdx = (fock.displacement(x + delta, y, cutoffs) - fock.displacement(x - delta, y, cutoffs)) / ( - 2 * delta - ) - dUdy = (fock.displacement(x, y + delta, cutoffs) - fock.displacement(x, y - delta, cutoffs)) / ( - 2 * delta - ) - - D = fock.displacement(x, y, shape=cutoffs) + dUdx = ( + fock_utils.displacement(x + delta, y, cutoffs) + - fock_utils.displacement(x - delta, y, cutoffs) + ) / (2 * delta) + dUdy = ( + fock_utils.displacement(x, y + delta, cutoffs) + - fock_utils.displacement(x, y - delta, cutoffs) + ) / (2 * delta) + + D = fock_utils.displacement(x, y, shape=cutoffs) dD_da, dD_dac = strategies.jacobian_displacement(math.asnumpy(D), alpha) assert np.allclose(dD_da + dD_dac, dUdx) assert np.allclose(1j * (dD_da - dD_dac), dUdy) @@ -297,7 +299,7 @@ def test_fock_representation_rgate(cutoffs, angles, modes): # compare with the standard way of calculating # transformation unitaries using the Choi isomorphism d = np.zeros(len(cutoffs) * 2) - expected_R = fock.wigner_to_fock_U(rgate.X_matrix, d, tuple(cutoffs + cutoffs)) + expected_R = fock_utils.wigner_to_fock_U(rgate.X_matrix, d, tuple(cutoffs + cutoffs)) assert np.allclose(R, expected_R, atol=1e-5) diff --git a/tests/test_lab_dev/test_circuit_components_utils.py b/tests/test_lab_dev/test_circuit_components_utils.py index 354059430..3433e6196 100644 --- a/tests/test_lab_dev/test_circuit_components_utils.py +++ b/tests/test_lab_dev/test_circuit_components_utils.py @@ -21,7 +21,7 @@ from mrmustard import math, settings from mrmustard.physics.triples import identity_Abc, displacement_map_s_parametrized_Abc -from mrmustard.physics.bargmann import wigner_to_bargmann_rho +from mrmustard.physics.bargmann_utils import wigner_to_bargmann_rho from mrmustard.physics.gaussian_integrals import ( contract_two_Abc, real_gaussian_integral, diff --git a/tests/test_lab_dev/test_states/test_number.py b/tests/test_lab_dev/test_states/test_number.py index 5321446ae..c67c3601a 100644 --- a/tests/test_lab_dev/test_states/test_number.py +++ b/tests/test_lab_dev/test_states/test_number.py @@ -19,7 +19,7 @@ import pytest from mrmustard import math -from mrmustard.physics.fock import fock_state +from mrmustard.physics.fock_utils import fock_state from mrmustard.lab_dev.states import Coherent, Number diff --git a/tests/test_math/test_compactFock.py b/tests/test_math/test_compactFock.py index 0e531f161..39e05b546 100644 --- a/tests/test_math/test_compactFock.py +++ b/tests/test_math/test_compactFock.py @@ -12,7 +12,7 @@ from mrmustard import math, settings from mrmustard.lab import Ggate, SqueezedVacuum, State, Vacuum from mrmustard.physics import fidelity, normalize -from mrmustard.physics.bargmann import wigner_to_bargmann_rho +from mrmustard.physics.bargmann_utils import wigner_to_bargmann_rho from mrmustard.training import Optimizer from tests.random import n_mode_mixed_state diff --git a/tests/test_math/test_lattice.py b/tests/test_math/test_lattice.py index 1a21f42f8..0c33bdb37 100644 --- a/tests/test_math/test_lattice.py +++ b/tests/test_math/test_lattice.py @@ -21,7 +21,7 @@ from mrmustard.lab import Gaussian, Dgate from mrmustard import lab_dev as mmld from mrmustard import settings, math -from mrmustard.physics.bargmann import wigner_to_bargmann_rho +from mrmustard.physics.bargmann_utils import wigner_to_bargmann_rho from mrmustard.math.lattice.strategies.binomial import binomial, binomial_dict from mrmustard.math.lattice.strategies.beamsplitter import ( apply_BS_schwinger, diff --git a/tests/test_physics/test_ansatz.py b/tests/test_physics/test_ansatz.py index 350287ce1..7b8b7ae9d 100644 --- a/tests/test_physics/test_ansatz.py +++ b/tests/test_physics/test_ansatz.py @@ -26,7 +26,7 @@ bargmann_Abc_to_phasespace_cov_means, ) from mrmustard.lab_dev.states.base import DM -from mrmustard.physics.bargmann import wigner_to_bargmann_rho +from mrmustard.physics.bargmann_utils import wigner_to_bargmann_rho from mrmustard.lab_dev.circuit_components_utils import BtoPS from ..random import Abc_triple diff --git a/tests/test_physics/test_bargmann.py b/tests/test_physics/test_bargmann.py index 65b4327da..e0d128ae3 100644 --- a/tests/test_physics/test_bargmann.py +++ b/tests/test_physics/test_bargmann.py @@ -3,7 +3,7 @@ from mrmustard import math from mrmustard.lab import Attenuator, Dgate, Gaussian, Ggate from mrmustard.lab_dev import Unitary, Vacuum, Channel -from mrmustard.physics.bargmann import ( +from mrmustard.physics.bargmann_utils import ( wigner_to_bargmann_Choi, wigner_to_bargmann_psi, wigner_to_bargmann_rho, diff --git a/tests/test_physics/test_fidelity.py b/tests/test_physics/test_fidelity.py index 116cb7c3c..499a5b64f 100644 --- a/tests/test_physics/test_fidelity.py +++ b/tests/test_physics/test_fidelity.py @@ -5,7 +5,7 @@ from mrmustard import physics, settings from mrmustard.lab import Coherent, Fock, State -from mrmustard.physics import fock as fp +from mrmustard.physics import fock_utils as fp from mrmustard.physics import gaussian as gp diff --git a/tests/test_physics/test_fock.py b/tests/test_physics/test_fock.py index bce02378d..95e32ddc5 100644 --- a/tests/test_physics/test_fock.py +++ b/tests/test_physics/test_fock.py @@ -40,7 +40,7 @@ Vacuum, ) from mrmustard.math.lattice.strategies import displacement, grad_displacement -from mrmustard.physics import fock +from mrmustard.physics import fock_utils # helper strategies st_angle = st.floats(min_value=0, max_value=2 * np.pi) @@ -49,15 +49,15 @@ def test_fock_state(): n = [4, 5, 6] - array1 = fock.fock_state(n) + array1 = fock_utils.fock_state(n) assert array1.shape == (5, 6, 7) assert array1[4, 5, 6] == 1 - array2 = fock.fock_state(n, cutoffs=10) + array2 = fock_utils.fock_state(n, cutoffs=10) assert array2.shape == (11, 11, 11) assert array2[4, 5, 6] == 1 - array3 = fock.fock_state(n, cutoffs=[5, 6, 7]) + array3 = fock_utils.fock_state(n, cutoffs=[5, 6, 7]) assert array3.shape == (6, 7, 8) assert array3[4, 5, 6] == 1 @@ -66,10 +66,10 @@ def test_fock_state_error(): n = [4, 5] with pytest.raises(ValueError): - fock.fock_state(n, cutoffs=[5, 6, 7]) + fock_utils.fock_state(n, cutoffs=[5, 6, 7]) with pytest.raises(ValueError): - fock.fock_state(n, cutoffs=2) + fock_utils.fock_state(n, cutoffs=2) @given(n_mean=st.floats(0, 3), phi=st_angle) @@ -203,13 +203,13 @@ def test_dm_to_ket(state, kwargs): """Tests pure state density matrix conversion to ket""" state = state(**kwargs) dm = state.dm() - ket = fock.dm_to_ket(dm) + ket = fock_utils.dm_to_ket(dm) # check if ket is normalized assert np.allclose(np.linalg.norm(ket), 1, atol=1e-4) # check kets are equivalent assert np.allclose(ket, state.ket(), atol=1e-4) - dm_reconstructed = fock.ket_to_dm(ket) + dm_reconstructed = fock_utils.ket_to_dm(ket) # check ket leads to same dm assert np.allclose(dm, dm_reconstructed, atol=1e-15) @@ -220,7 +220,7 @@ def test_dm_to_ket_error(): e = ValueError if math.backend_name == "tensorflow" else TypeError with pytest.raises(e): - fock.dm_to_ket(state) + fock_utils.dm_to_ket(state) def test_fock_trace_mode1_dm(): @@ -259,14 +259,14 @@ def test_fock_trace_function(): """tests that the Fock state is correctly traced""" state = Vacuum(2) >> Ggate(2) >> Attenuator([0.1, 0.1]) dm = state.dm([3, 20]) - dm_traced = fock.trace(dm, keep=[0]) + dm_traced = fock_utils.trace(dm, keep=[0]) assert np.allclose(dm_traced, State(dm=dm).get_modes(0).dm(), atol=1e-5) def test_dm_choi(): """tests that choi op is correctly applied to a dm""" circ = Ggate(1) >> Attenuator([0.1]) - dm_out = fock.apply_choi_to_dm(circ.choi([10]), Vacuum(1).dm([10]), [0], [0]) + dm_out = fock_utils.apply_choi_to_dm(circ.choi([10]), Vacuum(1).dm([10]), [0], [0]) dm_expected = (Vacuum(1) >> circ).dm([10]) assert np.allclose(dm_out, dm_expected, atol=1e-5) @@ -281,7 +281,7 @@ def test_apply_kraus_to_ket_1mode(): """Test that Kraus operators are applied to a ket on the correct indices""" ket = np.random.normal(size=(2, 3, 4)) kraus = np.random.normal(size=(5, 3)) - ket_out = fock.apply_kraus_to_ket(kraus, ket, [1], [1]) + ket_out = fock_utils.apply_kraus_to_ket(kraus, ket, [1], [1]) assert ket_out.shape == (2, 5, 4) @@ -289,7 +289,9 @@ def test_apply_kraus_to_ket_1mode_with_argument_names(): """Test that Kraus operators are applied to a ket on the correct indices with argument names""" ket = np.random.normal(size=(2, 3, 4)) kraus = np.random.normal(size=(5, 3)) - ket_out = fock.apply_kraus_to_ket(kraus=kraus, ket=ket, kraus_in_modes=[1], kraus_out_modes=[1]) + ket_out = fock_utils.apply_kraus_to_ket( + kraus=kraus, ket=ket, kraus_in_modes=[1], kraus_out_modes=[1] + ) assert ket_out.shape == (2, 5, 4) @@ -297,7 +299,7 @@ def test_apply_kraus_to_ket_2mode(): """Test that Kraus operators are applied to a ket on the correct indices""" ket = np.random.normal(size=(2, 3, 4)) kraus = np.random.normal(size=(5, 3, 4)) - ket_out = fock.apply_kraus_to_ket(kraus, ket, [1, 2], [1]) + ket_out = fock_utils.apply_kraus_to_ket(kraus, ket, [1, 2], [1]) assert ket_out.shape == (2, 5) @@ -305,7 +307,7 @@ def test_apply_kraus_to_ket_2mode_2(): """Test that Kraus operators are applied to a ket on the correct indices""" ket = np.random.normal(size=(2, 3)) kraus = np.random.normal(size=(5, 4, 3)) - ket_out = fock.apply_kraus_to_ket(kraus, ket, [1], [1, 2]) + ket_out = fock_utils.apply_kraus_to_ket(kraus, ket, [1], [1, 2]) assert ket_out.shape == (2, 5, 4) @@ -313,7 +315,7 @@ def test_apply_kraus_to_dm_1mode(): """Test that Kraus operators are applied to a dm on the correct indices""" dm = np.random.normal(size=(2, 3, 2, 3)) kraus = np.random.normal(size=(5, 3)) - dm_out = fock.apply_kraus_to_dm(kraus, dm, [1], [1]) + dm_out = fock_utils.apply_kraus_to_dm(kraus, dm, [1], [1]) assert dm_out.shape == (2, 5, 2, 5) @@ -321,7 +323,9 @@ def test_apply_kraus_to_dm_1mode_with_argument_names(): """Test that Kraus operators are applied to a dm on the correct indices with argument names""" dm = np.random.normal(size=(2, 3, 2, 3)) kraus = np.random.normal(size=(5, 3)) - dm_out = fock.apply_kraus_to_dm(kraus=kraus, dm=dm, kraus_in_modes=[1], kraus_out_modes=[1]) + dm_out = fock_utils.apply_kraus_to_dm( + kraus=kraus, dm=dm, kraus_in_modes=[1], kraus_out_modes=[1] + ) assert dm_out.shape == (2, 5, 2, 5) @@ -329,7 +333,7 @@ def test_apply_kraus_to_dm_2mode(): """Test that Kraus operators are applied to a dm on the correct indices""" dm = np.random.normal(size=(2, 3, 4, 2, 3, 4)) kraus = np.random.normal(size=(5, 3, 4)) - dm_out = fock.apply_kraus_to_dm(kraus, dm, [1, 2], [1]) + dm_out = fock_utils.apply_kraus_to_dm(kraus, dm, [1, 2], [1]) assert dm_out.shape == (2, 5, 2, 5) @@ -337,7 +341,7 @@ def test_apply_kraus_to_dm_2mode_2(): """Test that Kraus operators are applied to a dm on the correct indices""" dm = np.random.normal(size=(2, 3, 4, 2, 3, 4)) kraus = np.random.normal(size=(5, 6, 3)) - dm_out = fock.apply_kraus_to_dm(kraus, dm, [1], [3, 1]) + dm_out = fock_utils.apply_kraus_to_dm(kraus, dm, [1], [3, 1]) assert dm_out.shape == (2, 6, 4, 5, 2, 6, 4, 5) @@ -345,7 +349,7 @@ def test_apply_choi_to_ket_1mode(): """Test that choi operators are applied to a ket on the correct indices""" ket = np.random.normal(size=(3, 5)) choi = np.random.normal(size=(4, 3, 4, 3)) # [out_l, in_l, out_r, in_r] - ket_out = fock.apply_choi_to_ket(choi, ket, [0], [0]) + ket_out = fock_utils.apply_choi_to_ket(choi, ket, [0], [0]) assert ket_out.shape == (4, 5, 4, 5) @@ -353,7 +357,9 @@ def test_apply_choi_to_ket_1mode_with_argument_names(): """Test that choi operators are applied to a ket on the correct indices with argument names""" ket = np.random.normal(size=(3, 5)) choi = np.random.normal(size=(4, 3, 4, 3)) # [out_l, in_l, out_r, in_r] - ket_out = fock.apply_choi_to_ket(choi=choi, ket=ket, choi_in_modes=[0], choi_out_modes=[0]) + ket_out = fock_utils.apply_choi_to_ket( + choi=choi, ket=ket, choi_in_modes=[0], choi_out_modes=[0] + ) assert ket_out.shape == (4, 5, 4, 5) @@ -361,7 +367,7 @@ def test_apply_choi_to_ket_2mode(): """Test that choi operators are applied to a ket on the correct indices""" ket = np.random.normal(size=(3, 5)) choi = np.random.normal(size=(2, 3, 5, 2, 3, 5)) # [out_l, in_l, out_r, in_r] - ket_out = fock.apply_choi_to_ket(choi, ket, [0, 1], [0]) + ket_out = fock_utils.apply_choi_to_ket(choi, ket, [0, 1], [0]) assert ket_out.shape == (2, 2) @@ -369,7 +375,7 @@ def test_apply_choi_to_dm_1mode(): """Test that choi operators are applied to a dm on the correct indices""" dm = np.random.normal(size=(3, 5, 3, 5)) choi = np.random.normal(size=(4, 3, 4, 3)) # [out_l, in_l, out_r, in_r] - dm_out = fock.apply_choi_to_dm(choi, dm, [0], [0]) + dm_out = fock_utils.apply_choi_to_dm(choi, dm, [0], [0]) assert dm_out.shape == (4, 5, 4, 5) @@ -377,7 +383,7 @@ def test_apply_choi_to_dm_1mode_with_argument_names(): """Test that choi operators are applied to a dm on the correct indices with argument names""" dm = np.random.normal(size=(3, 5, 3, 5)) choi = np.random.normal(size=(4, 3, 4, 3)) # [out_l, in_l, out_r, in_r] - dm_out = fock.apply_choi_to_dm(choi=choi, dm=dm, choi_in_modes=[0], choi_out_modes=[0]) + dm_out = fock_utils.apply_choi_to_dm(choi=choi, dm=dm, choi_in_modes=[0], choi_out_modes=[0]) assert dm_out.shape == (4, 5, 4, 5) @@ -385,7 +391,7 @@ def test_apply_choi_to_dm_2mode(): """Test that choi operators are applied to a dm on the correct indices""" dm = np.random.normal(size=(4, 5, 4, 5)) choi = np.random.normal(size=(2, 3, 5, 2, 3, 5)) # [out_l_1,2, in_l_1, out_r_1,2, in_r_1] - dm_out = fock.apply_choi_to_dm(choi, dm, [1], [1, 2]) + dm_out = fock_utils.apply_choi_to_dm(choi, dm, [1], [1, 2]) assert dm_out.shape == (4, 2, 3, 4, 2, 3) @@ -466,15 +472,17 @@ def test_number_means(x, y): @given(x=st.floats(-1, 1), y=st.floats(-1, 1)) def test_number_variances_coh(x, y): - assert np.allclose(fock.number_variances(Coherent(x, y).ket([80]), False)[0], x * x + y * y) - assert np.allclose(fock.number_variances(Coherent(x, y).dm([80]), True)[0], x * x + y * y) + assert np.allclose( + fock_utils.number_variances(Coherent(x, y).ket([80]), False)[0], x * x + y * y + ) + assert np.allclose(fock_utils.number_variances(Coherent(x, y).dm([80]), True)[0], x * x + y * y) def test_number_variances_fock(): - assert np.allclose(fock.number_variances(Fock(n=1).ket(), False), 0) - assert np.allclose(fock.number_variances(Fock(n=1).dm(), True), 0) + assert np.allclose(fock_utils.number_variances(Fock(n=1).ket(), False), 0) + assert np.allclose(fock_utils.number_variances(Fock(n=1).dm(), True), 0) def test_normalize_dm(): dm = np.array([[0.2, 0], [0, 0.2]]) - assert np.allclose(fock.normalize(dm, True), np.array([[0.5, 0], [0, 0.5]])) + assert np.allclose(fock_utils.normalize(dm, True), np.array([[0.5, 0], [0, 0.5]])) From 5d3db1ea77000d8f08fbd90befbfce7986949e64 Mon Sep 17 00:00:00 2001 From: Anthony Date: Mon, 9 Sep 2024 14:49:36 -0400 Subject: [PATCH 02/87] rename --- mrmustard/physics/bargmann_utils.py | 251 +++++++ mrmustard/physics/fock_utils.py | 1033 +++++++++++++++++++++++++++ 2 files changed, 1284 insertions(+) create mode 100644 mrmustard/physics/bargmann_utils.py create mode 100644 mrmustard/physics/fock_utils.py diff --git a/mrmustard/physics/bargmann_utils.py b/mrmustard/physics/bargmann_utils.py new file mode 100644 index 000000000..6600cf371 --- /dev/null +++ b/mrmustard/physics/bargmann_utils.py @@ -0,0 +1,251 @@ +# Copyright 2023 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module contains functions for performing calculations on objects in the Bargmann representations. +""" + +import numpy as np + +from mrmustard import math, settings +from mrmustard.physics.husimi import pq_to_aadag, wigner_to_husimi +from mrmustard.utils.typing import ComplexMatrix + + +def cayley(X, c): + r"""Returns the Cayley transform of a matrix: + :math:`cay(X) = (X - cI)(X + cI)^{-1}` + + Args: + c (float): the parameter of the Cayley transform + X (Tensor): a matrix + + Returns: + Tensor: the Cayley transform of X + """ + I = math.eye(X.shape[0], dtype=X.dtype) + return math.solve(X + c * I, X - c * I) + + +def wigner_to_bargmann_rho(cov, means): + r"""Converts the wigner representation in terms of covariance matrix and mean vector into the Bargmann `A,B,C` triple + for a density matrix (i.e. for `M` modes, `A` has shape `2M x 2M` and `B` has shape `2M`). + The order of the rows/columns of A and B corresponds to a density matrix with the usual ordering of the indices. + + Note that here A and B are defined with respect to the literature. + """ + N = cov.shape[-1] // 2 + A = math.matmul(math.Xmat(N), cayley(pq_to_aadag(cov), c=0.5)) + Q, beta = wigner_to_husimi(cov, means) + b = math.solve(Q, beta) + B = math.conj(b) + num_C = math.exp(-0.5 * math.sum(math.conj(beta) * b)) + detQ = math.det(Q) + den_C = math.sqrt(detQ, dtype=num_C.dtype) + C = num_C / den_C + return A, B, C + + +def wigner_to_bargmann_psi(cov, means): + r"""Converts the wigner representation in terms of covariance matrix and mean vector into the Bargmann A,B,C triple + for a Hilbert vector (i.e. for M modes, A has shape M x M and B has shape M). + """ + N = cov.shape[-1] // 2 + A, B, C = wigner_to_bargmann_rho(cov, means) + return A[N:, N:], B[N:], math.sqrt(C) + # NOTE: c for th psi is to calculated from the global phase formula. + + +def wigner_to_bargmann_Choi(X, Y, d): + r"""Converts the wigner representation in terms of covariance matrix and mean vector into the Bargmann `A,B,C` triple + for a channel (i.e. for M modes, A has shape 4M x 4M and B has shape 4M).""" + N = X.shape[-1] // 2 + I2 = math.eye(2 * N, dtype=X.dtype) + XT = math.transpose(X) + xi = 0.5 * (I2 + math.matmul(X, XT) + 2 * Y / settings.HBAR) + detxi = math.det(xi) + xi_inv = math.inv(xi) + A = math.block( + [ + [I2 - xi_inv, math.matmul(xi_inv, X)], + [math.matmul(XT, xi_inv), I2 - math.matmul(math.matmul(XT, xi_inv), X)], + ] + ) + I = math.eye(N, dtype=math.complex128) + o = math.zeros_like(I) + R = math.block( + [[I, 1j * I, o, o], [o, o, I, -1j * I], [I, -1j * I, o, o], [o, o, I, 1j * I]] + ) / np.sqrt(2) + A = math.matmul(math.matmul(R, A), math.dagger(R)) + A = math.matmul(math.Xmat(2 * N), A) + b = math.matvec(xi_inv, d) + B = math.matvec(math.conj(R), math.concat([b, -math.matvec(XT, b)], axis=-1)) / math.sqrt( + settings.HBAR, dtype=R.dtype + ) + C = math.exp(-0.5 * math.sum(d * b) / settings.HBAR) / math.sqrt(detxi, dtype=b.dtype) + # now A and B have order [out_r, in_r out_l, in_l]. + return A, B, math.cast(C, math.complex128) + + +def wigner_to_bargmann_U(X, d): + r"""Converts the wigner representation in terms of covariance matrix and mean vector into the Bargmann `A,B,C` triple + for a unitary (i.e. for `M` modes, `A` has shape `2M x 2M` and `B` has shape `2M`). + """ + N = X.shape[-1] // 2 + A, B, C = wigner_to_bargmann_Choi(X, math.zeros_like(X), d) + return A[2 * N :, 2 * N :], B[2 * N :], math.sqrt(C) + + +def norm_ket(A, b, c): + r"""Calculates the l2 norm of a Ket with a representation given by the Bargmann triple A,b,c.""" + M = math.block([[math.conj(A), -math.eye_like(A)], [-math.eye_like(A), A]]) + B = math.concat([math.conj(b), b], 0) + norm_squared = ( + math.abs(c) ** 2 + * math.exp(-0.5 * math.sum(B * math.matvec(math.inv(M), B))) + / math.sqrt((-1) ** A.shape[-1] * math.det(M)) + ) + return math.real(math.sqrt(norm_squared)) + + +def trace_dm(A, b, c): + r"""Calculates the total trace of the density matrix with representation given by the Bargmann triple A,b,c.""" + M = A - math.Xmat(A.shape[-1] // 2) + trace = ( + c + * math.exp(-0.5 * math.sum(b * math.matvec(math.inv(M), b))) + / math.sqrt((-1) ** (A.shape[-1] // 2) * math.det(M)) + ) + return math.real(trace) + + +def au2Symplectic(A): + r""" + helper for finding the Au of a unitary from its symplectic rep. + Au : in bra-ket order + """ + # A represents the A matrix corresponding to unitary U + A = A * (1.0 + 0.0 * 1j) + m = A.shape[-1] + m = m // 2 + + # identifying blocks of A_u + u_2 = A[..., :m, m:] + u_3 = A[..., m:, m:] + + # The formula to apply comes here + S_1 = math.conj(math.inv(math.transpose(u_2))) + S_2 = -S_1 @ math.conj(u_3) + S_3 = math.conj(S_2) + S_4 = math.conj(S_1) + + S = math.block([[S_1, S_2], [S_3, S_4]]) + + transformation = ( + 1 + / np.sqrt(2) + * math.block( + [ + [math.eye(m, dtype=math.complex128), math.eye(m, dtype=math.complex128)], + [-1j * math.eye(m, dtype=math.complex128), 1j * math.eye(m, dtype=math.complex128)], + ] + ) + ) + + return math.real(transformation @ S @ math.conj(math.transpose(transformation))) + + +def symplectic2Au(S): + r""" + The inverse of au2Symplectic i.e., returns symplectic, given Au + + S: symplectic in XXPP order + """ + m = S.shape[-1] + m = m // 2 + # the following lines of code transform the quadrature symplectic matrix to + # the annihilation one + transformation = ( + 1 + / np.sqrt(2) + * math.block( + [ + [math.eye(m, dtype=math.complex128), math.eye(m, dtype=math.complex128)], + [-1j * math.eye(m, dtype=math.complex128), 1j * math.eye(m, dtype=math.complex128)], + ] + ) + ) + S = np.conjugate(math.transpose(transformation)) @ S @ transformation + # identifying blocks of S + S_1 = S[:m, :m] + S_2 = S[:m, m:] + + # TODO: broadcasting/batch stuff consider a batch dimension + + # the formula to apply comes here + A_1 = S_2 @ math.conj(math.inv(S_1)) # use solve for inverse + A_2 = math.conj(math.inv(math.transpose(S_1))) + A_3 = math.transpose(A_2) + A_4 = -math.conj(math.solve(S_1, S_2)) + # -np.conjugate(np.linalg.pinv(S_1)) @ np.conjugate(S_2) + + A = math.block([[A_1, A_2], [A_3, A_4]]) + + return A + + +def XY_of_channel(A: ComplexMatrix): + r""" + Outputting the X and Y matrices corresponding to a channel determined by the "A" + matrix. + + Args: + A: the A matrix of the channel + """ + n = A.shape[-1] // 2 + m = n // 2 + + # here we transform to the other convention for wires i.e. {out-bra, out-ket, in-bra, in-ket} + A_out = math.block( + [[A[:m, :m], A[:m, 2 * m : 3 * m]], [A[2 * m : 3 * m, :m], A[2 * m : 3 * m, 2 * m : 3 * m]]] + ) + R = math.block( + [ + [A[:m, m : 2 * m], A[:m, 3 * m :]], + [A[2 * m : 3 * m, m : 2 * m], A[2 * m : 3 * m, 3 * m :]], + ] + ) + X_tilde = -math.inv(np.eye(n) - math.Xmat(m) @ A_out) @ math.Xmat(m) @ R @ math.Xmat(m) + transformation = math.block( + [ + [math.eye(m, dtype=math.complex128), math.eye(m, dtype=math.complex128)], + [-1j * math.eye(m, dtype=math.complex128), 1j * math.eye(m, dtype=math.complex128)], + ] + ) + X = -transformation @ X_tilde @ math.conj(transformation).T / 2 + + sigma_H = math.inv(math.eye(n) - math.Xmat(m) @ A_out) # the complex-Husimi covariance matrix + + N = sigma_H[m:, m:] + M = sigma_H[:m, m:] + sigma = ( + math.block([[math.real(N + M), math.imag(N + M)], [math.imag(M - N), math.real(N - M)]]) + - math.eye(n) / 2 + ) + Y = sigma - X @ X.T / 2 + if math.norm(math.imag(X)) > settings.ATOL or math.norm(math.imag(Y)) > settings.ATOL: + raise ValueError( + "Invalid input for the A matrix of channel, caused imaginary X and/or Y matrices." + ) + return math.real(X), math.real(Y) diff --git a/mrmustard/physics/fock_utils.py b/mrmustard/physics/fock_utils.py new file mode 100644 index 000000000..0d4f7c83a --- /dev/null +++ b/mrmustard/physics/fock_utils.py @@ -0,0 +1,1033 @@ +# Copyright 2021 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=redefined-outer-name + +""" +This module contains functions for performing calculations on objects in the Fock representations. +""" + +from __future__ import annotations + +from functools import lru_cache +from typing import Sequence, Iterable + +import numpy as np + +from mrmustard import math, settings +from mrmustard.math.lattice import strategies +from mrmustard.math.caching import tensor_int_cache +from mrmustard.math.tensor_wrappers.mmtensor import MMTensor +from mrmustard.physics.bargmann_utils import ( + wigner_to_bargmann_Choi, + wigner_to_bargmann_psi, + wigner_to_bargmann_rho, + wigner_to_bargmann_U, +) +from mrmustard.utils.typing import ComplexTensor, Matrix, Scalar, Tensor, Vector, Batch + +SQRT = np.sqrt(np.arange(1e6)) + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ~~~~~~~~~~~~~~ static functions ~~~~~~~~~~~~~~ +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +def fock_state(n: Sequence[int], cutoffs: int | Sequence[int] | None = None) -> Tensor: + r""" + The Fock array of a tensor product of one-mode ``Number`` states. + + Args: + n: The photon numbers of the number states. + cutoffs: The cutoffs of the arrays for the number states. If it is given as + an ``int``, it is broadcasted to all the states. If ``None``, it + defaults to ``[n1+1, n2+1, ...]``, where ``ni`` is the photon number + of the ``i``th mode. + + Returns: + The Fock array of a tensor product of one-mode ``Number`` states. + """ + n = math.atleast_1d(n) + if cutoffs is None: + cutoffs = list(n) + elif isinstance(cutoffs, int): + cutoffs = [cutoffs] * len(n) + + if len(cutoffs) != len(n): + msg = f"Expected ``len(cutoffs)={len(n)}`` but found ``{len(cutoffs)}``." + raise ValueError(msg) + + shape = tuple([c + 1 for c in cutoffs]) + array = np.zeros(shape, dtype=np.complex128) + + try: + array[tuple(n)] = 1 + except IndexError: + msg = "Photon numbers cannot be larger than the corresponding cutoffs." + raise ValueError(msg) + + return math.astensor(array) + + +def autocutoffs(cov: Matrix, means: Vector, probability: float): + r"""Returns the cutoffs of a Gaussian state by computing the 1-mode marginals until + the probability of the marginal is less than ``probability``. + + Args: + cov: the covariance matrix + means: the means vector + probability: the cutoff probability + + Returns: + Tuple[int, ...]: the suggested cutoffs + """ + M = len(means) // 2 + cutoffs = [] + for i in range(M): + cov_i = np.array([[cov[i, i], cov[i, i + M]], [cov[i + M, i], cov[i + M, i + M]]]) + means_i = np.array([means[i], means[i + M]]) + # apply 1-d recursion until probability is less than 0.99 + A, B, C = [math.asnumpy(x) for x in wigner_to_bargmann_rho(cov_i, means_i)] + diag = math.hermite_renormalized_diagonal(A, B, C, cutoffs=[settings.AUTOCUTOFF_MAX_CUTOFF]) + # find at what index in the cumsum the probability is more than 0.99 + for i, val in enumerate(np.cumsum(diag)): + if val > probability: + cutoffs.append(max(i + 1, settings.AUTOCUTOFF_MIN_CUTOFF)) + break + else: + cutoffs.append(settings.AUTOCUTOFF_MAX_CUTOFF) + return cutoffs + + +def wigner_to_fock_state( + cov: Matrix, + means: Vector, + shape: Sequence[int], + max_prob: float = 1.0, + max_photons: int | None = None, + return_dm: bool = True, +) -> Tensor: + r"""Returns the Fock representation of a Gaussian state. + Use with caution: if the cov matrix is that of a mixed state, + setting return_dm to False will produce nonsense. + If return_dm=False, we can apply max_prob and max_photons to stop the + computation of the Fock representation early, when those conditions are met. + + * If the state is pure it can return the state vector (ket) or the density matrix. + The index ordering is going to be [i's] in ket_i + * If the state is mixed it can return the density matrix. + The index order is going to be [i's,j's] in dm_ij + + Args: + cov: the Wigner covariance matrix + means: the Wigner means vector + shape: the shape of the tensor + max_prob: the maximum probability of a the state (applies only if the ket is returned) + max_photons: the maximum number of photons in the state (applies only if the ket is returned) + return_dm: whether to return the density matrix (otherwise it returns the ket) + + Returns: + Tensor: the fock representation + """ + if return_dm: + A, B, C = wigner_to_bargmann_rho(cov, means) + # NOTE: change the order of the index in AB + Xmat = math.Xmat(A.shape[-1] // 2) + A = math.matmul(math.matmul(Xmat, A), Xmat) + B = math.matvec(Xmat, B) + return math.hermite_renormalized(A, B, C, shape=shape) + else: # here we can apply max prob and max photons + A, B, C = wigner_to_bargmann_psi(cov, means) + if max_photons is None: + max_photons = sum(shape) - len(shape) + if max_prob < 1.0 or max_photons < sum(shape) - len(shape): + return math.hermite_renormalized_binomial( + A, B, C, shape=shape, max_l2=max_prob, global_cutoff=max_photons + 1 + ) + return math.hermite_renormalized(A, B, C, shape=tuple(shape)) + + +def wigner_to_fock_U(X, d, shape): + r"""Returns the Fock representation of a Gaussian unitary transformation. + The index order is out_l, in_l, where in_l is to be contracted with the indices of a ket, + or with the left indices of a density matrix. + + Arguments: + X: the X matrix + d: the d vector + shape: the shape of the tensor + + Returns: + Tensor: the fock representation of the unitary transformation + """ + A, B, C = wigner_to_bargmann_U(X, d) + return math.hermite_renormalized(A, B, C, shape=tuple(shape)) + + +def wigner_to_fock_Choi(X, Y, d, shape): + r"""Returns the Fock representation of a Gaussian Choi matrix. + The order of choi indices is :math:`[\mathrm{out}_l, \mathrm{in}_l, \mathrm{out}_r, \mathrm{in}_r]` + where :math:`\mathrm{in}_l` and :math:`\mathrm{in}_r` are to be contracted with the left and right indices of a density matrix. + + Arguments: + X: the X matrix + Y: the Y matrix + d: the d vector + shape: the shape of the tensor + + Returns: + Tensor: the fock representation of the Choi matrix + """ + A, B, C = wigner_to_bargmann_Choi(X, Y, d) + # NOTE: change the order of the index in AB + Xmat = math.Xmat(A.shape[-1] // 2) + A = math.matmul(math.matmul(Xmat, A), Xmat) + N = B.shape[-1] // 2 + B = math.concat([B[N:], B[:N]], axis=-1) + return math.hermite_renormalized(A, B, C, shape=tuple(shape)) + + +def ket_to_dm(ket: Tensor) -> Tensor: + r"""Maps a ket to a density matrix. + + Args: + ket: the ket + + Returns: + Tensor: the density matrix + """ + return math.outer(ket, math.conj(ket)) + + +def dm_to_ket(dm: Tensor) -> Tensor: + r"""Maps a density matrix to a ket if the state is pure. + + If the state is pure :math:`\hat \rho= |\psi\rangle\langle \psi|` then the + ket is the eigenvector of :math:`\rho` corresponding to the eigenvalue 1. + + Args: + dm (Tensor): the density matrix + + Returns: + Tensor: the ket + + Raises: + ValueError: if ket for mixed states cannot be calculated + """ + + is_pure_dm = np.isclose(purity(dm), 1.0, atol=1e-6) + if not is_pure_dm: + raise ValueError("Cannot calculate ket for mixed states.") + + cutoffs = dm.shape[: len(dm.shape) // 2] + d = int(np.prod(cutoffs)) + dm = math.reshape(dm, (d, d)) + + eigvals, eigvecs = math.eigh(dm) + # eigenvalues and related eigenvectors are sorted in non-decreasing order, + # meaning the associated eigvec to largest eigval is stored last. + ket = eigvecs[:, -1] * math.sqrt(eigvals[-1]) + ket = math.reshape(ket, cutoffs) + + return ket + + +def ket_to_probs(ket: Tensor) -> Tensor: + r"""Maps a ket to probabilities. + + Args: + ket: the ket + + Returns: + Tensor: the probabilities vector + """ + return math.abs(ket) ** 2 + + +def dm_to_probs(dm: Tensor) -> Tensor: + r"""Extracts the diagonals of a density matrix. + + Args: + dm: the density matrix + + Returns: + Tensor: the probabilities vector + """ + return math.all_diagonals(dm, real=True) + + +def U_to_choi(U: Tensor, Udual: Tensor | None = None) -> Tensor: + r"""Converts a unitary transformation to a Choi tensor. + + Args: + U: the unitary transformation + Udual: the dual unitary transformation (optional, will use conj U if not provided) + + Returns: + Tensor: the Choi tensor. The index order is going to be :math:`[\mathrm{out}_l, \mathrm{in}_l, \mathrm{out}_r, \mathrm{in}_r]` + where :math:`\mathrm{in}_l` and :math:`\mathrm{in}_r` are to be contracted with the left and right indices of the density matrix. + """ + return math.outer(U, math.conj(U) if Udual is None else Udual) + + +def fidelity(state_a, state_b, a_ket: bool, b_ket: bool) -> Scalar: + r"""Computes the fidelity between two states in Fock representation.""" + if a_ket and b_ket: + min_cutoffs = [slice(min(a, b)) for a, b in zip(state_a.shape, state_b.shape)] + state_a = state_a[tuple(min_cutoffs)] + state_b = state_b[tuple(min_cutoffs)] + return math.abs(math.sum(math.conj(state_a) * state_b)) ** 2 + + if a_ket: + min_cutoffs = [ + slice(min(a, b)) + for a, b in zip(state_a.shape, state_b.shape[: len(state_b.shape) // 2]) + ] + state_a = state_a[tuple(min_cutoffs)] + state_b = state_b[tuple(min_cutoffs * 2)] + a = math.reshape(state_a, -1) + return math.real( + math.sum(math.conj(a) * math.matvec(math.reshape(state_b, (len(a), len(a))), a)) + ) + + if b_ket: + min_cutoffs = [ + slice(min(a, b)) + for a, b in zip(state_a.shape[: len(state_a.shape) // 2], state_b.shape) + ] + state_a = state_a[tuple(min_cutoffs * 2)] + state_b = state_b[tuple(min_cutoffs)] + b = math.reshape(state_b, -1) + return math.real( + math.sum(math.conj(b) * math.matvec(math.reshape(state_a, (len(b), len(b))), b)) + ) + + # mixed state + # Richard Jozsa (1994) Fidelity for Mixed Quantum States, Journal of Modern Optics, 41:12, 2315-2323, DOI: 10.1080/09500349414552171 + + # trim states to have same cutoff + min_cutoffs = [ + slice(min(a, b)) + for a, b in zip( + state_a.shape[: len(state_a.shape) // 2], + state_b.shape[: len(state_b.shape) // 2], + ) + ] + state_a = state_a[tuple(min_cutoffs * 2)] + state_b = state_b[tuple(min_cutoffs * 2)] + return math.abs( + ( + math.trace( + math.sqrtm( + math.matmul(math.matmul(math.sqrtm(state_a), state_b), math.sqrtm(state_a)) + ) + ) + ** 2 + ) + ) + + +def number_means(tensor, is_dm: bool): + r"""Returns the mean of the number operator in each mode.""" + probs = math.all_diagonals(tensor, real=True) if is_dm else math.abs(tensor) ** 2 + modes = list(range(len(probs.shape))) + marginals = [math.sum(probs, axes=modes[:k] + modes[k + 1 :]) for k in range(len(modes))] + return math.astensor( + [ + math.sum(marginal * math.arange(len(marginal), dtype=math.float64)) + for marginal in marginals + ] + ) + + +def number_variances(tensor, is_dm: bool): + r"""Returns the variance of the number operator in each mode.""" + probs = math.all_diagonals(tensor, real=True) if is_dm else math.abs(tensor) ** 2 + modes = list(range(len(probs.shape))) + marginals = [math.sum(probs, axes=modes[:k] + modes[k + 1 :]) for k in range(len(modes))] + return math.astensor( + [ + ( + math.sum(marginal * math.arange(marginal.shape[0], dtype=marginal.dtype) ** 2) + - math.sum(marginal * math.arange(marginal.shape[0], dtype=marginal.dtype)) ** 2 + ) + for marginal in marginals + ] + ) + + +def purity(dm: Tensor) -> Scalar: + r"""Returns the purity of a density matrix.""" + cutoffs = dm.shape[: len(dm.shape) // 2] + d = int(np.prod(cutoffs)) # combined cutoffs in all modes + dm = math.reshape(dm, (d, d)) + dm = dm / math.trace(dm) # assumes all nonzero values are included in the density matrix + return math.abs(math.sum(math.transpose(dm) * dm)) # tr(rho^2) + + +def validate_contraction_indices(in_idx, out_idx, M, name): + r"""Validates the indices used for the contraction of a tensor.""" + if len(set(in_idx)) != len(in_idx): + raise ValueError(f"{name}_in_idx should not contain repeated indices.") + if len(set(out_idx)) != len(out_idx): + raise ValueError(f"{name}_out_idx should not contain repeated indices.") + if not set(range(M)).intersection(out_idx).issubset(set(in_idx)): + wrong_indices = set(range(M)).intersection(out_idx) - set(in_idx) + raise ValueError( + f"Indices {wrong_indices} in {name}_out_idx are trying to replace uncontracted indices." + ) + + +def apply_kraus_to_ket(kraus, ket, kraus_in_modes, kraus_out_modes=None): + r"""Applies a kraus operator to a ket. + It assumes that the ket is indexed as left_1, ..., left_n. + + The kraus op has indices that contract with the ket (kraus_in_modes) and indices that are left over (kraus_out_modes). + The final index order will be sorted (note that an index appearing in both kraus_in_modes and kraus_out_modes will replace the original index). + + Args: + kraus (array): the kraus operator to be applied + ket (array): the ket to which the operator is applied + kraus_in_modes (list of ints): the indices (counting from 0) of the kraus operator that contract with the ket + kraus_out_modes (list of ints): the indices (counting from 0) of the kraus operator that are leftover + + Returns: + array: the resulting ket with indices as kraus_out_modes + uncontracted ket indices + """ + if kraus_out_modes is None: + kraus_out_modes = kraus_in_modes + + if not set(kraus_in_modes).issubset(range(ket.ndim)): + raise ValueError("kraus_in_modes should be a subset of the ket indices.") + + # check that there are no repeated indices in kraus_in_modes and kraus_out_modes (separately) + validate_contraction_indices(kraus_in_modes, kraus_out_modes, ket.ndim, "kraus") + + ket = MMTensor(ket, axis_labels=[f"in_left_{i}" for i in range(ket.ndim)]) + kraus = MMTensor( + kraus, + axis_labels=[f"out_left_{i}" for i in kraus_out_modes] + + [f"in_left_{i}" for i in kraus_in_modes], + ) + + # contract the operator with the ket. + # now the leftover indices are in the order kraus_out_modes + uncontracted ket indices + kraus_ket = kraus @ ket + + # sort kraus_ket.axis_labels by the int at the end of each label. + # Each label is guaranteed to have a unique int at the end. + new_axis_labels = sorted(kraus_ket.axis_labels, key=lambda x: int(x.split("_")[-1])) + + return kraus_ket.transpose(new_axis_labels).tensor + + +def apply_kraus_to_dm(kraus, dm, kraus_in_modes, kraus_out_modes=None): + r"""Applies a kraus operator to a density matrix. + It assumes that the density matrix is indexed as left_1, ..., left_n, right_1, ..., right_n. + + The kraus operator has indices that contract with the density matrix (kraus_in_modes) and indices that are leftover (kraus_out_modes). + `kraus` will contract from the left and from the right with the density matrix. For right contraction the kraus op is conjugated. + + Args: + kraus (array): the operator to be applied + dm (array): the density matrix to which the operator is applied + kraus_in_modes (list of ints): the indices (counting from 0) of the kraus operator that contract with the density matrix + kraus_out_modes (list of ints): the indices (counting from 0) of the kraus operator that are leftover (default None, in which case kraus_out_modes = kraus_in_modes) + + Returns: + array: the resulting density matrix + """ + if kraus_out_modes is None: + kraus_out_modes = kraus_in_modes + + if not set(kraus_in_modes).issubset(range(dm.ndim // 2)): + raise ValueError("kraus_in_modes should be a subset of the density matrix indices.") + + # check that there are no repeated indices in kraus_in_modes and kraus_out_modes (separately) + validate_contraction_indices(kraus_in_modes, kraus_out_modes, dm.ndim // 2, "kraus") + + dm = MMTensor( + dm, + axis_labels=[f"left_{i}" for i in range(dm.ndim // 2)] + + [f"right_{i}" for i in range(dm.ndim // 2)], + ) + kraus = MMTensor( + kraus, + axis_labels=[f"out_left_{i}" for i in kraus_out_modes] + + [f"left_{i}" for i in kraus_in_modes], + ) + kraus_conj = MMTensor( + math.conj(kraus.tensor), + axis_labels=[f"out_right_{i}" for i in kraus_out_modes] + + [f"right_{i}" for i in kraus_in_modes], + ) + + # contract the kraus operator with the density matrix from the left and from the right. + k_dm_k = kraus @ dm @ kraus_conj + # now the leftover indices are in the order: + # out_left_modes + uncontracted left indices + uncontracted right indices + out_right_modes + + # sort k_dm_k.axis_labels by the int at the end of each label, first left, then right + N = k_dm_k.tensor.ndim // 2 + left = sorted(k_dm_k.axis_labels[:N], key=lambda x: int(x.split("_")[-1])) + right = sorted(k_dm_k.axis_labels[N:], key=lambda x: int(x.split("_")[-1])) + + return k_dm_k.transpose(left + right).tensor + + +def apply_choi_to_dm( + choi: ComplexTensor, + dm: ComplexTensor, + choi_in_modes: Sequence[int], + choi_out_modes: Sequence[int] | None = None, +): + r"""Applies a choi operator to a density matrix. + It assumes that the density matrix is indexed as left_1, ..., left_n, right_1, ..., right_n. + + The choi operator has indices that contract with the density matrix (choi_in_modes) and indices that are left over (choi_out_modes). + `choi` will contract choi_in_modes from the left and from the right with the density matrix. + + Args: + choi (array): the choi operator to be applied + dm (array): the density matrix to which the choi operator is applied + choi_in_modes (list of ints): the input modes of the choi operator that contract with the density matrix + choi_out_modes (list of ints): the output modes of the choi operator + + Returns: + array: the resulting density matrix + """ + if choi_out_modes is None: + choi_out_modes = choi_in_modes + if not set(choi_in_modes).issubset(range(dm.ndim // 2)): + raise ValueError("choi_in_modes should be a subset of the density matrix indices.") + + # check that there are no repeated indices in kraus_in_modes and kraus_out_modes (separately) + validate_contraction_indices(choi_in_modes, choi_out_modes, dm.ndim // 2, "choi") + + dm = MMTensor( + dm, + axis_labels=[f"in_left_{i}" for i in range(dm.ndim // 2)] + + [f"in_right_{i}" for i in range(dm.ndim // 2)], + ) + choi = MMTensor( + choi, + axis_labels=[f"out_left_{i}" for i in choi_out_modes] + + [f"in_left_{i}" for i in choi_in_modes] + + [f"out_right_{i}" for i in choi_out_modes] + + [f"in_right_{i}" for i in choi_in_modes], + ) + + # contract the choi matrix with the density matrix. + # now the leftover indices are in the order out_left_modes + out_right_modes + uncontracted left indices + uncontracted right indices + choi_dm = choi @ dm + + # sort choi_dm.axis_labels by the int at the end of each label, first left, then right + left_labels = [label for label in choi_dm.axis_labels if "left" in label] + left = sorted(left_labels, key=lambda x: int(x.split("_")[-1])) + right_labels = [label for label in choi_dm.axis_labels if "right" in label] + right = sorted(right_labels, key=lambda x: int(x.split("_")[-1])) + + return choi_dm.transpose(left + right).tensor + + +def apply_choi_to_ket(choi, ket, choi_in_modes, choi_out_modes=None): + r"""Applies a choi operator to a ket. + It assumes that the ket is indexed as left_1, ..., left_n. + + The choi operator has indices that contract with the ket (choi_in_modes) and indices that are left over (choi_out_modes). + `choi` will contract choi_in_modes from the left and from the right with the ket. + + Args: + choi (array): the choi operator to be applied + ket (array): the ket to which the choi operator is applied + choi_in_modes (list of ints): the indices of the choi operator that contract with the ket + choi_out_modes (list of ints): the indices of the choi operator that re leftover + + Returns: + array: the resulting ket + """ + if choi_out_modes is None: + choi_out_modes = choi_in_modes + + if not set(choi_in_modes).issubset(range(ket.ndim)): + raise ValueError("choi_in_modes should be a subset of the ket indices.") + + # check that there are no repeated indices in kraus_in_modes and kraus_out_modes (separately) + validate_contraction_indices(choi_in_modes, choi_out_modes, ket.ndim, "choi") + + ket = MMTensor(ket, axis_labels=[f"left_{i}" for i in range(ket.ndim)]) + ket_dual = MMTensor(math.conj(ket.tensor), axis_labels=[f"right_{i}" for i in range(ket.ndim)]) + choi = MMTensor( + choi, + axis_labels=[f"out_left_{i}" for i in choi_out_modes] + + [f"left_{i}" for i in choi_in_modes] + + [f"out_right_{i}" for i in choi_out_modes] + + [f"right_{i}" for i in choi_in_modes], + ) + + # contract the choi matrix with the ket and its dual, like choi @ |ket> Tensor: + r"""Harmonic oscillator eigenstate wavefunction `\psi_n(q) = `. + + Args: + q (Vector): a vector containing the q points at which the function is evaluated (units of \sqrt{\hbar}) + cutoff (int): maximum number of photons + + Returns: + Tensor: a tensor of size ``len(q)*cutoff``. Each entry with index ``[i, j]`` represents the eigenstate evaluated + with number of photons ``i`` evaluated at position ``q[j]``, i.e., `\psi_i(q_j)`. + + .. details:: + + .. admonition:: Definition + :class: defn + + The q-quadrature eigenstates are defined as + + .. math:: + + \psi_n(x) = 1/sqrt[2^n n!](\frac{\omega}{\pi \hbar})^{1/4} + \exp{-\frac{\omega}{2\hbar} x^2} H_n(\sqrt{\frac{\omega}{\pi}} x) + + where :math:`H_n(x)` is the (physicists) `n`-th Hermite polynomial. + """ + hbar = settings.HBAR + x = math.cast(q / np.sqrt(hbar), math.complex128) # unit-less vector + + # prefactor term (\Omega/\hbar \pi)**(1/4) * 1 / sqrt(2**n) + prefactor = math.cast( + (np.pi * hbar) ** (-0.25) * math.pow(0.5, math.arange(0, cutoff) / 2), + math.complex128, + ) + + # Renormalized physicist hermite polys: Hn / sqrt(n!) + R = -np.array([[2 + 0j]]) # to get the physicist polys + + def f_hermite_polys(xi): + return math.hermite_renormalized(R, math.astensor([2 * xi]), 1 + 0j, [cutoff]) + + hermite_polys = math.map_fn(f_hermite_polys, x) + + # (real) wavefunction + psi = math.exp(-(x**2 / 2)) * math.transpose(prefactor * hermite_polys) + return psi + + +@lru_cache +def estimate_dx(cutoff, period_resolution=20): + r"""Estimates a suitable quadrature discretization interval `dx`. Uses the fact + that Fock state `n` oscillates with angular frequency :math:`\sqrt{2(n + 1)}`, + which follows from the relation + + .. math:: + + \psi^{[n]}'(q) = q - sqrt(2*(n + 1))*\psi^{[n+1]}(q) + + by setting q = 0, and approximating the oscillation amplitude by `\psi^{[n+1]}(0) + + Ref: https://en.wikipedia.org/wiki/Hermite_polynomials#Hermite_functions + + Args + cutoff (int): Fock cutoff + period_resolution (int): Number of points used to sample one Fock + wavefunction oscillation. Larger values yields better approximations + and thus smaller `dx`. + + Returns + (float): discretization value of quadrature + """ + fock_cutoff_frequency = np.sqrt(2 * (cutoff + 1)) + fock_cutoff_period = 2 * np.pi / fock_cutoff_frequency + dx_estimate = fock_cutoff_period / period_resolution + return dx_estimate + + +@lru_cache +def estimate_xmax(cutoff, minimum=5): + r"""Estimates a suitable quadrature axis length + + Args + cutoff (int): Fock cutoff + minimum (float): Minimum value of the returned xmax + + Returns + (float): maximum quadrature value + """ + if cutoff == 0: + xmax_estimate = 3 + else: + # maximum q for a classical particle with energy n=cutoff + classical_endpoint = np.sqrt(2 * cutoff) + # approximate probability of finding particle outside classical region + excess_probability = 1 / (7.464 * cutoff ** (1 / 3)) + # Emperical factor that yields reasonable results + A = 5 + xmax_estimate = classical_endpoint * (1 + A * excess_probability) + return max(minimum, xmax_estimate) + + +@lru_cache +def estimate_quadrature_axis(cutoff, minimum=5, period_resolution=20): + """Generates a suitable quadrature axis. + + Args + cutoff (int): Fock cutoff + minimum (float): Minimum value of the returned xmax + period_resolution (int): Number of points used to sample one Fock + wavefunction oscillation. Larger values yields better approximations + and thus smaller dx. + + Returns + (array): quadrature axis + """ + xmax = estimate_xmax(cutoff, minimum=minimum) + dx = estimate_dx(cutoff, period_resolution=period_resolution) + xaxis = np.arange(-xmax, xmax, dx) + xaxis = np.append(xaxis, xaxis[-1] + dx) + xaxis = xaxis - np.mean(xaxis) # center around 0 + return xaxis + + +def quadrature_basis( + fock_array: Tensor, + quad: Batch[Vector], + conjugates: bool | list[bool] = False, + phi: Scalar = 0.0, +): + r"""Given the Fock basis representation return the quadrature basis representation. + + Args: + fock_array (Tensor): fock tensor amplitudes + quad (Batch[Vector]): points at which the quadrature basis is evaluated + conjugates (list[bool]): which dimensions of the array to conjugate based on + whether it is a bra or a ket + phi (float): angle of the quadrature basis vector + + Returns: + tuple(Tensor): quadrature basis representation at the points in quad + """ + dims = len(fock_array.shape) + + if quad.shape[-1] != dims: + raise ValueError( + f"Input fock array has dimension {dims} whereas ``quad`` has {quad.shape[-1]}." + ) + + conjugates = conjugates if isinstance(conjugates, Iterable) else [conjugates] * dims + + # construct quadrature basis vectors + shapes = fock_array.shape + quad_basis_vecs = [] + for dim in range(dims): + q_to_n = oscillator_eigenstate(quad[..., dim], shapes[dim]) + if not np.isclose(phi, 0.0): + theta = -math.arange(shapes[dim]) * phi + Ur = math.make_complex(math.cos(theta), math.sin(theta)) + q_to_n = math.einsum("a,ab->ab", Ur, q_to_n) + if conjugates[dim]: + q_to_n = math.conj(q_to_n) + quad_basis_vecs += [math.cast(q_to_n, "complex128")] + + # Convert each dimension to quadrature + subscripts = [chr(i) for i in range(98, 98 + dims)] + fock_string = "".join(subscripts[:dims]) #'bcd....' + q_string = "".join([fock_string[i] + "a," for i in range(dims - 1)] + [fock_string[-1] + "a"]) + quad_array = math.einsum( + fock_string + "," + q_string + "->" + "a", fock_array, *quad_basis_vecs + ) + + return quad_array + + +def quadrature_distribution( + state: Tensor, + quadrature_angle: float = 0.0, + x: Vector | None = None, +): + r"""Given the ket or density matrix of a single-mode state, it generates the probability + density distribution :math:`\tr [ \rho |x_\phi> the quadrature eigenvector with angle `\phi` + equal to ``quadrature_angle``. + + Args: + state (Tensor): single mode state ket or density matrix + quadrature_angle (float): angle of the quadrature basis vector + x (Vector): points at which the quadrature distribution is evaluated + + Returns: + tuple(Vector, Vector): coordinates at which the pdf is evaluated and the probability distribution + """ + cutoff = state.shape[0] + if x is None: + x = np.sqrt(settings.HBAR) * math.new_constant(estimate_quadrature_axis(cutoff), "q_tensor") + + dims = len(state.shape) + is_dm = dims == 2 + + quad = math.transpose(math.astensor([x] * dims)) + conjugates = [True, False] if is_dm else [False] + quad_basis = quadrature_basis(state, quad, conjugates, quadrature_angle) + pdf = quad_basis if is_dm else math.abs(quad_basis) ** 2 + + return x, math.real(pdf) + + +def sample_homodyne(state: Tensor, quadrature_angle: float = 0.0) -> tuple[float, float]: + r"""Given a single-mode state, it generates the pdf of :math:`\tr [ \rho |x_\phi> 2: + raise ValueError( + "Input state has dimension {state.shape}. Make sure is either a single-mode ket or dm." + ) + + x, pdf = quadrature_distribution(state, quadrature_angle) + probs = pdf * (x[1] - x[0]) + + # draw a sample from the distribution + pdf = math.Categorical(probs=probs, name="homodyne_dist") + sample_idx = pdf.sample() + homodyne_sample = math.gather(x, sample_idx) + probability_sample = math.gather(probs, sample_idx) + + return homodyne_sample, probability_sample + + +@math.custom_gradient +def displacement(x, y, shape, tol=1e-15): + r"""creates a single mode displacement matrix""" + alpha = math.asnumpy(x) + 1j * math.asnumpy(y) + + if np.sqrt(x * x + y * y) > tol: + gate = strategies.displacement(tuple(shape), alpha) + else: + gate = math.eye(max(shape), dtype="complex128")[: shape[0], : shape[1]] + + ret = math.astensor(gate, dtype=gate.dtype.name) + if math.backend_name == "numpy": + return ret + + def grad(dL_dDc): + dD_da, dD_dac = strategies.jacobian_displacement(math.asnumpy(gate), alpha) + dL_dac = np.sum(np.conj(dL_dDc) * dD_dac + dL_dDc * np.conj(dD_da)) + dLdx = 2 * np.real(dL_dac) + dLdy = 2 * np.imag(dL_dac) + return math.astensor(dLdx, dtype=x.dtype), math.astensor(dLdy, dtype=y.dtype) + + return ret, grad + + +@math.custom_gradient +def beamsplitter(theta: float, phi: float, shape: Sequence[int], method: str): + r"""Creates a beamsplitter tensor with given cutoffs using a numba-based fock lattice strategy. + + Args: + theta (float): transmittivity angle of the beamsplitter + phi (float): phase angle of the beamsplitter + cutoffs (int,int): cutoff dimensions of the two modes + """ + if method == "vanilla": + bs_unitary = strategies.beamsplitter(shape, math.asnumpy(theta), math.asnumpy(phi)) + elif method == "schwinger": + bs_unitary = strategies.beamsplitter_schwinger( + shape, math.asnumpy(theta), math.asnumpy(phi) + ) + else: + raise ValueError( + f"Unknown beamsplitter method {method}. Options are 'vanilla' and 'schwinger'." + ) + + ret = math.astensor(bs_unitary, dtype=bs_unitary.dtype.name) + if math.backend_name == "numpy": + return ret + + def vjp(dLdGc): + dtheta, dphi = strategies.beamsplitter_vjp( + math.asnumpy(bs_unitary), + math.asnumpy(math.conj(dLdGc)), + math.asnumpy(theta), + math.asnumpy(phi), + ) + return math.astensor(dtheta, dtype=theta.dtype), math.astensor(dphi, dtype=phi.dtype) + + return ret, vjp + + +@math.custom_gradient +def squeezer(r, phi, shape): + r"""creates a single mode squeezer matrix using a numba-based fock lattice strategy""" + sq_unitary = strategies.squeezer(shape, math.asnumpy(r), math.asnumpy(phi)) + + ret = math.astensor(sq_unitary, dtype=sq_unitary.dtype.name) + if math.backend_name == "numpy": + return ret + + def vjp(dLdGc): + dr, dphi = strategies.squeezer_vjp( + math.asnumpy(sq_unitary), + math.asnumpy(math.conj(dLdGc)), + math.asnumpy(r), + math.asnumpy(phi), + ) + return math.astensor(dr, dtype=r.dtype), math.astensor(dphi, phi.dtype) + + return ret, vjp + + +@math.custom_gradient +def squeezed(r, phi, shape): + r"""creates a single mode squeezed state using a numba-based fock lattice strategy""" + sq_ket = strategies.squeezed(shape, math.asnumpy(r), math.asnumpy(phi)) + + ret = math.astensor(sq_ket, dtype=sq_ket.dtype.name) + if math.backend_name == "numpy": + return ret + + def vjp(dLdGc): + dr, dphi = strategies.squeezed_vjp( + math.asnumpy(sq_ket), + math.asnumpy(math.conj(dLdGc)), + math.asnumpy(r), + math.asnumpy(phi), + ) + return math.astensor(dr, dtype=r.dtype), math.astensor(dphi, phi.dtype) + + return ret, vjp From 1ac933c8e5d41d2df1bc5d40031cd6b3496d673e Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 10 Sep 2024 10:19:05 -0400 Subject: [PATCH 03/87] init --- mrmustard/lab_dev/states/base.py | 2 +- mrmustard/physics/bargmann_utils.py | 47 +- mrmustard/physics/representations.py | 774 ---------- mrmustard/physics/representations/__init__.py | 20 + .../bargmann.py} | 1363 +++++++---------- mrmustard/physics/representations/base.py | 85 + mrmustard/physics/representations/fock.py | 465 ++++++ ...est_bargmann.py => test_bargmann_utils.py} | 0 .../{test_fock.py => test_fock_utils.py} | 0 tests/test_physics/test_representations.py | 432 ------ .../test_representations/__init__.py | 0 .../test_representations/test_bargmann.py | 231 +++ 12 files changed, 1444 insertions(+), 1975 deletions(-) delete mode 100644 mrmustard/physics/representations.py create mode 100644 mrmustard/physics/representations/__init__.py rename mrmustard/physics/{ansatze.py => representations/bargmann.py} (51%) create mode 100644 mrmustard/physics/representations/base.py create mode 100644 mrmustard/physics/representations/fock.py rename tests/test_physics/{test_bargmann.py => test_bargmann_utils.py} (100%) rename tests/test_physics/{test_fock.py => test_fock_utils.py} (100%) delete mode 100644 tests/test_physics/test_representations.py create mode 100644 tests/test_physics/test_representations/__init__.py create mode 100644 tests/test_physics/test_representations/test_bargmann.py diff --git a/mrmustard/lab_dev/states/base.py b/mrmustard/lab_dev/states/base.py index 0237958ca..d23a28b5c 100644 --- a/mrmustard/lab_dev/states/base.py +++ b/mrmustard/lab_dev/states/base.py @@ -55,7 +55,7 @@ from mrmustard.physics.gaussian import purity from mrmustard.physics.representations import Bargmann, Fock from mrmustard.lab_dev.utils import shape_check -from mrmustard.physics.ansatze import ( +from mrmustard.physics.bargmann_utils import ( bargmann_Abc_to_phasespace_cov_means, ) from mrmustard.lab_dev.circuit_components_utils import BtoPS, BtoQ, TraceOut diff --git a/mrmustard/physics/bargmann_utils.py b/mrmustard/physics/bargmann_utils.py index 6600cf371..10c903371 100644 --- a/mrmustard/physics/bargmann_utils.py +++ b/mrmustard/physics/bargmann_utils.py @@ -20,7 +20,52 @@ from mrmustard import math, settings from mrmustard.physics.husimi import pq_to_aadag, wigner_to_husimi -from mrmustard.utils.typing import ComplexMatrix +from mrmustard.utils.typing import ComplexMatrix, Matrix, Vector, Scalar + + +def bargmann_Abc_to_phasespace_cov_means( + A: Matrix, b: Vector, c: Scalar, batched: bool = False +) -> tuple[Matrix, Vector, Scalar]: + r""" + Function to derive the covariance matrix and mean vector of a Gaussian state from its Wigner characteristic function in ABC form. + + The covariance matrix and mean vector can be used to write the characteristic function of a Gaussian state + :math: + \Chi_G(r) = \exp\left( -\frac{1}{2}r^T \Omega^T cov \Omega r + i r^T\Omega^T mean \right), + and the Wigner function of a Gaussian state: + :math: + W_G(r) = \frac{1}{\sqrt{\Det(cov)}} \exp\left( -\frac{1}{2}(r - mean)^T cov^{-1} (r-mean) \right). + + The internal expression of our Gaussian state :math:`\rho` is in Bargmann representation, one can write the characteristic function of a Gaussian state in Bargmann representation as + :math: + \Chi_G(\alpha) = \Tr(\rho D) = c \exp\left( -\frac{1}{2}\alpha^T A \alpha + \alpha^T b \right). + + This function is to go from the Abc triple in characteristic phase space into the covariance and mean vector for Gaussian state. + + Args: + A, b, c: The ``(A, b, c)`` triple of the state in characteristic phase space. + + Returns: + The covariance matrix, mean vector and coefficient of the state in phase space. + """ + # batched = len(A.shape) == 3 and len(b.shape) == 2 and len(c.shape) == 1 + A = math.atleast_3d(A) + b = math.atleast_2d(b) + c = math.atleast_1d(c) + num_modes = A.shape[-1] // 2 + Omega = math.cast(math.transpose(math.J(num_modes)), dtype=math.complex128) + W = math.transpose(math.conj(math.rotmat(num_modes))) + coeff = c + cov = [ + -Omega @ W @ Amat @ math.transpose(W) @ math.transpose(Omega) * settings.HBAR for Amat in A + ] + mean = [ + 1j * math.matvec(Omega @ W, bvec) * math.sqrt(settings.HBAR, dtype=math.complex128) + for bvec in b + ] + if batched: + return math.astensor(cov), math.astensor(mean), coeff + return cov[0], mean[0], coeff[0] def cayley(X, c): diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py deleted file mode 100644 index 38967647f..000000000 --- a/mrmustard/physics/representations.py +++ /dev/null @@ -1,774 +0,0 @@ -# Copyright 2023 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This module contains the classes for the available representations. -""" - -from __future__ import annotations -from abc import ABC, abstractmethod -from typing import Any, Callable, Iterable, Union - -from inspect import signature - -import numpy as np -from numpy.typing import ArrayLike - -from matplotlib import colors -import matplotlib.pyplot as plt - -from IPython.display import display - -from mrmustard import math, settings -from mrmustard.physics.gaussian_integrals import ( - contract_two_Abc_poly, - reorder_abc, - complex_gaussian_integral, -) -from mrmustard.physics.ansatze import Ansatz, PolyExpAnsatz, ArrayAnsatz -from mrmustard.utils.typing import ( - Batch, - ComplexMatrix, - ComplexTensor, - ComplexVector, - Scalar, - Tensor, -) -from mrmustard import widgets - -__all__ = ["Representation", "Bargmann", "Fock"] - - -class Representation(ABC): - r""" - A base class for representations. - - Representations can be initialized using the ``from_ansatz`` method, which automatically equips - them with all the functionality required to perform mathematical operations, such as equality, - multiplication, subtraction, etc. - """ - - def __init__(self) -> None: - self._contract_idxs: tuple[int, ...] = () - self._ansatz = None - - @property - def ansatz(self) -> Ansatz: - r""" - The ansatz of the representation. - """ - return self._ansatz - - @property - @abstractmethod - def data(self) -> tuple | Tensor: - r""" - The data of the representation. - For now, it's the triple for Bargmann and the array for Fock. - """ - - @property - @abstractmethod - def scalar(self) -> Scalar: - r""" - The scalar part of the representation. - For now it's ``c`` for Bargmann and the array for Fock. - """ - - @property - def triple( - self, - ) -> tuple[Batch[ComplexMatrix], Batch[ComplexVector], Batch[ComplexTensor]]: - r""" - The batch of triples :math:`(A_i, b_i, c_i)`. - """ - return self.ansatz.triple - - @classmethod - def from_ansatz(cls, ansatz: Ansatz) -> Representation: - r""" - Returns a representation from an ansatz. - """ - ret = cls(**{key: None for key in signature(cls).parameters.keys()}) - ret._ansatz = ansatz - return ret - - @abstractmethod - def from_function(cls, fn: Callable, **kwargs: Any) -> Representation: - r""" - Returns a representation from a function and kwargs. - """ - - @abstractmethod - def reorder(self, order: tuple[int, ...] | list[int]) -> Representation: - r""" - Reorders the representation indices. - """ - - @abstractmethod - def to_dict(self) -> dict[str, ArrayLike]: - r"""Serialize a Representation.""" - - @classmethod - @abstractmethod - def from_dict(cls, data: dict[str, ArrayLike]) -> Representation: - r"""Deserialize a Representation.""" - - def __eq__(self, other: Representation) -> bool: - r""" - Whether this representation is equal to another. - """ - return self.ansatz == other.ansatz - - def __add__(self, other: Representation) -> Representation: - r""" - Adds this representation to another. - """ - if self.__class__.__name__ != other.__class__.__name__: - msg = f"Cannot add ``{self.__class__.__name__}`` representation to " - msg += f"``{other.__class__.__name__}`` representation." - raise ValueError(msg) - return self.from_ansatz(self.ansatz + other.ansatz) - - def __sub__(self, other) -> Representation: - r""" - Subtracts another representation from this one. - """ - return self.from_ansatz(self.ansatz - other.ansatz) - - def __mul__(self, other: Representation | Scalar) -> Representation: - r""" - Multiplies this representation by another or by a scalar. - """ - try: - return self.from_ansatz(self.ansatz * other.ansatz) - except AttributeError: - return self.from_ansatz(self.ansatz * other) - - def __rmul__(self, other: Representation | Scalar) -> Representation: - r""" - Multiplies this representation by another or by a scalar on the right. - """ - return self.__mul__(other) - - def __truediv__(self, other: Representation | Scalar) -> Representation: - r""" - Divides this representation by another or by a scalar. - """ - try: - return self.from_ansatz(self.ansatz / other.ansatz) - except AttributeError: - return self.from_ansatz(self.ansatz / other) - - def __rtruediv__(self, other: Representation | Scalar) -> Representation: - r""" - Divides this representation by another or by a scalar on the right. - """ - return self.from_ansatz(other / self.ansatz) - - def __and__(self, other: Representation) -> Representation: - r""" - Takes the outer product of this representation with another. - """ - return self.from_ansatz(self.ansatz & other.ansatz) - - def __getitem__(self, idx: int | tuple[int, ...]) -> Representation: - r""" - Stores the indices for contraction. - """ - raise NotImplementedError - - -class Bargmann(Representation): - r""" - The Fock-Bargmann representation of a broad class of quantum states, transformations, - measurements, channels, etc. - - The ansatz available in this representation is a linear combination of exponentials - of bilinear forms with a polynomial part: - - .. math:: - F(z) = \sum_i \textrm{poly}_i(z) \textrm{exp}(z^T A_i z / 2 + z^T b_i) - - This function allows for vector space operations on Bargmann objects including - linear combinations (``+``), outer product (``&``), and inner product (``@``). - - .. code-block :: - - >>> from mrmustard.physics.representations import Bargmann - >>> from mrmustard.physics.triples import displacement_gate_Abc, vacuum_state_Abc - - >>> # bargmann representation of one-mode vacuum - >>> rep_vac = Bargmann(*vacuum_state_Abc(1)) - - >>> # bargmann representation of one-mode dgate with gamma=1+0j - >>> rep_dgate = Bargmann(*displacement_gate_Abc(1)) - - The inner product is defined as the contraction of two Bargmann objects across marked indices. - Indices are marked using ``__getitem__``. Once the indices are marked for contraction, they are - be used the next time the inner product (``@``) is called. For example: - - .. code-block :: - - >>> import numpy as np - - >>> # mark indices for contraction - >>> idx_vac = [0] - >>> idx_rep = [1] - - >>> # bargmann representation of coh = vacuum >> dgate - >>> rep_coh = rep_vac[idx_vac] @ rep_dgate[idx_rep] - >>> assert np.allclose(rep_coh.A, [[0,],]) - >>> assert np.allclose(rep_coh.b, [1,]) - >>> assert np.allclose(rep_coh.c, 0.6065306597126334) - - This can also be used to contract existing indices in a single Bargmann object, e.g. - to implement the partial trace. - - .. code-block :: - - >>> trace = (rep_coh @ rep_coh.conj()).trace([0], [1]) - >>> assert np.allclose(trace.A, 0) - >>> assert np.allclose(trace.b, 0) - >>> assert trace.c == 1 - - The ``A``, ``b``, and ``c`` parameters can be batched to represent superpositions. - - .. code-block :: - - >>> # bargmann representation of one-mode coherent state with gamma=1+0j - >>> A_plus = [[0,],] - >>> b_plus = [1,] - >>> c_plus = 0.6065306597126334 - - >>> # bargmann representation of one-mode coherent state with gamma=-1+0j - >>> A_minus = [[0,],] - >>> b_minus = [-1,] - >>> c_minus = 0.6065306597126334 - - >>> # bargmann representation of a superposition of coherent states - >>> A = [A_plus, A_minus] - >>> b = [b_plus, b_minus] - >>> c = [c_plus, c_minus] - >>> rep_coh_sup = Bargmann(A, b, c) - - Note that the operations that change the shape of the ansatz (outer product and inner - product) do not automatically modify the ordering of the combined or leftover indices. - However, the ``reordering`` method allows reordering the representation after the products - have been carried out. - - Args: - A: A batch of quadratic coefficient :math:`A_i`. - b: A batch of linear coefficients :math:`b_i`. - c: A batch of arrays :math:`c_i`. - - Note: The args can be passed non-batched, as they will be automatically broadcasted to the - correct batch shape. - """ - - def __init__( - self, - A: Batch[ComplexMatrix], - b: Batch[ComplexVector], - c: Batch[ComplexTensor] = 1.0, - ): - super().__init__() - self._ansatz = PolyExpAnsatz(A=A, b=b, c=c) - - @property - def A(self) -> Batch[ComplexMatrix]: - r""" - The batch of quadratic coefficient :math:`A_i`. - """ - return self.ansatz.A - - @property - def b(self) -> Batch[ComplexVector]: - r""" - The batch of linear coefficients :math:`b_i` - """ - return self.ansatz.b - - @property - def c(self) -> Batch[ComplexTensor]: - r""" - The batch of arrays :math:`c_i`. - """ - return self.ansatz.c - - @property - def data( - self, - ) -> tuple[Batch[ComplexMatrix], Batch[ComplexVector], Batch[ComplexTensor]]: - r""" - The data of the representation. - """ - return self.triple - - @property - def scalar(self) -> Batch[ComplexTensor]: - r""" - The scalar part of the representation. - """ - if self.ansatz.polynomial_shape[0] > 0: - return self([]) - else: - return self.c - - @classmethod - def from_function(cls, fn: Callable, **kwargs: Any) -> Bargmann: - r""" - Returns a Bargmann object from a generator function. - """ - return cls.from_ansatz(PolyExpAnsatz.from_function(fn, **kwargs)) - - def conj(self): - r""" - The conjugate of this Bargmann object. - """ - new = self.__class__(math.conj(self.A), math.conj(self.b), math.conj(self.c)) - new._contract_idxs = self._contract_idxs # pylint: disable=protected-access - return new - - def plot( - self, - just_phase: bool = False, - with_measure: bool = False, - log_scale: bool = False, - xlim=(-2 * np.pi, 2 * np.pi), - ylim=(-2 * np.pi, 2 * np.pi), - ) -> tuple[plt.figure.Figure, plt.axes.Axes]: # pragma: no cover - r""" - Plots the Bargmann function :math:`F(z)` on the complex plane. Phase is represented by - color, magnitude by brightness. The function can be multiplied by :math:`exp(-|z|^2)` - to represent the Bargmann function times the measure function (for integration). - - Args: - just_phase: Whether to plot only the phase of the Bargmann function. - with_measure: Whether to plot the bargmann function times the measure function - :math:`exp(-|z|^2)`. - log_scale: Whether to plot the log of the Bargmann function. - xlim: The `x` limits of the plot. - ylim: The `y` limits of the plot. - - Returns: - The figure and axes of the plot - """ - # eval F(z) on a grid of complex numbers - X, Y = np.mgrid[xlim[0] : xlim[1] : 400j, ylim[0] : ylim[1] : 400j] - Z = (X + 1j * Y).T - f_values = self(Z[..., None]) - if log_scale: - f_values = np.log(np.abs(f_values)) * np.exp(1j * np.angle(f_values)) - if with_measure: - f_values = f_values * np.exp(-(np.abs(Z) ** 2)) - - # Get phase and magnitude of F(z) - phases = np.angle(f_values) / (2 * np.pi) % 1 - magnitudes = np.abs(f_values) - magnitudes_scaled = magnitudes / np.max(magnitudes) - - # Convert to RGB - hsv_values = np.zeros(f_values.shape + (3,)) - hsv_values[..., 0] = phases - hsv_values[..., 1] = 1 - hsv_values[..., 2] = 1 if just_phase else magnitudes_scaled - rgb_values = colors.hsv_to_rgb(hsv_values) - - # Plot the image - fig, ax = plt.subplots() - ax.imshow(rgb_values, origin="lower", extent=[xlim[0], xlim[1], ylim[0], ylim[1]]) - ax.set_xlabel("$Re(z)$") - ax.set_ylabel("$Im(z)$") - - name = "F_{" + self.ansatz.name + "}(z)" - name = f"\\arg({name})\\log|{name}|" if log_scale else name - title = name + "e^{-|z|^2}" if with_measure else name - title = f"\\arg({name})" if just_phase else title - ax.set_title(f"${title}$") - plt.show(block=False) - return fig, ax - - def reorder(self, order: tuple[int, ...] | list[int]) -> Bargmann: - r""" - Reorders the indices of the ``A`` matrix and ``b`` vector of the ``(A, b, c)`` triple in - this Bargmann object. - - .. code-block:: - - >>> from mrmustard.physics.representations import Bargmann - >>> from mrmustard.physics.triples import displacement_gate_Abc - - >>> rep_dgate1 = Bargmann(*displacement_gate_Abc([0.1, 0.2, 0.3])) - >>> rep_dgate2 = Bargmann(*displacement_gate_Abc([0.2, 0.3, 0.1])) - - >>> assert rep_dgate1.reorder([1, 2, 0, 4, 5, 3]) == rep_dgate2 - - Args: - order: The new order. - - Returns: - The reordered Bargmann object. - """ - A, b, c = reorder_abc((self.A, self.b, self.c), order) - return self.__class__(A, b, c) - - def trace(self, idx_z: tuple[int, ...], idx_zconj: tuple[int, ...]) -> Bargmann: - r""" - The partial trace over the given index pairs. - - Args: - idx_z: The first part of the pairs of indices to trace over. - idx_zconj: The second part. - - Returns: - Bargmann: the ansatz with the given indices traced over - """ - A, b, c = [], [], [] - for Abc in zip(self.A, self.b, self.c): - Aij, bij, cij = complex_gaussian_integral(Abc, idx_z, idx_zconj, measure=-1.0) - A.append(Aij) - b.append(bij) - c.append(cij) - return Bargmann(A, b, c) - - def __call__(self, z: ComplexTensor) -> ComplexTensor: - r""" - Evaluates the Bargmann function at the given array of points. - - Args: - z: The array of points. - - Returns: - The value of the Bargmann function at ``z``. - """ - return self.ansatz(z) - - def __getitem__(self, idx: int | tuple[int, ...]) -> Bargmann: - r""" - A copy of self with the given indices marked for contraction. - """ - idx = (idx,) if isinstance(idx, int) else idx - for i in idx: - if i >= self.ansatz.num_vars: - raise IndexError( - f"Index {i} out of bounds for ansatz {self.ansatz.__class__.__qualname__} of dimension {self.ansatz.num_vars}." - ) - new = self.__class__(self.A, self.b, self.c) - new._contract_idxs = idx - return new - - def __matmul__(self, other: Bargmann) -> Bargmann: - r""" - Implements the inner product in Bargmann representation. - - ..code-block:: - - >>> from mrmustard.physics.representations import Bargmann - >>> from mrmustard.physics.triples import displacement_gate_Abc, vacuum_state_Abc - >>> rep1 = Bargmann(*vacuum_state_Abc(1)) - >>> rep2 = Bargmann(*displacement_gate_Abc(1)) - >>> rep3 = rep1[0] @ rep2[1] - >>> assert np.allclose(rep3.A, [[0,],]) - >>> assert np.allclose(rep3.b, [1,]) - - Args: - other: Another Bargmann representation. - - Returns: - Bargmann: the resulting Bargmann representation. - - """ - if isinstance(other, Fock): - raise NotImplementedError("Only matmul Bargmann with Bargmann") - - idx_s = self._contract_idxs - idx_o = other._contract_idxs - - Abc = [] - if settings.UNSAFE_ZIP_BATCH: - if self.ansatz.batch_size != other.ansatz.batch_size: - raise ValueError( - f"Batch size of the two ansatze must match since the settings.UNSAFE_ZIP_BATCH is {settings.UNSAFE_ZIP_BATCH}." - ) - for (A1, b1, c1), (A2, b2, c2) in zip( - zip(self.A, self.b, self.c), zip(other.A, other.b, other.c) - ): - Abc.append(contract_two_Abc_poly((A1, b1, c1), (A2, b2, c2), idx_s, idx_o)) - else: - for A1, b1, c1 in zip(self.A, self.b, self.c): - for A2, b2, c2 in zip(other.A, other.b, other.c): - Abc.append(contract_two_Abc_poly((A1, b1, c1), (A2, b2, c2), idx_s, idx_o)) - - A, b, c = zip(*Abc) - return Bargmann(A, b, c) - - def to_dict(self) -> dict[str, ArrayLike]: - """Serialize a Bargmann instance.""" - return {"A": self.A, "b": self.b, "c": self.c} - - @classmethod - def from_dict(cls, data: dict[str, ArrayLike]) -> Bargmann: - """Deserialize a Bargmann instance.""" - return cls(**data) - - def _ipython_display_(self): - display(widgets.bargmann(self)) - - -class Fock(Representation): - r""" - The Fock representation of a broad class of quantum states, transformations, measurements, - channels, etc. - - The ansatz available in this representation is ``ArrayAnsatz``. - - This function allows for vector space operations on Fock objects including - linear combinations, outer product (``&``), and inner product (``@``). - - .. code-block:: - - >>> from mrmustard.physics.representations import Fock - - >>> # initialize Fock objects - >>> array1 = np.random.random((5,7,8)) - >>> array2 = np.random.random((5,7,8)) - >>> array3 = np.random.random((3,5,7,8)) # where 3 is the batch. - >>> fock1 = Fock(array1) - >>> fock2 = Fock(array2) - >>> fock3 = Fock(array3, batched=True) - - >>> # linear combination can be done with the same batch dimension - >>> fock4 = 1.3 * fock1 - fock2 * 2.1 - - >>> # division by a scalar - >>> fock5 = fock1 / 1.3 - - >>> # inner product by contracting on marked indices - >>> fock6 = fock1[2] @ fock3[2] - - >>> # outer product (tensor product) - >>> fock7 = fock1 & fock3 - - >>> # conjugation - >>> fock8 = fock1.conj() - - Args: - array: the (batched) array in Fock representation. - batched: whether the array input has a batch dimension. - - Note: The args can be passed non-batched, as they will be automatically broadcasted to the - correct batch shape. - - """ - - def __init__(self, array: Batch[Tensor], batched=False): - super().__init__() - self._ansatz = ArrayAnsatz(array=array, batched=batched) - - @property - def array(self) -> Batch[Tensor]: - r""" - The array from the ansatz. - """ - return self.ansatz.array - - @property - def data(self) -> Batch[Tensor]: - r""" - The data of the representation. - """ - return self.array - - @property - def scalar(self) -> Scalar: - r""" - The scalar part of the representation. - I.e. the vacuum component of the Fock object, whatever it may be. - Given that the first axis of the array is the batch axis, this is the first element of the array. - """ - return self.array[(slice(None),) + (0,) * self.ansatz.num_vars] - - @classmethod - def from_function(cls, fn: Callable, **kwargs: Any) -> Fock: - r""" - Returns a Fock object from a generator function. - """ - return cls.from_ansatz(ArrayAnsatz.from_function(fn, **kwargs)) - - def conj(self): - r""" - The conjugate of this Fock object. - """ - new = self.from_ansatz(self.ansatz.conj) - new._contract_idxs = self._contract_idxs # pylint: disable=protected-access - return new - - def reduce(self, shape: Union[int, Iterable[int]]) -> Fock: - r""" - Returns a new ``Fock`` with a sliced array. - - .. code-block:: - - >>> from mrmustard import math - >>> from mrmustard.physics.representations import Fock - - >>> array1 = math.arange(27).reshape((3, 3, 3)) - >>> fock1 = Fock(array1) - - >>> fock2 = fock1.reduce(3) - >>> assert fock1 == fock2 - - >>> fock3 = fock1.reduce(2) - >>> array3 = [[[0, 1], [3, 4]], [[9, 10], [12, 13]]] - >>> assert fock3 == Fock(array3) - - >>> fock4 = fock1.reduce((1, 3, 1)) - >>> array4 = [[[0], [3], [6]]] - >>> assert fock4 == Fock(array4) - - Args: - shape: The shape of the array of the returned ``Fock``. - """ - return self.from_ansatz(self.ansatz.reduce(shape)) - - def reorder(self, order: tuple[int, ...] | list[int]) -> Fock: - r""" - Reorders the indices of the array with the given order. - - Args: - order: The order. Does not need to refer to the batch dimension. - - Returns: - The reordered Fock. - """ - return self.from_ansatz( - ArrayAnsatz(math.transpose(self.array, [0] + [i + 1 for i in order])) - ) - - def sum_batch(self) -> Fock: - r""" - Sums over the batch dimension of the array. Turns an object with any batch size to a batch size of 1. - - Returns: - The collapsed Fock object. - """ - return self.from_ansatz(ArrayAnsatz(math.expand_dims(math.sum(self.array, axes=[0]), 0))) - - def trace(self, idxs1: tuple[int, ...], idxs2: tuple[int, ...]) -> Fock: - r""" - Implements the partial trace over the given index pairs. - - Args: - idxs1: The first part of the pairs of indices to trace over. - idxs2: The second part. - - Returns: - The traced-over Fock object. - """ - if len(idxs1) != len(idxs2) or not set(idxs1).isdisjoint(idxs2): - raise ValueError("idxs must be of equal length and disjoint") - order = ( - [0] - + [i + 1 for i in range(len(self.array.shape) - 1) if i not in idxs1 + idxs2] - + [i + 1 for i in idxs1] - + [i + 1 for i in idxs2] - ) - new_array = math.transpose(self.array, order) - n = np.prod(new_array.shape[-len(idxs2) :]) - new_array = math.reshape(new_array, new_array.shape[: -2 * len(idxs1)] + (n, n)) - trace = math.trace(new_array) - return self.from_ansatz(ArrayAnsatz([trace] if trace.shape == () else trace)) - - def __getitem__(self, idx: int | tuple[int, ...]) -> Fock: - r""" - Returns a copy of self with the given indices marked for contraction. - """ - idx = (idx,) if isinstance(idx, int) else idx - for i in idx: - if i >= len(self.array.shape): - raise IndexError( - f"Index {i} out of bounds for ansatz {self.ansatz.__class__.__qualname__} with {self.ansatz.num_vars} variables." - ) - new = self.from_ansatz(self.ansatz) - new._contract_idxs = idx - return new - - def __matmul__(self, other: Fock) -> Fock: - r""" - Implements the inner product of fock arrays over the marked indices. - - .. code-block:: - >>> from mrmustard.physics.representations import Fock - >>> f = Fock(np.random.random((3, 5, 10))) # 10 is reduced to 8 - >>> g = Fock(np.random.random((2, 5, 8))) - >>> h = f[1,2] @ g[1,2] - >>> assert h.array.shape == (1,3,2) # batch size is 1 - >>> f = Fock(np.random.random((3, 5, 10)), batched=True) - >>> g = Fock(np.random.random((2, 5, 8)), batched=True) - >>> h = f[0,1] @ g[0,1] - >>> assert h.array.shape == (6,) # batch size is 3 x 2 = 6 - - Args: - other: Another representation. - - Returns: - A ``Fock``representation. - """ - if isinstance(other, Bargmann): - raise NotImplementedError("only matmul Fock with Fock") - - idx_s = list(self._contract_idxs) - idx_o = list(other._contract_idxs) - - # the number of batches in self and other - n_batches_s = self.array.shape[0] - n_batches_o = other.array.shape[0] - - # the shapes each batch in self and other - shape_s = self.array.shape[1:] - shape_o = other.array.shape[1:] - - new_shape_s = list(shape_s) - new_shape_o = list(shape_o) - for s, o in zip(idx_s, idx_o): - new_shape_s[s] = min(shape_s[s], shape_o[o]) - new_shape_o[o] = min(shape_s[s], shape_o[o]) - - reduced_s = self.reduce(new_shape_s)[idx_s] - reduced_o = other.reduce(new_shape_o)[idx_o] - - axes = [list(idx_s), list(idx_o)] - batched_array = [] - for i in range(n_batches_s): - for j in range(n_batches_o): - batched_array.append(math.tensordot(reduced_s.array[i], reduced_o.array[j], axes)) - return self.from_ansatz(ArrayAnsatz(batched_array)) - - def to_dict(self) -> dict[str, ArrayLike]: - """Serialize a Fock instance.""" - return {"array": self.data} - - @classmethod - def from_dict(cls, data: dict[str, ArrayLike]) -> Fock: - """Deserialize a Fock instance.""" - return cls(data["array"], batched=True) - - def _ipython_display_(self): - w = widgets.fock(self) - if w is None: - print(repr(self)) - return - display(w) diff --git a/mrmustard/physics/representations/__init__.py b/mrmustard/physics/representations/__init__.py new file mode 100644 index 000000000..34c0933bb --- /dev/null +++ b/mrmustard/physics/representations/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r""" +""" + +from .base import * +from .bargmann import * +from .fock import * diff --git a/mrmustard/physics/ansatze.py b/mrmustard/physics/representations/bargmann.py similarity index 51% rename from mrmustard/physics/ansatze.py rename to mrmustard/physics/representations/bargmann.py index 6b6a99c4a..ca8b15831 100644 --- a/mrmustard/physics/ansatze.py +++ b/mrmustard/physics/representations/bargmann.py @@ -12,296 +12,292 @@ # See the License for the specific language governing permissions and # limitations under the License. + """ -This module contains the classes for the available ansatze. +This module contains the Bargmann representation. """ from __future__ import annotations +from typing import Any, Callable import itertools -from abc import ABC, abstractmethod -from typing import Any, Callable, Sequence -from warnings import warn import numpy as np +from numpy.typing import ArrayLike + +from matplotlib import colors +import matplotlib.pyplot as plt + +from IPython.display import display -from mrmustard import math, settings -from mrmustard.math.parameters import Variable -from mrmustard.utils.argsort import argsort_gen from mrmustard.utils.typing import ( Batch, ComplexMatrix, ComplexTensor, ComplexVector, - Matrix, Scalar, - Tensor, Vector, ) -__all__ = [ - "Ansatz", - "ArrayAnsatz", - "PolyExpBase", - "PolyExpAnsatz", -] +from mrmustard.physics.gaussian_integrals import ( + reorder_abc, + complex_gaussian_integral, + contract_two_Abc_poly, +) +from mrmustard import math, settings, widgets +from mrmustard.math.parameters import Variable -class Ansatz(ABC): +from mrmustard.utils.argsort import argsort_gen + +__all__ = ["Bargmann"] + + +class Bargmann: r""" - A function over a continuous and/or discrete domain. + The Fock-Bargmann representation of a broad class of quantum states, transformations, + measurements, channels, etc. - An ansatz supports basic mathematical operations such as addition, subtraction, - multiplication, division, negation, equality, etc. + The ansatz available in this representation is a linear combination of exponentials + of bilinear forms with a polynomial part: - Note that ``n``-dimensional arrays are like functions defined over an integer lattice of points, - so this class also works for, e.g., the Fock representation. + .. math:: + F(z) = \sum_i \textrm{poly}_i(z) \textrm{exp}(z^T A_i z / 2 + z^T b_i) - This class is abstract. Concrete ``Ansatz`` classes have to implement the - ``__call__``, ``__mul__``, ``__add__``, ``__sub__``, ``__neg__``, and ``__eq__`` methods. - """ + This function allows for vector space operations on Bargmann objects including + linear combinations (``+``), outer product (``&``), and inner product (``@``). - def __init__(self) -> None: - self._fn = None - self._kwargs = {} + .. code-block :: - @abstractmethod - def from_function(cls, fn: Callable, **kwargs: Any) -> Ansatz: - r""" - Returns an ansatz from a function and kwargs. - """ + >>> from mrmustard.physics.representations import Bargmann + >>> from mrmustard.physics.triples import displacement_gate_Abc, vacuum_state_Abc - @abstractmethod - def __neg__(self) -> Ansatz: - r""" - Negates this ansatz. - """ + >>> # bargmann representation of one-mode vacuum + >>> rep_vac = Bargmann(*vacuum_state_Abc(1)) - @abstractmethod - def __eq__(self, other: Ansatz) -> bool: - r""" - Whether this ansatz is equal to another ansatz. - """ + >>> # bargmann representation of one-mode dgate with gamma=1+0j + >>> rep_dgate = Bargmann(*displacement_gate_Abc(1)) - @abstractmethod - def __add__(self, other: Ansatz) -> Ansatz: - r""" - Sums this ansatz to another ansatz. - """ + The inner product is defined as the contraction of two Bargmann objects across marked indices. + Indices are marked using ``__getitem__``. Once the indices are marked for contraction, they are + be used the next time the inner product (``@``) is called. For example: - def __sub__(self, other: Ansatz) -> Ansatz: - r""" - Subtracts other from this ansatz. - """ - try: - return self.__add__(-other) - except AttributeError as e: - raise TypeError(f"Cannot subtract {self.__class__} and {other.__class__}.") from e + .. code-block :: - @abstractmethod - def __call__(self, point: Any) -> Scalar: - r""" - Evaluates this ansatz at a given point in the domain. - """ + >>> import numpy as np - @abstractmethod - def __truediv__(self, other: Scalar | Ansatz) -> Ansatz: - r""" - Divides this ansatz by another ansatz or by a scalar. - """ + >>> # mark indices for contraction + >>> idx_vac = [0] + >>> idx_rep = [1] - @abstractmethod - def __mul__(self, other: Scalar | Ansatz) -> Ansatz: - r""" - Multiplies this ansatz by another ansatz. - """ + >>> # bargmann representation of coh = vacuum >> dgate + >>> rep_coh = rep_vac[idx_vac] @ rep_dgate[idx_rep] + >>> assert np.allclose(rep_coh.A, [[0,],]) + >>> assert np.allclose(rep_coh.b, [1,]) + >>> assert np.allclose(rep_coh.c, 0.6065306597126334) - @abstractmethod - def __and__(self, other: Ansatz) -> Ansatz: - r""" - Tensor product of this ansatz with another ansatz. - """ + This can also be used to contract existing indices in a single Bargmann object, e.g. + to implement the partial trace. - def __rmul__(self, other: Scalar) -> Ansatz: - r""" - Multiplies this ansatz by a scalar. - """ - return self * other + .. code-block :: + >>> trace = (rep_coh @ rep_coh.conj()).trace([0], [1]) + >>> assert np.allclose(trace.A, 0) + >>> assert np.allclose(trace.b, 0) + >>> assert trace.c == 1 -# pylint: disable=too-many-instance-attributes -class PolyExpBase(Ansatz): - r""" - A family of Ansatze parametrized by a triple of a matrix, a vector and an array. - For example, the Bargmann representation :math:`c\:\textrm{exp}(z A z / 2 + b z)` is of this - form (where ``A``, ``b``, ``c`` is the triple), or the characteristic function of the - Wigner representation (where ``Sigma``, ``mu``, ``1`` is the triple). + The ``A``, ``b``, and ``c`` parameters can be batched to represent superpositions. + + .. code-block :: + + >>> # bargmann representation of one-mode coherent state with gamma=1+0j + >>> A_plus = [[0,],] + >>> b_plus = [1,] + >>> c_plus = 0.6065306597126334 - Note that this class is not initializable (despite having an initializer) because it does - not implement all the abstract methods of ``Ansatz``, and it is in fact more general. - Concrete ansatze that inherit from this class need to implement ``__call__``, - ``__mul__`` and ``__matmul__``, which are representation-specific. + >>> # bargmann representation of one-mode coherent state with gamma=-1+0j + >>> A_minus = [[0,],] + >>> b_minus = [-1,] + >>> c_minus = 0.6065306597126334 - Note that the arguments are expected to be batched, i.e. to have a batch dimension - or to be an iterable. This is because this class also provides the linear superposition - functionality by implementing the ``__add__`` method, which concatenates the batch dimensions. + >>> # bargmann representation of a superposition of coherent states + >>> A = [A_plus, A_minus] + >>> b = [b_plus, b_minus] + >>> c = [c_plus, c_minus] + >>> rep_coh_sup = Bargmann(A, b, c) - As this can blow up the number of terms in the representation, it is recommended to - run the `simplify()` method after adding terms together, which combines together - terms that have the same exponential part. + Note that the operations that change the shape of the ansatz (outer product and inner + product) do not automatically modify the ordering of the combined or leftover indices. + However, the ``reordering`` method allows reordering the representation after the products + have been carried out. Args: - mat: the matrix-like data - vec: the vector-like data - array: the array-like data + A: A batch of quadratic coefficient :math:`A_i`. + b: A batch of linear coefficients :math:`b_i`. + c: A batch of arrays :math:`c_i`. + + Note: The args can be passed non-batched, as they will be automatically broadcasted to the + correct batch shape. """ def __init__( self, - mat: Batch[Matrix], - vec: Batch[Vector], - array: Batch[Tensor], + A: Batch[ComplexMatrix], + b: Batch[ComplexVector], + c: Batch[ComplexTensor] = 1.0, + name: str = "", + batched: bool = False, ): - super().__init__() - self._mat = mat - self._vec = vec - self._array = array + if A is None and b is None and c is not None: + raise ValueError("Please provide either A or b.") - # if (mat, vec, array) have been converted to backend - self._backends = [False, False, False] + # Representation base class + self._contract_idxs: tuple[int, ...] = () + self._fn = None + self._kwargs = {} + self.name = name + self._A = A + self._b = b + self._c = c + self._backends = [False, False, False] self._simplified = False @property - def array(self) -> Batch[ComplexMatrix]: + def A(self) -> Batch[ComplexMatrix]: r""" - The array of this ansatz. + The batch of quadratic coefficient :math:`A_i`. """ self._generate_ansatz() - if not self._backends[2]: - self._array = math.atleast_1d(self._array) - self._backends[2] = True - return self._array + if not self._backends[0]: + self._A = math.atleast_3d(self._A) + self._backends[0] = True + return self._A - @array.setter - def array(self, array): - self._array = array - self._backends[2] = False + @A.setter + def A(self, value): + self._A = value + self._backends[0] = False @property - def batch_size(self): + def b(self) -> Batch[ComplexVector]: r""" - The batch size of this ansatz. + The batch of linear coefficients :math:`b_i` """ - return self.mat.shape[0] + self._generate_ansatz() + if not self._backends[1]: + self._b = math.atleast_2d(self._b) + self._backends[1] = True + return self._b + + @b.setter + def b(self, value): + self._b = value + self._backends[1] = False @property - def polynomial_shape(self) -> tuple[int, tuple]: + def batch_size(self): r""" - This method finds the dimensionality of the polynomial, i.e. how many wires - have polynomials attached to them and what the degree(+1) of the polynomial is - on each of the wires. + The batch size of this representation. """ - dim_poly = len(self.array.shape) - 1 - shape_poly = self.array.shape[1:] - return dim_poly, shape_poly + return self.c.shape[0] @property - def mat(self) -> Batch[ComplexMatrix]: + def c(self) -> Batch[ComplexTensor]: r""" - The matrix of this ansatz. + The batch of arrays :math:`c_i`. """ self._generate_ansatz() - if not self._backends[0]: - self._mat = math.atleast_3d(self._mat) - self._backends[0] = True - return self._mat + if not self._backends[2]: + self._c = math.atleast_1d(self._c) + self._backends[2] = True + return self._c - @mat.setter - def mat(self, array): - self._mat = array - self._backends[0] = False + @c.setter + def c(self, value): + self._c = value + self._backends[2] = False + + @property + def data( + self, + ) -> tuple[Batch[ComplexMatrix], Batch[ComplexVector], Batch[ComplexTensor]]: + r""" + The data of the representation. + """ + return self.triple @property def num_vars(self): r""" The number of variables in this ansatz. """ - return self.mat.shape[-1] - self.polynomial_shape[0] + return self.A.shape[-1] - self.polynomial_shape[0] @property - def vec(self) -> Batch[ComplexMatrix]: + def polynomial_shape(self) -> tuple[int, tuple]: r""" - The vector of this ansatz. + This method finds the dimensionality of the polynomial, i.e. how many wires + have polynomials attached to them and what the degree(+1) of the polynomial is + on each of the wires. """ - self._generate_ansatz() - if not self._backends[1]: - self._vec = math.atleast_2d(self._vec) - self._backends[1] = True - return self._vec + dim_poly = len(self.c.shape) - 1 + shape_poly = self.c.shape[1:] + return dim_poly, shape_poly - @vec.setter - def vec(self, array): - self._vec = array - self._backends[1] = False + @property + def scalar(self) -> Batch[ComplexTensor]: + r""" + The scalar part of the representation. + """ + if self.polynomial_shape[0] > 0: + return self([]) + else: + return self.c - def simplify(self) -> None: + @property + def triple( + self, + ) -> tuple[Batch[ComplexMatrix], Batch[ComplexVector], Batch[ComplexTensor]]: r""" - Simplifies the representation by combining together terms that have the same - exponential part, i.e. two terms along the batch are considered equal if their - matrix and vector are equal. In this case only one is kept and the arrays are added. + The batch of triples :math:`(A_i, b_i, c_i)`. + """ + return self.A, self.b, self.c - Does not run if the representation has already been simplified, so it is safe to call. + @classmethod + def from_dict(cls, data: dict[str, ArrayLike]) -> Bargmann: + """Deserialize a Bargmann instance.""" + return cls(**data) + + @classmethod + def from_function(cls, fn: Callable, **kwargs: Any) -> Bargmann: + r""" + Returns a Bargmann object from a generator function. """ - if self._simplified: - return - indices_to_check = set(range(self.batch_size)) - removed = [] - while indices_to_check: - i = indices_to_check.pop() - for j in indices_to_check.copy(): - if np.allclose(self.mat[i], self.mat[j]) and np.allclose(self.vec[i], self.vec[j]): - self.array = math.update_add_tensor(self.array, [[i]], [self.array[j]]) - indices_to_check.remove(j) - removed.append(j) - to_keep = [i for i in range(self.batch_size) if i not in removed] - self.mat = math.gather(self.mat, to_keep, axis=0) - self.vec = math.gather(self.vec, to_keep, axis=0) - self.array = math.gather(self.array, to_keep, axis=0) - self._simplified = True + ret = cls(None, None, None) + ret._fn = fn + ret._kwargs = kwargs + return ret - def simplify_v2(self) -> None: + def conj(self): r""" - A different implementation of ``simplify`` that orders the batch dimension first. + The conjugate of this Bargmann object. """ - if self._simplified: - return - self._order_batch() - to_keep = [d0 := 0] - mat, vec = self.mat[d0], self.vec[d0] - for d in range(1, self.batch_size): - if np.allclose(mat, self.mat[d]) and np.allclose(vec, self.vec[d]): - self.array = math.update_add_tensor(self.array, [[d0]], [self.array[d]]) - else: - to_keep.append(d) - d0 = d - mat, vec = self.mat[d0], self.vec[d0] - self.mat = math.gather(self.mat, to_keep, axis=0) - self.vec = math.gather(self.vec, to_keep, axis=0) - self.array = math.gather(self.array, to_keep, axis=0) - self._simplified = True + ret = Bargmann(math.conj(self.A), math.conj(self.b), math.conj(self.c)) + ret._contract_idxs = self._contract_idxs # pylint: disable=protected-access + return ret - def decompose_ansatz(self) -> PolyExpAnsatz: + def decompose_ansatz(self) -> Bargmann: r""" - This method decomposes a PolyExpAnsatz. Given an ansatz of dimensions: + This method decomposes a Bargmann representation. Given a representation of dimensions: A=(batch,n+m,n+m), b=(batch,n+m), c = (batch,k_1,k_2,...,k_m), - it can be rewritten as an ansatz of dimensions + it can be rewritten as a representation of dimensions A=(batch,2n,2n), b=(batch,2n), c = (batch,l_1,l_2,...,l_n), with l_i = sum_j k_j This decomposition is typically favourable if m>n, and will only run if that is the case. The naming convention is ``n = dim_alpha`` and ``m = dim_beta`` and ``(k_1,k_2,...,k_m) = shape_beta`` """ dim_beta, _ = self.polynomial_shape - dim_alpha = self.mat.shape[-1] - dim_beta + dim_alpha = self.A.shape[-1] - dim_beta batch_size = self.batch_size if dim_beta > dim_alpha: A_decomp = [] @@ -309,248 +305,171 @@ def decompose_ansatz(self) -> PolyExpAnsatz: c_decomp = [] for i in range(batch_size): A_decomp_i, b_decomp_i, c_decomp_i = self._decompose_ansatz_single( - self.mat[i], self.vec[i], self.array[i] + self.A[i], self.b[i], self.c[i] ) A_decomp.append(A_decomp_i) b_decomp.append(b_decomp_i) c_decomp.append(c_decomp_i) - return PolyExpAnsatz(A_decomp, b_decomp, c_decomp) + return Bargmann(A_decomp, b_decomp, c_decomp) else: - return PolyExpAnsatz(self.mat, self.vec, self.array) + return Bargmann(self.A, self.b, self.c) - def _decompose_ansatz_single(self, Ai, bi, ci): - dim_beta, shape_beta = self.polynomial_shape - dim_alpha = self.mat.shape[-1] - dim_beta - A_bar = math.block( - [ - [ - math.zeros((dim_alpha, dim_alpha), dtype=Ai.dtype), - Ai[:dim_alpha, dim_alpha:], - ], - [ - Ai[dim_alpha:, :dim_alpha], - Ai[dim_alpha:, dim_alpha:], - ], - ] - ) - b_bar = math.concat((math.zeros((dim_alpha), dtype=bi.dtype), bi[dim_alpha:]), axis=0) - poly_bar = math.hermite_renormalized( - A_bar, - b_bar, - complex(1), - (math.sum(shape_beta),) * dim_alpha + shape_beta, - ) - c_decomp = math.sum( - poly_bar * ci, - axes=math.arange( - len(poly_bar.shape) - dim_beta, len(poly_bar.shape), dtype=math.int32 - ).tolist(), - ) - A_decomp = math.block( - [ - [ - Ai[:dim_alpha, :dim_alpha], - math.eye(dim_alpha, dtype=Ai.dtype), - ], - [ - math.eye((dim_alpha), dtype=Ai.dtype), - math.zeros((dim_alpha, dim_alpha), dtype=Ai.dtype), - ], - ] - ) - b_decomp = math.concat((bi[:dim_alpha], math.zeros((dim_alpha), dtype=bi.dtype)), axis=0) - return A_decomp, b_decomp, c_decomp - - def _equal_no_array(self, other: PolyExpBase) -> bool: - self.simplify() - other.simplify() - return np.allclose(self.vec, other.vec, atol=1e-10) and np.allclose( - self.mat, other.mat, atol=1e-10 - ) - - def _generate_ansatz(self): + def plot( + self, + just_phase: bool = False, + with_measure: bool = False, + log_scale: bool = False, + xlim=(-2 * np.pi, 2 * np.pi), + ylim=(-2 * np.pi, 2 * np.pi), + ) -> tuple[plt.figure.Figure, plt.axes.Axes]: # pragma: no cover r""" - This method computes and sets the matrix, vector and array given a function - and some kwargs. - """ - names = list(self._kwargs.keys()) - vars = list(self._kwargs.values()) + Plots the Bargmann function :math:`F(z)` on the complex plane. Phase is represented by + color, magnitude by brightness. The function can be multiplied by :math:`exp(-|z|^2)` + to represent the Bargmann function times the measure function (for integration). - params = {} - param_types = [] - for name, param in zip(names, vars): - try: - params[name] = param.value - param_types.append(type(param)) - except AttributeError: - params[name] = param - - if self._array is None or Variable in param_types: - mat, vec, array = self._fn(**params) - self.mat = mat - self.vec = vec - self.array = array + Args: + just_phase: Whether to plot only the phase of the Bargmann function. + with_measure: Whether to plot the bargmann function times the measure function + :math:`exp(-|z|^2)`. + log_scale: Whether to plot the log of the Bargmann function. + xlim: The `x` limits of the plot. + ylim: The `y` limits of the plot. - def _order_batch(self): - r""" - This method orders the batch dimension by the lexicographical order of the - flattened arrays (mat, vec, array). This is a very cheap way to enforce - an ordering of the batch dimension, which is useful for simplification and for - determining (in)equality between two Bargmann representations. + Returns: + The figure and axes of the plot """ - generators = [ - itertools.chain( - math.asnumpy(self.vec[i]).flat, - math.asnumpy(self.mat[i]).flat, - math.asnumpy(self.array[i]).flat, - ) - for i in range(self.batch_size) - ] - sorted_indices = argsort_gen(generators) - self.mat = math.gather(self.mat, sorted_indices, axis=0) - self.vec = math.gather(self.vec, sorted_indices, axis=0) - self.array = math.gather(self.array, sorted_indices, axis=0) - - def __add__(self, other: PolyExpBase) -> PolyExpBase: + # eval F(z) on a grid of complex numbers + X, Y = np.mgrid[xlim[0] : xlim[1] : 400j, ylim[0] : ylim[1] : 400j] + Z = (X + 1j * Y).T + f_values = self(Z[..., None]) + if log_scale: + f_values = np.log(np.abs(f_values)) * np.exp(1j * np.angle(f_values)) + if with_measure: + f_values = f_values * np.exp(-(np.abs(Z) ** 2)) + + # Get phase and magnitude of F(z) + phases = np.angle(f_values) / (2 * np.pi) % 1 + magnitudes = np.abs(f_values) + magnitudes_scaled = magnitudes / np.max(magnitudes) + + # Convert to RGB + hsv_values = np.zeros(f_values.shape + (3,)) + hsv_values[..., 0] = phases + hsv_values[..., 1] = 1 + hsv_values[..., 2] = 1 if just_phase else magnitudes_scaled + rgb_values = colors.hsv_to_rgb(hsv_values) + + # Plot the image + fig, ax = plt.subplots() + ax.imshow(rgb_values, origin="lower", extent=[xlim[0], xlim[1], ylim[0], ylim[1]]) + ax.set_xlabel("$Re(z)$") + ax.set_ylabel("$Im(z)$") + + name = "F_{" + self.name + "}(z)" + name = f"\\arg({name})\\log|{name}|" if log_scale else name + title = name + "e^{-|z|^2}" if with_measure else name + title = f"\\arg({name})" if just_phase else title + ax.set_title(f"${title}$") + plt.show(block=False) + return fig, ax + + def reorder(self, order: tuple[int, ...] | list[int]) -> Bargmann: r""" - Adds two ansatze together. This means concatenating them in the batch dimension. - In the case where c is a polynomial of different shapes it will add padding zeros to make - the shapes fit. Example: If the shape of c1 is (1,3,4,5) and the shape of c2 is (1,5,4,3) then the - shape of the combined object will be (2,5,4,5). - """ - combined_matrices = math.concat([self.mat, other.mat], axis=0) - combined_vectors = math.concat([self.vec, other.vec], axis=0) - - a0s = self.array.shape[1:] - a1s = other.array.shape[1:] - if a0s == a1s: - combined_arrays = math.concat([self.array, other.array], axis=0) - else: - s_max = np.maximum(np.array(a0s), np.array(a1s)) - - padding_array0 = np.array( - ( - np.zeros(len(s_max) + 1), - np.concatenate((np.array([0]), np.array((s_max - a0s)))), - ), - dtype=int, - ).T - padding_tuple0 = tuple(tuple(padding_array0[i]) for i in range(len(s_max) + 1)) - - padding_array1 = np.array( - ( - np.zeros(len(s_max) + 1), - np.concatenate((np.array([0]), np.array((s_max - a1s)))), - ), - dtype=int, - ).T - padding_tuple1 = tuple(tuple(padding_array1[i]) for i in range(len(s_max) + 1)) - a0_new = np.pad(self.array, padding_tuple0, "constant") - a1_new = np.pad(other.array, padding_tuple1, "constant") - combined_arrays = math.concat([a0_new, a1_new], axis=0) - # note output is not simplified - return self.__class__(combined_matrices, combined_vectors, combined_arrays) - - def __eq__(self, other: PolyExpBase) -> bool: - return self._equal_no_array(other) and np.allclose(self.array, other.array, atol=1e-10) - - def __neg__(self) -> PolyExpBase: - return self.__class__(self.mat, self.vec, -self.array) - - -class PolyExpAnsatz(PolyExpBase): - r""" - The ansatz of the Fock-Bargmann representation. - - Represents the ansatz function: - - :math:`F(z) = \sum_i [\sum_k c^{(i)}_k \partial_y^k \textrm{exp}((z,y)^T A_i (z,y) / 2 + (z,y)^T b_i)|_{y=0}]` - - with ``k`` being a multi-index. The matrices :math:`A_i` and vectors :math:`b_i` are - parameters of the exponential terms in the ansatz, and :math:`z` is a vector of variables, and and :math:`y` is a vector linked to the polynomial coefficients. - The dimension of ``z + y`` must be equal to the dimension of ``A`` and ``b``. + Reorders the indices of the ``A`` matrix and ``b`` vector of the ``(A, b, c)`` triple in + this Bargmann object. .. code-block:: - >>> from mrmustard.physics.ansatze import PolyExpAnsatz - - - >>> A = np.array([[1.0, 0.0], [0.0, 1.0]]) - >>> b = np.array([1.0, 1.0]) - >>> c = np.array([[1.0,2.0,3.0]]) + >>> from mrmustard.physics.representations import Bargmann + >>> from mrmustard.physics.triples import displacement_gate_Abc - >>> F = PolyExpAnsatz(A, b, c) - >>> z = np.array([[1.0],[2.0],[3.0]]) + >>> rep_dgate1 = Bargmann(*displacement_gate_Abc([0.1, 0.2, 0.3])) + >>> rep_dgate2 = Bargmann(*displacement_gate_Abc([0.2, 0.3, 0.1])) - >>> # calculate the value of the function at the three different ``z``, since z is batched. - >>> val = F(z) + >>> assert rep_dgate1.reorder([1, 2, 0, 4, 5, 3]) == rep_dgate2 - A and b can be batched or not, but c needs to include an explicit batch dimension that matches A and b. - Args: - A: The list of square matrices :math:`A_i` - b: The list of vectors :math:`b_i` - c: The list of arrays :math:`c_i` is coefficients for the polynomial terms in the ansatz. - An explicit batch dimension that matched A and b has to be given for c. - - """ - - def __init__( - self, - A: Batch[Matrix] | None = None, - b: Batch[Vector] | None = None, - c: Batch[Tensor | Scalar] = np.array([[1.0]]), - name: str = "", - ): - self.name = name - - if A is None and b is None and c is not None: - raise ValueError("Please provide either A or b.") - super().__init__(mat=A, vec=b, array=c) + Args: + order: The new order. - @property - def A(self) -> Batch[ComplexMatrix]: - r""" - The list of square matrices :math:`A_i`. + Returns: + The reordered Bargmann object. """ - return self.mat + A, b, c = reorder_abc(self.triple, order) + return Bargmann(A, b, c) - @property - def b(self) -> Batch[ComplexVector]: + def simplify(self) -> None: r""" - The list of vectors :math:`b_i`. - """ - return self.vec + Simplifies the representation by combining together terms that have the same + exponential part, i.e. two terms along the batch are considered equal if their + matrix and vector are equal. In this case only one is kept and the arrays are added. - @property - def c(self) -> Batch[ComplexTensor]: - r""" - The array of coefficients for the polynomial terms in the ansatz. + Does not run if the representation has already been simplified, so it is safe to call. """ - return self.array + if self._simplified: + return + indices_to_check = set(range(self.batch_size)) + removed = [] + while indices_to_check: + i = indices_to_check.pop() + for j in indices_to_check.copy(): + if np.allclose(self.A[i], self.A[j]) and np.allclose(self.b[i], self.b[j]): + self.c = math.update_add_tensor(self.c, [[i]], [self.c[j]]) + indices_to_check.remove(j) + removed.append(j) + to_keep = [i for i in range(self.batch_size) if i not in removed] + self.A = math.gather(self.A, to_keep, axis=0) + self.b = math.gather(self.b, to_keep, axis=0) + self.c = math.gather(self.c, to_keep, axis=0) + self._simplified = True - @property - def triple( - self, - ) -> tuple[Batch[ComplexMatrix], Batch[ComplexVector], Batch[ComplexTensor]]: + def simplify_v2(self) -> None: r""" - The batch of triples :math:`(A_i, b_i, c_i)`. + A different implementation of ``simplify`` that orders the batch dimension first. """ - return self.A, self.b, self.c + if self._simplified: + return + self._order_batch() + to_keep = [d0 := 0] + mat, vec = self.A[d0], self.b[d0] + for d in range(1, self.batch_size): + if np.allclose(mat, self.A[d]) and np.allclose(vec, self.b[d]): + self.c = math.update_add_tensor(self.c, [[d0]], [self.c[d]]) + else: + to_keep.append(d) + d0 = d + mat, vec = self.A[d0], self.b[d0] + self.A = math.gather(self.A, to_keep, axis=0) + self.b = math.gather(self.b, to_keep, axis=0) + self.c = math.gather(self.c, to_keep, axis=0) + self._simplified = True - @classmethod - def from_function(cls, fn: Callable, **kwargs: Any) -> PolyExpAnsatz: + def to_dict(self) -> dict[str, ArrayLike]: + """Serialize a Bargmann instance.""" + return {"A": self.A, "b": self.b, "c": self.c} + + def trace(self, idx_z: tuple[int, ...], idx_zconj: tuple[int, ...]) -> Bargmann: r""" - Returns a PolyExpAnsatz object from a generator function. - """ - ret = cls(None, None, None) - ret._fn = fn - ret._kwargs = kwargs - return ret + The partial trace over the given index pairs. + + Args: + idx_z: The first part of the pairs of indices to trace over. + idx_zconj: The second part. - def _call_all(self, z: Batch[Vector]) -> PolyExpAnsatz: + Returns: + Bargmann: the ansatz with the given indices traced over + """ + A, b, c = [], [], [] + for Abc in zip(self.A, self.b, self.c): + Aij, bij, cij = complex_gaussian_integral(Abc, idx_z, idx_zconj, measure=-1.0) + A.append(Aij) + b.append(bij) + c.append(cij) + return Bargmann(A, b, c) + + def _call_all(self, z: Batch[Vector]) -> Bargmann: r""" - Value of this ansatz at ``z``. If ``z`` is batched a value of the function at each of the batches are returned. + Value of this representation at ``z``. If ``z`` is batched a value of the function at each of the batches are returned. If ``Abc`` is batched it is thought of as a linear combination, and thus the results are added linearly together. Note that the batch dimension of ``z`` and ``Abc`` can be different. @@ -613,6 +532,39 @@ def _call_all(self, z: Batch[Vector]) -> PolyExpAnsatz: ) # (b_arg) return val + def _call_none(self, z: Batch[Vector]) -> Bargmann: + r""" + Returns a new ansatz that corresponds to currying (partially evaluate) the current one. + For example, if ``self`` represents the function ``F(z1,z2)``, the call ``self._call_none([np.array([1.0, None]])`` + returns ``F(1.0, z2)`` as a new ansatz with a single variable. + Note that the batch of the triple and argument in this method is handled parwise, unlike the regular call where the batch over the triple is a superposition. + + Args: + z: slice in C^n where the function is evaluated, while unevaluated along other axes of the space. + + Returns: + A new ansatz. + """ + + batch_abc = self.batch_size + batch_arg = z.shape[0] + Abc = [] + if batch_abc == 1 and batch_arg > 1: + for i in range(batch_arg): + Abc.append(self._call_none_single(self.A[0], self.b[0], self.c[0], z[i])) + elif batch_arg == 1 and batch_abc > 1: + for i in range(batch_abc): + Abc.append(self._call_none_single(self.A[i], self.b[i], self.c[i], z[0])) + elif batch_abc == batch_arg: + for i in range(batch_abc): + Abc.append(self._call_none_single(self.A[i], self.b[i], self.c[i], z[i])) + else: + raise ValueError( + "Batch size of the ansatz and argument must match or one of the batch sizes must be 1." + ) + A, b, c = zip(*Abc) + return Bargmann(A=A, b=b, c=c) + def _call_none_single(self, Ai, bi, ci, zi): r""" Helper function for the call_none method. Returns the new triple. @@ -656,56 +608,158 @@ def _call_none_single(self, Ai, bi, ci, zi): gamma, math.gather(math.gather(Ai, z_not_none, axis=0), z_not_none, axis=1), ) - b_part = math.einsum("j,j", math.gather(bi, z_not_none, axis=0), gamma) - exp_sum = math.exp(1 / 2 * A_part + b_part) - new_c = ci * exp_sum - return new_A, new_b, new_c + b_part = math.einsum("j,j", math.gather(bi, z_not_none, axis=0), gamma) + exp_sum = math.exp(1 / 2 * A_part + b_part) + new_c = ci * exp_sum + return new_A, new_b, new_c + + def _decompose_ansatz_single(self, Ai, bi, ci): + dim_beta, shape_beta = self.polynomial_shape + dim_alpha = self.A.shape[-1] - dim_beta + A_bar = math.block( + [ + [ + math.zeros((dim_alpha, dim_alpha), dtype=Ai.dtype), + Ai[:dim_alpha, dim_alpha:], + ], + [ + Ai[dim_alpha:, :dim_alpha], + Ai[dim_alpha:, dim_alpha:], + ], + ] + ) + b_bar = math.concat((math.zeros((dim_alpha), dtype=bi.dtype), bi[dim_alpha:]), axis=0) + poly_bar = math.hermite_renormalized( + A_bar, + b_bar, + complex(1), + (math.sum(shape_beta),) * dim_alpha + shape_beta, + ) + c_decomp = math.sum( + poly_bar * ci, + axes=math.arange( + len(poly_bar.shape) - dim_beta, len(poly_bar.shape), dtype=math.int32 + ).tolist(), + ) + A_decomp = math.block( + [ + [ + Ai[:dim_alpha, :dim_alpha], + math.eye(dim_alpha, dtype=Ai.dtype), + ], + [ + math.eye((dim_alpha), dtype=Ai.dtype), + math.zeros((dim_alpha, dim_alpha), dtype=Ai.dtype), + ], + ] + ) + b_decomp = math.concat((bi[:dim_alpha], math.zeros((dim_alpha), dtype=bi.dtype)), axis=0) + return A_decomp, b_decomp, c_decomp - def _call_none(self, z: Batch[Vector]) -> PolyExpAnsatz: + def _equal_no_array(self, other: Bargmann) -> bool: + self.simplify() + other.simplify() + return np.allclose(self.b, other.b, atol=1e-10) and np.allclose(self.A, other.A, atol=1e-10) + + def _generate_ansatz(self): r""" - Returns a new ansatz that corresponds to currying (partially evaluate) the current one. - For example, if ``self`` represents the function ``F(z1,z2)``, the call ``self.call_none([np.array([1.0, None]])`` - returns ``F(1.0, z2)`` as a new ansatz with a single variable. - Note that the batch of the triple and argument in this method is handled parwise, unlike the regular call where the batch over the triple is a superposition. + This method computes and sets the (A, b, c) given a function + and some kwargs. + """ + names = list(self._kwargs.keys()) + vars = list(self._kwargs.values()) - Args: - z: slice in C^n where the function is evaluated, while unevaluated along other axes of the space. + params = {} + param_types = [] + for name, param in zip(names, vars): + try: + params[name] = param.value + param_types.append(type(param)) + except AttributeError: + params[name] = param - Returns: - A new ansatz. + if self._c is None or Variable in param_types: + A, b, c = self._fn(**params) + self.A = A + self.b = b + self.c = c + + def _ipython_display_(self): + display(widgets.bargmann(self)) + + def _order_batch(self): + r""" + This method orders the batch dimension by the lexicographical order of the + flattened arrays (A, b, c). This is a very cheap way to enforce + an ordering of the batch dimension, which is useful for simplification and for + determining (in)equality between two Bargmann representations. """ + generators = [ + itertools.chain( + math.asnumpy(self.b[i]).flat, + math.asnumpy(self.A[i]).flat, + math.asnumpy(self.c[i]).flat, + ) + for i in range(self.batch_size) + ] + sorted_indices = argsort_gen(generators) + self.A = math.gather(self.A, sorted_indices, axis=0) + self.b = math.gather(self.b, sorted_indices, axis=0) + self.c = math.gather(self.c, sorted_indices, axis=0) - batch_abc = self.batch_size - batch_arg = z.shape[0] - Abc = [] - if batch_abc == 1 and batch_arg > 1: - for i in range(batch_arg): - Abc.append(self._call_none_single(self.A[0], self.b[0], self.c[0], z[i])) - elif batch_arg == 1 and batch_abc > 1: - for i in range(batch_abc): - Abc.append(self._call_none_single(self.A[i], self.b[i], self.c[i], z[0])) - elif batch_abc == batch_arg: - for i in range(batch_abc): - Abc.append(self._call_none_single(self.A[i], self.b[i], self.c[i], z[i])) + def __add__(self, other: Bargmann) -> Bargmann: + r""" + Adds two Bargmann representations together. This means concatenating them in the batch dimension. + In the case where c is a polynomial of different shapes it will add padding zeros to make + the shapes fit. Example: If the shape of c1 is (1,3,4,5) and the shape of c2 is (1,5,4,3) then the + shape of the combined object will be (2,5,4,5). + """ + combined_matrices = math.concat([self.A, other.A], axis=0) + combined_vectors = math.concat([self.b, other.b], axis=0) + + a0s = self.c.shape[1:] + a1s = other.c.shape[1:] + if a0s == a1s: + combined_arrays = math.concat([self.c, other.c], axis=0) else: - raise ValueError( - "Batch size of the ansatz and argument must match or one of the batch sizes must be 1." - ) - A, b, c = zip(*Abc) - return self.__class__(A=A, b=b, c=c) + s_max = np.maximum(np.array(a0s), np.array(a1s)) + + padding_array0 = np.array( + ( + np.zeros(len(s_max) + 1), + np.concatenate((np.array([0]), np.array((s_max - a0s)))), + ), + dtype=int, + ).T + padding_tuple0 = tuple(tuple(padding_array0[i]) for i in range(len(s_max) + 1)) + + padding_array1 = np.array( + ( + np.zeros(len(s_max) + 1), + np.concatenate((np.array([0]), np.array((s_max - a1s)))), + ), + dtype=int, + ).T + padding_tuple1 = tuple(tuple(padding_array1[i]) for i in range(len(s_max) + 1)) + a0_new = np.pad(self.c, padding_tuple0, "constant") + a1_new = np.pad(other.c, padding_tuple1, "constant") + combined_arrays = math.concat([a0_new, a1_new], axis=0) + # note output is not simplified + return Bargmann(combined_matrices, combined_vectors, combined_arrays) - def __and__(self, other: PolyExpAnsatz) -> PolyExpAnsatz: - r"""Tensor product of this ansatz with another ansatz. + def __and__(self, other: Bargmann) -> Bargmann: + r""" + Tensor product of this Bargmann with another Bargmann. Equivalent to :math:`F(a) * G(b)` (with different arguments, that is). As it distributes over addition on both self and other, the batch size of the result is the product of the batch - size of this anzatz and the other one. + size of this representation and the other one. Args: - other: Another ansatz. + other: Another Barmann. Returns: - The tensor product of this ansatz and other. + The tensor product of this Bargmann and other. """ def andA(A1, A2, dim_alpha1, dim_alpha2, dim_beta1, dim_beta2): @@ -778,13 +832,13 @@ def andc(c1, c2): ] bs = [andb(b1, b2, dim_alpha1, dim_alpha2) for b1, b2 in itertools.product(self.b, other.b)] cs = [andc(c1, c2) for c1, c2 in itertools.product(self.c, other.c)] - return self.__class__(As, bs, cs) + return Bargmann(As, bs, cs) - def __call__(self, z: Batch[Vector]) -> Scalar | PolyExpAnsatz: + def __call__(self, z: Batch[Vector]) -> Scalar | Bargmann: r""" - Returns either the value of the ansatz or a new ansatz depending on the argument. - If the argument contains None, returns a new ansatz. - If the argument only contains numbers, returns the value of the ansatz at that argument. + Returns either the value of the representation or a new representation depending on the argument. + If the argument contains None, returns a new representation. + If the argument only contains numbers, returns the value of the representation at that argument. Note that the batch dimensions are handled differently in the two cases. See subfunctions for furhter information. Args: @@ -798,17 +852,82 @@ def __call__(self, z: Batch[Vector]) -> Scalar | PolyExpAnsatz: else: return self._call_all(z) - def __mul__(self, other: Scalar | PolyExpAnsatz) -> PolyExpAnsatz: - r"""Multiplies this ansatz by a scalar or another ansatz or a plain scalar. + def __eq__(self, other: Bargmann) -> bool: + return self._equal_no_array(other) and np.allclose(self.c, other.c, atol=1e-10) + + def __neg__(self) -> Bargmann: + return Bargmann(self.A, self.b, -self.c) + + def __getitem__(self, idx: int | tuple[int, ...]) -> Bargmann: + r""" + A copy of self with the given indices marked for contraction. + """ + idx = (idx,) if isinstance(idx, int) else idx + for i in idx: + if i >= self.num_vars: + raise IndexError( + f"Index {i} out of bounds for representation of dimension {self.num_vars}." + ) + ret = Bargmann(self.A, self.b, self.c) + ret._contract_idxs = idx + return ret + + def __matmul__(self, other: Bargmann) -> Bargmann: + r""" + Implements the inner product in Bargmann representation. + + ..code-block:: + + >>> from mrmustard.physics.representations import Bargmann + >>> from mrmustard.physics.triples import displacement_gate_Abc, vacuum_state_Abc + >>> rep1 = Bargmann(*vacuum_state_Abc(1)) + >>> rep2 = Bargmann(*displacement_gate_Abc(1)) + >>> rep3 = rep1[0] @ rep2[1] + >>> assert np.allclose(rep3.A, [[0,],]) + >>> assert np.allclose(rep3.b, [1,]) + + Args: + other: Another Bargmann representation. + + Returns: + Bargmann: the resulting Bargmann representation. + + """ + if not isinstance(other, Bargmann): + raise NotImplementedError("Only matmul Bargmann with Bargmann") + + idx_s = self._contract_idxs + idx_o = other._contract_idxs + + Abc = [] + if settings.UNSAFE_ZIP_BATCH: + if self.batch_size != other.batch_size: + raise ValueError( + f"Batch size of the two ansatze must match since the settings.UNSAFE_ZIP_BATCH is {settings.UNSAFE_ZIP_BATCH}." + ) + for (A1, b1, c1), (A2, b2, c2) in zip( + zip(self.A, self.b, self.c), zip(other.A, other.b, other.c) + ): + Abc.append(contract_two_Abc_poly((A1, b1, c1), (A2, b2, c2), idx_s, idx_o)) + else: + for A1, b1, c1 in zip(self.A, self.b, self.c): + for A2, b2, c2 in zip(other.A, other.b, other.c): + Abc.append(contract_two_Abc_poly((A1, b1, c1), (A2, b2, c2), idx_s, idx_o)) + + A, b, c = zip(*Abc) + return Bargmann(A, b, c) + + def __mul__(self, other: Scalar | Bargmann) -> Bargmann: + r"""Multiplies this representation by a scalar or another Bargmann representation. Args: - other: A scalar or another ansatz. + other: A scalar or another Bargmann representation. Raises: - TypeError: If other is neither a scalar nor an ansatz. + TypeError: If other is neither a scalar nor a Bargmann representation. Returns: - PolyExpAnsatz: The product of this ansatz and other. + Bargmann: The product of this representation and other. """ @@ -845,7 +964,7 @@ def mul_c(c1, c2): c3 = math.reshape(math.outer(c1, c2), (c1.shape + c2.shape)) return c3 - if isinstance(other, PolyExpAnsatz): + if isinstance(other, Bargmann): dim_beta1, _ = self.polynomial_shape dim_beta2, _ = other.polynomial_shape @@ -868,24 +987,40 @@ def mul_c(c1, c2): new_b = [mul_b(b1, b2, dim_alpha) for b1, b2 in itertools.product(self.b, other.b)] new_c = [mul_c(c1, c2) for c1, c2 in itertools.product(self.c, other.c)] - return self.__class__(A=new_a, b=new_b, c=new_c) + return Bargmann(A=new_a, b=new_b, c=new_c) else: try: - return self.__class__(self.A, self.b, self.c * other) + return Bargmann(self.A, self.b, self.c * other) except Exception as e: raise TypeError(f"Cannot multiply {self.__class__} and {other.__class__}.") from e - def __truediv__(self, other: Scalar | PolyExpAnsatz) -> PolyExpAnsatz: - r"""Multiplies this ansatz by a scalar or another ansatz or a plain scalar. + def __rmul__(self, other: any | Scalar) -> any: + r""" + Multiplies this representation by another or by a scalar on the right. + """ + return self.__mul__(other) + + def __sub__(self, other): + r""" + Subtracts other from this representation. + """ + try: + return self.__add__(-other) + except AttributeError as e: + raise TypeError(f"Cannot subtract {self.__class__} and {other.__class__}.") from e + + def __truediv__(self, other: Scalar | Bargmann) -> Bargmann: + r""" + Multiplies this Bargmann by a scalar or another Bargmann. Args: - other: A scalar or another ansatz. + other: A scalar or another Bargmann. Raises: - TypeError: If other is neither a scalar nor an ansatz. + TypeError: If other is neither a scalar nor a Bargmann. Returns: - PolyExpAnsatz: The product of this ansatz and other. + Bargmann: The product of this Bargmann and other. """ @@ -922,7 +1057,7 @@ def div_c(c1, c2): c3 = math.reshape(math.outer(c1, c2), (c1.shape + c2.shape)) return c3 - if isinstance(other, PolyExpAnsatz): + if isinstance(other, Bargmann): dim_beta1, _ = self.polynomial_shape dim_beta2, _ = other.polynomial_shape if dim_beta1 == 0 and dim_beta2 == 0: @@ -945,317 +1080,11 @@ def div_c(c1, c2): new_b = [div_b(b1, -b2, dim_alpha) for b1, b2 in itertools.product(self.b, other.b)] new_c = [div_c(c1, 1 / c2) for c1, c2 in itertools.product(self.c, other.c)] - return self.__class__(A=new_a, b=new_b, c=new_c) + return Bargmann(A=new_a, b=new_b, c=new_c) else: raise NotImplementedError("Only implemented if both c are scalars") else: try: - return self.__class__(self.A, self.b, self.c / other) - except Exception as e: - raise TypeError(f"Cannot divide {self.__class__} and {other.__class__}.") from e - - -class ArrayAnsatz(Ansatz): - r""" - The ansatz of the Fock-Bargmann representation. - - Represents the ansatz as a multidimensional array. - - .. code-block:: - - >>> from mrmustard.physics.ansatze import ArrayAnsatz - - >>> array = np.random.random((2, 4, 5)) - >>> ansatz = ArrayAnsatz(array) - - Args: - array: A (potentially) batched array. - batched: Whether the array input has a batch dimension. - - Note: The args can be passed non-batched, as they will be automatically broadcasted to the - correct batch shape if ``batched`` is set to ``False``. - """ - - def __init__(self, array: Batch[Tensor], batched: bool = True): - super().__init__() - - self._array = array if batched else [array] - self._backend_array = False - self._original_abc_data = None - - @property - def array(self) -> Batch[Tensor]: - r""" - The array of this ansatz. - """ - self._generate_ansatz() - if not self._backend_array: - self._array = math.astensor(self._array) - self._backend_array = True - return self._array - - @array.setter - def array(self, value): - self._array = value - self._backend_array = False - - @property - def batch_size(self): - r""" - The batch size of this ansatz. - """ - return self.array.shape[0] - - @property - def conj(self): - r""" - The conjugate of this ansatz. - """ - return self.__class__(math.conj(self.array)) - - @property - def num_vars(self) -> int: - r""" - The number of variables in this ansatz. - """ - return len(self.array.shape) - 1 - - @property - def triple(self) -> tuple: - r""" - The data of the original PolyExpAnsatz if it exists. - """ - if self._original_abc_data is None: - raise AttributeError( - "This Fock object does not have an original Bargmann representation." - ) - return self._original_abc_data - - @classmethod - def from_function(cls, fn: Callable, **kwargs: Any) -> ArrayAnsatz: - r""" - Returns an ArrayAnsatz object from a generator function. - """ - ret = cls(None, True) - ret._fn = fn - ret._kwargs = kwargs - return ret - - def reduce(self, shape: int | Sequence[int]) -> ArrayAnsatz: - r""" - Returns a new ``ArrayAnsatz`` with a sliced array. - - Args: - shape: The shape of the array of the returned ``ArrayAnsatz``. - """ - if shape == self.array.shape[1:]: - return self - length = self.num_vars - shape = (shape,) * length if isinstance(shape, int) else shape - if len(shape) != length: - msg = f"Expected shape of length {length}, " - msg += f"given shape has length {len(shape)}." - raise ValueError(msg) - - if any(s > t for s, t in zip(shape, self.array.shape[1:])): - warn( - "Warning: the fock array is being padded with zeros. If possible slice the arrays this one will contract with instead." - ) - padded = math.pad( - self.array, - [(0, 0)] + [(0, s - t) for s, t in zip(shape, self.array.shape[1:])], - ) - return ArrayAnsatz(padded) - - ret = self.array[(slice(0, None),) + tuple(slice(0, s) for s in shape)] - return ArrayAnsatz(array=ret, batched=True) - - def _generate_ansatz(self): - r""" - This method computes and sets the array given a function - and some kwargs. - """ - if self._array is None: - self.array = [self._fn(**self._kwargs)] - - def __add__(self, other: ArrayAnsatz) -> ArrayAnsatz: - r""" - Adds the array of this ansatz and the array of another ansatz. - - Args: - other: Another ansatz. - - Raises: - ValueError: If the arrays don't have the same shape. - - Returns: - ArrayAnsatz: The addition of this ansatz and other. - """ - try: - diff = sum(self.array.shape[1:]) - sum(other.array.shape[1:]) - if diff < 0: - new_array = [ - a + b for a in self.reduce(other.array.shape[1:]).array for b in other.array - ] - else: - new_array = [ - a + b for a in self.array for b in other.reduce(self.array.shape[1:]).array - ] - return self.__class__(array=new_array) - except Exception as e: - raise TypeError(f"Cannot add {self.__class__} and {other.__class__}.") from e - - def __and__(self, other: ArrayAnsatz) -> ArrayAnsatz: - r""" - Tensor product of this ansatz with another ansatz. - - Args: - other: Another ansatz. - - Returns: - The tensor product of this ansatz and other. - Batch size is the product of two batches. - """ - new_array = [math.outer(a, b) for a in self.array for b in other.array] - return self.__class__(array=new_array) - - def __call__(self, point: Any) -> Scalar: - r""" - Evaluates this ansatz at a given point in the domain. - """ - raise AttributeError("Cannot plot ArrayAnsatz.") - - def __eq__(self, other: Ansatz) -> bool: - r""" - Whether this ansatz's array is equal to another ansatz's array. - - Note that the comparison is done by numpy allclose with numpy's default rtol and atol. - - """ - slices = (slice(0, None),) + tuple( - slice(0, min(si, oi)) for si, oi in zip(self.array.shape[1:], other.array.shape[1:]) - ) - return np.allclose(self.array[slices], other.array[slices], atol=1e-10) - - def __mul__(self, other: Scalar | ArrayAnsatz) -> ArrayAnsatz: - r""" - Multiplies this ansatz by another ansatz. - - Args: - other: A scalar or another ansatz. - - Raises: - ValueError: If both of array don't have the same shape. - - Returns: - ArrayAnsatz: The product of this ansatz and other. - """ - if isinstance(other, ArrayAnsatz): - try: - diff = sum(self.array.shape[1:]) - sum(other.array.shape[1:]) - if diff < 0: - new_array = [ - a * b for a in self.reduce(other.array.shape[1:]).array for b in other.array - ] - else: - new_array = [ - a * b for a in self.array for b in other.reduce(self.array.shape[1:]).array - ] - return self.__class__(array=new_array) - except Exception as e: - raise TypeError(f"Cannot multiply {self.__class__} and {other.__class__}.") from e - else: - ret = self.__class__(array=self.array * other) - ret._original_abc_data = ( - tuple(i * j for i, j in zip(self._original_abc_data, (1, 1, other))) - if self._original_abc_data is not None - else None - ) - return ret - - def __neg__(self) -> ArrayAnsatz: - r""" - Negates the values in the array. - """ - return self.__class__(array=-self.array) - - def __truediv__(self, other: Scalar | ArrayAnsatz) -> ArrayAnsatz: - r""" - Divides this ansatz by another ansatz. - - Args: - other: A scalar or another ansatz. - - Raises: - ValueError: If the arrays don't have the same shape. - - Returns: - ArrayAnsatz: The division of this ansatz and other. - """ - if isinstance(other, ArrayAnsatz): - try: - diff = sum(self.array.shape[1:]) - sum(other.array.shape[1:]) - if diff < 0: - new_array = [ - a / b for a in self.reduce(other.array.shape[1:]).array for b in other.array - ] - else: - new_array = [ - a / b for a in self.array for b in other.reduce(self.array.shape[1:]).array - ] - return self.__class__(array=new_array) + return Bargmann(self.A, self.b, self.c / other) except Exception as e: raise TypeError(f"Cannot divide {self.__class__} and {other.__class__}.") from e - else: - ret = self.__class__(array=self.array / other) - ret._original_abc_data = ( - tuple(i / j for i, j in zip(self._original_abc_data, (1, 1, other))) - if self._original_abc_data is not None - else None - ) - return ret - - -def bargmann_Abc_to_phasespace_cov_means( - A: Matrix, b: Vector, c: Scalar, batched: bool = False -) -> tuple[Matrix, Vector, Scalar]: - r""" - Function to derive the covariance matrix and mean vector of a Gaussian state from its Wigner characteristic function in ABC form. - - The covariance matrix and mean vector can be used to write the characteristic function of a Gaussian state - :math: - \Chi_G(r) = \exp\left( -\frac{1}{2}r^T \Omega^T cov \Omega r + i r^T\Omega^T mean \right), - and the Wigner function of a Gaussian state: - :math: - W_G(r) = \frac{1}{\sqrt{\Det(cov)}} \exp\left( -\frac{1}{2}(r - mean)^T cov^{-1} (r-mean) \right). - - The internal expression of our Gaussian state :math:`\rho` is in Bargmann representation, one can write the characteristic function of a Gaussian state in Bargmann representation as - :math: - \Chi_G(\alpha) = \Tr(\rho D) = c \exp\left( -\frac{1}{2}\alpha^T A \alpha + \alpha^T b \right). - - This function is to go from the Abc triple in characteristic phase space into the covariance and mean vector for Gaussian state. - - Args: - A, b, c: The ``(A, b, c)`` triple of the state in characteristic phase space. - - Returns: - The covariance matrix, mean vector and coefficient of the state in phase space. - """ - # batched = len(A.shape) == 3 and len(b.shape) == 2 and len(c.shape) == 1 - A = math.atleast_3d(A) - b = math.atleast_2d(b) - c = math.atleast_1d(c) - num_modes = A.shape[-1] // 2 - Omega = math.cast(math.transpose(math.J(num_modes)), dtype=math.complex128) - W = math.transpose(math.conj(math.rotmat(num_modes))) - coeff = c - cov = [ - -Omega @ W @ Amat @ math.transpose(W) @ math.transpose(Omega) * settings.HBAR for Amat in A - ] - mean = [ - 1j * math.matvec(Omega @ W, bvec) * math.sqrt(settings.HBAR, dtype=math.complex128) - for bvec in b - ] - if batched: - return math.astensor(cov), math.astensor(mean), coeff - return cov[0], mean[0], coeff[0] diff --git a/mrmustard/physics/representations/base.py b/mrmustard/physics/representations/base.py new file mode 100644 index 000000000..798354715 --- /dev/null +++ b/mrmustard/physics/representations/base.py @@ -0,0 +1,85 @@ +# Copyright 2023 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This module contains the classes for the available representations. +""" + +from __future__ import annotations +from abc import ABC, abstractmethod +from typing import Any, Callable + +from mrmustard.utils.typing import ( + Batch, + ComplexMatrix, + ComplexTensor, + ComplexVector, + Scalar, + Tensor, +) + +__all__ = ["Representation"] + + +class Representation(ABC): + r""" + A base class for representations. + + Representations can be initialized using the ``from_ansatz`` method, which automatically equips + them with all the functionality required to perform mathematical operations, such as equality, + multiplication, subtraction, etc. + """ + + def __init__(self) -> None: + self._contract_idxs: tuple[int, ...] = () + self._fn = None + self._kwargs = {} + + @property + @abstractmethod + def data(self) -> tuple | Tensor: + r""" + The data of the representation. + For now, it's the triple for Bargmann and the array for Fock. + """ + + @property + @abstractmethod + def scalar(self) -> Scalar: + r""" + The scalar part of the representation. + For now it's ``c`` for Bargmann and the array for Fock. + """ + + @property + @abstractmethod + def triple( + self, + ) -> tuple[Batch[ComplexMatrix], Batch[ComplexVector], Batch[ComplexTensor]]: + r""" + The batch of triples :math:`(A_i, b_i, c_i)`. + """ + + @abstractmethod + def from_function(cls, fn: Callable, **kwargs: Any) -> Representation: + r""" + Returns a representation from a function and kwargs. + """ + + @abstractmethod + def reorder(self, order: tuple[int, ...] | list[int]) -> Representation: + r""" + Reorders the representation indices. + """ diff --git a/mrmustard/physics/representations/fock.py b/mrmustard/physics/representations/fock.py new file mode 100644 index 000000000..ff7809612 --- /dev/null +++ b/mrmustard/physics/representations/fock.py @@ -0,0 +1,465 @@ +# Copyright 2023 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This module contains the classes for the available representations. +""" + +from __future__ import annotations +from typing import Any, Callable, Sequence + +from warnings import warn + +import numpy as np +from numpy.typing import ArrayLike + +from IPython.display import display + +from mrmustard import math, widgets +from mrmustard.utils.typing import ( + Batch, + Scalar, + Tensor, +) + +from .base import Representation + +__all__ = ["Fock"] + + +class Fock: + r""" """ + + def __init__(self, array: Batch[Tensor], batched=False): + + self._fn = None + self._kwargs = {} + self._contract_idxs: tuple[int, ...] = () + + self._array = array if batched else [array] + self._backend_array = False + self._original_abc_data = None + + @property + def array(self) -> Batch[Tensor]: + r""" + The array of this ansatz. + """ + self._generate_ansatz() + if not self._backend_array: + self._array = math.astensor(self._array) + self._backend_array = True + return self._array + + @array.setter + def array(self, value): + self._array = value + self._backend_array = False + + @property + def batch_size(self): + r""" + The batch size of this ansatz. + """ + return self.array.shape[0] + + @property + def conj(self): + r""" + The conjugate of this ansatz. + """ + ret = Fock(math.conj(self.array), batched=True) + ret._contract_idxs = self._contract_idxs + return ret + + @property + def data(self) -> Batch[Tensor]: + r""" + The data of the representation. + """ + return self.array + + @property + def num_vars(self) -> int: + r""" + The number of variables in this ansatz. + """ + return len(self.array.shape) - 1 + + @property + def scalar(self) -> Scalar: + r""" + The scalar part of the representation. + I.e. the vacuum component of the Fock object, whatever it may be. + Given that the first axis of the array is the batch axis, this is the first element of the array. + """ + return self.array[(slice(None),) + (0,) * self.num_vars] + + @property + def triple(self) -> tuple: + r""" + The data of the original Bargmann if it exists. + """ + if self._original_abc_data is None: + raise AttributeError( + "This Fock object does not have an original Bargmann representation." + ) + return self._original_abc_data + + @classmethod + def from_dict(cls, data: dict[str, ArrayLike]) -> Fock: + """Deserialize a Fock instance.""" + return cls(data["array"], batched=True) + + @classmethod + def from_function(cls, fn: Callable, **kwargs: Any) -> Fock: + r""" + Returns a Fock object from a generator function. + """ + ret = cls(None, True) + ret._fn = fn + ret._kwargs = kwargs + return ret + + def reduce(self, shape: int | Sequence[int]) -> Fock: + r""" + Returns a new ``Fock`` with a sliced array. + + .. code-block:: + + >>> from mrmustard import math + >>> from mrmustard.physics.representations import Fock + + >>> array1 = math.arange(27).reshape((3, 3, 3)) + >>> fock1 = Fock(array1) + + >>> fock2 = fock1.reduce(3) + >>> assert fock1 == fock2 + + >>> fock3 = fock1.reduce(2) + >>> array3 = [[[0, 1], [3, 4]], [[9, 10], [12, 13]]] + >>> assert fock3 == Fock(array3) + + >>> fock4 = fock1.reduce((1, 3, 1)) + >>> array4 = [[[0], [3], [6]]] + >>> assert fock4 == Fock(array4) + + Args: + shape: The shape of the array of the returned ``Fock``. + """ + if shape == self.array.shape[1:]: + return self + length = self.num_vars + shape = (shape,) * length if isinstance(shape, int) else shape + if len(shape) != length: + msg = f"Expected shape of length {length}, " + msg += f"given shape has length {len(shape)}." + raise ValueError(msg) + + if any(s > t for s, t in zip(shape, self.array.shape[1:])): + warn( + "Warning: the fock array is being padded with zeros. If possible slice the arrays this one will contract with instead." + ) + padded = math.pad( + self.array, + [(0, 0)] + [(0, s - t) for s, t in zip(shape, self.array.shape[1:])], + ) + return Fock(padded, batched=True) + + ret = self.array[(slice(0, None),) + tuple(slice(0, s) for s in shape)] + return Fock(array=ret, batched=True) + + def reorder(self, order: tuple[int, ...] | list[int]) -> Fock: + r""" + Reorders the indices of the array with the given order. + + Args: + order: The order. Does not need to refer to the batch dimension. + + Returns: + The reordered Fock. + """ + + return Fock(math.transpose(self.array, [0] + [i + 1 for i in order]), batched=True) + + def sum_batch(self) -> Fock: + r""" + Sums over the batch dimension of the array. Turns an object with any batch size to a batch size of 1. + + Returns: + The collapsed Fock object. + """ + return Fock(math.sum(self.array, axes=[0]), batched=True) + + def to_dict(self) -> dict[str, ArrayLike]: + """Serialize a Fock instance.""" + return {"array": self.data} + + def trace(self, idxs1: tuple[int, ...], idxs2: tuple[int, ...]) -> Fock: + r""" + Implements the partial trace over the given index pairs. + + Args: + idxs1: The first part of the pairs of indices to trace over. + idxs2: The second part. + + Returns: + The traced-over Fock object. + """ + if len(idxs1) != len(idxs2) or not set(idxs1).isdisjoint(idxs2): + raise ValueError("idxs must be of equal length and disjoint") + order = ( + [0] + + [i + 1 for i in range(len(self.array.shape) - 1) if i not in idxs1 + idxs2] + + [i + 1 for i in idxs1] + + [i + 1 for i in idxs2] + ) + new_array = math.transpose(self.array, order) + n = np.prod(new_array.shape[-len(idxs2) :]) + new_array = math.reshape(new_array, new_array.shape[: -2 * len(idxs1)] + (n, n)) + trace = math.trace(new_array) + return Fock([trace] if trace.shape == () else trace, batched=True) + + def _generate_ansatz(self): + r""" + This method computes and sets the array given a function + and some kwargs. + """ + if self._array is None: + self.array = [self._fn(**self._kwargs)] + + def _ipython_display_(self): + w = widgets.fock(self) + if w is None: + print(repr(self)) + return + display(w) + + def __add__(self, other: Fock) -> Fock: + r""" + Adds the array of this Fock representation and the array of another Fock representation. + + Args: + other: Another Fock representation. + + Raises: + ValueError: If the arrays don't have the same shape. + + Returns: + ArrayAnsatz: The addition of this representation and other. + """ + try: + diff = sum(self.array.shape[1:]) - sum(other.array.shape[1:]) + if diff < 0: + new_array = [ + a + b for a in self.reduce(other.array.shape[1:]).array for b in other.array + ] + else: + new_array = [ + a + b for a in self.array for b in other.reduce(self.array.shape[1:]).array + ] + return Fock(array=new_array, batched=True) + except Exception as e: + raise TypeError(f"Cannot add {self.__class__} and {other.__class__}.") from e + + def __and__(self, other: Fock) -> Fock: + r""" + Tensor product of this Fock representation with another Fock representation. + + Args: + other: Another Fock representation. + + Returns: + The tensor product of this representation and other. + Batch size is the product of two batches. + """ + new_array = [math.outer(a, b) for a in self.array for b in other.array] + return Fock(array=new_array, batched=True) + + def __call__(self, point: Any) -> Scalar: + r""" + Evaluates this representation at a given point in the domain. + """ + raise AttributeError("Cannot call Fock.") + + def __eq__(self, other: Representation) -> bool: + r""" + Whether this ansatz's array is equal to another ansatz's array. + + Note that the comparison is done by numpy allclose with numpy's default rtol and atol. + + """ + slices = (slice(0, None),) + tuple( + slice(0, min(si, oi)) for si, oi in zip(self.array.shape[1:], other.array.shape[1:]) + ) + return np.allclose(self.array[slices], other.array[slices], atol=1e-10) + + def __getitem__(self, idx: int | tuple[int, ...]) -> Fock: + r""" + Returns a copy of self with the given indices marked for contraction. + """ + idx = (idx,) if isinstance(idx, int) else idx + for i in idx: + if i >= self.num_vars: + raise IndexError( + f"Index {i} out of bounds for representation with {self.num_vars} variables." + ) + ret = Fock(self.array) + ret._contract_idxs = idx + return ret + + def __matmul__(self, other: Fock) -> Fock: + r""" + Implements the inner product of fock arrays over the marked indices. + + .. code-block:: + >>> from mrmustard.physics.representations import Fock + >>> f = Fock(np.random.random((3, 5, 10))) # 10 is reduced to 8 + >>> g = Fock(np.random.random((2, 5, 8))) + >>> h = f[1,2] @ g[1,2] + >>> assert h.array.shape == (1,3,2) # batch size is 1 + >>> f = Fock(np.random.random((3, 5, 10)), batched=True) + >>> g = Fock(np.random.random((2, 5, 8)), batched=True) + >>> h = f[0,1] @ g[0,1] + >>> assert h.array.shape == (6,) # batch size is 3 x 2 = 6 + + Args: + other: Another representation. + + Returns: + A ``Fock``representation. + """ + if not isinstance(other, Fock): + raise NotImplementedError("only matmul Fock with Fock") + + idx_s = list(self._contract_idxs) + idx_o = list(other._contract_idxs) + + # the number of batches in self and other + n_batches_s = self.array.shape[0] + n_batches_o = other.array.shape[0] + + # the shapes each batch in self and other + shape_s = self.array.shape[1:] + shape_o = other.array.shape[1:] + + new_shape_s = list(shape_s) + new_shape_o = list(shape_o) + for s, o in zip(idx_s, idx_o): + new_shape_s[s] = min(shape_s[s], shape_o[o]) + new_shape_o[o] = min(shape_s[s], shape_o[o]) + + reduced_s = self.reduce(new_shape_s)[idx_s] + reduced_o = other.reduce(new_shape_o)[idx_o] + + axes = [list(idx_s), list(idx_o)] + batched_array = [] + for i in range(n_batches_s): + for j in range(n_batches_o): + batched_array.append(math.tensordot(reduced_s.array[i], reduced_o.array[j], axes)) + return Fock(batched_array, batched=True) + + def __mul__(self, other: Scalar | Fock) -> Fock: + r""" + Multiplies this Fock representation by another Fock representation. + + Args: + other: A scalar or another Fock representation. + + Raises: + ValueError: If both of array don't have the same shape. + + Returns: + ArrayAnsatz: The product of this representation and other. + """ + if isinstance(other, Fock): + try: + diff = sum(self.array.shape[1:]) - sum(other.array.shape[1:]) + if diff < 0: + new_array = [ + a * b for a in self.reduce(other.array.shape[1:]).array for b in other.array + ] + else: + new_array = [ + a * b for a in self.array for b in other.reduce(self.array.shape[1:]).array + ] + return Fock(array=new_array, batched=True) + except Exception as e: + raise TypeError(f"Cannot multiply {self.__class__} and {other.__class__}.") from e + else: + ret = Fock(array=self.array * other, batched=True) + ret._original_abc_data = ( + tuple(i * j for i, j in zip(self._original_abc_data, (1, 1, other))) + if self._original_abc_data is not None + else None + ) + return ret + + def __neg__(self) -> Fock: + r""" + Negates the values in the array. + """ + return Fock(array=-self.array, batched=True) + + def __rmul__(self, other: Fock | Scalar) -> Fock: + r""" + Multiplies this representation by another or by a scalar on the right. + """ + return self.__mul__(other) + + def __sub__(self, other: Fock) -> Fock: + r""" + Subtracts other from this ansatz. + """ + try: + return self.__add__(-other) + except AttributeError as e: + raise TypeError(f"Cannot subtract {self.__class__} and {other.__class__}.") from e + + def __truediv__(self, other: Scalar | Fock) -> Fock: + r""" + Divides this Fock representation by another Fock representation. + + Args: + other: A scalar or another Fock representation. + + Raises: + ValueError: If the arrays don't have the same shape. + + Returns: + ArrayAnsatz: The division of this representation and other. + """ + if isinstance(other, Fock): + try: + diff = sum(self.array.shape[1:]) - sum(other.array.shape[1:]) + if diff < 0: + new_array = [ + a / b for a in self.reduce(other.array.shape[1:]).array for b in other.array + ] + else: + new_array = [ + a / b for a in self.array for b in other.reduce(self.array.shape[1:]).array + ] + return Fock(array=new_array, batched=True) + except Exception as e: + raise TypeError(f"Cannot divide {self.__class__} and {other.__class__}.") from e + else: + ret = Fock(array=self.array / other, batched=True) + ret._original_abc_data = ( + tuple(i / j for i, j in zip(self._original_abc_data, (1, 1, other))) + if self._original_abc_data is not None + else None + ) + return ret diff --git a/tests/test_physics/test_bargmann.py b/tests/test_physics/test_bargmann_utils.py similarity index 100% rename from tests/test_physics/test_bargmann.py rename to tests/test_physics/test_bargmann_utils.py diff --git a/tests/test_physics/test_fock.py b/tests/test_physics/test_fock_utils.py similarity index 100% rename from tests/test_physics/test_fock.py rename to tests/test_physics/test_fock_utils.py diff --git a/tests/test_physics/test_representations.py b/tests/test_physics/test_representations.py deleted file mode 100644 index b0f03f7e1..000000000 --- a/tests/test_physics/test_representations.py +++ /dev/null @@ -1,432 +0,0 @@ -# Copyright 2022 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This module contains tests for ``Representation`` objects.""" - -from unittest.mock import patch - -import numpy as np -from ipywidgets import Box, HBox, VBox, HTML, IntText, Stack, IntSlider, Tab -from plotly.graph_objs import FigureWidget -import pytest - -from mrmustard import math, settings -from mrmustard.physics.gaussian_integrals import ( - contract_two_Abc, - complex_gaussian_integral, -) -from mrmustard.physics.representations import Bargmann, Fock -from ..random import Abc_triple - -# original settings -autocutoff_max0 = settings.AUTOCUTOFF_MAX_CUTOFF - -# pylint: disable = missing-function-docstring - - -class TestBargmannRepresentation: - r""" - Tests the Bargmann Representation. - """ - - Abc_n1 = Abc_triple(1) - Abc_n2 = Abc_triple(2) - Abc_n3 = Abc_triple(3) - - @pytest.mark.parametrize("triple", [Abc_n1, Abc_n2, Abc_n3]) - def test_init_non_batched(self, triple): - A, b, c = triple - bargmann = Bargmann(*triple) - - assert np.allclose(bargmann.A, A) - assert np.allclose(bargmann.b, b) - assert np.allclose(bargmann.c, c) - - @pytest.mark.parametrize("triple", [Abc_n1, Abc_n2, Abc_n3]) - def test_init_from_ansatz(self, triple): - bargmann1 = Bargmann(*triple) - bargmann2 = Bargmann.from_ansatz(bargmann1.ansatz) - - assert bargmann1 == bargmann2 - - @pytest.mark.parametrize("triple", [Abc_n1, Abc_n2, Abc_n3]) - def test_conj(self, triple): - A, b, c = triple - bargmann = Bargmann(*triple).conj() - - assert np.allclose(bargmann.A, math.conj(A)) - assert np.allclose(bargmann.b, math.conj(b)) - assert np.allclose(bargmann.c, math.conj(c)) - - @pytest.mark.parametrize("n", [1, 2, 3]) - def test_and(self, n): - triple1 = Abc_triple(n) - triple2 = Abc_triple(n) - - bargmann = Bargmann(*triple1) & Bargmann(*triple2) - - assert bargmann.A.shape == (1, 2 * n, 2 * n) - assert bargmann.b.shape == (1, 2 * n) - assert bargmann.c.shape == (1,) - - @pytest.mark.parametrize("scalar", [0.5, 1.2]) - @pytest.mark.parametrize("triple", [Abc_n1, Abc_n2, Abc_n3]) - def test_mul_with_scalar(self, scalar, triple): - bargmann1 = Bargmann(*triple) - bargmann_mul = bargmann1 * scalar - - assert np.allclose(bargmann1.A, bargmann_mul.A) - assert np.allclose(bargmann1.b, bargmann_mul.b) - assert np.allclose(bargmann1.c * scalar, bargmann_mul.c) - - @pytest.mark.parametrize("n", [1, 2, 3]) - def test_mul(self, n): - triple1 = Abc_triple(n) - triple2 = Abc_triple(n) - - bargmann1 = Bargmann(*triple1) - bargmann2 = Bargmann(*triple2) - bargmann_mul = bargmann1 * bargmann2 - - assert np.allclose(bargmann_mul.A, bargmann1.A + bargmann2.A) - assert np.allclose(bargmann_mul.b, bargmann1.b + bargmann2.b) - assert np.allclose(bargmann_mul.c, bargmann1.c * bargmann2.c) - - @pytest.mark.parametrize("scalar", [0.5, 1.2]) - @pytest.mark.parametrize("triple", [Abc_n1, Abc_n2, Abc_n3]) - def test_div_with_scalar(self, scalar, triple): - bargmann1 = Bargmann(*triple) - bargmann_div = bargmann1 / scalar - - assert np.allclose(bargmann1.A, bargmann_div.A) - assert np.allclose(bargmann1.b, bargmann_div.b) - assert np.allclose(bargmann1.c / scalar, bargmann_div.c) - - @pytest.mark.parametrize("n", [1, 2, 3]) - def test_div(self, n): - triple1 = Abc_triple(n) - triple2 = Abc_triple(n) - - bargmann1 = Bargmann(*triple1) - bargmann2 = Bargmann(*triple2) - bargmann_div = bargmann1 / bargmann2 - - assert np.allclose(bargmann_div.A, bargmann1.A - bargmann2.A) - assert np.allclose(bargmann_div.b, bargmann1.b - bargmann2.b) - assert np.allclose(bargmann_div.c, bargmann1.c / bargmann2.c) - - @pytest.mark.parametrize("n", [1, 2, 3]) - def test_add(self, n): - triple1 = Abc_triple(n) - triple2 = Abc_triple(n) - - bargmann1 = Bargmann(*triple1) - bargmann2 = Bargmann(*triple2) - bargmann_add = bargmann1 + bargmann2 - - assert np.allclose(bargmann_add.A, math.concat([bargmann1.A, bargmann2.A], axis=0)) - assert np.allclose(bargmann_add.b, math.concat([bargmann1.b, bargmann2.b], axis=0)) - assert np.allclose(bargmann_add.c, math.concat([bargmann1.c, bargmann2.c], axis=0)) - - def test_add_error(self): - bargmann = Bargmann(*Abc_triple(3)) - fock = Fock(np.random.random((1, 4, 4, 4)), batched=True) - - with pytest.raises(ValueError): - bargmann + fock # pylint: disable=pointless-statement - - @pytest.mark.parametrize("n", [1, 2, 3]) - def test_sub(self, n): - triple1 = Abc_triple(n) - triple2 = Abc_triple(n) - - bargmann1 = Bargmann(*triple1) - bargmann2 = Bargmann(*triple2) - bargmann_add = bargmann1 - bargmann2 - - assert np.allclose(bargmann_add.A, math.concat([bargmann1.A, bargmann2.A], axis=0)) - assert np.allclose(bargmann_add.b, math.concat([bargmann1.b, bargmann2.b], axis=0)) - assert np.allclose(bargmann_add.c, math.concat([bargmann1.c, -bargmann2.c], axis=0)) - - def test_trace(self): - triple = Abc_triple(4) - bargmann = Bargmann(*triple).trace([0], [2]) - A, b, c = complex_gaussian_integral(triple, [0], [2]) - - assert np.allclose(bargmann.A, A) - assert np.allclose(bargmann.b, b) - assert np.allclose(bargmann.c, c) - - def test_reorder(self): - triple = Abc_triple(3) - bargmann = Bargmann(*triple).reorder((0, 2, 1)) - - assert np.allclose(bargmann.A[0], triple[0][[0, 2, 1], :][:, [0, 2, 1]]) - assert np.allclose(bargmann.b[0], triple[1][[0, 2, 1]]) - - @pytest.mark.parametrize("triple", [Abc_n1, Abc_n2, Abc_n3]) - def test_call(self, triple): - bargmann = Bargmann(*triple) - z = 0.1 + 0.2j - dim = triple[0].shape[0] - assert bargmann(z * np.ones(dim)) == bargmann.ansatz(z * np.ones(dim)) - - def test_matmul_barg_barg(self): - triple1 = Abc_triple(3) - triple2 = Abc_triple(3) - - res1 = Bargmann(*triple1) @ Bargmann(*triple2) - exp1 = contract_two_Abc(triple1, triple2, [], []) - assert np.allclose(res1.A, exp1[0]) - assert np.allclose(res1.b, exp1[1]) - assert np.allclose(res1.c, exp1[2]) - - @patch("mrmustard.physics.representations.display") - def test_ipython_repr(self, mock_display): - """Test the IPython repr function.""" - rep = Bargmann(*Abc_triple(2)) - rep._ipython_display_() # pylint:disable=protected-access - [box] = mock_display.call_args.args - assert isinstance(box, Box) - assert box.layout.max_width == "50%" - - # data on left, eigvals on right - [data_vbox, eigs_vbox] = box.children - assert isinstance(data_vbox, VBox) - assert isinstance(eigs_vbox, VBox) - - # data forms a stack: header, ansatz, triple data - [header, sub, table] = data_vbox.children - assert isinstance(header, HTML) - assert isinstance(sub, HBox) - assert isinstance(table, HTML) - - # ansatz goes beside button to modify rounding - [ansatz, round_w] = sub.children - assert isinstance(ansatz, HTML) - assert isinstance(round_w, IntText) - - # eigvals have a header and a unit circle plot - [eig_header, unit_circle] = eigs_vbox.children - assert isinstance(eig_header, HTML) - assert isinstance(unit_circle, FigureWidget) - - @patch("mrmustard.physics.representations.display") - def test_ipython_repr_batched(self, mock_display): - """Test the IPython repr function for a batched repr.""" - A1, b1, c1 = Abc_triple(2) - A2, b2, c2 = Abc_triple(2) - rep = Bargmann(np.array([A1, A2]), np.array([b1, b2]), np.array([c1, c2])) - rep._ipython_display_() # pylint:disable=protected-access - [vbox] = mock_display.call_args.args - assert isinstance(vbox, VBox) - - [slider, stack] = vbox.children - assert isinstance(slider, IntSlider) - assert slider.max == 1 # the batch size - 1 - assert isinstance(stack, Stack) - - # max_width is spot-check that this is bargmann widget - assert len(stack.children) == 2 - assert all(box.layout.max_width == "50%" for box in stack.children) - - -class TestFockRepresentation: # pylint:disable=too-many-public-methods - r"""Tests the Fock Representation.""" - - array578 = np.random.random((5, 7, 8)) - array1578 = np.random.random((1, 5, 7, 8)) - array2578 = np.random.random((2, 5, 7, 8)) - array5578 = np.random.random((5, 5, 7, 8)) - - def test_init_batched(self): - fock = Fock(self.array1578, batched=True) - assert isinstance(fock, Fock) - assert np.allclose(fock.array, self.array1578) - - def test_init_non_batched(self): - fock = Fock(self.array578, batched=False) - assert isinstance(fock, Fock) - assert fock.array.shape == (1, 5, 7, 8) - assert np.allclose(fock.array[0, :, :, :], self.array578) - - def test_init_from_ansatz(self): - fock1 = Fock(self.array5578) - fock2 = Fock.from_ansatz(fock1.ansatz) - assert fock1 == fock2 - - def test_sum_batch(self): - fock = Fock(self.array2578, batched=True) - fock_collapsed = fock.sum_batch()[0] - assert fock_collapsed.array.shape == (1, 5, 7, 8) - assert np.allclose(fock_collapsed.array, np.sum(self.array2578, axis=0)) - - def test_and(self): - fock1 = Fock(self.array1578, batched=True) - fock2 = Fock(self.array5578, batched=True) - fock_test = fock1 & fock2 - assert fock_test.array.shape == (5, 5, 7, 8, 5, 7, 8) - assert np.allclose( - math.reshape(fock_test.array, -1), - math.reshape(np.einsum("bcde, pfgh -> bpcdefgh", self.array1578, self.array5578), -1), - ) - - def test_multiply_a_scalar(self): - fock1 = Fock(self.array1578, batched=True) - fock_test = 1.3 * fock1 - assert np.allclose(fock_test.array, 1.3 * self.array1578) - - def test_mul(self): - fock1 = Fock(self.array1578, batched=True) - fock2 = Fock(self.array5578, batched=True) - fock1_mul_fock2 = fock1 * fock2 - assert fock1_mul_fock2.array.shape == (5, 5, 7, 8) - assert np.allclose( - math.reshape(fock1_mul_fock2.array, -1), - math.reshape(np.einsum("bcde, pcde -> bpcde", self.array1578, self.array5578), -1), - ) - - def test_divide_on_a_scalar(self): - fock1 = Fock(self.array1578, batched=True) - fock_test = fock1 / 1.5 - assert np.allclose(fock_test.array, self.array1578 / 1.5) - - def test_truediv(self): - fock1 = Fock(self.array1578, batched=True) - fock2 = Fock(self.array5578, batched=True) - fock1_mul_fock2 = fock1 / fock2 - assert fock1_mul_fock2.array.shape == (5, 5, 7, 8) - assert np.allclose( - math.reshape(fock1_mul_fock2.array, -1), - math.reshape(np.einsum("bcde, pcde -> bpcde", self.array1578, 1 / self.array5578), -1), - ) - - def test_conj(self): - fock = Fock(self.array1578, batched=True) - fock_conj = fock.conj() - assert np.allclose(fock_conj.array, np.conj(self.array1578)) - - def test_matmul_fock_fock(self): - array2 = math.astensor(np.random.random((5, 6, 7, 8, 10))) - fock1 = Fock(self.array2578, batched=True) - fock2 = Fock(array2, batched=True) - fock_test = fock1[2] @ fock2[2] - assert fock_test.array.shape == (10, 5, 7, 6, 7, 10) - assert np.allclose( - math.reshape(fock_test.array, -1), - math.reshape(np.einsum("bcde, pfgeh -> bpcdfgh", self.array2578, array2), -1), - ) - - def test_add(self): - fock1 = Fock(self.array2578, batched=True) - fock2 = Fock(self.array5578, batched=True) - fock1_add_fock2 = fock1 + fock2 - assert fock1_add_fock2.array.shape == (10, 5, 7, 8) - assert np.allclose(fock1_add_fock2.array[0], self.array2578[0] + self.array5578[0]) - assert np.allclose(fock1_add_fock2.array[4], self.array2578[0] + self.array5578[4]) - assert np.allclose(fock1_add_fock2.array[5], self.array2578[1] + self.array5578[0]) - - def test_sub(self): - fock1 = Fock(self.array2578, batched=True) - fock2 = Fock(self.array5578, batched=True) - fock1_sub_fock2 = fock1 - fock2 - assert fock1_sub_fock2.array.shape == (10, 5, 7, 8) - assert np.allclose(fock1_sub_fock2.array[0], self.array2578[0] - self.array5578[0]) - assert np.allclose(fock1_sub_fock2.array[4], self.array2578[0] - self.array5578[4]) - assert np.allclose(fock1_sub_fock2.array[9], self.array2578[1] - self.array5578[4]) - - def test_trace(self): - array1 = math.astensor(np.random.random((2, 5, 5, 1, 7, 4, 1, 7, 3))) - fock1 = Fock(array1, batched=True) - fock2 = fock1.trace(idxs1=[0, 3], idxs2=[1, 6]) - assert fock2.array.shape == (2, 1, 4, 1, 3) - assert np.allclose(fock2.array, np.einsum("bccefghfj -> beghj", array1)) - - def test_reorder(self): - array1 = math.astensor(np.arange(8).reshape((1, 2, 2, 2))) - fock1 = Fock(array1, batched=True) - fock2 = fock1.reorder(order=(2, 1, 0)) - assert np.allclose(fock2.array, np.array([[[[0, 4], [2, 6]], [[1, 5], [3, 7]]]])) - assert np.allclose(fock2.array, np.arange(8).reshape((1, 2, 2, 2), order="F")) - - @pytest.mark.parametrize("batched", [True, False]) - def test_reduce(self, batched): - shape = (1, 3, 3, 3) if batched else (3, 3, 3) - array1 = math.astensor(np.arange(27).reshape(shape)) - fock1 = Fock(array1, batched=batched) - - fock2 = fock1.reduce(3) - assert fock1 == fock2 - - fock3 = fock1.reduce(2) - array3 = math.astensor([[[0, 1], [3, 4]], [[9, 10], [12, 13]]]) - assert fock3 == Fock(array3) - - fock4 = fock1.reduce((1, 3, 1)) - array4 = math.astensor([[[0], [3], [6]]]) - assert fock4 == Fock(array4) - - def test_reduce_error(self): - array1 = math.astensor(np.arange(27).reshape((3, 3, 3))) - fock1 = Fock(array1) - - with pytest.raises(ValueError, match="Expected shape"): - fock1.reduce((1, 2)) - - with pytest.raises(ValueError, match="Expected shape"): - fock1.reduce((1, 2, 3, 4, 5)) - - def test_reduce_padded(self): - fock = Fock(self.array578) - with pytest.warns(UserWarning): - fock1 = fock.reduce((8, 8, 8)) - assert fock1.array.shape == (1, 8, 8, 8) - - @pytest.mark.parametrize("shape", [(1, 8), (1, 8, 8)]) - @patch("mrmustard.physics.representations.display") - def test_ipython_repr(self, mock_display, shape): - """Test the IPython repr function.""" - rep = Fock(np.random.random(shape), batched=True) - rep._ipython_display_() # pylint:disable=protected-access - [hbox] = mock_display.call_args.args - assert isinstance(hbox, HBox) - - # the CSS, the header+ansatz, and the tabs of plots - [css, left, plots] = hbox.children - assert isinstance(css, HTML) - assert isinstance(left, VBox) - assert isinstance(plots, Tab) - - # left contains header and ansatz - left = left.children - assert len(left) == 2 and all(isinstance(w, HTML) for w in left) - - # one plot for magnitude, another for phase - assert plots.titles == ("Magnitude", "Phase") - plots = plots.children - assert len(plots) == 2 and all(isinstance(p, FigureWidget) for p in plots) - - @patch("mrmustard.physics.representations.display") - def test_ipython_repr_expects_batch_1(self, mock_display): - """Test the IPython repr function does nothing with real batch.""" - rep = Fock(np.random.random((2, 8)), batched=True) - rep._ipython_display_() # pylint:disable=protected-access - mock_display.assert_not_called() - - @patch("mrmustard.physics.representations.display") - def test_ipython_repr_expects_3_dims_or_less(self, mock_display): - """Test the IPython repr function does nothing with 4+ dims.""" - rep = Fock(np.random.random((1, 4, 4, 4)), batched=True) - rep._ipython_display_() # pylint:disable=protected-access - mock_display.assert_not_called() diff --git a/tests/test_physics/test_representations/__init__.py b/tests/test_physics/test_representations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_physics/test_representations/test_bargmann.py b/tests/test_physics/test_representations/test_bargmann.py new file mode 100644 index 000000000..8bfd2fa3f --- /dev/null +++ b/tests/test_physics/test_representations/test_bargmann.py @@ -0,0 +1,231 @@ +# Copyright 2022 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains tests for ``Representation`` objects.""" + +from unittest.mock import patch + +import numpy as np +from ipywidgets import Box, HBox, VBox, HTML, IntText, Stack, IntSlider +from plotly.graph_objs import FigureWidget +import pytest + +from mrmustard import math, settings +from mrmustard.physics.gaussian_integrals import ( + contract_two_Abc, + complex_gaussian_integral, +) +from mrmustard.physics.representations.bargmann import Bargmann +from ...random import Abc_triple + +# original settings +autocutoff_max0 = settings.AUTOCUTOFF_MAX_CUTOFF + +# pylint: disable = missing-function-docstring + + +class TestBargmannRepresentation: + r""" + Tests the Bargmann Representation. + """ + + Abc_n1 = Abc_triple(1) + Abc_n2 = Abc_triple(2) + Abc_n3 = Abc_triple(3) + + @pytest.mark.parametrize("triple", [Abc_n1, Abc_n2, Abc_n3]) + def test_init_non_batched(self, triple): + A, b, c = triple + bargmann = Bargmann(*triple) + + assert np.allclose(bargmann.A, A) + assert np.allclose(bargmann.b, b) + assert np.allclose(bargmann.c, c) + + @pytest.mark.parametrize("triple", [Abc_n1, Abc_n2, Abc_n3]) + def test_conj(self, triple): + A, b, c = triple + bargmann = Bargmann(*triple).conj() + + assert np.allclose(bargmann.A, math.conj(A)) + assert np.allclose(bargmann.b, math.conj(b)) + assert np.allclose(bargmann.c, math.conj(c)) + + @pytest.mark.parametrize("n", [1, 2, 3]) + def test_and(self, n): + triple1 = Abc_triple(n) + triple2 = Abc_triple(n) + + temp1 = Bargmann(*triple1) + print(temp1.A.shape) + + bargmann = Bargmann(*triple1) & Bargmann(*triple2) + + assert bargmann.A.shape == (1, 2 * n, 2 * n) + assert bargmann.b.shape == (1, 2 * n) + assert bargmann.c.shape == (1,) + + @pytest.mark.parametrize("scalar", [0.5, 1.2]) + @pytest.mark.parametrize("triple", [Abc_n1, Abc_n2, Abc_n3]) + def test_mul_with_scalar(self, scalar, triple): + bargmann1 = Bargmann(*triple) + bargmann_mul = bargmann1 * scalar + + assert np.allclose(bargmann1.A, bargmann_mul.A) + assert np.allclose(bargmann1.b, bargmann_mul.b) + assert np.allclose(bargmann1.c * scalar, bargmann_mul.c) + + @pytest.mark.parametrize("n", [1, 2, 3]) + def test_mul(self, n): + triple1 = Abc_triple(n) + triple2 = Abc_triple(n) + + bargmann1 = Bargmann(*triple1) + bargmann2 = Bargmann(*triple2) + bargmann_mul = bargmann1 * bargmann2 + + assert np.allclose(bargmann_mul.A, bargmann1.A + bargmann2.A) + assert np.allclose(bargmann_mul.b, bargmann1.b + bargmann2.b) + assert np.allclose(bargmann_mul.c, bargmann1.c * bargmann2.c) + + @pytest.mark.parametrize("scalar", [0.5, 1.2]) + @pytest.mark.parametrize("triple", [Abc_n1, Abc_n2, Abc_n3]) + def test_div_with_scalar(self, scalar, triple): + bargmann1 = Bargmann(*triple) + bargmann_div = bargmann1 / scalar + + assert np.allclose(bargmann1.A, bargmann_div.A) + assert np.allclose(bargmann1.b, bargmann_div.b) + assert np.allclose(bargmann1.c / scalar, bargmann_div.c) + + @pytest.mark.parametrize("n", [1, 2, 3]) + def test_div(self, n): + triple1 = Abc_triple(n) + triple2 = Abc_triple(n) + + bargmann1 = Bargmann(*triple1) + bargmann2 = Bargmann(*triple2) + bargmann_div = bargmann1 / bargmann2 + + assert np.allclose(bargmann_div.A, bargmann1.A - bargmann2.A) + assert np.allclose(bargmann_div.b, bargmann1.b - bargmann2.b) + assert np.allclose(bargmann_div.c, bargmann1.c / bargmann2.c) + + @pytest.mark.parametrize("n", [1, 2, 3]) + def test_add(self, n): + triple1 = Abc_triple(n) + triple2 = Abc_triple(n) + + bargmann1 = Bargmann(*triple1) + bargmann2 = Bargmann(*triple2) + bargmann_add = bargmann1 + bargmann2 + + assert np.allclose(bargmann_add.A, math.concat([bargmann1.A, bargmann2.A], axis=0)) + assert np.allclose(bargmann_add.b, math.concat([bargmann1.b, bargmann2.b], axis=0)) + assert np.allclose(bargmann_add.c, math.concat([bargmann1.c, bargmann2.c], axis=0)) + + # def test_add_error(self): + # bargmann = Bargmann(*Abc_triple(3)) + # fock = Fock(np.random.random((1, 4, 4, 4)), batched=True) + + # with pytest.raises(ValueError): + # bargmann + fock # pylint: disable=pointless-statement + + @pytest.mark.parametrize("n", [1, 2, 3]) + def test_sub(self, n): + triple1 = Abc_triple(n) + triple2 = Abc_triple(n) + + bargmann1 = Bargmann(*triple1) + bargmann2 = Bargmann(*triple2) + bargmann_add = bargmann1 - bargmann2 + + assert np.allclose(bargmann_add.A, math.concat([bargmann1.A, bargmann2.A], axis=0)) + assert np.allclose(bargmann_add.b, math.concat([bargmann1.b, bargmann2.b], axis=0)) + assert np.allclose(bargmann_add.c, math.concat([bargmann1.c, -bargmann2.c], axis=0)) + + def test_trace(self): + triple = Abc_triple(4) + bargmann = Bargmann(*triple).trace([0], [2]) + A, b, c = complex_gaussian_integral(triple, [0], [2]) + + assert np.allclose(bargmann.A, A) + assert np.allclose(bargmann.b, b) + assert np.allclose(bargmann.c, c) + + def test_reorder(self): + triple = Abc_triple(3) + bargmann = Bargmann(*triple).reorder((0, 2, 1)) + + assert np.allclose(bargmann.A[0], triple[0][[0, 2, 1], :][:, [0, 2, 1]]) + assert np.allclose(bargmann.b[0], triple[1][[0, 2, 1]]) + + def test_matmul_barg_barg(self): + triple1 = Abc_triple(3) + triple2 = Abc_triple(3) + + res1 = Bargmann(*triple1) @ Bargmann(*triple2) + exp1 = contract_two_Abc(triple1, triple2, [], []) + assert np.allclose(res1.A, exp1[0]) + assert np.allclose(res1.b, exp1[1]) + assert np.allclose(res1.c, exp1[2]) + + # @patch("mrmustard.physics.representations.bargmann.display") + # def test_ipython_repr(self, mock_display): + # """Test the IPython repr function.""" + # rep = Bargmann(*Abc_triple(2)) + # rep._ipython_display_() # pylint:disable=protected-access + # [box] = mock_display.call_args.args + # assert isinstance(box, Box) + # assert box.layout.max_width == "50%" + + # # data on left, eigvals on right + # [data_vbox, eigs_vbox] = box.children + # assert isinstance(data_vbox, VBox) + # assert isinstance(eigs_vbox, VBox) + + # # data forms a stack: header, ansatz, triple data + # [header, sub, table] = data_vbox.children + # assert isinstance(header, HTML) + # assert isinstance(sub, HBox) + # assert isinstance(table, HTML) + + # # ansatz goes beside button to modify rounding + # [ansatz, round_w] = sub.children + # assert isinstance(ansatz, HTML) + # assert isinstance(round_w, IntText) + + # # eigvals have a header and a unit circle plot + # [eig_header, unit_circle] = eigs_vbox.children + # assert isinstance(eig_header, HTML) + # assert isinstance(unit_circle, FigureWidget) + + # @patch("mrmustard.physics.representations.bargmann.display") + # def test_ipython_repr_batched(self, mock_display): + # """Test the IPython repr function for a batched repr.""" + # A1, b1, c1 = Abc_triple(2) + # A2, b2, c2 = Abc_triple(2) + # rep = Bargmann(np.array([A1, A2]), np.array([b1, b2]), np.array([c1, c2])) + # rep._ipython_display_() # pylint:disable=protected-access + # [vbox] = mock_display.call_args.args + # assert isinstance(vbox, VBox) + + # [slider, stack] = vbox.children + # assert isinstance(slider, IntSlider) + # assert slider.max == 1 # the batch size - 1 + # assert isinstance(stack, Stack) + + # # max_width is spot-check that this is bargmann widget + # assert len(stack.children) == 2 + # assert all(box.layout.max_width == "50%" for box in stack.children) From b6feb3b2c253f55bd9ccf68cc48a19d321551c24 Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 10 Sep 2024 10:19:12 -0400 Subject: [PATCH 04/87] init --- .../test_representations/test_fock.py | 215 ++++++++++++++++++ 1 file changed, 215 insertions(+) create mode 100644 tests/test_physics/test_representations/test_fock.py diff --git a/tests/test_physics/test_representations/test_fock.py b/tests/test_physics/test_representations/test_fock.py new file mode 100644 index 000000000..8e47f9ce3 --- /dev/null +++ b/tests/test_physics/test_representations/test_fock.py @@ -0,0 +1,215 @@ +# Copyright 2022 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains tests for ``Representation`` objects.""" + +from unittest.mock import patch + +import numpy as np +from ipywidgets import Box, HBox, VBox, HTML, Tab +from plotly.graph_objs import FigureWidget +import pytest + +from mrmustard import math, settings +from mrmustard.physics.representations.fock import Fock + +# original settings +autocutoff_max0 = settings.AUTOCUTOFF_MAX_CUTOFF + +# pylint: disable = missing-function-docstring + + +class TestFockRepresentation: # pylint:disable=too-many-public-methods + r"""Tests the Fock Representation.""" + + array578 = np.random.random((5, 7, 8)) + array1578 = np.random.random((1, 5, 7, 8)) + array2578 = np.random.random((2, 5, 7, 8)) + array5578 = np.random.random((5, 5, 7, 8)) + + def test_init_batched(self): + fock = Fock(self.array1578, batched=True) + assert isinstance(fock, Fock) + assert np.allclose(fock.array, self.array1578) + + def test_init_non_batched(self): + fock = Fock(self.array578, batched=False) + assert isinstance(fock, Fock) + assert fock.array.shape == (1, 5, 7, 8) + assert np.allclose(fock.array[0, :, :, :], self.array578) + + def test_sum_batch(self): + fock = Fock(self.array2578, batched=True) + fock_collapsed = fock.sum_batch()[0] + assert fock_collapsed.array.shape == (1, 5, 7, 8) + assert np.allclose(fock_collapsed.array, np.sum(self.array2578, axis=0)) + + def test_and(self): + fock1 = Fock(self.array1578, batched=True) + fock2 = Fock(self.array5578, batched=True) + fock_test = fock1 & fock2 + assert fock_test.array.shape == (5, 5, 7, 8, 5, 7, 8) + assert np.allclose( + math.reshape(fock_test.array, -1), + math.reshape(np.einsum("bcde, pfgh -> bpcdefgh", self.array1578, self.array5578), -1), + ) + + def test_multiply_a_scalar(self): + fock1 = Fock(self.array1578, batched=True) + fock_test = 1.3 * fock1 + assert np.allclose(fock_test.array, 1.3 * self.array1578) + + def test_mul(self): + fock1 = Fock(self.array1578, batched=True) + fock2 = Fock(self.array5578, batched=True) + fock1_mul_fock2 = fock1 * fock2 + assert fock1_mul_fock2.array.shape == (5, 5, 7, 8) + assert np.allclose( + math.reshape(fock1_mul_fock2.array, -1), + math.reshape(np.einsum("bcde, pcde -> bpcde", self.array1578, self.array5578), -1), + ) + + def test_divide_on_a_scalar(self): + fock1 = Fock(self.array1578, batched=True) + fock_test = fock1 / 1.5 + assert np.allclose(fock_test.array, self.array1578 / 1.5) + + def test_truediv(self): + fock1 = Fock(self.array1578, batched=True) + fock2 = Fock(self.array5578, batched=True) + fock1_mul_fock2 = fock1 / fock2 + assert fock1_mul_fock2.array.shape == (5, 5, 7, 8) + assert np.allclose( + math.reshape(fock1_mul_fock2.array, -1), + math.reshape(np.einsum("bcde, pcde -> bpcde", self.array1578, 1 / self.array5578), -1), + ) + + def test_conj(self): + fock = Fock(self.array1578, batched=True) + fock_conj = fock.conj + assert np.allclose(fock_conj.array, np.conj(self.array1578)) + + def test_matmul_fock_fock(self): + array2 = math.astensor(np.random.random((5, 6, 7, 8, 10))) + fock1 = Fock(self.array2578, batched=True) + fock2 = Fock(array2, batched=True) + fock_test = fock1[2] @ fock2[2] + assert fock_test.array.shape == (10, 5, 7, 6, 7, 10) + assert np.allclose( + math.reshape(fock_test.array, -1), + math.reshape(np.einsum("bcde, pfgeh -> bpcdfgh", self.array2578, array2), -1), + ) + + def test_add(self): + fock1 = Fock(self.array2578, batched=True) + fock2 = Fock(self.array5578, batched=True) + fock1_add_fock2 = fock1 + fock2 + assert fock1_add_fock2.array.shape == (10, 5, 7, 8) + assert np.allclose(fock1_add_fock2.array[0], self.array2578[0] + self.array5578[0]) + assert np.allclose(fock1_add_fock2.array[4], self.array2578[0] + self.array5578[4]) + assert np.allclose(fock1_add_fock2.array[5], self.array2578[1] + self.array5578[0]) + + def test_sub(self): + fock1 = Fock(self.array2578, batched=True) + fock2 = Fock(self.array5578, batched=True) + fock1_sub_fock2 = fock1 - fock2 + assert fock1_sub_fock2.array.shape == (10, 5, 7, 8) + assert np.allclose(fock1_sub_fock2.array[0], self.array2578[0] - self.array5578[0]) + assert np.allclose(fock1_sub_fock2.array[4], self.array2578[0] - self.array5578[4]) + assert np.allclose(fock1_sub_fock2.array[9], self.array2578[1] - self.array5578[4]) + + def test_trace(self): + array1 = math.astensor(np.random.random((2, 5, 5, 1, 7, 4, 1, 7, 3))) + fock1 = Fock(array1, batched=True) + fock2 = fock1.trace(idxs1=[0, 3], idxs2=[1, 6]) + assert fock2.array.shape == (2, 1, 4, 1, 3) + assert np.allclose(fock2.array, np.einsum("bccefghfj -> beghj", array1)) + + def test_reorder(self): + array1 = math.astensor(np.arange(8).reshape((1, 2, 2, 2))) + fock1 = Fock(array1, batched=True) + fock2 = fock1.reorder(order=(2, 1, 0)) + assert np.allclose(fock2.array, np.array([[[[0, 4], [2, 6]], [[1, 5], [3, 7]]]])) + assert np.allclose(fock2.array, np.arange(8).reshape((1, 2, 2, 2), order="F")) + + @pytest.mark.parametrize("batched", [True, False]) + def test_reduce(self, batched): + shape = (1, 3, 3, 3) if batched else (3, 3, 3) + array1 = math.astensor(np.arange(27).reshape(shape)) + fock1 = Fock(array1, batched=batched) + + fock2 = fock1.reduce(3) + assert fock1 == fock2 + + fock3 = fock1.reduce(2) + array3 = math.astensor([[[0, 1], [3, 4]], [[9, 10], [12, 13]]]) + assert fock3 == Fock(array3) + + fock4 = fock1.reduce((1, 3, 1)) + array4 = math.astensor([[[0], [3], [6]]]) + assert fock4 == Fock(array4) + + def test_reduce_error(self): + array1 = math.astensor(np.arange(27).reshape((3, 3, 3))) + fock1 = Fock(array1) + + with pytest.raises(ValueError, match="Expected shape"): + fock1.reduce((1, 2)) + + with pytest.raises(ValueError, match="Expected shape"): + fock1.reduce((1, 2, 3, 4, 5)) + + def test_reduce_padded(self): + fock = Fock(self.array578) + with pytest.warns(UserWarning): + fock1 = fock.reduce((8, 8, 8)) + assert fock1.array.shape == (1, 8, 8, 8) + + # @pytest.mark.parametrize("shape", [(1, 8), (1, 8, 8)]) + # @patch("mrmustard.physics.representations.display") + # def test_ipython_repr(self, mock_display, shape): + # """Test the IPython repr function.""" + # rep = Fock(np.random.random(shape), batched=True) + # rep._ipython_display_() # pylint:disable=protected-access + # [hbox] = mock_display.call_args.args + # assert isinstance(hbox, HBox) + + # # the CSS, the header+ansatz, and the tabs of plots + # [css, left, plots] = hbox.children + # assert isinstance(css, HTML) + # assert isinstance(left, VBox) + # assert isinstance(plots, Tab) + + # # left contains header and ansatz + # left = left.children + # assert len(left) == 2 and all(isinstance(w, HTML) for w in left) + + # # one plot for magnitude, another for phase + # assert plots.titles == ("Magnitude", "Phase") + # plots = plots.children + # assert len(plots) == 2 and all(isinstance(p, FigureWidget) for p in plots) + + # @patch("mrmustard.physics.representations.display") + # def test_ipython_repr_expects_batch_1(self, mock_display): + # """Test the IPython repr function does nothing with real batch.""" + # rep = Fock(np.random.random((2, 8)), batched=True) + # rep._ipython_display_() # pylint:disable=protected-access + # mock_display.assert_not_called() + + # @patch("mrmustard.physics.representations.display") + # def test_ipython_repr_expects_3_dims_or_less(self, mock_display): + # """Test the IPython repr function does nothing with 4+ dims.""" + # rep = Fock(np.random.random((1, 4, 4, 4)), batched=True) + # rep._ipython_display_() # pylint:disable=protected-access + # mock_display.assert_not_called() From 61a83cffa63fa90d3bb29813abf4167aa849c1c5 Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 10 Sep 2024 10:35:45 -0400 Subject: [PATCH 05/87] fock working --- mrmustard/physics/representations/fock.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mrmustard/physics/representations/fock.py b/mrmustard/physics/representations/fock.py index ff7809612..1ef05f9b8 100644 --- a/mrmustard/physics/representations/fock.py +++ b/mrmustard/physics/representations/fock.py @@ -201,7 +201,7 @@ def sum_batch(self) -> Fock: Returns: The collapsed Fock object. """ - return Fock(math.sum(self.array, axes=[0]), batched=True) + return Fock(math.expand_dims(math.sum(self.array, axes=[0]), 0), batched=True) def to_dict(self) -> dict[str, ArrayLike]: """Serialize a Fock instance.""" @@ -316,7 +316,7 @@ def __getitem__(self, idx: int | tuple[int, ...]) -> Fock: raise IndexError( f"Index {i} out of bounds for representation with {self.num_vars} variables." ) - ret = Fock(self.array) + ret = Fock(self.array, batched=True) ret._contract_idxs = idx return ret From 9f60a21597328c12a5aa2a4d28c85675f25847e0 Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 10 Sep 2024 10:39:28 -0400 Subject: [PATCH 06/87] Representation --- mrmustard/physics/representations/bargmann.py | 13 ++--- mrmustard/physics/representations/fock.py | 51 ++++++++++++++++--- 2 files changed, 50 insertions(+), 14 deletions(-) diff --git a/mrmustard/physics/representations/bargmann.py b/mrmustard/physics/representations/bargmann.py index ca8b15831..e9f56560b 100644 --- a/mrmustard/physics/representations/bargmann.py +++ b/mrmustard/physics/representations/bargmann.py @@ -50,10 +50,12 @@ from mrmustard.utils.argsort import argsort_gen +from .base import Representation + __all__ = ["Bargmann"] -class Bargmann: +class Bargmann(Representation): r""" The Fock-Bargmann representation of a broad class of quantum states, transformations, measurements, channels, etc. @@ -146,22 +148,17 @@ def __init__( b: Batch[ComplexVector], c: Batch[ComplexTensor] = 1.0, name: str = "", - batched: bool = False, ): if A is None and b is None and c is not None: raise ValueError("Please provide either A or b.") - # Representation base class - self._contract_idxs: tuple[int, ...] = () - self._fn = None - self._kwargs = {} - self.name = name - + super().__init__() self._A = A self._b = b self._c = c self._backends = [False, False, False] self._simplified = False + self.name = name @property def A(self) -> Batch[ComplexMatrix]: diff --git a/mrmustard/physics/representations/fock.py b/mrmustard/physics/representations/fock.py index 1ef05f9b8..be546851c 100644 --- a/mrmustard/physics/representations/fock.py +++ b/mrmustard/physics/representations/fock.py @@ -39,15 +39,54 @@ __all__ = ["Fock"] -class Fock: - r""" """ +class Fock(Representation): + r""" + The Fock representation of a broad class of quantum states, transformations, measurements, + channels, etc. - def __init__(self, array: Batch[Tensor], batched=False): + The ansatz available in this representation is ``ArrayAnsatz``. + + This function allows for vector space operations on Fock objects including + linear combinations, outer product (``&``), and inner product (``@``). + + .. code-block:: + + >>> from mrmustard.physics.representations import Fock + + >>> # initialize Fock objects + >>> array1 = np.random.random((5,7,8)) + >>> array2 = np.random.random((5,7,8)) + >>> array3 = np.random.random((3,5,7,8)) # where 3 is the batch. + >>> fock1 = Fock(array1) + >>> fock2 = Fock(array2) + >>> fock3 = Fock(array3, batched=True) + + >>> # linear combination can be done with the same batch dimension + >>> fock4 = 1.3 * fock1 - fock2 * 2.1 + + >>> # division by a scalar + >>> fock5 = fock1 / 1.3 - self._fn = None - self._kwargs = {} - self._contract_idxs: tuple[int, ...] = () + >>> # inner product by contracting on marked indices + >>> fock6 = fock1[2] @ fock3[2] + >>> # outer product (tensor product) + >>> fock7 = fock1 & fock3 + + >>> # conjugation + >>> fock8 = fock1.conj() + + Args: + array: the (batched) array in Fock representation. + batched: whether the array input has a batch dimension. + + Note: The args can be passed non-batched, as they will be automatically broadcasted to the + correct batch shape. + + """ + + def __init__(self, array: Batch[Tensor], batched=False): + super().__init__() self._array = array if batched else [array] self._backend_array = False self._original_abc_data = None From 803cd88f26b03690ea939b873e00fc1f51fdf26c Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 10 Sep 2024 10:52:56 -0400 Subject: [PATCH 07/87] tests passing --- mrmustard/lab_dev/circuit_components.py | 18 +- .../circuit_components_utils/trace_out.py | 2 +- mrmustard/lab_dev/states/base.py | 22 +- mrmustard/lab_dev/transformations/base.py | 10 +- mrmustard/physics/representations/bargmann.py | 21 +- mrmustard/physics/representations/fock.py | 2 +- tests/test_lab_dev/test_circuit_components.py | 28 +- .../test_circuit_components_utils.py | 2 +- .../test_lab_dev/test_states/test_coherent.py | 6 +- .../test_transformations/test_cft.py | 2 +- .../test_transformations_base.py | 4 +- tests/test_physics/test_ansatz.py | 612 ------------------ .../test_representations/test_bargmann.py | 2 +- 13 files changed, 60 insertions(+), 671 deletions(-) delete mode 100644 tests/test_physics/test_ansatz.py diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index f88f71448..fad1c79d2 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -163,7 +163,7 @@ def adjoint(self) -> CircuitComponent: """ bras = self.wires.bra.indices kets = self.wires.ket.indices - rep = self.representation.reorder(kets + bras).conj() if self.representation else None + rep = self.representation.reorder(kets + bras).conj if self.representation else None ret = CircuitComponent(rep, self.wires.adjoint, self.name) ret.short_name = self.short_name @@ -180,7 +180,7 @@ def dual(self) -> CircuitComponent: ik = self.wires.ket.input.indices ib = self.wires.bra.input.indices ob = self.wires.bra.output.indices - rep = self.representation.reorder(ib + ob + ik + ok).conj() if self.representation else None + rep = self.representation.reorder(ib + ob + ik + ok).conj if self.representation else None ret = CircuitComponent(rep, self.wires.dual, self.name) ret.short_name = self.short_name @@ -457,7 +457,7 @@ def bargmann_triple( """ try: A, b, c = self.representation.triple - if not batched and self.representation.ansatz.batch_size == 1: + if not batched and self.representation.batch_size == 1: return A[0], b[0], c[0] else: return A, b, c @@ -478,7 +478,7 @@ def fock(self, shape: int | Sequence[int] | None = None, batched=False) -> Compl Returns: array: The Fock representation of this component. """ - num_vars = self.representation.ansatz.num_vars + num_vars = self.representation.num_vars if isinstance(shape, int): shape = (shape,) * num_vars try: @@ -488,7 +488,7 @@ def fock(self, shape: int | Sequence[int] | None = None, batched=False) -> Compl raise ValueError( f"Expected Fock shape of length {num_vars}, got length {len(shape)}" ) - if self.representation.ansatz.polynomial_shape[0] == 0: + if self.representation.polynomial_shape[0] == 0: arrays = [math.hermite_renormalized(A, b, c, shape) for A, b, c in zip(As, bs, cs)] else: arrays = [ @@ -575,9 +575,9 @@ def to_fock(self, shape: int | Sequence[int] | None = None) -> CircuitComponent: """ fock = Fock(self.fock(shape, batched=True), batched=True) try: - fock.ansatz._original_abc_data = self.representation.triple + fock._original_abc_data = self.representation.triple except AttributeError: - fock.ansatz._original_abc_data = None + fock._original_abc_data = None try: ret = self._getitem_builtin(self.modes) ret._representation = fock @@ -607,8 +607,8 @@ def to_bargmann(self) -> CircuitComponent: if isinstance(self.representation, Bargmann): return self else: - if self.representation.ansatz._original_abc_data: - A, b, c = self.representation.ansatz._original_abc_data + if self.representation._original_abc_data: + A, b, c = self.representation._original_abc_data else: A, b, _ = identity_Abc(len(self.wires.quantum)) c = self.representation.data diff --git a/mrmustard/lab_dev/circuit_components_utils/trace_out.py b/mrmustard/lab_dev/circuit_components_utils/trace_out.py index 0b42332d2..39beecaed 100644 --- a/mrmustard/lab_dev/circuit_components_utils/trace_out.py +++ b/mrmustard/lab_dev/circuit_components_utils/trace_out.py @@ -83,7 +83,7 @@ def __custom_rrshift__(self, other: CircuitComponent | complex) -> CircuitCompon repr = other.representation wires = other.wires elif not ket or not bra: - repr = other.representation.conj()[idx_z] @ other.representation[idx_z] + repr = other.representation.conj[idx_z] @ other.representation[idx_z] wires, _ = (other.wires.adjoint @ other.wires)[0] @ self.wires else: repr = other.representation.trace(idx_z, idx_zconj) diff --git a/mrmustard/lab_dev/states/base.py b/mrmustard/lab_dev/states/base.py index d23a28b5c..f9df4a19f 100644 --- a/mrmustard/lab_dev/states/base.py +++ b/mrmustard/lab_dev/states/base.py @@ -618,9 +618,9 @@ def __init__( representation: Bargmann | Fock | None = None, name: str | None = None, ): - if representation and representation.ansatz.num_vars != 2 * len(modes): + if representation and representation.num_vars != 2 * len(modes): raise ValueError( - f"Expected a representation with {2*len(modes)} variables, found {representation.ansatz.num_vars}." + f"Expected a representation with {2*len(modes)} variables, found {representation.num_vars}." ) super().__init__( wires=[modes, (), modes, ()], @@ -634,7 +634,7 @@ def is_positive(self) -> bool: r""" Whether this DM is a positive operator. """ - batch_dim = self.representation.ansatz.batch_size + batch_dim = self.representation.batch_size if batch_dim > 1: raise ValueError( "Physicality conditions are not implemented for batch dimension larger than 1." @@ -752,11 +752,11 @@ def auto_shape( respect_manual_shape: Whether to respect the non-None values in ``manual_shape``. """ # experimental: - if self.representation.ansatz.batch_size == 1: + if self.representation.batch_size == 1: try: # fock shape = self._representation.array.shape[1:] except AttributeError: # bargmann - if self.representation.ansatz.polynomial_shape[0] == 0: + if self.representation.polynomial_shape[0] == 0: repr = self.representation A, b, c = repr.A[0], repr.b[0], repr.c[0] repr = repr / self.probability @@ -928,9 +928,9 @@ def __init__( representation: Bargmann | Fock | None = None, name: str | None = None, ): - if representation and representation.ansatz.num_vars != len(modes): + if representation and representation.num_vars != len(modes): raise ValueError( - f"Expected a representation with {len(modes)} variables, found {representation.ansatz.num_vars}." + f"Expected a representation with {len(modes)} variables, found {representation.num_vars}." ) super().__init__( wires=[(), (), modes, ()], @@ -944,7 +944,7 @@ def is_physical(self) -> bool: r""" Whether the ket object is a physical one. """ - batch_dim = self.representation.ansatz.batch_size + batch_dim = self.representation.batch_size if batch_dim > 1: raise ValueError( "Physicality conditions are not implemented for batch dimension larger than 1." @@ -1047,12 +1047,12 @@ def auto_shape( respect_manual_shape: Whether to respect the non-None values in ``manual_shape``. """ # experimental: - if self.representation.ansatz.batch_size == 1: + if self.representation.batch_size == 1: try: # fock shape = self._representation.array.shape[1:] except AttributeError: # bargmann - if self.representation.ansatz.polynomial_shape[0] == 0: - repr = self.representation.conj() & self.representation + if self.representation.polynomial_shape[0] == 0: + repr = self.representation.conj & self.representation A, b, c = repr.A[0], repr.b[0], repr.c[0] repr = repr / self.probability shape = autoshape_numba( diff --git a/mrmustard/lab_dev/transformations/base.py b/mrmustard/lab_dev/transformations/base.py index 2367c9623..7c6f35d1b 100644 --- a/mrmustard/lab_dev/transformations/base.py +++ b/mrmustard/lab_dev/transformations/base.py @@ -95,11 +95,11 @@ def inverse(self) -> Transformation: ) if not isinstance(self.representation, Bargmann): raise NotImplementedError("Only Bargmann representation is supported.") - if self.representation.ansatz.batch_size > 1: + if self.representation.batch_size > 1: raise NotImplementedError("Batched transformations are not supported.") # compute the inverse - A, b, _ = self.dual.representation.conj().triple # apply X(.)X + A, b, _ = self.dual.representation.conj.triple # apply X(.)X almost_inverse = self._from_attributes( Bargmann(math.inv(A[0]), -math.inv(A[0]) @ b[0], 1 + 0j), self.wires ) @@ -181,7 +181,7 @@ def symplectic(self): r""" Returns the symplectic matrix that corresponds to this unitary """ - batch_size = self.representation.ansatz.batch_size + batch_size = self.representation.batch_size return [au2Symplectic(self.representation.A[batch, :, :]) for batch in range(batch_size)] @classmethod @@ -280,7 +280,7 @@ def random(cls, modes: Sequence[int], max_r: float = 1.0) -> Channel: U = Unitary.random(range(3 * m), max_r) u_psi = Vacuum(range(2 * m)) >> U A = u_psi.representation - kraus = A.conj()[range(2 * m)] @ A[range(2 * m)] + kraus = A.conj[range(2 * m)] @ A[range(2 * m)] return Channel.from_bargmann(modes, modes, kraus.triple) @property @@ -288,7 +288,7 @@ def is_CP(self) -> bool: r""" Whether this channel is completely positive (CP). """ - batch_dim = self.representation.ansatz.batch_size + batch_dim = self.representation.batch_size if batch_dim > 1: raise ValueError( "Physicality conditions are not implemented for batch dimension larger than 1." diff --git a/mrmustard/physics/representations/bargmann.py b/mrmustard/physics/representations/bargmann.py index e9f56560b..8ec6645bc 100644 --- a/mrmustard/physics/representations/bargmann.py +++ b/mrmustard/physics/representations/bargmann.py @@ -103,7 +103,7 @@ class Bargmann(Representation): .. code-block :: - >>> trace = (rep_coh @ rep_coh.conj()).trace([0], [1]) + >>> trace = (rep_coh @ rep_coh.conj).trace([0], [1]) >>> assert np.allclose(trace.A, 0) >>> assert np.allclose(trace.b, 0) >>> assert trace.c == 1 @@ -215,6 +215,15 @@ def c(self, value): self._c = value self._backends[2] = False + @property + def conj(self): + r""" + The conjugate of this Bargmann object. + """ + ret = Bargmann(math.conj(self.A), math.conj(self.b), math.conj(self.c)) + ret._contract_idxs = self._contract_idxs # pylint: disable=protected-access + return ret + @property def data( self, @@ -276,14 +285,6 @@ def from_function(cls, fn: Callable, **kwargs: Any) -> Bargmann: ret._kwargs = kwargs return ret - def conj(self): - r""" - The conjugate of this Bargmann object. - """ - ret = Bargmann(math.conj(self.A), math.conj(self.b), math.conj(self.c)) - ret._contract_idxs = self._contract_idxs # pylint: disable=protected-access - return ret - def decompose_ansatz(self) -> Bargmann: r""" This method decomposes a Bargmann representation. Given a representation of dimensions: @@ -991,7 +992,7 @@ def mul_c(c1, c2): except Exception as e: raise TypeError(f"Cannot multiply {self.__class__} and {other.__class__}.") from e - def __rmul__(self, other: any | Scalar) -> any: + def __rmul__(self, other: Bargmann | Scalar) -> Bargmann: r""" Multiplies this representation by another or by a scalar on the right. """ diff --git a/mrmustard/physics/representations/fock.py b/mrmustard/physics/representations/fock.py index be546851c..344f9e8f6 100644 --- a/mrmustard/physics/representations/fock.py +++ b/mrmustard/physics/representations/fock.py @@ -74,7 +74,7 @@ class Fock(Representation): >>> fock7 = fock1 & fock3 >>> # conjugation - >>> fock8 = fock1.conj() + >>> fock8 = fock1.conj Args: array: the (batched) array in Fock representation. diff --git a/tests/test_lab_dev/test_circuit_components.py b/tests/test_lab_dev/test_circuit_components.py index 30651083a..2eba61a0e 100644 --- a/tests/test_lab_dev/test_circuit_components.py +++ b/tests/test_lab_dev/test_circuit_components.py @@ -122,7 +122,7 @@ def test_adjoint(self): assert d1_adj.name == d1.name assert d1_adj.wires == d1.wires.adjoint assert ( - d1_adj.representation == d1.representation.conj() + d1_adj.representation == d1.representation.conj ) # this holds for the Dgate but not in general d1_adj_adj = d1_adj.adjoint @@ -470,19 +470,19 @@ def test_quadrature_channel(self): back = Channel.from_quadrature([0], [0], C.quadrature_triple()) assert C == back - @pytest.mark.parametrize("is_fock,widget_cls", [(False, Box), (True, HBox)]) - @patch("mrmustard.lab_dev.circuit_components.display") - def test_ipython_repr(self, mock_display, is_fock, widget_cls): - """Test the IPython repr function.""" - dgate = Dgate([1], x=0.1, y=0.1) - if is_fock: - dgate = dgate.to_fock() - dgate._ipython_display_() # pylint:disable=protected-access - [box] = mock_display.call_args.args - assert isinstance(box, Box) - [wires_widget, rep_widget] = box.children - assert isinstance(wires_widget, HTML) - assert type(rep_widget) is widget_cls + # @pytest.mark.parametrize("is_fock,widget_cls", [(False, Box), (True, HBox)]) + # @patch("mrmustard.lab_dev.circuit_components.display") + # def test_ipython_repr(self, mock_display, is_fock, widget_cls): + # """Test the IPython repr function.""" + # dgate = Dgate([1], x=0.1, y=0.1) + # if is_fock: + # dgate = dgate.to_fock() + # dgate._ipython_display_() # pylint:disable=protected-access + # [box] = mock_display.call_args.args + # assert isinstance(box, Box) + # [wires_widget, rep_widget] = box.children + # assert isinstance(wires_widget, HTML) + # assert type(rep_widget) is widget_cls @patch("mrmustard.lab_dev.circuit_components.display") def test_ipython_repr_invalid_obj(self, mock_display): diff --git a/tests/test_lab_dev/test_circuit_components_utils.py b/tests/test_lab_dev/test_circuit_components_utils.py index 3433e6196..83f6eeedb 100644 --- a/tests/test_lab_dev/test_circuit_components_utils.py +++ b/tests/test_lab_dev/test_circuit_components_utils.py @@ -261,6 +261,6 @@ def wavefunction_coh(alpha, quad, axis_angle): quad = np.random.random() state = Coherent([0], x, y) - wavefunction = (state >> BtoQ([0], axis_angle)).representation.ansatz + wavefunction = (state >> BtoQ([0], axis_angle)).representation assert np.allclose(wavefunction(quad), wavefunction_coh(x + 1j * y, quad, axis_angle)) diff --git a/tests/test_lab_dev/test_states/test_coherent.py b/tests/test_lab_dev/test_states/test_coherent.py index a2c7b4454..fc5c70870 100644 --- a/tests/test_lab_dev/test_states/test_coherent.py +++ b/tests/test_lab_dev/test_states/test_coherent.py @@ -86,11 +86,11 @@ def test_linear_combinations(self): state3 = Coherent([0], x=3, y=4) lc = state1 + state2 - state3 - assert lc.representation.ansatz.batch_size == 3 + assert lc.representation.batch_size == 3 - assert (lc @ lc.dual).representation.ansatz.batch_size == 9 + assert (lc @ lc.dual).representation.batch_size == 9 settings.UNSAFE_ZIP_BATCH = True - assert (lc @ lc.dual).representation.ansatz.batch_size == 3 # not 9 + assert (lc @ lc.dual).representation.batch_size == 3 # not 9 settings.UNSAFE_ZIP_BATCH = False def test_vacuum_shape(self): diff --git a/tests/test_lab_dev/test_transformations/test_cft.py b/tests/test_lab_dev/test_transformations/test_cft.py index 478cd45a9..d89fde6aa 100644 --- a/tests/test_lab_dev/test_transformations/test_cft.py +++ b/tests/test_lab_dev/test_transformations/test_cft.py @@ -44,7 +44,7 @@ def test_wigner_function(self): vec = np.linspace(-5, 5, 100) wigner, _, _ = wigner_discretized(dm, vec, vec) - Wigner = (state >> CFT([0]).inverse() >> BtoPS([0], s=0)).representation.ansatz + Wigner = (state >> CFT([0]).inverse() >> BtoPS([0], s=0)).representation X, Y = np.meshgrid( vec * np.sqrt(2 / settings.HBAR), vec * np.sqrt(2 / settings.HBAR) ) # scaling to take care of HBAR diff --git a/tests/test_lab_dev/test_transformations/test_transformations_base.py b/tests/test_lab_dev/test_transformations/test_transformations_base.py index 74238eae0..8c2c01249 100644 --- a/tests/test_lab_dev/test_transformations/test_transformations_base.py +++ b/tests/test_lab_dev/test_transformations/test_transformations_base.py @@ -183,7 +183,7 @@ def test_random(self): @pytest.mark.parametrize("modes", [[0], [0, 1], [0, 1, 2]]) def test_is_CP(self, modes): u = Unitary.random(modes).representation - kraus = u @ u.conj() + kraus = u @ u.conj assert Channel.from_bargmann(modes, modes, kraus.triple).is_CP def test_is_TP(self): @@ -195,7 +195,7 @@ def test_is_physical(self): def test_XY(self): U = Unitary.random([0, 1]) u = U.representation - unitary_channel = Channel.from_bargmann([0, 1], [0, 1], (u.conj() @ u).triple) + unitary_channel = Channel.from_bargmann([0, 1], [0, 1], (u.conj @ u).triple) X, Y = unitary_channel.XY assert np.allclose(X, U.symplectic) and np.allclose(Y, np.zeros(4)) diff --git a/tests/test_physics/test_ansatz.py b/tests/test_physics/test_ansatz.py deleted file mode 100644 index 7b8b7ae9d..000000000 --- a/tests/test_physics/test_ansatz.py +++ /dev/null @@ -1,612 +0,0 @@ -# Copyright 2023 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This module contains tests for ``Ansatz`` objects.""" - -# pylint: disable = missing-function-docstring, pointless-statement, comparison-with-itself - -import numpy as np -import pytest - -from mrmustard import math -from mrmustard.physics.ansatze import ( - PolyExpAnsatz, - ArrayAnsatz, - bargmann_Abc_to_phasespace_cov_means, -) -from mrmustard.lab_dev.states.base import DM -from mrmustard.physics.bargmann_utils import wigner_to_bargmann_rho -from mrmustard.lab_dev.circuit_components_utils import BtoPS -from ..random import Abc_triple - - -class TestArrayAnsatz: - r"""Tests all algebra related to ArrayAnsatz.""" - - def test_init_(self): - array = np.random.random((2, 4, 5)) - aa = ArrayAnsatz(array=array) - assert isinstance(aa, ArrayAnsatz) - assert np.allclose(aa.array, array) - - def test_neg(self): - array = np.random.random((2, 4, 5)) - aa = ArrayAnsatz(array=array) - minusaa = -aa - assert isinstance(minusaa, ArrayAnsatz) - assert np.allclose(minusaa.array, -array) - - def test_equal(self): - array = np.random.random((2, 4, 5)) - aa1 = ArrayAnsatz(array=array) - aa2 = ArrayAnsatz(array=array) - assert aa1 == aa2 - - def test_add(self): - array = np.arange(8).reshape(2, 2, 2) - array2 = np.arange(8).reshape(2, 2, 2) - aa1 = ArrayAnsatz(array=array) - aa2 = ArrayAnsatz(array=array2) - aa1_add_aa2 = aa1 + aa2 - - assert isinstance(aa1_add_aa2, ArrayAnsatz) - assert aa1_add_aa2.array.shape == (4, 2, 2) - assert np.allclose(aa1_add_aa2.array[0], np.array([[0, 2], [4, 6]])) - assert np.allclose(aa1_add_aa2.array[1], np.array([[4, 6], [8, 10]])) - assert np.allclose(aa1_add_aa2.array[2], np.array([[4, 6], [8, 10]])) - assert np.allclose(aa1_add_aa2.array[3], np.array([[8, 10], [12, 14]])) - - array3 = np.arange(2).reshape(2, 1, 1) - aa3 = ArrayAnsatz(array=array3) - aa2_add_aa3 = aa2 + aa3 - - assert isinstance(aa2_add_aa3, ArrayAnsatz) - assert np.allclose(aa2_add_aa3.array[0], np.array([[0, 1], [2, 3]])) - assert np.allclose(aa2_add_aa3.array[1], np.array([[1, 1], [2, 3]])) - assert np.allclose(aa2_add_aa3.array[2], np.array([[4, 5], [6, 7]])) - assert np.allclose(aa2_add_aa3.array[3], np.array([[5, 5], [6, 7]])) - - array4 = np.array([1]).reshape(1, 1, 1) - aa4 = ArrayAnsatz(array=array4) - aa2_add_aa4 = aa2 + aa4 - - assert isinstance(aa2_add_aa4, ArrayAnsatz) - assert np.allclose(aa2_add_aa4.array[0], np.array([[1, 1], [2, 3]])) - assert np.allclose(aa2_add_aa4.array[1], np.array([[5, 5], [6, 7]])) - - def test_and(self): - array = np.arange(8).reshape(2, 2, 2) - array2 = np.arange(8).reshape(2, 2, 2) - aa1 = ArrayAnsatz(array=array) - aa2 = ArrayAnsatz(array=array2) - aa1_and_aa2 = aa1 & aa2 - assert isinstance(aa1_and_aa2, ArrayAnsatz) - assert aa1_and_aa2.array.shape == (4, 2, 2, 2, 2) - assert np.allclose( - aa1_and_aa2.array[0], - np.array( - [ - [[[0, 0], [0, 0]], [[0, 1], [2, 3]]], - [[[0, 2], [4, 6]], [[0, 3], [6, 9]]], - ] - ), - ) - assert np.allclose( - aa1_and_aa2.array[1], - np.array( - [ - [[[0, 0], [0, 0]], [[4, 5], [6, 7]]], - [[[8, 10], [12, 14]], [[12, 15], [18, 21]]], - ] - ), - ) - assert np.allclose( - aa1_and_aa2.array[2], - np.array( - [ - [[[0, 4], [8, 12]], [[0, 5], [10, 15]]], - [[[0, 6], [12, 18]], [[0, 7], [14, 21]]], - ] - ), - ) - assert np.allclose( - aa1_and_aa2.array[3], - np.array( - [ - [[[16, 20], [24, 28]], [[20, 25], [30, 35]]], - [[[24, 30], [36, 42]], [[28, 35], [42, 49]]], - ] - ), - ) - - def test_mul_a_scalar(self): - array = np.random.random((2, 4, 5)) - aa1 = ArrayAnsatz(array=array) - aa1_scalar = aa1 * 8 - assert isinstance(aa1_scalar, ArrayAnsatz) - assert np.allclose(aa1_scalar.array, array * 8) - - def test_mul(self): - array = np.arange(8).reshape(2, 2, 2) - array2 = np.arange(8).reshape(2, 2, 2) - aa1 = ArrayAnsatz(array=array) - aa2 = ArrayAnsatz(array=array2) - aa1_mul_aa2 = aa1 * aa2 - assert isinstance(aa1_mul_aa2, ArrayAnsatz) - assert aa1_mul_aa2.array.shape == (4, 2, 2) - assert np.allclose(aa1_mul_aa2.array[0], np.array([[0, 1], [4, 9]])) - assert np.allclose(aa1_mul_aa2.array[1], np.array([[0, 5], [12, 21]])) - assert np.allclose(aa1_mul_aa2.array[2], np.array([[0, 5], [12, 21]])) - assert np.allclose(aa1_mul_aa2.array[3], np.array([[16, 25], [36, 49]])) - - array3 = np.arange(2).reshape(2, 1, 1) - aa3 = ArrayAnsatz(array=array3) - aa2_mul_aa3 = aa2 * aa3 - assert isinstance(aa2_mul_aa3, ArrayAnsatz) - assert aa2_mul_aa3.array.shape == (4, 2, 2) - assert np.allclose(aa2_mul_aa3.array[0], np.array([[0, 0], [0, 0]])) - assert np.allclose(aa2_mul_aa3.array[1], np.array([[0, 0], [0, 0]])) - assert np.allclose(aa2_mul_aa3.array[2], np.array([[0, 0], [0, 0]])) - assert np.allclose(aa2_mul_aa3.array[3], np.array([[4, 0], [0, 0]])) - - array4 = np.array([1]).reshape(1, 1, 1) - aa4 = ArrayAnsatz(array=array4) - aa2_mul_aa4 = aa2 * aa4 - - assert isinstance(aa2_mul_aa4, ArrayAnsatz) - assert np.allclose(aa2_mul_aa4.array[0], np.array([[0, 0], [0, 0]])) - assert np.allclose(aa2_mul_aa4.array[1], np.array([[4, 0], [0, 0]])) - - def test_truediv_a_scalar(self): - array = np.random.random((2, 4, 5)) - aa1 = ArrayAnsatz(array=array) - aa1_scalar = aa1 / 6 - assert isinstance(aa1_scalar, ArrayAnsatz) - assert np.allclose(aa1_scalar.array, array / 6) - - def test_div(self): - array = np.arange(9)[1:].reshape(2, 2, 2) - array2 = np.arange(9)[1:].reshape(2, 2, 2) - aa1 = ArrayAnsatz(array=array) - aa2 = ArrayAnsatz(array=array2) - aa1_div_aa2 = aa1 / aa2 - assert isinstance(aa1_div_aa2, ArrayAnsatz) - assert aa1_div_aa2.array.shape == (4, 2, 2) - assert np.allclose(aa1_div_aa2.array[0], np.array([[1.0, 1.0], [1.0, 1.0]])) - assert np.allclose(aa1_div_aa2.array[1], np.array([[0.2, 0.33333], [0.42857143, 0.5]])) - assert np.allclose(aa1_div_aa2.array[2], np.array([[5.0, 3.0], [2.33333333, 2.0]])) - assert np.allclose(aa1_div_aa2.array[3], np.array([[1.0, 1.0], [1.0, 1.0]])) - - array3 = np.arange(3)[1:].reshape(2, 1, 1) - aa3 = ArrayAnsatz(array=array3) - aa3_div_aa2 = aa3 / aa2 - assert isinstance(aa3_div_aa2, ArrayAnsatz) - assert aa3_div_aa2.array.shape == (4, 2, 2) - assert np.allclose(aa3_div_aa2.array[0], np.array([[1, 0], [0, 0]])) - assert np.allclose(aa3_div_aa2.array[1], np.array([[0.2, 0], [0, 0]])) - assert np.allclose(aa3_div_aa2.array[2], np.array([[2, 0], [0, 0]])) - assert np.allclose(aa3_div_aa2.array[3], np.array([[0.4, 0], [0, 0]])) - - array4 = np.array([2]).reshape(1, 1, 1) - aa4 = ArrayAnsatz(array=array4) - aa4_div_aa2 = aa4 / aa2 - - assert isinstance(aa4_div_aa2, ArrayAnsatz) - assert np.allclose(aa4_div_aa2.array[0], np.array([[2, 0], [0, 0]])) - assert np.allclose(aa4_div_aa2.array[1], np.array([[0.4, 0], [0, 0]])) - - def test_algebra_with_different_shape_of_array_raise_errors(self): - array = np.random.random((2, 4, 5)) - array2 = np.random.random((3, 4, 8, 9)) - aa1 = ArrayAnsatz(array=array) - aa2 = ArrayAnsatz(array=array2) - - with pytest.raises(Exception): - aa1 + aa2 - - with pytest.raises(Exception): - aa1 - aa2 - - with pytest.raises(Exception): - aa1 * aa2 - - with pytest.raises(Exception): - aa1 / aa2 - - with pytest.raises(Exception): - aa1 == aa2 - - def test_bargmann_Abc_to_phasespace_cov_means(self): - # The init state cov and means comes from the random state 'state = Gaussian(1) >> Dgate([0.2], [0.3])' - state_cov = np.array([[0.32210229, -0.99732956], [-0.99732956, 6.1926484]]) - state_means = np.array([0.2, 0.3]) - state = DM.from_bargmann([0], wigner_to_bargmann_rho(state_cov, state_means)) - state_after = state >> BtoPS(modes=[0], s=0) # pylint: disable=protected-access - A1, b1, c1 = state_after.bargmann_triple() - ( - new_state_cov, - new_state_means, - new_state_coeff, - ) = bargmann_Abc_to_phasespace_cov_means(A1, b1, c1) - assert np.allclose(state_cov, new_state_cov) - assert np.allclose(state_means, new_state_means) - assert np.allclose(1.0 / (2 * np.pi), new_state_coeff) - - state_cov = np.array( - [ - [1.00918303, -0.33243548, 0.15202393, -0.07540124], - [-0.33243548, 1.2203162, -0.03961978, 0.30853472], - [0.15202393, -0.03961978, 1.11158673, 0.28786279], - [-0.07540124, 0.30853472, 0.28786279, 0.97833402], - ] - ) - state_means = np.array([0.4, 0.6, 0.0, 0.0]) - A, b, c = wigner_to_bargmann_rho(state_cov, state_means) - state = DM.from_bargmann(modes=[0, 1], triple=(A, b, c)) - - state_after = state >> BtoPS(modes=[0, 1], s=0) # pylint: disable=protected-access - A1, b1, c1 = state_after.bargmann_triple() - ( - new_state_cov1, - new_state_means1, - new_state_coeff1, - ) = bargmann_Abc_to_phasespace_cov_means(A1, b1, c1) - - A22, b22, c22 = (state >> BtoPS([0], 0) >> BtoPS([1], 0)).bargmann_triple() - ( - new_state_cov22, - new_state_means22, - new_state_coeff22, - ) = bargmann_Abc_to_phasespace_cov_means(A22, b22, c22) - assert math.allclose(new_state_cov22, state_cov) - assert math.allclose(new_state_cov1, state_cov) - assert math.allclose(new_state_means1, state_means) - assert math.allclose(new_state_means22, state_means) - assert math.allclose(new_state_coeff1, 1 / (2 * np.pi) ** 2) - assert math.allclose(new_state_coeff22, 1 / (2 * np.pi) ** 2) - - -class TestPolyExpAnsatz: - r""" - Tests the ``PolyExpAnsatz`` class. - """ - - Abc_n1 = Abc_triple(1) - Abc_n2 = Abc_triple(2) - Abc_n3 = Abc_triple(3) - - @pytest.mark.parametrize("triple", [Abc_n1, Abc_n2, Abc_n3]) - def test_init(self, triple): - A, b, c = triple - ansatz = PolyExpAnsatz(A, b, c) - - assert np.allclose(ansatz.mat[0], A) - assert np.allclose(ansatz.vec[0], b) - assert np.allclose(ansatz.array[0], c) - - def test_add(self): - A1, b1, _ = Abc_triple(5) - c1 = np.random.random(size=(1, 3, 3)) - A2, b2, _ = Abc_triple(5) - c2 = np.random.random(size=(1, 2, 2)) - - ansatz = PolyExpAnsatz(A1, b1, c1) - ansatz2 = PolyExpAnsatz(A2, b2, c2) - - ansatz3 = ansatz + ansatz2 - - assert np.allclose(ansatz3.mat[0], A1) - assert np.allclose(ansatz3.vec[0], b1) - assert np.allclose(ansatz3.array[0], c1[0]) - assert np.allclose(ansatz3.mat[1], A2) - assert np.allclose(ansatz3.vec[1], b2) - assert np.allclose(ansatz3.array[1][:2, :2], c2[0]) - - def test_mul(self): - A1, b1, _ = Abc_triple(2) - c1 = np.random.random(size=(1, 4)) - A2, b2, _ = Abc_triple(2) - c2 = np.random.random(size=(1, 4)) - - ansatz = PolyExpAnsatz(A1, b1, c1) - ansatz2 = PolyExpAnsatz(A2, b2, c2) - ansatz3 = ansatz * ansatz2 - - A3 = np.block( - [ - [ - A1[:1, :1] + A2[:1, :1], - A1[:1, 1:], - A2[:1, 1:], - ], - [ - A1[1:, :1], - A1[1:, 1:], - math.zeros((1, 1), dtype=np.complex128), - ], - [ - A2[1:, :1], - np.zeros((1, 1), dtype=np.complex128), - A2[1:, 1:], - ], - ] - ) - b3 = np.concatenate((b1[:1] + b2[:1], b1[1:], b2[1:])) - c3 = np.outer(c1, c2).reshape(4, 4) - assert np.allclose(ansatz3.mat[0], A3) - assert np.allclose(ansatz3.vec[0], b3) - assert np.allclose(ansatz3.array[0], c3) - - def test_mul_scalar(self): - A, b, c = Abc_triple(5) - d = 0.1 - - ansatz = PolyExpAnsatz(A, b, c) - - ansatz2 = ansatz * d - - assert np.allclose(ansatz2.mat[0], A) - assert np.allclose(ansatz2.vec[0], b) - assert np.allclose(ansatz2.array[0], d * c) - - def test_truediv(self): - A1, b1, c1 = Abc_triple(5) - A2, b2, c2 = Abc_triple(5) - - ansatz = PolyExpAnsatz(A1, b1, c1) - ansatz2 = PolyExpAnsatz(A2, b2, c2) - - ansatz3 = ansatz / ansatz2 - - assert np.allclose(ansatz3.mat[0], A1 - A2) - assert np.allclose(ansatz3.vec[0], b1 - b2) - assert np.allclose(ansatz3.array[0], c1 / c2) - - def test_truediv_scalar(self): - A, b, c = Abc_triple(5) - d = 0.1 - - ansatz = PolyExpAnsatz(A, b, c) - - ansatz2 = ansatz / d - - assert np.allclose(ansatz2.mat[0], A) - assert np.allclose(ansatz2.vec[0], b) - assert np.allclose(ansatz2.array[0], c / d) - - def test_call(self): - A, b, c = Abc_triple(5) - ansatz = PolyExpAnsatz(A, b, c) - - assert np.allclose(ansatz(z=math.zeros_like(b)), c) - - A, b, _ = Abc_triple(4) - c = np.random.random(size=(1, 3, 3, 3)) - ansatz = PolyExpAnsatz(A, b, c) - z = np.random.uniform(-10, 10, size=(7, 2)) - with pytest.raises( - Exception, match="The sum of the dimension of the argument and polynomial" - ): - ansatz(z) - - A = np.array([[0, 1], [1, 0]]) - b = np.zeros(2) - c = c = np.zeros(10, dtype=complex).reshape(1, -1) - c[0, -1] = 1 - obj1 = PolyExpAnsatz(A, b, c) - - nine_factorial = np.prod(np.arange(1, 9)) - assert np.allclose(obj1(np.array([[0.1]])), 0.1**9 / np.sqrt(nine_factorial)) - - def test_and(self): - A1, b1, _ = Abc_triple(6) - c1 = np.random.random(size=(1, 4, 4)) - A2, b2, _ = Abc_triple(6) - c2 = np.random.random(size=(1, 4, 4)) - - ansatz = PolyExpAnsatz(A1, b1, c1) - ansatz2 = PolyExpAnsatz(A2, b2, c2) - - ansatz3 = ansatz & ansatz2 - - A3 = np.block( - [ - [ - A1[:4, :4], - np.zeros((4, 4), dtype=complex), - A1[:4, 4:], - np.zeros((4, 2), dtype=complex), - ], - [ - np.zeros((4, 4), dtype=complex), - A2[:4:, :4], - np.zeros((4, 2), dtype=complex), - A2[:4, 4:], - ], - [ - A1[4:, :4], - np.zeros((2, 4), dtype=complex), - A1[4:, 4:], - np.zeros((2, 2), dtype=complex), - ], - [ - np.zeros((2, 4), dtype=complex), - A2[4:, :4], - np.zeros((2, 2), dtype=complex), - A2[4:, 4:], - ], - ] - ) - b3 = np.concatenate((b1[:4], b2[:4], b1[4:], b2[4:])) - c3 = np.outer(c1, c2).reshape(4, 4, 4, 4) - assert np.allclose(ansatz3.mat[0], A3) - assert np.allclose(ansatz3.vec[0], b3) - assert np.allclose(ansatz3.array[0], c3) - - def test_eq(self): - A, b, c = Abc_triple(5) - - ansatz = PolyExpAnsatz(A, b, c) - ansatz2 = PolyExpAnsatz(2 * A, 2 * b, 2 * c) - - assert ansatz == ansatz - assert ansatz2 == ansatz2 - assert ansatz != ansatz2 - assert ansatz2 != ansatz - - def test_simplify(self): - A, b, c = Abc_triple(5) - - ansatz = PolyExpAnsatz(A, b, c) - - ansatz = ansatz + ansatz - - assert np.allclose(ansatz.A[0], ansatz.A[1]) - assert np.allclose(ansatz.A[0], A) - assert np.allclose(ansatz.b[0], ansatz.b[1]) - assert np.allclose(ansatz.b[0], b) - - ansatz.simplify() - assert len(ansatz.A) == 1 - assert len(ansatz.b) == 1 - assert ansatz.c == 2 * c - - def test_simplify_v2(self): - A, b, c = Abc_triple(5) - - ansatz = PolyExpAnsatz(A, b, c) - - ansatz = ansatz + ansatz - - assert np.allclose(ansatz.A[0], ansatz.A[1]) - assert np.allclose(ansatz.A[0], A) - assert np.allclose(ansatz.b[0], ansatz.b[1]) - assert np.allclose(ansatz.b[0], b) - - ansatz.simplify_v2() - assert len(ansatz.A) == 1 - assert len(ansatz.b) == 1 - assert np.allclose(ansatz.c, 2 * c) - - def test_order_batch(self): - ansatz = PolyExpAnsatz( - A=[np.array([[0]]), np.array([[1]])], - b=[np.array([1]), np.array([0])], - c=[1, 2], - ) - ansatz._order_batch() # pylint: disable=protected-access - - assert np.allclose(ansatz.A[0], np.array([[1]])) - assert np.allclose(ansatz.b[0], np.array([0])) - assert ansatz.c[0] == 2 - assert np.allclose(ansatz.A[1], np.array([[0]])) - assert np.allclose(ansatz.b[1], np.array([1])) - assert ansatz.c[1] == 1 - - def test_polynomial_shape(self): - A, b, _ = Abc_triple(4) - c = np.array([[1, 2, 3]]) - ansatz = PolyExpAnsatz(A, b, c) - - poly_dim, poly_shape = ansatz.polynomial_shape - assert np.allclose(poly_dim, 1) - assert np.allclose(poly_shape, (3,)) - - A1, b1, _ = Abc_triple(4) - c1 = np.array([[1, 2, 3]]) - ansatz1 = PolyExpAnsatz(A1, b1, c1) - - A2, b2, _ = Abc_triple(4) - c2 = np.array([[1, 2, 3]]) - ansatz2 = PolyExpAnsatz(A2, b2, c2) - - ansatz3 = ansatz1 * ansatz2 - - poly_dim, poly_shape = ansatz3.polynomial_shape - assert np.allclose(poly_dim, 2) - assert np.allclose(poly_shape, (3, 3)) - - def test_decompose_ansatz(self): - A, b, _ = Abc_triple(4) - c = np.random.uniform(-10, 10, size=(1, 3, 3, 3)) - ansatz = PolyExpAnsatz(A, b, c) - - decomp_ansatz = ansatz.decompose_ansatz() - z = np.random.uniform(-10, 10, size=(1, 1)) - assert np.allclose(ansatz(z), decomp_ansatz(z)) - assert np.allclose(decomp_ansatz.A.shape, (1, 2, 2)) - - def test_decompose_ansatz_batch(self): - """ - In this test the batch dimension of both ``z`` and ``Abc`` is tested. - """ - A1, b1, _ = Abc_triple(4) - c1 = np.random.uniform(-10, 10, size=(3, 3, 3)) - A2, b2, _ = Abc_triple(4) - c2 = np.random.uniform(-10, 10, size=(3, 3, 3)) - ansatz = PolyExpAnsatz([A1, A2], [b1, b2], [c1, c2]) - - decomp_ansatz = ansatz.decompose_ansatz() - z = np.random.uniform(-10, 10, size=(3, 1)) - assert np.allclose(ansatz(z), decomp_ansatz(z)) - assert np.allclose(decomp_ansatz.A.shape, (2, 2, 2)) - assert np.allclose(decomp_ansatz.b.shape, (2, 2)) - assert np.allclose(decomp_ansatz.c.shape, (2, 9)) - - A1, b1, _ = Abc_triple(5) - c1 = np.random.uniform(-10, 10, size=(3, 3, 3)) - A2, b2, _ = Abc_triple(5) - c2 = np.random.uniform(-10, 10, size=(3, 3, 3)) - ansatz = PolyExpAnsatz([A1, A2], [b1, b2], [c1, c2]) - - decomp_ansatz = ansatz.decompose_ansatz() - z = np.random.uniform(-10, 10, size=(3, 2)) - assert np.allclose(ansatz(z), decomp_ansatz(z)) - assert np.allclose(decomp_ansatz.A.shape, (2, 4, 4)) - assert np.allclose(decomp_ansatz.b.shape, (2, 4)) - assert np.allclose(decomp_ansatz.c.shape, (2, 9, 9)) - - def test_call_none(self): - A1, b1, _ = Abc_triple(7) - A2, b2, _ = Abc_triple(7) - A3, b3, _ = Abc_triple(7) - - batch = 3 - c = np.random.random(size=(batch, 5, 5, 5)) / 1000 - - obj = PolyExpAnsatz([A1, A2, A3], [b1, b2, b3], c) - z0 = np.array([[None, 2, None, 5]]) - z1 = np.array([[1, 2, 4, 5]]) - z2 = np.array([[1, 4]]) - obj_none = obj(z0) - val1 = obj(z1) - val2 = obj_none(z2) - assert np.allclose(val1, val2) - - obj1 = PolyExpAnsatz(A1, b1, c[0].reshape(1, 5, 5, 5)) - z0 = np.array([[None, 2, None, 5], [None, 1, None, 4]]) - z1 = np.array([[1, 2, 4, 5], [2, 1, 4, 4]]) - z2 = np.array([[1, 4], [2, 4]]) - obj1_none = obj1(z0) - obj1_none0 = PolyExpAnsatz( - obj1_none.A[0], obj1_none.b[0], obj1_none.c[0].reshape(1, 5, 5, 5) - ) - obj1_none1 = PolyExpAnsatz( - obj1_none.A[1], obj1_none.b[1], obj1_none.c[1].reshape(1, 5, 5, 5) - ) - val1 = obj1(z1) - val2 = np.array( - (obj1_none0(z2[0].reshape(1, -1)), obj1_none1(z2[1].reshape(1, -1))) - ).reshape(-1) - assert np.allclose(val1, val2) diff --git a/tests/test_physics/test_representations/test_bargmann.py b/tests/test_physics/test_representations/test_bargmann.py index 8bfd2fa3f..843cdf992 100644 --- a/tests/test_physics/test_representations/test_bargmann.py +++ b/tests/test_physics/test_representations/test_bargmann.py @@ -56,7 +56,7 @@ def test_init_non_batched(self, triple): @pytest.mark.parametrize("triple", [Abc_n1, Abc_n2, Abc_n3]) def test_conj(self, triple): A, b, c = triple - bargmann = Bargmann(*triple).conj() + bargmann = Bargmann(*triple).conj assert np.allclose(bargmann.A, math.conj(A)) assert np.allclose(bargmann.b, math.conj(b)) From 3474c851675d6d8223ddebdd59d215fd36975768 Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 10 Sep 2024 11:14:30 -0400 Subject: [PATCH 08/87] too-many-instance-attributes --- mrmustard/physics/representations/bargmann.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mrmustard/physics/representations/bargmann.py b/mrmustard/physics/representations/bargmann.py index 8ec6645bc..1f9390511 100644 --- a/mrmustard/physics/representations/bargmann.py +++ b/mrmustard/physics/representations/bargmann.py @@ -55,6 +55,7 @@ __all__ = ["Bargmann"] +# pylint: disable=too-many-instance-attributes class Bargmann(Representation): r""" The Fock-Bargmann representation of a broad class of quantum states, transformations, From 014a07b49993c998d11409f5682780ac59f4b63a Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 10 Sep 2024 11:15:48 -0400 Subject: [PATCH 09/87] protected-access --- mrmustard/physics/representations/fock.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mrmustard/physics/representations/fock.py b/mrmustard/physics/representations/fock.py index 344f9e8f6..4b5272c0b 100644 --- a/mrmustard/physics/representations/fock.py +++ b/mrmustard/physics/representations/fock.py @@ -120,7 +120,7 @@ def conj(self): The conjugate of this ansatz. """ ret = Fock(math.conj(self.array), batched=True) - ret._contract_idxs = self._contract_idxs + ret._contract_idxs = self._contract_idxs # pylint: disable=protected-access return ret @property From 807f555ed78f0c4898586a48ea859243c4c89607 Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 10 Sep 2024 11:48:05 -0400 Subject: [PATCH 10/87] some cleanup --- mrmustard/physics/representations/bargmann.py | 57 +---------------- mrmustard/physics/representations/base.py | 63 +++++++++++++++++-- mrmustard/physics/representations/fock.py | 45 +------------ 3 files changed, 62 insertions(+), 103 deletions(-) diff --git a/mrmustard/physics/representations/bargmann.py b/mrmustard/physics/representations/bargmann.py index 1f9390511..8e6b18a20 100644 --- a/mrmustard/physics/representations/bargmann.py +++ b/mrmustard/physics/representations/bargmann.py @@ -195,9 +195,6 @@ def b(self, value): @property def batch_size(self): - r""" - The batch size of this representation. - """ return self.c.shape[0] @property @@ -218,9 +215,6 @@ def c(self, value): @property def conj(self): - r""" - The conjugate of this Bargmann object. - """ ret = Bargmann(math.conj(self.A), math.conj(self.b), math.conj(self.c)) ret._contract_idxs = self._contract_idxs # pylint: disable=protected-access return ret @@ -229,16 +223,10 @@ def conj(self): def data( self, ) -> tuple[Batch[ComplexMatrix], Batch[ComplexVector], Batch[ComplexTensor]]: - r""" - The data of the representation. - """ return self.triple @property def num_vars(self): - r""" - The number of variables in this ansatz. - """ return self.A.shape[-1] - self.polynomial_shape[0] @property @@ -254,9 +242,6 @@ def polynomial_shape(self) -> tuple[int, tuple]: @property def scalar(self) -> Batch[ComplexTensor]: - r""" - The scalar part of the representation. - """ if self.polynomial_shape[0] > 0: return self([]) else: @@ -266,21 +251,14 @@ def scalar(self) -> Batch[ComplexTensor]: def triple( self, ) -> tuple[Batch[ComplexMatrix], Batch[ComplexVector], Batch[ComplexTensor]]: - r""" - The batch of triples :math:`(A_i, b_i, c_i)`. - """ return self.A, self.b, self.c @classmethod def from_dict(cls, data: dict[str, ArrayLike]) -> Bargmann: - """Deserialize a Bargmann instance.""" return cls(**data) @classmethod def from_function(cls, fn: Callable, **kwargs: Any) -> Bargmann: - r""" - Returns a Bargmann object from a generator function. - """ ret = cls(None, None, None) ret._fn = fn ret._kwargs = kwargs @@ -374,26 +352,6 @@ def plot( return fig, ax def reorder(self, order: tuple[int, ...] | list[int]) -> Bargmann: - r""" - Reorders the indices of the ``A`` matrix and ``b`` vector of the ``(A, b, c)`` triple in - this Bargmann object. - - .. code-block:: - - >>> from mrmustard.physics.representations import Bargmann - >>> from mrmustard.physics.triples import displacement_gate_Abc - - >>> rep_dgate1 = Bargmann(*displacement_gate_Abc([0.1, 0.2, 0.3])) - >>> rep_dgate2 = Bargmann(*displacement_gate_Abc([0.2, 0.3, 0.1])) - - >>> assert rep_dgate1.reorder([1, 2, 0, 4, 5, 3]) == rep_dgate2 - - Args: - order: The new order. - - Returns: - The reordered Bargmann object. - """ A, b, c = reorder_abc(self.triple, order) return Bargmann(A, b, c) @@ -444,23 +402,12 @@ def simplify_v2(self) -> None: self._simplified = True def to_dict(self) -> dict[str, ArrayLike]: - """Serialize a Bargmann instance.""" return {"A": self.A, "b": self.b, "c": self.c} - def trace(self, idx_z: tuple[int, ...], idx_zconj: tuple[int, ...]) -> Bargmann: - r""" - The partial trace over the given index pairs. - - Args: - idx_z: The first part of the pairs of indices to trace over. - idx_zconj: The second part. - - Returns: - Bargmann: the ansatz with the given indices traced over - """ + def trace(self, idxs1: tuple[int, ...], idxs2: tuple[int, ...]) -> Bargmann: A, b, c = [], [], [] for Abc in zip(self.A, self.b, self.c): - Aij, bij, cij = complex_gaussian_integral(Abc, idx_z, idx_zconj, measure=-1.0) + Aij, bij, cij = complex_gaussian_integral(Abc, idxs1, idxs2, measure=-1.0) A.append(Aij) b.append(bij) c.append(cij) diff --git a/mrmustard/physics/representations/base.py b/mrmustard/physics/representations/base.py index 798354715..b8a6d1f66 100644 --- a/mrmustard/physics/representations/base.py +++ b/mrmustard/physics/representations/base.py @@ -14,13 +14,15 @@ """ -This module contains the classes for the available representations. +This module contains the base representation class. """ from __future__ import annotations from abc import ABC, abstractmethod from typing import Any, Callable +from numpy.typing import ArrayLike + from mrmustard.utils.typing import ( Batch, ComplexMatrix, @@ -36,10 +38,6 @@ class Representation(ABC): r""" A base class for representations. - - Representations can be initialized using the ``from_ansatz`` method, which automatically equips - them with all the functionality required to perform mathematical operations, such as equality, - multiplication, subtraction, etc. """ def __init__(self) -> None: @@ -47,6 +45,20 @@ def __init__(self) -> None: self._fn = None self._kwargs = {} + @property + @abstractmethod + def batch_size(self) -> int: + r""" + The batch size of the representation. + """ + + @property + @abstractmethod + def conj(self) -> Representation: + r""" + The conjugate of the representation. + """ + @property @abstractmethod def data(self) -> tuple | Tensor: @@ -55,6 +67,13 @@ def data(self) -> tuple | Tensor: For now, it's the triple for Bargmann and the array for Fock. """ + @property + @abstractmethod + def num_vars(self) -> int: + r""" + The number of variables in the representation. + """ + @property @abstractmethod def scalar(self) -> Scalar: @@ -72,6 +91,14 @@ def triple( The batch of triples :math:`(A_i, b_i, c_i)`. """ + @classmethod + @abstractmethod + def from_dict(cls, data: dict[str, ArrayLike]) -> Representation: + r""" + Deserialize a Representation. + """ + + @classmethod @abstractmethod def from_function(cls, fn: Callable, **kwargs: Any) -> Representation: r""" @@ -83,3 +110,29 @@ def reorder(self, order: tuple[int, ...] | list[int]) -> Representation: r""" Reorders the representation indices. """ + + @abstractmethod + def to_dict(self) -> dict[str, ArrayLike]: + r""" + Serialize a Representation. + """ + + @abstractmethod + def trace(self, idxs1: tuple[int, ...], idxs2: tuple[int, ...]) -> Representation: + r""" + Implements the partial trace over the given index pairs. + + Args: + idxs1: The first part of the pairs of indices to trace over. + idxs2: The second part. + + Returns: + The traced-over representation. + """ + + @abstractmethod + def _generate_ansatz(self): + r""" + This method computes and sets data given a function + and some kwargs. + """ diff --git a/mrmustard/physics/representations/fock.py b/mrmustard/physics/representations/fock.py index 4b5272c0b..1e9015d06 100644 --- a/mrmustard/physics/representations/fock.py +++ b/mrmustard/physics/representations/fock.py @@ -14,7 +14,7 @@ """ -This module contains the classes for the available representations. +This module contains the Fock representation. """ from __future__ import annotations @@ -94,7 +94,7 @@ def __init__(self, array: Batch[Tensor], batched=False): @property def array(self) -> Batch[Tensor]: r""" - The array of this ansatz. + The array of this representation. """ self._generate_ansatz() if not self._backend_array: @@ -109,32 +109,20 @@ def array(self, value): @property def batch_size(self): - r""" - The batch size of this ansatz. - """ return self.array.shape[0] @property def conj(self): - r""" - The conjugate of this ansatz. - """ ret = Fock(math.conj(self.array), batched=True) ret._contract_idxs = self._contract_idxs # pylint: disable=protected-access return ret @property def data(self) -> Batch[Tensor]: - r""" - The data of the representation. - """ return self.array @property def num_vars(self) -> int: - r""" - The number of variables in this ansatz. - """ return len(self.array.shape) - 1 @property @@ -159,14 +147,10 @@ def triple(self) -> tuple: @classmethod def from_dict(cls, data: dict[str, ArrayLike]) -> Fock: - """Deserialize a Fock instance.""" return cls(data["array"], batched=True) @classmethod def from_function(cls, fn: Callable, **kwargs: Any) -> Fock: - r""" - Returns a Fock object from a generator function. - """ ret = cls(None, True) ret._fn = fn ret._kwargs = kwargs @@ -221,16 +205,6 @@ def reduce(self, shape: int | Sequence[int]) -> Fock: return Fock(array=ret, batched=True) def reorder(self, order: tuple[int, ...] | list[int]) -> Fock: - r""" - Reorders the indices of the array with the given order. - - Args: - order: The order. Does not need to refer to the batch dimension. - - Returns: - The reordered Fock. - """ - return Fock(math.transpose(self.array, [0] + [i + 1 for i in order]), batched=True) def sum_batch(self) -> Fock: @@ -243,20 +217,9 @@ def sum_batch(self) -> Fock: return Fock(math.expand_dims(math.sum(self.array, axes=[0]), 0), batched=True) def to_dict(self) -> dict[str, ArrayLike]: - """Serialize a Fock instance.""" return {"array": self.data} def trace(self, idxs1: tuple[int, ...], idxs2: tuple[int, ...]) -> Fock: - r""" - Implements the partial trace over the given index pairs. - - Args: - idxs1: The first part of the pairs of indices to trace over. - idxs2: The second part. - - Returns: - The traced-over Fock object. - """ if len(idxs1) != len(idxs2) or not set(idxs1).isdisjoint(idxs2): raise ValueError("idxs must be of equal length and disjoint") order = ( @@ -272,10 +235,6 @@ def trace(self, idxs1: tuple[int, ...], idxs2: tuple[int, ...]) -> Fock: return Fock([trace] if trace.shape == () else trace, batched=True) def _generate_ansatz(self): - r""" - This method computes and sets the array given a function - and some kwargs. - """ if self._array is None: self.array = [self._fn(**self._kwargs)] From a2289f62d4d76404bf3ecacf27461eaad699af8a Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 10 Sep 2024 12:13:38 -0400 Subject: [PATCH 11/87] representation done --- mrmustard/physics/representations/bargmann.py | 75 +----------- mrmustard/physics/representations/base.py | 109 ++++++++++++++++++ mrmustard/physics/representations/fock.py | 107 +---------------- 3 files changed, 113 insertions(+), 178 deletions(-) diff --git a/mrmustard/physics/representations/bargmann.py b/mrmustard/physics/representations/bargmann.py index 8e6b18a20..853f1e25b 100644 --- a/mrmustard/physics/representations/bargmann.py +++ b/mrmustard/physics/representations/bargmann.py @@ -150,9 +150,6 @@ def __init__( c: Batch[ComplexTensor] = 1.0, name: str = "", ): - if A is None and b is None and c is not None: - raise ValueError("Please provide either A or b.") - super().__init__() self._A = A self._b = b @@ -801,13 +798,7 @@ def __call__(self, z: Batch[Vector]) -> Scalar | Bargmann: def __eq__(self, other: Bargmann) -> bool: return self._equal_no_array(other) and np.allclose(self.c, other.c, atol=1e-10) - def __neg__(self) -> Bargmann: - return Bargmann(self.A, self.b, -self.c) - def __getitem__(self, idx: int | tuple[int, ...]) -> Bargmann: - r""" - A copy of self with the given indices marked for contraction. - """ idx = (idx,) if isinstance(idx, int) else idx for i in idx: if i >= self.num_vars: @@ -819,29 +810,6 @@ def __getitem__(self, idx: int | tuple[int, ...]) -> Bargmann: return ret def __matmul__(self, other: Bargmann) -> Bargmann: - r""" - Implements the inner product in Bargmann representation. - - ..code-block:: - - >>> from mrmustard.physics.representations import Bargmann - >>> from mrmustard.physics.triples import displacement_gate_Abc, vacuum_state_Abc - >>> rep1 = Bargmann(*vacuum_state_Abc(1)) - >>> rep2 = Bargmann(*displacement_gate_Abc(1)) - >>> rep3 = rep1[0] @ rep2[1] - >>> assert np.allclose(rep3.A, [[0,],]) - >>> assert np.allclose(rep3.b, [1,]) - - Args: - other: Another Bargmann representation. - - Returns: - Bargmann: the resulting Bargmann representation. - - """ - if not isinstance(other, Bargmann): - raise NotImplementedError("Only matmul Bargmann with Bargmann") - idx_s = self._contract_idxs idx_o = other._contract_idxs @@ -864,19 +832,6 @@ def __matmul__(self, other: Bargmann) -> Bargmann: return Bargmann(A, b, c) def __mul__(self, other: Scalar | Bargmann) -> Bargmann: - r"""Multiplies this representation by a scalar or another Bargmann representation. - - Args: - other: A scalar or another Bargmann representation. - - Raises: - TypeError: If other is neither a scalar nor a Bargmann representation. - - Returns: - Bargmann: The product of this representation and other. - - """ - def mul_A(A1, A2, dim_alpha, dim_beta1, dim_beta2): A3 = math.block( [ @@ -940,36 +895,10 @@ def mul_c(c1, c2): except Exception as e: raise TypeError(f"Cannot multiply {self.__class__} and {other.__class__}.") from e - def __rmul__(self, other: Bargmann | Scalar) -> Bargmann: - r""" - Multiplies this representation by another or by a scalar on the right. - """ - return self.__mul__(other) - - def __sub__(self, other): - r""" - Subtracts other from this representation. - """ - try: - return self.__add__(-other) - except AttributeError as e: - raise TypeError(f"Cannot subtract {self.__class__} and {other.__class__}.") from e + def __neg__(self) -> Bargmann: + return Bargmann(self.A, self.b, -self.c) def __truediv__(self, other: Scalar | Bargmann) -> Bargmann: - r""" - Multiplies this Bargmann by a scalar or another Bargmann. - - Args: - other: A scalar or another Bargmann. - - Raises: - TypeError: If other is neither a scalar nor a Bargmann. - - Returns: - Bargmann: The product of this Bargmann and other. - - """ - def div_A(A1, A2, dim_alpha, dim_beta1, dim_beta2): A3 = math.block( [ diff --git a/mrmustard/physics/representations/base.py b/mrmustard/physics/representations/base.py index b8a6d1f66..74adf755a 100644 --- a/mrmustard/physics/representations/base.py +++ b/mrmustard/physics/representations/base.py @@ -30,6 +30,7 @@ ComplexVector, Scalar, Tensor, + Vector, ) __all__ = ["Representation"] @@ -136,3 +137,111 @@ def _generate_ansatz(self): This method computes and sets data given a function and some kwargs. """ + + @abstractmethod + def __add__(self, other: Representation) -> Representation: + r""" + Adds this representation and another representation. + + Args: + other: Another representation. + + Returns: + The addition of this representation and other. + """ + + @abstractmethod + def __and__(self, other: Representation) -> Representation: + r""" + Tensor product of this representation with another. + + Args: + other: Another representation. + + Returns: + The tensor product of this representation and other. + """ + + @abstractmethod + def __call__(self, z: Batch[Vector]) -> Scalar | Representation: + r""" + Evaluates this representation at a given point in the domain. + + Args: + z: point in C^n where the function is evaluated + + Returns: + The value of the function if ``z`` has no ``None``, else it returns a new ansatz. + """ + + @abstractmethod + def __eq__(self, other: Representation) -> bool: + r""" + Whether this representation is equal to another. + """ + + @abstractmethod + def __getitem__(self, idx: int | tuple[int, ...]) -> Representation: + r""" + Returns a copy of self with the given indices marked for contraction. + """ + + @abstractmethod + def __matmul__(self, other: Representation) -> Representation: + r""" + Implements the inner product of representations over the marked indices. + + Args: + other: Another representation. + + Returns: + The resulting representation. + """ + + @abstractmethod + def __mul__(self, other: Scalar | Representation) -> Representation: + r""" + Multiplies this representation by a scalar or another representation. + + Args: + other: A scalar or another representation. + + Raises: + TypeError: If other is neither a scalar nor a representation. + + Returns: + The product of this representation and other. + """ + + @abstractmethod + def __neg__(self) -> Representation: + r""" + Negates the values in the representation. + """ + + def __rmul__(self, other: Representation | Scalar) -> Representation: + r""" + Multiplies this representation by another or by a scalar on the right. + """ + return self.__mul__(other) + + def __sub__(self, other: Representation) -> Representation: + r""" + Subtracts other from this representation. + """ + try: + return self.__add__(-other) + except AttributeError as e: + raise TypeError(f"Cannot subtract {self.__class__} and {other.__class__}.") from e + + @abstractmethod + def __truediv__(self, other: Scalar | Representation) -> Representation: + r""" + Divides this representation by another representation. + + Args: + other: A scalar or another representation. + + Returns: + The division of this representation and other. + """ diff --git a/mrmustard/physics/representations/fock.py b/mrmustard/physics/representations/fock.py index 1e9015d06..b248bcdb9 100644 --- a/mrmustard/physics/representations/fock.py +++ b/mrmustard/physics/representations/fock.py @@ -28,11 +28,7 @@ from IPython.display import display from mrmustard import math, widgets -from mrmustard.utils.typing import ( - Batch, - Scalar, - Tensor, -) +from mrmustard.utils.typing import Batch, Scalar, Tensor, Vector from .base import Representation @@ -246,18 +242,6 @@ def _ipython_display_(self): display(w) def __add__(self, other: Fock) -> Fock: - r""" - Adds the array of this Fock representation and the array of another Fock representation. - - Args: - other: Another Fock representation. - - Raises: - ValueError: If the arrays don't have the same shape. - - Returns: - ArrayAnsatz: The addition of this representation and other. - """ try: diff = sum(self.array.shape[1:]) - sum(other.array.shape[1:]) if diff < 0: @@ -273,41 +257,19 @@ def __add__(self, other: Fock) -> Fock: raise TypeError(f"Cannot add {self.__class__} and {other.__class__}.") from e def __and__(self, other: Fock) -> Fock: - r""" - Tensor product of this Fock representation with another Fock representation. - - Args: - other: Another Fock representation. - - Returns: - The tensor product of this representation and other. - Batch size is the product of two batches. - """ new_array = [math.outer(a, b) for a in self.array for b in other.array] return Fock(array=new_array, batched=True) - def __call__(self, point: Any) -> Scalar: - r""" - Evaluates this representation at a given point in the domain. - """ + def __call__(self, z: Batch[Vector]) -> Scalar: raise AttributeError("Cannot call Fock.") def __eq__(self, other: Representation) -> bool: - r""" - Whether this ansatz's array is equal to another ansatz's array. - - Note that the comparison is done by numpy allclose with numpy's default rtol and atol. - - """ slices = (slice(0, None),) + tuple( slice(0, min(si, oi)) for si, oi in zip(self.array.shape[1:], other.array.shape[1:]) ) return np.allclose(self.array[slices], other.array[slices], atol=1e-10) def __getitem__(self, idx: int | tuple[int, ...]) -> Fock: - r""" - Returns a copy of self with the given indices marked for contraction. - """ idx = (idx,) if isinstance(idx, int) else idx for i in idx: if i >= self.num_vars: @@ -319,29 +281,6 @@ def __getitem__(self, idx: int | tuple[int, ...]) -> Fock: return ret def __matmul__(self, other: Fock) -> Fock: - r""" - Implements the inner product of fock arrays over the marked indices. - - .. code-block:: - >>> from mrmustard.physics.representations import Fock - >>> f = Fock(np.random.random((3, 5, 10))) # 10 is reduced to 8 - >>> g = Fock(np.random.random((2, 5, 8))) - >>> h = f[1,2] @ g[1,2] - >>> assert h.array.shape == (1,3,2) # batch size is 1 - >>> f = Fock(np.random.random((3, 5, 10)), batched=True) - >>> g = Fock(np.random.random((2, 5, 8)), batched=True) - >>> h = f[0,1] @ g[0,1] - >>> assert h.array.shape == (6,) # batch size is 3 x 2 = 6 - - Args: - other: Another representation. - - Returns: - A ``Fock``representation. - """ - if not isinstance(other, Fock): - raise NotImplementedError("only matmul Fock with Fock") - idx_s = list(self._contract_idxs) idx_o = list(other._contract_idxs) @@ -370,18 +309,6 @@ def __matmul__(self, other: Fock) -> Fock: return Fock(batched_array, batched=True) def __mul__(self, other: Scalar | Fock) -> Fock: - r""" - Multiplies this Fock representation by another Fock representation. - - Args: - other: A scalar or another Fock representation. - - Raises: - ValueError: If both of array don't have the same shape. - - Returns: - ArrayAnsatz: The product of this representation and other. - """ if isinstance(other, Fock): try: diff = sum(self.array.shape[1:]) - sum(other.array.shape[1:]) @@ -406,39 +333,9 @@ def __mul__(self, other: Scalar | Fock) -> Fock: return ret def __neg__(self) -> Fock: - r""" - Negates the values in the array. - """ return Fock(array=-self.array, batched=True) - def __rmul__(self, other: Fock | Scalar) -> Fock: - r""" - Multiplies this representation by another or by a scalar on the right. - """ - return self.__mul__(other) - - def __sub__(self, other: Fock) -> Fock: - r""" - Subtracts other from this ansatz. - """ - try: - return self.__add__(-other) - except AttributeError as e: - raise TypeError(f"Cannot subtract {self.__class__} and {other.__class__}.") from e - def __truediv__(self, other: Scalar | Fock) -> Fock: - r""" - Divides this Fock representation by another Fock representation. - - Args: - other: A scalar or another Fock representation. - - Raises: - ValueError: If the arrays don't have the same shape. - - Returns: - ArrayAnsatz: The division of this representation and other. - """ if isinstance(other, Fock): try: diff = sum(self.array.shape[1:]) - sum(other.array.shape[1:]) From 74d4cef15ddc3f0ea4a1195a9d55f7c41e6517e3 Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 10 Sep 2024 12:23:39 -0400 Subject: [PATCH 12/87] merge --- mrmustard/lab_dev/transformations/dgate.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mrmustard/lab_dev/transformations/dgate.py b/mrmustard/lab_dev/transformations/dgate.py index a7640f3ef..64504755f 100644 --- a/mrmustard/lab_dev/transformations/dgate.py +++ b/mrmustard/lab_dev/transformations/dgate.py @@ -25,7 +25,7 @@ from .base import Unitary from ...physics.representations import Bargmann -from ...physics import triples, fock +from ...physics import triples, fock_utils from ..utils import make_parameter, reshape_params __all__ = ["Dgate"] @@ -113,7 +113,7 @@ def fock(self, shape: int | Sequence[int] = None, batched=False) -> ComplexTenso array: The Fock representation of this component. """ if isinstance(shape, int): - shape = (shape,) * self.representation.ansatz.num_vars + shape = (shape,) * self.representation.num_vars auto_shape = self.auto_shape() shape = shape or auto_shape if len(shape) != len(auto_shape): @@ -129,9 +129,9 @@ def fock(self, shape: int | Sequence[int] = None, batched=False) -> ComplexTenso Ud = None for idx, out_in in enumerate(zip(shape[:N], shape[N:])): if Ud is None: - Ud = fock.displacement(x[idx], y[idx], shape=out_in) + Ud = fock_utils.displacement(x[idx], y[idx], shape=out_in) else: - U_next = fock.displacement(x[idx], y[idx], shape=out_in) + U_next = fock_utils.displacement(x[idx], y[idx], shape=out_in) Ud = math.outer(Ud, U_next) array = math.transpose( @@ -139,6 +139,6 @@ def fock(self, shape: int | Sequence[int] = None, batched=False) -> ComplexTenso list(range(0, 2 * N, 2)) + list(range(1, 2 * N, 2)), ) else: - array = fock.displacement(x[0], y[0], shape=shape) + array = fock_utils.displacement(x[0], y[0], shape=shape) arrays = math.expand_dims(array, 0) if batched else array return arrays From 3bf79a3f37937ec10c146a60a34f8e046649166c Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 10 Sep 2024 13:45:19 -0400 Subject: [PATCH 13/87] codefactor --- mrmustard/physics/fock_utils.py | 6 +++--- tests/test_physics/test_fock_utils.py | 6 ++++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/mrmustard/physics/fock_utils.py b/mrmustard/physics/fock_utils.py index 0d4f7c83a..d18f98343 100644 --- a/mrmustard/physics/fock_utils.py +++ b/mrmustard/physics/fock_utils.py @@ -68,14 +68,14 @@ def fock_state(n: Sequence[int], cutoffs: int | Sequence[int] | None = None) -> msg = f"Expected ``len(cutoffs)={len(n)}`` but found ``{len(cutoffs)}``." raise ValueError(msg) - shape = tuple([c + 1 for c in cutoffs]) + shape = tuple(c + 1 for c in cutoffs) array = np.zeros(shape, dtype=np.complex128) try: array[tuple(n)] = 1 - except IndexError: + except IndexError as e: msg = "Photon numbers cannot be larger than the corresponding cutoffs." - raise ValueError(msg) + raise ValueError(msg) from e return math.astensor(array) diff --git a/tests/test_physics/test_fock_utils.py b/tests/test_physics/test_fock_utils.py index 95e32ddc5..d1cf3295b 100644 --- a/tests/test_physics/test_fock_utils.py +++ b/tests/test_physics/test_fock_utils.py @@ -47,6 +47,9 @@ def test_fock_state(): + r""" + Tests that the `fock_state` method gives expected values. + """ n = [4, 5, 6] array1 = fock_utils.fock_state(n) @@ -63,6 +66,9 @@ def test_fock_state(): def test_fock_state_error(): + r""" + Tests that the `fock_state` method handles errors as expected. + """ n = [4, 5] with pytest.raises(ValueError): From 809173890e3f2357426cb5f40fcc8c086398c31d Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 10 Sep 2024 13:48:32 -0400 Subject: [PATCH 14/87] doc --- mrmustard/physics/representations/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mrmustard/physics/representations/__init__.py b/mrmustard/physics/representations/__init__.py index 34c0933bb..a75d9469e 100644 --- a/mrmustard/physics/representations/__init__.py +++ b/mrmustard/physics/representations/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. r""" +The classes for representations in circuit components. """ from .base import * From 31bafedf99694523ba118f959e7a91e8f2a801e7 Mon Sep 17 00:00:00 2001 From: Anthony Date: Mon, 30 Sep 2024 16:21:23 -0400 Subject: [PATCH 15/87] move wires to physics --- mrmustard/lab_dev/__init__.py | 1 - mrmustard/lab_dev/circuit_components.py | 2 +- .../lab_dev/circuit_components_utils/branch_and_bound.py | 2 +- mrmustard/lab_dev/states/base.py | 2 +- mrmustard/{lab_dev => physics}/wires.py | 0 tests/test_lab_dev/test_circuit_components.py | 2 +- tests/test_lab_dev/test_circuit_components_utils.py | 2 +- tests/test_lab_dev/test_states/test_states_base.py | 2 +- .../test_transformations/test_transformations_base.py | 2 +- tests/{test_lab_dev => test_physics}/test_wires.py | 4 ++-- 10 files changed, 9 insertions(+), 10 deletions(-) rename mrmustard/{lab_dev => physics}/wires.py (100%) rename tests/{test_lab_dev => test_physics}/test_wires.py (98%) diff --git a/mrmustard/lab_dev/__init__.py b/mrmustard/lab_dev/__init__.py index 4ef839b6b..36f1248ab 100644 --- a/mrmustard/lab_dev/__init__.py +++ b/mrmustard/lab_dev/__init__.py @@ -21,4 +21,3 @@ from .circuits import * from .states import * from .transformations import * -from .wires import Wires diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index 607d85d23..e22c5c6ed 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -43,7 +43,7 @@ from mrmustard.physics.fock import quadrature_basis from mrmustard.math.parameter_set import ParameterSet from mrmustard.math.parameters import Constant, Variable -from mrmustard.lab_dev.wires import Wires +from mrmustard.physics.wires import Wires from mrmustard.physics.triples import identity_Abc __all__ = ["CircuitComponent"] diff --git a/mrmustard/lab_dev/circuit_components_utils/branch_and_bound.py b/mrmustard/lab_dev/circuit_components_utils/branch_and_bound.py index ca44ff142..2f50f1bde 100644 --- a/mrmustard/lab_dev/circuit_components_utils/branch_and_bound.py +++ b/mrmustard/lab_dev/circuit_components_utils/branch_and_bound.py @@ -24,7 +24,7 @@ import numpy as np from typing import Generator import networkx as nx -from mrmustard.lab_dev.wires import Wires +from mrmustard.physics.wires import Wires from mrmustard.lab_dev.circuit_components import CircuitComponent Edge = tuple[int, int] diff --git a/mrmustard/lab_dev/states/base.py b/mrmustard/lab_dev/states/base.py index 49c248a43..6460c28f6 100644 --- a/mrmustard/lab_dev/states/base.py +++ b/mrmustard/lab_dev/states/base.py @@ -60,7 +60,7 @@ ) from mrmustard.lab_dev.circuit_components_utils import BtoPS, BtoQ, TraceOut from mrmustard.lab_dev.circuit_components import CircuitComponent -from mrmustard.lab_dev.wires import Wires +from mrmustard.physics.wires import Wires __all__ = ["State", "DM", "Ket"] diff --git a/mrmustard/lab_dev/wires.py b/mrmustard/physics/wires.py similarity index 100% rename from mrmustard/lab_dev/wires.py rename to mrmustard/physics/wires.py diff --git a/tests/test_lab_dev/test_circuit_components.py b/tests/test_lab_dev/test_circuit_components.py index 876409d84..30b42b948 100644 --- a/tests/test_lab_dev/test_circuit_components.py +++ b/tests/test_lab_dev/test_circuit_components.py @@ -37,7 +37,7 @@ SqueezedVacuum, ) from mrmustard.lab_dev.transformations import Dgate, Attenuator, Unitary, Sgate, Channel -from mrmustard.lab_dev.wires import Wires +from mrmustard.physics.wires import Wires from ..random import Abc_triple diff --git a/tests/test_lab_dev/test_circuit_components_utils.py b/tests/test_lab_dev/test_circuit_components_utils.py index 354059430..36f493c59 100644 --- a/tests/test_lab_dev/test_circuit_components_utils.py +++ b/tests/test_lab_dev/test_circuit_components_utils.py @@ -33,7 +33,7 @@ from mrmustard.lab_dev.circuit_components_utils import TraceOut, BtoPS, BtoQ from mrmustard.lab_dev.circuit_components import CircuitComponent from mrmustard.lab_dev.states import Coherent, DM -from mrmustard.lab_dev.wires import Wires +from mrmustard.physics.wires import Wires # original settings diff --git a/tests/test_lab_dev/test_states/test_states_base.py b/tests/test_lab_dev/test_states/test_states_base.py index a34a18438..b4d8d195a 100644 --- a/tests/test_lab_dev/test_states/test_states_base.py +++ b/tests/test_lab_dev/test_states/test_states_base.py @@ -36,7 +36,7 @@ Vacuum, ) from mrmustard.lab_dev.transformations import Attenuator, Dgate, Sgate -from mrmustard.lab_dev.wires import Wires +from mrmustard.physics.wires import Wires from mrmustard.widgets import state as state_widget # original settings diff --git a/tests/test_lab_dev/test_transformations/test_transformations_base.py b/tests/test_lab_dev/test_transformations/test_transformations_base.py index 74238eae0..54d0c14ce 100644 --- a/tests/test_lab_dev/test_transformations/test_transformations_base.py +++ b/tests/test_lab_dev/test_transformations/test_transformations_base.py @@ -30,7 +30,7 @@ Unitary, Operation, ) -from mrmustard.lab_dev.wires import Wires +from mrmustard.physics.wires import Wires from mrmustard.lab_dev.states import Vacuum diff --git a/tests/test_lab_dev/test_wires.py b/tests/test_physics/test_wires.py similarity index 98% rename from tests/test_lab_dev/test_wires.py rename to tests/test_physics/test_wires.py index c1d4af695..4a9ca81a5 100644 --- a/tests/test_lab_dev/test_wires.py +++ b/tests/test_physics/test_wires.py @@ -21,7 +21,7 @@ from ipywidgets import HTML import pytest -from mrmustard.lab_dev.wires import Wires +from mrmustard.physics.wires import Wires class TestWires: @@ -213,7 +213,7 @@ def test_matmul_error(self): with pytest.raises(ValueError): u @ v # pylint: disable=pointless-statement - @patch("mrmustard.lab_dev.wires.display") + @patch("mrmustard.physics.wires.display") def test_ipython_repr(self, mock_display): """Test the IPython repr function.""" wires = Wires({0}, {}, {3}, {3, 4}) From a27838c27cb6a9c4c49e12d4335bf0fe13548d6d Mon Sep 17 00:00:00 2001 From: Anthony Date: Mon, 30 Sep 2024 16:22:04 -0400 Subject: [PATCH 16/87] move wires to physics --- mrmustard/physics/representations.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index 38967647f..938a11ff0 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -48,6 +48,8 @@ ) from mrmustard import widgets +from .wires import Wires + __all__ = ["Representation", "Bargmann", "Fock"] From 0f8bfc90e7541103b9ecb0d01fc2831cda9d71dc Mon Sep 17 00:00:00 2001 From: Anthony Date: Thu, 3 Oct 2024 11:51:50 -0400 Subject: [PATCH 17/87] init --- mrmustard/lab_dev/circuit_components.py | 203 +++++---------- mrmustard/lab_dev/states/base.py | 10 +- mrmustard/lab_dev/states/coherent.py | 5 +- .../lab_dev/states/displaced_squeezed.py | 16 +- mrmustard/lab_dev/states/number.py | 6 +- .../lab_dev/states/quadrature_eigenstate.py | 7 +- mrmustard/lab_dev/states/sauron.py | 8 +- mrmustard/lab_dev/states/squeezed_vacuum.py | 7 +- mrmustard/lab_dev/states/thermal.py | 6 +- .../states/two_mode_squeezed_vacuum.py | 9 +- .../lab_dev/transformations/amplifier.py | 5 +- .../lab_dev/transformations/attenuator.py | 5 +- mrmustard/lab_dev/transformations/bsgate.py | 9 +- mrmustard/lab_dev/transformations/cft.py | 6 +- mrmustard/lab_dev/transformations/dgate.py | 8 +- .../lab_dev/transformations/fockdamping.py | 5 +- mrmustard/lab_dev/transformations/ggate.py | 11 +- mrmustard/lab_dev/transformations/rgate.py | 6 +- mrmustard/lab_dev/transformations/s2gate.py | 7 +- mrmustard/lab_dev/transformations/sgate.py | 7 +- mrmustard/physics/multi_representations.py | 241 ++++++++++++++++++ tests/test_lab_dev/test_circuit_components.py | 5 +- 22 files changed, 400 insertions(+), 192 deletions(-) create mode 100644 mrmustard/physics/multi_representations.py diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index e22c5c6ed..4768a916a 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -44,7 +44,7 @@ from mrmustard.math.parameter_set import ParameterSet from mrmustard.math.parameters import Constant, Variable from mrmustard.physics.wires import Wires -from mrmustard.physics.triples import identity_Abc +from mrmustard.physics.multi_representations import MultiRepresentation __all__ = ["CircuitComponent"] @@ -75,20 +75,17 @@ def __init__( ) -> None: self._name = name self._parameter_set = ParameterSet() - self._representation = representation - if isinstance(wires, Wires): - self._wires = wires - else: - wires = [tuple(elem) for elem in wires] if wires else [(), (), (), ()] - modes_out_bra, modes_in_bra, modes_out_ket, modes_in_ket = wires - self._wires = Wires( + if not isinstance(wires, Wires): + modes_out_bra, modes_in_bra, modes_out_ket, modes_in_ket = ( + [tuple(elem) for elem in wires] if wires else [(), (), (), ()] + ) + wires = Wires( set(modes_out_bra), set(modes_in_bra), set(modes_out_ket), set(modes_in_ket), ) - # handle out-of-order modes ob = tuple(sorted(modes_out_bra)) ib = tuple(sorted(modes_in_bra)) @@ -107,8 +104,13 @@ def __init__( + tuple(np.argsort(modes_out_ket) + offsets[1]) + tuple(np.argsort(modes_in_ket) + offsets[2]) ) - if self._representation: - self._representation = self._representation.reorder(tuple(perm)) + if representation is not None: + self._multi_rep = MultiRepresentation( + representation.reorder(tuple(perm)), wires + ) + + if not hasattr(self, "_multi_rep"): + self._multi_rep = MultiRepresentation(representation, wires) def _serialize(self) -> tuple[dict[str, Any], dict[str, ArrayLike]]: """ @@ -243,14 +245,14 @@ def representation(self) -> Representation | None: r""" A representation of this circuit component. """ - return self._representation + return self._multi_rep.representation @property def wires(self) -> Wires: r""" The wires of this component. """ - return self._wires + return self._multi_rep.wires @classmethod def from_bargmann( @@ -419,8 +421,7 @@ def _from_attributes( if tp.__name__ in types: ret = tp() ret._name = name - ret._representation = representation - ret._wires = wires + ret._multi_rep = MultiRepresentation(representation, wires) return ret return CircuitComponent(representation, wires, name) @@ -452,14 +453,7 @@ def bargmann_triple( >>> assert isinstance(coh_cc, CircuitComponent) >>> assert coh == coh_cc # equality looks at representation and wires """ - try: - A, b, c = self.representation.triple - if not batched and self.representation.ansatz.batch_size == 1: - return A[0], b[0], c[0] - else: - return A, b, c - except AttributeError as e: - raise AttributeError("No Bargmann data for this component.") from e + return self._multi_rep.bargmann_triple(batched) def fock(self, shape: int | Sequence[int] | None = None, batched=False) -> ComplexTensor: r""" @@ -475,38 +469,7 @@ def fock(self, shape: int | Sequence[int] | None = None, batched=False) -> Compl Returns: array: The Fock representation of this component. """ - num_vars = self.representation.ansatz.num_vars - if isinstance(shape, int): - shape = (shape,) * num_vars - try: - As, bs, cs = self.bargmann_triple(batched=True) - shape = shape or self.auto_shape() - if len(shape) != num_vars: - raise ValueError( - f"Expected Fock shape of length {num_vars}, got length {len(shape)}" - ) - if self.representation.ansatz.polynomial_shape[0] == 0: - arrays = [math.hermite_renormalized(A, b, c, shape) for A, b, c in zip(As, bs, cs)] - else: - arrays = [ - math.sum( - math.hermite_renormalized(A, b, 1, shape + c.shape) * c, - axes=math.arange( - num_vars, num_vars + len(c.shape), dtype=math.int32 - ).tolist(), - ) - for A, b, c in zip(As, bs, cs) - ] - except AttributeError: - shape = shape or self.auto_shape() - if len(shape) != num_vars: - raise ValueError( - f"Expected Fock shape of length {num_vars}, got length {len(shape)}" - ) - arrays = self.representation.reduce(shape).array - array = math.sum(arrays, axes=[0]) - arrays = math.expand_dims(array, 0) if batched else array - return arrays + return self._multi_rep.fock(shape or self.auto_shape(), batched) def on(self, modes: Sequence[int]) -> CircuitComponent: r""" @@ -539,16 +502,46 @@ def on(self, modes: Sequence[int]) -> CircuitComponent: for subset in subsets: if subset and len(subset) != len(modes): raise ValueError(f"Expected ``{len(modes)}`` modes, found ``{len(subset)}``.") - ret = self._light_copy() - ret._wires = Wires( - modes_out_bra=set(modes) if ob else set(), - modes_in_bra=set(modes) if ib else set(), - modes_out_ket=set(modes) if ok else set(), - modes_in_ket=set(modes) if ik else set(), + ret = self._light_copy( + Wires( + modes_out_bra=set(modes) if ob else set(), + modes_in_bra=set(modes) if ib else set(), + modes_out_ket=set(modes) if ok else set(), + modes_in_ket=set(modes) if ik else set(), + ) ) - return ret + def to_bargmann(self) -> CircuitComponent: + r""" + Returns a new circuit component with the same attributes as this and a ``Bargmann`` representation. + .. code-block:: + + >>> from mrmustard.lab_dev import Dgate + >>> from mrmustard.physics.representations import Bargmann + + >>> d = Dgate([1], x=0.1, y=0.1) + >>> d_fock = d.to_fock(shape=3) + >>> d_bargmann = d_fock.to_bargmann() + + + >>> assert d_bargmann.name == d.name + >>> assert d_bargmann.wires == d.wires + >>> assert isinstance(d_bargmann.representation, Bargmann) + """ + if isinstance(self.representation, Bargmann): + return self + else: + mult_rep = self._multi_rep.to_bargmann() + try: + ret = self._getitem_builtin(self.modes) + ret._multi_rep = mult_rep + except TypeError: + ret = self._from_attributes(mult_rep.representation, mult_rep.wires, self.name) + if "manual_shape" in ret.__dict__: + del ret.manual_shape + return ret + def to_fock(self, shape: int | Sequence[int] | None = None) -> CircuitComponent: r""" Returns a new circuit component with the same attributes as this and a ``Fock`` representation. @@ -570,56 +563,16 @@ def to_fock(self, shape: int | Sequence[int] | None = None) -> CircuitComponent: an ``int``, it is broadcasted to all the dimensions. If ``None``, it defaults to the value of ``AUTOSHAPE_MAX`` in the settings. """ - fock = Fock(self.fock(shape, batched=True), batched=True) - try: - if self.representation.ansatz.polynomial_shape[0] == 0: - fock.ansatz._original_abc_data = self.representation.triple - except AttributeError: - fock.ansatz._original_abc_data = None + mult_rep = self._multi_rep.to_fock(shape or self.auto_shape()) try: ret = self._getitem_builtin(self.modes) - ret._representation = fock + ret._multi_rep = mult_rep except TypeError: - ret = self._from_attributes(fock, self.wires, self.name) + ret = self._from_attributes(mult_rep.representation, mult_rep.wires, self.name) if "manual_shape" in ret.__dict__: del ret.manual_shape return ret - def to_bargmann(self) -> CircuitComponent: - r""" - Returns a new circuit component with the same attributes as this and a ``Bargmann`` representation. - .. code-block:: - - >>> from mrmustard.lab_dev import Dgate - >>> from mrmustard.physics.representations import Bargmann - - >>> d = Dgate([1], x=0.1, y=0.1) - >>> d_fock = d.to_fock(shape=3) - >>> d_bargmann = d_fock.to_bargmann() - - - >>> assert d_bargmann.name == d.name - >>> assert d_bargmann.wires == d.wires - >>> assert isinstance(d_bargmann.representation, Bargmann) - """ - if isinstance(self.representation, Bargmann): - return self - else: - if self.representation.ansatz._original_abc_data: - A, b, c = self.representation.ansatz._original_abc_data - else: - A, b, _ = identity_Abc(len(self.wires.quantum)) - c = self.representation.data - bargmann = Bargmann(A, b, c) - try: - ret = self._getitem_builtin(self.modes) - ret._representation = bargmann - except TypeError: - ret = self._from_attributes(bargmann, self.wires, self.name) - if "manual_shape" in ret.__dict__: - del ret.manual_shape - return ret - def _add_parameter(self, parameter: Constant | Variable): r""" Adds a parameter to this circuit component and makes it accessible as an attribute. @@ -656,23 +609,11 @@ def _light_copy(self, wires: Wires | None = None) -> CircuitComponent: """ instance = super().__new__(self.__class__) instance.__dict__ = self.__dict__.copy() - instance.__dict__["_wires"] = wires or Wires(*self.wires.args) + instance.__dict__["_multi_rep"] = MultiRepresentation( + self.representation, wires or Wires(*self.wires.args) + ) return instance - def _matmul_indices(self, other: CircuitComponent) -> tuple[tuple[int, ...], tuple[int, ...]]: - r""" - Finds the indices of the wires being contracted when ``self @ other`` is called. - """ - # find the indices of the wires being contracted on the bra side - bra_modes = tuple(self.wires.bra.output.modes & other.wires.bra.input.modes) - idx_z = self.wires.bra.output[bra_modes].indices - idx_zconj = other.wires.bra.input[bra_modes].indices - # find the indices of the wires being contracted on the ket side - ket_modes = tuple(self.wires.ket.output.modes & other.wires.ket.input.modes) - idx_z += self.wires.ket.output[ket_modes].indices - idx_zconj += other.wires.ket.input[ket_modes].indices - return idx_z, idx_zconj - def _rshift_return( self, ret: CircuitComponent | np.ndarray | complex ) -> CircuitComponent | np.ndarray | complex: @@ -696,9 +637,12 @@ def __eq__(self, other) -> bool: r""" Whether this component is equal to another component. - Compares representations and wires, but not the other attributes (e.g. name and parameter set). + Compares multi-representations, but not the other attributes + (e.g. name and parameter set). """ - return self.representation == other.representation and self.wires == other.wires + if isinstance(other, CircuitComponent): + return self._multi_rep == other._multi_rep + return False def __matmul__(self, other: CircuitComponent | Scalar) -> CircuitComponent: r""" @@ -719,19 +663,8 @@ def __matmul__(self, other: CircuitComponent | Scalar) -> CircuitComponent: """ if isinstance(other, (numbers.Number, np.ndarray)): return self * other - - wires_result, perm = self.wires @ other.wires - idx_z, idx_zconj = self._matmul_indices(other) - if type(self.representation) == type(other.representation): - self_rep = self.representation - other_rep = other.representation - else: - self_rep = self.to_bargmann().representation - other_rep = other.to_bargmann().representation - - rep = self_rep[idx_z] @ other_rep[idx_zconj] - rep = rep.reorder(perm) if perm else rep - return CircuitComponent._from_attributes(rep, wires_result, None) + result = self._multi_rep @ other._multi_rep + return CircuitComponent._from_attributes(result.representation, result.wires, None) def __mul__(self, other: Scalar) -> CircuitComponent: r""" diff --git a/mrmustard/lab_dev/states/base.py b/mrmustard/lab_dev/states/base.py index 6460c28f6..abdd94db9 100644 --- a/mrmustard/lab_dev/states/base.py +++ b/mrmustard/lab_dev/states/base.py @@ -623,11 +623,10 @@ def __init__( f"Expected a representation with {2*len(modes)} variables, found {representation.ansatz.num_vars}." ) super().__init__( + representation=representation, wires=[modes, (), modes, ()], name=name, ) - if representation is not None: - self._representation = representation @property def is_positive(self) -> bool: @@ -754,7 +753,7 @@ def auto_shape( # experimental: if self.representation.ansatz.batch_size == 1: try: # fock - shape = self._representation.array.shape[1:] + shape = self.representation.array.shape[1:] except AttributeError: # bargmann if self.representation.ansatz.polynomial_shape[0] == 0: repr = self.representation @@ -933,11 +932,10 @@ def __init__( f"Expected a representation with {len(modes)} variables, found {representation.ansatz.num_vars}." ) super().__init__( + representation=representation, wires=[(), (), modes, ()], name=name, ) - if representation is not None: - self._representation = representation @property def is_physical(self) -> bool: @@ -1049,7 +1047,7 @@ def auto_shape( # experimental: if self.representation.ansatz.batch_size == 1: try: # fock - shape = self._representation.array.shape[1:] + shape = self.representation.array.shape[1:] except AttributeError: # bargmann if self.representation.ansatz.polynomial_shape[0] == 0: repr = self.representation.conj() & self.representation diff --git a/mrmustard/lab_dev/states/coherent.py b/mrmustard/lab_dev/states/coherent.py index 2b64ee045..d5caa67e6 100644 --- a/mrmustard/lab_dev/states/coherent.py +++ b/mrmustard/lab_dev/states/coherent.py @@ -20,6 +20,7 @@ from typing import Sequence +from mrmustard.physics.multi_representations import MultiRepresentation from mrmustard.physics.representations import Bargmann from mrmustard.physics import triples from .base import Ket @@ -82,6 +83,6 @@ def __init__( self._add_parameter(make_parameter(x_trainable, xs, "x", x_bounds)) self._add_parameter(make_parameter(y_trainable, ys, "y", y_bounds)) - self._representation = Bargmann.from_function( - fn=triples.coherent_state_Abc, x=self.x, y=self.y + self._multi_rep = MultiRepresentation( + Bargmann.from_function(fn=triples.coherent_state_Abc, x=self.x, y=self.y), self.wires ) diff --git a/mrmustard/lab_dev/states/displaced_squeezed.py b/mrmustard/lab_dev/states/displaced_squeezed.py index 83bb74808..a6ac4aaba 100644 --- a/mrmustard/lab_dev/states/displaced_squeezed.py +++ b/mrmustard/lab_dev/states/displaced_squeezed.py @@ -20,6 +20,7 @@ from typing import Sequence +from mrmustard.physics.multi_representations import MultiRepresentation from mrmustard.physics.representations import Bargmann from mrmustard.physics import triples from .base import Ket @@ -84,10 +85,13 @@ def __init__( self._add_parameter(make_parameter(r_trainable, rs, "r", r_bounds)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - self._representation = Bargmann.from_function( - fn=triples.displaced_squeezed_vacuum_state_Abc, - x=self.x, - y=self.y, - r=self.r, - phi=self.phi, + self._multi_rep = MultiRepresentation( + Bargmann.from_function( + fn=triples.displaced_squeezed_vacuum_state_Abc, + x=self.x, + y=self.y, + r=self.r, + phi=self.phi, + ), + self.wires, ) diff --git a/mrmustard/lab_dev/states/number.py b/mrmustard/lab_dev/states/number.py index 07261c8ae..9f6a52f63 100644 --- a/mrmustard/lab_dev/states/number.py +++ b/mrmustard/lab_dev/states/number.py @@ -20,6 +20,7 @@ from typing import Sequence +from mrmustard.physics.multi_representations import MultiRepresentation from mrmustard.physics.representations import Fock from mrmustard.physics.fock import fock_state from .base import Ket @@ -73,6 +74,7 @@ def __init__( self.short_name = [str(int(n)) for n in self.n.value] for i, cutoff in enumerate(self.cutoffs.value): self.manual_shape[i] = int(cutoff) + 1 - self._representation = Fock.from_function( - fock_state, n=self.n.value, cutoffs=self.cutoffs.value + + self._multi_rep = MultiRepresentation( + Fock.from_function(fock_state, n=self.n.value, cutoffs=self.cutoffs.value), self.wires ) diff --git a/mrmustard/lab_dev/states/quadrature_eigenstate.py b/mrmustard/lab_dev/states/quadrature_eigenstate.py index f7f6b89d2..72ce0f174 100644 --- a/mrmustard/lab_dev/states/quadrature_eigenstate.py +++ b/mrmustard/lab_dev/states/quadrature_eigenstate.py @@ -22,6 +22,7 @@ import numpy as np +from mrmustard.physics.multi_representations import MultiRepresentation from mrmustard.physics.representations import Bargmann from mrmustard.physics import triples from .base import Ket @@ -67,10 +68,10 @@ def __init__( xs, phis = list(reshape_params(len(modes), x=x, phi=phi)) self._add_parameter(make_parameter(x_trainable, xs, "x", x_bounds)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - self._representation = Bargmann.from_function( - fn=triples.quadrature_eigenstates_Abc, x=self.x, phi=self.phi + self._multi_rep = MultiRepresentation( + Bargmann.from_function(fn=triples.quadrature_eigenstates_Abc, x=self.x, phi=self.phi), + self.wires, ) - self.manual_shape = (50,) @property diff --git a/mrmustard/lab_dev/states/sauron.py b/mrmustard/lab_dev/states/sauron.py index dc0e7f10b..ed15f7ead 100644 --- a/mrmustard/lab_dev/states/sauron.py +++ b/mrmustard/lab_dev/states/sauron.py @@ -16,6 +16,7 @@ from typing import Sequence from mrmustard.lab_dev.states.base import Ket +from mrmustard.physics.multi_representations import MultiRepresentation from mrmustard.physics.representations import Bargmann from mrmustard.physics import triples @@ -42,6 +43,9 @@ def __init__(self, modes: Sequence[int], n: int, epsilon: float = 0.1): super().__init__(name=f"Sauron-{n}", modes=modes) self._add_parameter(make_parameter(False, n, "n", (None, None), dtype="int64")) self._add_parameter(make_parameter(False, epsilon, "epsilon", (None, None))) - self._representation = Bargmann.from_function( - triples.sauron_state_Abc, n=self.n.value, epsilon=self.epsilon.value + self._multi_rep = MultiRepresentation( + Bargmann.from_function( + triples.sauron_state_Abc, n=self.n.value, epsilon=self.epsilon.value + ), + self.wires, ) diff --git a/mrmustard/lab_dev/states/squeezed_vacuum.py b/mrmustard/lab_dev/states/squeezed_vacuum.py index 79ee503a0..5c449bd1e 100644 --- a/mrmustard/lab_dev/states/squeezed_vacuum.py +++ b/mrmustard/lab_dev/states/squeezed_vacuum.py @@ -20,6 +20,7 @@ from typing import Sequence +from mrmustard.physics.multi_representations import MultiRepresentation from mrmustard.physics.representations import Bargmann from mrmustard.physics import triples from .base import Ket @@ -68,7 +69,7 @@ def __init__( rs, phis = list(reshape_params(len(modes), r=r, phi=phi)) self._add_parameter(make_parameter(r_trainable, rs, "r", r_bounds)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - - self._representation = Bargmann.from_function( - fn=triples.squeezed_vacuum_state_Abc, r=self.r, phi=self.phi + self._multi_rep = MultiRepresentation( + Bargmann.from_function(fn=triples.squeezed_vacuum_state_Abc, r=self.r, phi=self.phi), + self.wires, ) diff --git a/mrmustard/lab_dev/states/thermal.py b/mrmustard/lab_dev/states/thermal.py index bce628435..1c9d00e67 100644 --- a/mrmustard/lab_dev/states/thermal.py +++ b/mrmustard/lab_dev/states/thermal.py @@ -20,6 +20,7 @@ from typing import Sequence +from mrmustard.physics.multi_representations import MultiRepresentation from mrmustard.physics.representations import Bargmann from mrmustard.physics import triples from .base import DM @@ -61,5 +62,6 @@ def __init__( super().__init__(modes=modes, name="Thermal") (nbars,) = list(reshape_params(len(modes), nbar=nbar)) self._add_parameter(make_parameter(nbar_trainable, nbars, "nbar", nbar_bounds)) - - self._representation = Bargmann.from_function(fn=triples.thermal_state_Abc, nbar=self.nbar) + self._multi_rep = MultiRepresentation( + Bargmann.from_function(fn=triples.thermal_state_Abc, nbar=self.nbar), self.wires + ) diff --git a/mrmustard/lab_dev/states/two_mode_squeezed_vacuum.py b/mrmustard/lab_dev/states/two_mode_squeezed_vacuum.py index 7cb535ee0..6a6fcbd27 100644 --- a/mrmustard/lab_dev/states/two_mode_squeezed_vacuum.py +++ b/mrmustard/lab_dev/states/two_mode_squeezed_vacuum.py @@ -20,6 +20,7 @@ from typing import Sequence +from mrmustard.physics.multi_representations import MultiRepresentation from mrmustard.physics.representations import Bargmann from mrmustard.physics import triples from .base import Ket @@ -66,7 +67,9 @@ def __init__( rs, phis = list(reshape_params(int(len(modes) / 2), r=r, phi=phi)) self._add_parameter(make_parameter(r_trainable, rs, "r", r_bounds)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - - self._representation = Bargmann.from_function( - fn=triples.two_mode_squeezed_vacuum_state_Abc, r=self.r, phi=self.phi + self._multi_rep = MultiRepresentation( + Bargmann.from_function( + fn=triples.two_mode_squeezed_vacuum_state_Abc, r=self.r, phi=self.phi + ), + self.wires, ) diff --git a/mrmustard/lab_dev/transformations/amplifier.py b/mrmustard/lab_dev/transformations/amplifier.py index 2457a540f..223a3de34 100644 --- a/mrmustard/lab_dev/transformations/amplifier.py +++ b/mrmustard/lab_dev/transformations/amplifier.py @@ -21,6 +21,7 @@ from typing import Sequence from .base import Channel +from ...physics.multi_representations import MultiRepresentation from ...physics.representations import Bargmann from ...physics import triples from ..utils import make_parameter, reshape_params @@ -95,4 +96,6 @@ def __init__( None, ) ) - self._representation = Bargmann.from_function(fn=triples.amplifier_Abc, g=self.gain) + self._multi_rep = MultiRepresentation( + Bargmann.from_function(fn=triples.amplifier_Abc, g=self.gain), self.wires + ) diff --git a/mrmustard/lab_dev/transformations/attenuator.py b/mrmustard/lab_dev/transformations/attenuator.py index 72b88864e..c728ec582 100644 --- a/mrmustard/lab_dev/transformations/attenuator.py +++ b/mrmustard/lab_dev/transformations/attenuator.py @@ -21,6 +21,7 @@ from typing import Sequence from .base import Channel +from ...physics.multi_representations import MultiRepresentation from ...physics.representations import Bargmann from ...physics import triples from ..utils import make_parameter, reshape_params @@ -95,6 +96,6 @@ def __init__( None, ) ) - self._representation = Bargmann.from_function( - fn=triples.attenuator_Abc, eta=self.transmissivity + self._multi_rep = MultiRepresentation( + Bargmann.from_function(fn=triples.attenuator_Abc, eta=self.transmissivity), self.wires ) diff --git a/mrmustard/lab_dev/transformations/bsgate.py b/mrmustard/lab_dev/transformations/bsgate.py index 71df7cd21..f3807953f 100644 --- a/mrmustard/lab_dev/transformations/bsgate.py +++ b/mrmustard/lab_dev/transformations/bsgate.py @@ -21,6 +21,7 @@ from typing import Sequence from .base import Unitary +from ...physics.multi_representations import MultiRepresentation from ...physics.representations import Bargmann from ...physics import triples from ..utils import make_parameter @@ -104,7 +105,9 @@ def __init__( super().__init__(modes_out=modes, modes_in=modes, name="BSgate") self._add_parameter(make_parameter(theta_trainable, theta, "theta", theta_bounds)) self._add_parameter(make_parameter(phi_trainable, phi, "phi", phi_bounds)) - - self._representation = Bargmann.from_function( - fn=triples.beamsplitter_gate_Abc, theta=self.theta, phi=self.phi + self._multi_rep = MultiRepresentation( + Bargmann.from_function( + fn=triples.beamsplitter_gate_Abc, theta=self.theta, phi=self.phi + ), + self.wires, ) diff --git a/mrmustard/lab_dev/transformations/cft.py b/mrmustard/lab_dev/transformations/cft.py index c5174a026..6233c10c8 100644 --- a/mrmustard/lab_dev/transformations/cft.py +++ b/mrmustard/lab_dev/transformations/cft.py @@ -18,6 +18,7 @@ from typing import Sequence from mrmustard.lab_dev.transformations.base import Map +from mrmustard.physics.multi_representations import MultiRepresentation from mrmustard.physics.representations import Bargmann from mrmustard.physics import triples @@ -47,6 +48,7 @@ def __init__( modes_in=modes, name="CFT", ) - self._representation = Bargmann.from_function( - fn=triples.complex_fourier_transform_Abc, n_modes=len(modes) + self._multi_rep = MultiRepresentation( + Bargmann.from_function(fn=triples.complex_fourier_transform_Abc, n_modes=len(modes)), + self.wires, ) diff --git a/mrmustard/lab_dev/transformations/dgate.py b/mrmustard/lab_dev/transformations/dgate.py index a7640f3ef..597de0489 100644 --- a/mrmustard/lab_dev/transformations/dgate.py +++ b/mrmustard/lab_dev/transformations/dgate.py @@ -24,6 +24,7 @@ from mrmustard import math from .base import Unitary +from ...physics.multi_representations import MultiRepresentation from ...physics.representations import Bargmann from ...physics import triples, fock from ..utils import make_parameter, reshape_params @@ -94,13 +95,12 @@ def __init__( xs, ys = list(reshape_params(len(modes), x=x, y=y)) self._add_parameter(make_parameter(x_trainable, xs, "x", x_bounds)) self._add_parameter(make_parameter(y_trainable, ys, "y", y_bounds)) - - self._representation = Bargmann.from_function( - fn=triples.displacement_gate_Abc, x=self.x, y=self.y + self._multi_rep = MultiRepresentation( + Bargmann.from_function(fn=triples.displacement_gate_Abc, x=self.x, y=self.y), self.wires ) def fock(self, shape: int | Sequence[int] = None, batched=False) -> ComplexTensor: - r""", shape: Optional[int | Sequence[int]] = None, batched=False) -> CircuitComponent: + r""" Returns the unitary representation of the Displacement gate using the Laguerre polynomials. If the shape is not given, it defaults to the ``auto_shape`` of the component if it is available, otherwise it defaults to the value of ``AUTOSHAPE_MAX`` in the settings. diff --git a/mrmustard/lab_dev/transformations/fockdamping.py b/mrmustard/lab_dev/transformations/fockdamping.py index 8bcf98991..8d5b38e0c 100644 --- a/mrmustard/lab_dev/transformations/fockdamping.py +++ b/mrmustard/lab_dev/transformations/fockdamping.py @@ -21,6 +21,7 @@ from typing import Sequence from .base import Operation +from ...physics.multi_representations import MultiRepresentation from ...physics.representations import Bargmann from ...physics import triples from ..utils import make_parameter, reshape_params @@ -85,6 +86,6 @@ def __init__( None, ) ) - self._representation = Bargmann.from_function( - fn=triples.fock_damping_Abc, beta=self.damping + self._multi_rep = MultiRepresentation( + Bargmann.from_function(fn=triples.fock_damping_Abc, beta=self.damping), self.wires ) diff --git a/mrmustard/lab_dev/transformations/ggate.py b/mrmustard/lab_dev/transformations/ggate.py index 98fd50abf..ebec739c3 100644 --- a/mrmustard/lab_dev/transformations/ggate.py +++ b/mrmustard/lab_dev/transformations/ggate.py @@ -22,6 +22,7 @@ from mrmustard.utils.typing import RealMatrix from .base import Unitary +from ...physics.multi_representations import MultiRepresentation from ...physics.representations import Bargmann from ..utils import make_parameter @@ -57,8 +58,10 @@ def __init__( super().__init__(modes_out=modes, modes_in=modes, name="Ggate") S = make_parameter(symplectic_trainable, symplectic, "symplectic", (None, None)) self.parameter_set.add_parameter(S) - - self._representation = Bargmann.from_function( - fn=lambda s: Unitary.from_symplectic(modes, s).bargmann_triple(), - s=self.parameter_set.symplectic, + self._multi_rep = MultiRepresentation( + Bargmann.from_function( + fn=lambda s: Unitary.from_symplectic(modes, s).bargmann_triple(), + s=self.parameter_set.symplectic, + ), + self.wires, ) diff --git a/mrmustard/lab_dev/transformations/rgate.py b/mrmustard/lab_dev/transformations/rgate.py index 405f225d4..ff120f1b5 100644 --- a/mrmustard/lab_dev/transformations/rgate.py +++ b/mrmustard/lab_dev/transformations/rgate.py @@ -21,6 +21,7 @@ from typing import Sequence from .base import Unitary +from ...physics.multi_representations import MultiRepresentation from ...physics.representations import Bargmann from ...physics import triples from ..utils import make_parameter, reshape_params @@ -62,5 +63,6 @@ def __init__( super().__init__(modes_out=modes, modes_in=modes, name="Rgate") (phis,) = list(reshape_params(len(modes), phi=phi)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - - self._representation = Bargmann.from_function(fn=triples.rotation_gate_Abc, theta=self.phi) + self._multi_rep = MultiRepresentation( + Bargmann.from_function(fn=triples.rotation_gate_Abc, theta=self.phi), self.wires + ) diff --git a/mrmustard/lab_dev/transformations/s2gate.py b/mrmustard/lab_dev/transformations/s2gate.py index 32d30a37f..0f4dd9dfd 100644 --- a/mrmustard/lab_dev/transformations/s2gate.py +++ b/mrmustard/lab_dev/transformations/s2gate.py @@ -21,6 +21,7 @@ from typing import Sequence from .base import Unitary +from ...physics.multi_representations import MultiRepresentation from ...physics.representations import Bargmann from ...physics import triples from ..utils import make_parameter @@ -87,7 +88,7 @@ def __init__( super().__init__(modes_out=modes, modes_in=modes, name="S2gate") self._add_parameter(make_parameter(r_trainable, r, "r", r_bounds)) self._add_parameter(make_parameter(phi_trainable, phi, "phi", phi_bounds)) - - self._representation = Bargmann.from_function( - fn=triples.twomode_squeezing_gate_Abc, r=self.r, phi=self.phi + self._multi_rep = MultiRepresentation( + Bargmann.from_function(fn=triples.twomode_squeezing_gate_Abc, r=self.r, phi=self.phi), + self.wires, ) diff --git a/mrmustard/lab_dev/transformations/sgate.py b/mrmustard/lab_dev/transformations/sgate.py index 6610a6863..227bc78ff 100644 --- a/mrmustard/lab_dev/transformations/sgate.py +++ b/mrmustard/lab_dev/transformations/sgate.py @@ -21,6 +21,7 @@ from typing import Sequence from .base import Unitary +from ...physics.multi_representations import MultiRepresentation from ...physics.representations import Bargmann from ...physics import triples from ..utils import make_parameter, reshape_params @@ -94,7 +95,7 @@ def __init__( rs, phis = list(reshape_params(len(modes), r=r, phi=phi)) self._add_parameter(make_parameter(r_trainable, rs, "r", r_bounds)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - - self._representation = Bargmann.from_function( - fn=triples.squeezing_gate_Abc, r=self.r, delta=self.phi + self._multi_rep = MultiRepresentation( + Bargmann.from_function(fn=triples.squeezing_gate_Abc, r=self.r, delta=self.phi), + self.wires, ) diff --git a/mrmustard/physics/multi_representations.py b/mrmustard/physics/multi_representations.py new file mode 100644 index 000000000..e8c9bbe40 --- /dev/null +++ b/mrmustard/physics/multi_representations.py @@ -0,0 +1,241 @@ +# Copyright 2023 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This module contains the class for multi-representations. +""" + +from __future__ import annotations +from typing import Any, Sequence +from enum import Enum + +from mrmustard import settings, math, widgets as mmwidgets +from mrmustard.utils.typing import ( + Scalar, + ComplexTensor, + ComplexMatrix, + ComplexVector, + Vector, + Batch, +) + +from .representations import Representation, Bargmann, Fock +from .triples import identity_Abc +from .wires import Wires + +__all__ = ["MultiRepresentation"] + + +class RepEnum(Enum): + r""" + An enum to represent what representation a wire is in. + """ + + NONETYPE = 0 + BARGMANN = 1 + FOCK = 2 + QUADRATURE = 3 + PHASESPACE = 4 + + @classmethod + def from_representation(cls, value: Representation): + r""" """ + return cls[value.__class__.__name__.upper()] + + @classmethod + def _missing_(cls, value): + return cls.NONETYPE + + def __repr__(self) -> str: + return self.name + + +class MultiRepresentation: + r""" + A class for multi-representations. + + A multi-representation handles the underlying representation, the wires of + said representation and keeps track of representation conversions. + + Args: + representation: A representation for this multi-representation. + wires: The wires of this multi-representation. + wire_reps: An optional dictionary for keeping track of each wire's representation. + """ + + def __init__( + self, + representation: Representation | None, + wires: Wires | None, + wire_reps: dict | None = None, + ) -> None: + self._representation = representation + self._wires = wires + rep_enum = ( + RepEnum[representation.__class__.__name__.upper()] if representation else RepEnum(1) + ) + self._wire_reps = wire_reps or dict.fromkeys(wires.modes, rep_enum) + + @property + def representation(self) -> Representation | None: + r""" + The underlying representation of this multi-representation. + """ + return self._representation + + @property + def wires(self) -> Wires | None: + r""" + The wires of this multi-representation. + """ + return self._wires + + def bargmann_triple( + self, batched: bool = False + ) -> tuple[Batch[ComplexMatrix], Batch[ComplexVector], Batch[ComplexTensor]]: + r""" + The Bargmann parametrization of this multi-representation, if available. + It returns a triple (A, b, c) such that the Bargmann function of this is + :math:`F(z) = c \exp\left(\frac{1}{2} z^T A z + b^T z\right)` + + If ``batched`` is ``False`` (default), it removes the batch dimension if it is of size 1. + + Args: + batched: Whether to return the triple batched. + """ + try: + A, b, c = self.representation.triple + if not batched and self.representation.ansatz.batch_size == 1: + return A[0], b[0], c[0] + else: + return A, b, c + except AttributeError as e: + raise AttributeError("No Bargmann data for this component.") from e + + def fock(self, shape: int | Sequence[int], batched=False) -> ComplexTensor: + r""" + Returns an array representation of this component in the Fock basis with the given shape. + If the shape is not given, it defaults to the ``auto_shape`` of the component if it is + available, otherwise it defaults to the value of ``AUTOSHAPE_MAX`` in the settings. + + Args: + shape: The shape of the returned representation. If ``shape`` is given as an ``int``, + it is broadcasted to all the dimensions. If not given, it is estimated. + batched: Whether the returned representation is batched or not. If ``False`` (default) + it will squeeze the batch dimension if it is 1. + Returns: + array: The Fock representation of this component. + """ + num_vars = self.representation.ansatz.num_vars + if isinstance(shape, int): + shape = (shape,) * num_vars + try: + As, bs, cs = self.bargmann_triple(batched=True) + if len(shape) != num_vars: + raise ValueError( + f"Expected Fock shape of length {num_vars}, got length {len(shape)}" + ) + if self.representation.ansatz.polynomial_shape[0] == 0: + arrays = [math.hermite_renormalized(A, b, c, shape) for A, b, c in zip(As, bs, cs)] + else: + arrays = [ + math.sum( + math.hermite_renormalized(A, b, 1, shape + c.shape) * c, + axes=math.arange( + num_vars, num_vars + len(c.shape), dtype=math.int32 + ).tolist(), + ) + for A, b, c in zip(As, bs, cs) + ] + except AttributeError: + if len(shape) != num_vars: + raise ValueError( + f"Expected Fock shape of length {num_vars}, got length {len(shape)}" + ) + arrays = self.representation.reduce(shape).array + array = math.sum(arrays, axes=[0]) + arrays = math.expand_dims(array, 0) if batched else array + return arrays + + def to_bargmann(self) -> MultiRepresentation: + r""" + Returns a new circuit component with the same attributes as this and a ``Bargmann`` representation. + """ + if isinstance(self.representation, Bargmann): + return self + else: + if self.representation.ansatz._original_abc_data: + A, b, c = self.representation.ansatz._original_abc_data + else: + A, b, _ = identity_Abc(len(self.wires.quantum)) + c = self.representation.data + bargmann = Bargmann(A, b, c) + return MultiRepresentation(bargmann, self.wires) + + def to_fock(self, shape: int | Sequence[int]) -> MultiRepresentation: + r""" + Returns a new multi-representation with a ``Fock`` representation. + + Args: + shape: The shape of the returned representation. If ``shape``is given as + an ``int``, it is broadcasted to all the dimensions. If ``None``, it + defaults to the value of ``AUTOSHAPE_MAX`` in the settings. + """ + fock = Fock(self.fock(shape, batched=True), batched=True) + try: + if self.representation.ansatz.polynomial_shape[0] == 0: + fock.ansatz._original_abc_data = self.representation.triple + except AttributeError: + fock.ansatz._original_abc_data = None + return MultiRepresentation(fock, self.wires) + + def _matmul_indices( + self, other: MultiRepresentation + ) -> tuple[tuple[int, ...], tuple[int, ...]]: + r""" + Finds the indices of the wires being contracted when ``self @ other`` is called. + """ + # find the indices of the wires being contracted on the bra side + bra_modes = tuple(self.wires.bra.output.modes & other.wires.bra.input.modes) + idx_z = self.wires.bra.output[bra_modes].indices + idx_zconj = other.wires.bra.input[bra_modes].indices + # find the indices of the wires being contracted on the ket side + ket_modes = tuple(self.wires.ket.output.modes & other.wires.ket.input.modes) + idx_z += self.wires.ket.output[ket_modes].indices + idx_zconj += other.wires.ket.input[ket_modes].indices + return idx_z, idx_zconj + + def __eq__(self, other): + if isinstance(other, MultiRepresentation): + return ( + self.representation == other.representation + and self.wires == other.wires + and self._wire_reps == other._wire_reps + ) + return False + + def __matmul__(self, other: MultiRepresentation): + wires_result, perm = self.wires @ other.wires + idx_z, idx_zconj = self._matmul_indices(other) + if type(self.representation) is type(other.representation): + self_rep = self.representation + other_rep = other.representation + else: + self_rep = self.to_bargmann().representation + other_rep = other.to_bargmann().representation + + rep = self_rep[idx_z] @ other_rep[idx_zconj] + rep = rep.reorder(perm) if perm else rep + return MultiRepresentation(rep, wires_result) diff --git a/tests/test_lab_dev/test_circuit_components.py b/tests/test_lab_dev/test_circuit_components.py index 30b42b948..3cff71257 100644 --- a/tests/test_lab_dev/test_circuit_components.py +++ b/tests/test_lab_dev/test_circuit_components.py @@ -38,6 +38,7 @@ ) from mrmustard.lab_dev.transformations import Dgate, Attenuator, Unitary, Sgate, Channel from mrmustard.physics.wires import Wires +from mrmustard.physics.multi_representations import MultiRepresentation from ..random import Abc_triple @@ -169,7 +170,7 @@ def test_on(self): assert isinstance(d67.r, Variable) assert math.allclose(d89.r.value, d67.r.value) assert bool(d67.parameter_set) is True - assert d67._representation is d89._representation + assert d67.representation is d89.representation def test_on_error(self): with pytest.raises(ValueError): @@ -393,7 +394,7 @@ def test_rshift_bargmann_and_fock(self, shape): def test_rshift_error(self): vac012 = Vacuum([0, 1, 2]) d0 = Dgate([0], x=0.1, y=0.1) - d0._wires = Wires() + d0._multi_rep = MultiRepresentation(d0.representation, Wires()) with pytest.raises(ValueError, match="not clear"): vac012 >> d0 From d92fb82c1f7f462661643b3831b46cdb3b7b2389 Mon Sep 17 00:00:00 2001 From: Anthony Date: Thu, 3 Oct 2024 13:40:58 -0400 Subject: [PATCH 18/87] tests passing --- mrmustard/lab_dev/states/base.py | 5 +++-- mrmustard/lab_dev/transformations/dgate.py | 9 ++++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/mrmustard/lab_dev/states/base.py b/mrmustard/lab_dev/states/base.py index abdd94db9..9b703cebf 100644 --- a/mrmustard/lab_dev/states/base.py +++ b/mrmustard/lab_dev/states/base.py @@ -61,6 +61,7 @@ from mrmustard.lab_dev.circuit_components_utils import BtoPS, BtoQ, TraceOut from mrmustard.lab_dev.circuit_components import CircuitComponent from mrmustard.physics.wires import Wires +from mrmustard.physics.multi_representations import MultiRepresentation __all__ = ["State", "DM", "Ket"] @@ -623,10 +624,10 @@ def __init__( f"Expected a representation with {2*len(modes)} variables, found {representation.ansatz.num_vars}." ) super().__init__( - representation=representation, wires=[modes, (), modes, ()], name=name, ) + self._multi_rep = MultiRepresentation(representation, self.wires) @property def is_positive(self) -> bool: @@ -932,10 +933,10 @@ def __init__( f"Expected a representation with {len(modes)} variables, found {representation.ansatz.num_vars}." ) super().__init__( - representation=representation, wires=[(), (), modes, ()], name=name, ) + self._multi_rep = MultiRepresentation(representation, self.wires) @property def is_physical(self) -> bool: diff --git a/mrmustard/lab_dev/transformations/dgate.py b/mrmustard/lab_dev/transformations/dgate.py index 597de0489..35876c665 100644 --- a/mrmustard/lab_dev/transformations/dgate.py +++ b/mrmustard/lab_dev/transformations/dgate.py @@ -25,7 +25,7 @@ from .base import Unitary from ...physics.multi_representations import MultiRepresentation -from ...physics.representations import Bargmann +from ...physics.representations import Bargmann, Fock from ...physics import triples, fock from ..utils import make_parameter, reshape_params @@ -142,3 +142,10 @@ def fock(self, shape: int | Sequence[int] = None, batched=False) -> ComplexTenso array = fock.displacement(x[0], y[0], shape=shape) arrays = math.expand_dims(array, 0) if batched else array return arrays + + def to_fock(self, shape: int | Sequence[int] | None = None) -> Dgate: + fock = Fock(self.fock(shape, batched=True), batched=True) + fock.ansatz._original_abc_data = self.representation.triple + ret = self._getitem_builtin(self.modes) + ret._multi_rep = MultiRepresentation(fock, self.wires) + return ret From 8c45168c0f4f1017d04bee57c560eebf56ee8846 Mon Sep 17 00:00:00 2001 From: Anthony Date: Thu, 3 Oct 2024 13:46:17 -0400 Subject: [PATCH 19/87] unused imports --- mrmustard/physics/multi_representations.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/mrmustard/physics/multi_representations.py b/mrmustard/physics/multi_representations.py index e8c9bbe40..4f9313211 100644 --- a/mrmustard/physics/multi_representations.py +++ b/mrmustard/physics/multi_representations.py @@ -18,16 +18,14 @@ """ from __future__ import annotations -from typing import Any, Sequence +from typing import Sequence from enum import Enum -from mrmustard import settings, math, widgets as mmwidgets +from mrmustard import math from mrmustard.utils.typing import ( - Scalar, ComplexTensor, ComplexMatrix, ComplexVector, - Vector, Batch, ) @@ -51,7 +49,9 @@ class RepEnum(Enum): @classmethod def from_representation(cls, value: Representation): - r""" """ + r""" + Returns a ``RepEnum`` from a ``Representation``. + """ return cls[value.__class__.__name__.upper()] @classmethod @@ -83,10 +83,9 @@ def __init__( ) -> None: self._representation = representation self._wires = wires - rep_enum = ( - RepEnum[representation.__class__.__name__.upper()] if representation else RepEnum(1) + self._wire_reps = wire_reps or dict.fromkeys( + wires.modes, RepEnum.from_representation(representation) ) - self._wire_reps = wire_reps or dict.fromkeys(wires.modes, rep_enum) @property def representation(self) -> Representation | None: From be5dfe10b861304691f110fa7a7aa85e1ba360ed Mon Sep 17 00:00:00 2001 From: Anthony Date: Thu, 3 Oct 2024 14:41:55 -0400 Subject: [PATCH 20/87] some progress for btoq --- mrmustard/lab_dev/circuit_components.py | 1 + .../circuit_components_utils/b_to_q.py | 45 +++++++++++++++++++ mrmustard/physics/multi_representations.py | 2 + 3 files changed, 48 insertions(+) diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index 4768a916a..4482e0774 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -619,6 +619,7 @@ def _rshift_return( ) -> CircuitComponent | np.ndarray | complex: "internal convenience method for right-shift, to return the right type of object" if len(ret.wires) > 0: + print("ret", ret._multi_rep._wire_reps) return ret scalar = ret.representation.scalar return math.sum(scalar) if not settings.UNSAFE_ZIP_BATCH else scalar diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py index ff4ebfee1..33855b323 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py @@ -19,11 +19,16 @@ from __future__ import annotations from typing import Sequence +import numpy as np +import numbers + from mrmustard.physics import triples from mrmustard.math.parameters import Constant from ..transformations.base import Operation from ...physics.representations import Bargmann +from ...physics.multi_representations import RepEnum +from ..circuit_components import CircuitComponent __all__ = ["BtoQ"] @@ -54,3 +59,43 @@ def __init__( name="BtoQ", ) self._add_parameter(Constant(phi, "phi")) + + def __custom_rrshift__(self, other: CircuitComponent | complex) -> CircuitComponent | complex: + if hasattr(other, "__custom_rrshift__"): + return other.__custom_rrshift__(self) + + if isinstance(other, (numbers.Number, np.ndarray)): + return self * other + + s_k = other.wires.ket + s_b = other.wires.bra + o_k = self.wires.ket + o_b = self.wires.bra + + only_ket = (not s_b and s_k) and (not o_b and o_k) + only_bra = (not s_k and s_b) and (not o_k and o_b) + both_sides = s_b and s_k and o_b and o_k + + self_needs_bra = (not s_b and s_k) and (o_b and o_k) + self_needs_ket = (not s_k and s_b) and (o_b and o_k) + + other_needs_bra = (s_b and s_k) and (not o_b and o_k) + other_needs_ket = (s_b and s_k) and (not o_k and o_b) + + if only_ket or only_bra or both_sides: + ret = other @ self + elif self_needs_bra or self_needs_ket: + ret = other.adjoint @ (other @ self) + elif other_needs_bra or other_needs_ket: + ret = (other @ self) @ self.adjoint + else: + msg = f"``>>`` not supported between {other} and {self} because it's not clear " + msg += "whether or where to add bra wires. Use ``@`` instead and specify all the components." + raise ValueError(msg) + + # update ret._multi_rep._wire_reps + temp = dict.fromkeys(self.modes, RepEnum.QUADRATURE) + print("1", ret._multi_rep._wire_reps) + ret._multi_rep._wire_reps.update(temp) + print("2", ret._multi_rep._wire_reps) + return self._rshift_return(ret) diff --git a/mrmustard/physics/multi_representations.py b/mrmustard/physics/multi_representations.py index 4f9313211..b8458e484 100644 --- a/mrmustard/physics/multi_representations.py +++ b/mrmustard/physics/multi_representations.py @@ -63,6 +63,8 @@ def __repr__(self) -> str: class MultiRepresentation: + # TODO: merge current Representation and Anstaz -> Ansatz + # TODO: rename to Representation r""" A class for multi-representations. From 6569ce82b98bb3a87ad4ce0f96ae5df90c46d024 Mon Sep 17 00:00:00 2001 From: Anthony Date: Fri, 4 Oct 2024 11:44:59 -0400 Subject: [PATCH 21/87] merge --- mrmustard/lab_dev/circuits.py | 2 +- tests/test_lab_dev/test_circuit_components.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mrmustard/lab_dev/circuits.py b/mrmustard/lab_dev/circuits.py index 711ea3b73..7e73a1644 100644 --- a/mrmustard/lab_dev/circuits.py +++ b/mrmustard/lab_dev/circuits.py @@ -300,7 +300,7 @@ def deserialize(cls, data: dict) -> Circuit: def __eq__(self, other: Circuit) -> bool: if not isinstance(other, Circuit): - return false + return False return self.components == other.components def __getitem__(self, idx: int) -> CircuitComponent: diff --git a/tests/test_lab_dev/test_circuit_components.py b/tests/test_lab_dev/test_circuit_components.py index 9943a6e36..aeb4ac73e 100644 --- a/tests/test_lab_dev/test_circuit_components.py +++ b/tests/test_lab_dev/test_circuit_components.py @@ -202,7 +202,7 @@ def test_to_fock_bargmann_Dgate(self): d = Dgate([1], x=0.1, y=0.1) d_fock = d.to_fock(shape=(4, 6)) d_barg = d_fock.to_bargmann() - assert d_fock.representation.ansatz._original_abc_data == d.representation.triple + assert d_fock.representation._original_abc_data == d.representation.triple assert d_barg == d def test_to_fock_poly_exp(self): @@ -211,7 +211,7 @@ def test_to_fock_poly_exp(self): barg = Bargmann(A, b, c) fock_cc = CircuitComponent(barg, wires=[(), (), (0, 1), ()]).to_fock(shape=(10, 10)) poly = math.hermite_renormalized(A, b, 1, (10, 10, 5)) - assert fock_cc.representation.ansatz._original_abc_data is None + assert fock_cc.representation._original_abc_data is None assert np.allclose(fock_cc.representation.data, np.einsum("ijk,k", poly, c[0])) def test_add(self): From e261cbc734f7cd446e42c3b6227a7a2a7066ee04 Mon Sep 17 00:00:00 2001 From: Anthony Date: Mon, 7 Oct 2024 10:06:11 -0400 Subject: [PATCH 22/87] some ansatz tests ported --- .../test_states/test_states_base.py | 55 +++++- .../test_representations/test_fock.py | 163 +++++++++++------- 2 files changed, 156 insertions(+), 62 deletions(-) diff --git a/tests/test_lab_dev/test_states/test_states_base.py b/tests/test_lab_dev/test_states/test_states_base.py index ca09c37f6..aa1791b94 100644 --- a/tests/test_lab_dev/test_states/test_states_base.py +++ b/tests/test_lab_dev/test_states/test_states_base.py @@ -24,10 +24,14 @@ from mrmustard import math, settings from mrmustard.math.parameters import Constant, Variable +from mrmustard.physics.bargmann_utils import ( + bargmann_Abc_to_phasespace_cov_means, + wigner_to_bargmann_rho, +) from mrmustard.physics.gaussian import vacuum_cov, vacuum_means, squeezed_vacuum_cov from mrmustard.physics.triples import coherent_state_Abc from mrmustard.lab_dev.circuit_components import CircuitComponent -from mrmustard.lab_dev.circuit_components_utils import TraceOut +from mrmustard.lab_dev.circuit_components_utils import BtoPS, TraceOut from mrmustard.lab_dev.states import ( Coherent, DisplacedSqueezed, @@ -520,6 +524,55 @@ def test_from_fock_error(self): with pytest.raises(ValueError): DM.from_fock([0], state01.fock(5), "my_dm", True) + def test_bargmann_Abc_to_phasespace_cov_means(self): + # The init state cov and means comes from the random state 'state = Gaussian(1) >> Dgate([0.2], [0.3])' + state_cov = np.array([[0.32210229, -0.99732956], [-0.99732956, 6.1926484]]) + state_means = np.array([0.2, 0.3]) + state = DM.from_bargmann([0], wigner_to_bargmann_rho(state_cov, state_means)) + state_after = state >> BtoPS(modes=[0], s=0) # pylint: disable=protected-access + A1, b1, c1 = state_after.bargmann_triple() + ( + new_state_cov, + new_state_means, + new_state_coeff, + ) = bargmann_Abc_to_phasespace_cov_means(A1, b1, c1) + assert np.allclose(state_cov, new_state_cov) + assert np.allclose(state_means, new_state_means) + assert np.allclose(1.0 / (2 * np.pi), new_state_coeff) + + state_cov = np.array( + [ + [1.00918303, -0.33243548, 0.15202393, -0.07540124], + [-0.33243548, 1.2203162, -0.03961978, 0.30853472], + [0.15202393, -0.03961978, 1.11158673, 0.28786279], + [-0.07540124, 0.30853472, 0.28786279, 0.97833402], + ] + ) + state_means = np.array([0.4, 0.6, 0.0, 0.0]) + A, b, c = wigner_to_bargmann_rho(state_cov, state_means) + state = DM.from_bargmann(modes=[0, 1], triple=(A, b, c)) + + state_after = state >> BtoPS(modes=[0, 1], s=0) # pylint: disable=protected-access + A1, b1, c1 = state_after.bargmann_triple() + ( + new_state_cov1, + new_state_means1, + new_state_coeff1, + ) = bargmann_Abc_to_phasespace_cov_means(A1, b1, c1) + + A22, b22, c22 = (state >> BtoPS([0], 0) >> BtoPS([1], 0)).bargmann_triple() + ( + new_state_cov22, + new_state_means22, + new_state_coeff22, + ) = bargmann_Abc_to_phasespace_cov_means(A22, b22, c22) + assert math.allclose(new_state_cov22, state_cov) + assert math.allclose(new_state_cov1, state_cov) + assert math.allclose(new_state_means1, state_means) + assert math.allclose(new_state_means22, state_means) + assert math.allclose(new_state_coeff1, 1 / (2 * np.pi) ** 2) + assert math.allclose(new_state_coeff22, 1 / (2 * np.pi) ** 2) + def test_bargmann_triple_error(self): fock = Number([0], n=10).dm() with pytest.raises(AttributeError): diff --git a/tests/test_physics/test_representations/test_fock.py b/tests/test_physics/test_representations/test_fock.py index 8e47f9ce3..3d95912c6 100644 --- a/tests/test_physics/test_representations/test_fock.py +++ b/tests/test_physics/test_representations/test_fock.py @@ -49,11 +49,35 @@ def test_init_non_batched(self): assert fock.array.shape == (1, 5, 7, 8) assert np.allclose(fock.array[0, :, :, :], self.array578) - def test_sum_batch(self): - fock = Fock(self.array2578, batched=True) - fock_collapsed = fock.sum_batch()[0] - assert fock_collapsed.array.shape == (1, 5, 7, 8) - assert np.allclose(fock_collapsed.array, np.sum(self.array2578, axis=0)) + def test_add(self): + fock1 = Fock(self.array2578, batched=True) + fock2 = Fock(self.array5578, batched=True) + fock1_add_fock2 = fock1 + fock2 + assert fock1_add_fock2.array.shape == (10, 5, 7, 8) + assert np.allclose(fock1_add_fock2.array[0], self.array2578[0] + self.array5578[0]) + assert np.allclose(fock1_add_fock2.array[4], self.array2578[0] + self.array5578[4]) + assert np.allclose(fock1_add_fock2.array[5], self.array2578[1] + self.array5578[0]) + + def test_algebra_with_different_shape_of_array_raise_errors(self): + array = np.random.random((2, 4, 5)) + array2 = np.random.random((3, 4, 8, 9)) + aa1 = Fock(array=array) + aa2 = Fock(array=array2) + + with pytest.raises(Exception, match="Cannot add"): + aa1 + aa2 + + with pytest.raises(Exception, match="Cannot add"): + aa1 - aa2 + + with pytest.raises(Exception, match="Cannot multiply"): + aa1 * aa2 + + with pytest.raises(Exception, match="Cannot divide"): + aa1 / aa2 + + with pytest.raises(Exception): + aa1 == aa2 def test_and(self): fock1 = Fock(self.array1578, batched=True) @@ -65,40 +89,21 @@ def test_and(self): math.reshape(np.einsum("bcde, pfgh -> bpcdefgh", self.array1578, self.array5578), -1), ) - def test_multiply_a_scalar(self): - fock1 = Fock(self.array1578, batched=True) - fock_test = 1.3 * fock1 - assert np.allclose(fock_test.array, 1.3 * self.array1578) - - def test_mul(self): - fock1 = Fock(self.array1578, batched=True) - fock2 = Fock(self.array5578, batched=True) - fock1_mul_fock2 = fock1 * fock2 - assert fock1_mul_fock2.array.shape == (5, 5, 7, 8) - assert np.allclose( - math.reshape(fock1_mul_fock2.array, -1), - math.reshape(np.einsum("bcde, pcde -> bpcde", self.array1578, self.array5578), -1), - ) + def test_conj(self): + fock = Fock(self.array1578, batched=True) + fock_conj = fock.conj + assert np.allclose(fock_conj.array, np.conj(self.array1578)) def test_divide_on_a_scalar(self): fock1 = Fock(self.array1578, batched=True) fock_test = fock1 / 1.5 assert np.allclose(fock_test.array, self.array1578 / 1.5) - def test_truediv(self): - fock1 = Fock(self.array1578, batched=True) - fock2 = Fock(self.array5578, batched=True) - fock1_mul_fock2 = fock1 / fock2 - assert fock1_mul_fock2.array.shape == (5, 5, 7, 8) - assert np.allclose( - math.reshape(fock1_mul_fock2.array, -1), - math.reshape(np.einsum("bcde, pcde -> bpcde", self.array1578, 1 / self.array5578), -1), - ) - - def test_conj(self): - fock = Fock(self.array1578, batched=True) - fock_conj = fock.conj - assert np.allclose(fock_conj.array, np.conj(self.array1578)) + def test_equal(self): + array = np.random.random((2, 4, 5)) + aa1 = Fock(array=array) + aa2 = Fock(array=array) + assert aa1 == aa2 def test_matmul_fock_fock(self): array2 = math.astensor(np.random.random((5, 6, 7, 8, 10))) @@ -111,37 +116,27 @@ def test_matmul_fock_fock(self): math.reshape(np.einsum("bcde, pfgeh -> bpcdfgh", self.array2578, array2), -1), ) - def test_add(self): - fock1 = Fock(self.array2578, batched=True) - fock2 = Fock(self.array5578, batched=True) - fock1_add_fock2 = fock1 + fock2 - assert fock1_add_fock2.array.shape == (10, 5, 7, 8) - assert np.allclose(fock1_add_fock2.array[0], self.array2578[0] + self.array5578[0]) - assert np.allclose(fock1_add_fock2.array[4], self.array2578[0] + self.array5578[4]) - assert np.allclose(fock1_add_fock2.array[5], self.array2578[1] + self.array5578[0]) - - def test_sub(self): - fock1 = Fock(self.array2578, batched=True) + def test_mul(self): + fock1 = Fock(self.array1578, batched=True) fock2 = Fock(self.array5578, batched=True) - fock1_sub_fock2 = fock1 - fock2 - assert fock1_sub_fock2.array.shape == (10, 5, 7, 8) - assert np.allclose(fock1_sub_fock2.array[0], self.array2578[0] - self.array5578[0]) - assert np.allclose(fock1_sub_fock2.array[4], self.array2578[0] - self.array5578[4]) - assert np.allclose(fock1_sub_fock2.array[9], self.array2578[1] - self.array5578[4]) + fock1_mul_fock2 = fock1 * fock2 + assert fock1_mul_fock2.array.shape == (5, 5, 7, 8) + assert np.allclose( + math.reshape(fock1_mul_fock2.array, -1), + math.reshape(np.einsum("bcde, pcde -> bpcde", self.array1578, self.array5578), -1), + ) - def test_trace(self): - array1 = math.astensor(np.random.random((2, 5, 5, 1, 7, 4, 1, 7, 3))) - fock1 = Fock(array1, batched=True) - fock2 = fock1.trace(idxs1=[0, 3], idxs2=[1, 6]) - assert fock2.array.shape == (2, 1, 4, 1, 3) - assert np.allclose(fock2.array, np.einsum("bccefghfj -> beghj", array1)) + def test_multiply_a_scalar(self): + fock1 = Fock(self.array1578, batched=True) + fock_test = 1.3 * fock1 + assert np.allclose(fock_test.array, 1.3 * self.array1578) - def test_reorder(self): - array1 = math.astensor(np.arange(8).reshape((1, 2, 2, 2))) - fock1 = Fock(array1, batched=True) - fock2 = fock1.reorder(order=(2, 1, 0)) - assert np.allclose(fock2.array, np.array([[[[0, 4], [2, 6]], [[1, 5], [3, 7]]]])) - assert np.allclose(fock2.array, np.arange(8).reshape((1, 2, 2, 2), order="F")) + def test_neg(self): + array = np.random.random((2, 4, 5)) + aa = Fock(array=array) + minusaa = -aa + assert isinstance(minusaa, Fock) + assert np.allclose(minusaa.array, -array) @pytest.mark.parametrize("batched", [True, False]) def test_reduce(self, batched): @@ -176,6 +171,52 @@ def test_reduce_padded(self): fock1 = fock.reduce((8, 8, 8)) assert fock1.array.shape == (1, 8, 8, 8) + def test_reorder(self): + array1 = math.astensor(np.arange(8).reshape((1, 2, 2, 2))) + fock1 = Fock(array1, batched=True) + fock2 = fock1.reorder(order=(2, 1, 0)) + assert np.allclose(fock2.array, np.array([[[[0, 4], [2, 6]], [[1, 5], [3, 7]]]])) + assert np.allclose(fock2.array, np.arange(8).reshape((1, 2, 2, 2), order="F")) + + def test_sub(self): + fock1 = Fock(self.array2578, batched=True) + fock2 = Fock(self.array5578, batched=True) + fock1_sub_fock2 = fock1 - fock2 + assert fock1_sub_fock2.array.shape == (10, 5, 7, 8) + assert np.allclose(fock1_sub_fock2.array[0], self.array2578[0] - self.array5578[0]) + assert np.allclose(fock1_sub_fock2.array[4], self.array2578[0] - self.array5578[4]) + assert np.allclose(fock1_sub_fock2.array[9], self.array2578[1] - self.array5578[4]) + + def test_sum_batch(self): + fock = Fock(self.array2578, batched=True) + fock_collapsed = fock.sum_batch()[0] + assert fock_collapsed.array.shape == (1, 5, 7, 8) + assert np.allclose(fock_collapsed.array, np.sum(self.array2578, axis=0)) + + def test_trace(self): + array1 = math.astensor(np.random.random((2, 5, 5, 1, 7, 4, 1, 7, 3))) + fock1 = Fock(array1, batched=True) + fock2 = fock1.trace(idxs1=[0, 3], idxs2=[1, 6]) + assert fock2.array.shape == (2, 1, 4, 1, 3) + assert np.allclose(fock2.array, np.einsum("bccefghfj -> beghj", array1)) + + def test_truediv(self): + fock1 = Fock(self.array1578, batched=True) + fock2 = Fock(self.array5578, batched=True) + fock1_mul_fock2 = fock1 / fock2 + assert fock1_mul_fock2.array.shape == (5, 5, 7, 8) + assert np.allclose( + math.reshape(fock1_mul_fock2.array, -1), + math.reshape(np.einsum("bcde, pcde -> bpcde", self.array1578, 1 / self.array5578), -1), + ) + + def test_truediv_a_scalar(self): + array = np.random.random((2, 4, 5)) + aa1 = Fock(array=array) + aa1_scalar = aa1 / 6 + assert isinstance(aa1_scalar, Fock) + assert np.allclose(aa1_scalar.array, array / 6) + # @pytest.mark.parametrize("shape", [(1, 8), (1, 8, 8)]) # @patch("mrmustard.physics.representations.display") # def test_ipython_repr(self, mock_display, shape): From 40ef6f12ce79e0991562c64e882dd64774e260da Mon Sep 17 00:00:00 2001 From: Anthony Date: Mon, 7 Oct 2024 10:25:21 -0400 Subject: [PATCH 23/87] ansatz tests --- mrmustard/physics/representations/bargmann.py | 67 ++-- .../test_representations/test_bargmann.py | 294 ++++++++++++++---- 2 files changed, 270 insertions(+), 91 deletions(-) diff --git a/mrmustard/physics/representations/bargmann.py b/mrmustard/physics/representations/bargmann.py index 853f1e25b..bd44861d3 100644 --- a/mrmustard/physics/representations/bargmann.py +++ b/mrmustard/physics/representations/bargmann.py @@ -657,38 +657,41 @@ def __add__(self, other: Bargmann) -> Bargmann: the shapes fit. Example: If the shape of c1 is (1,3,4,5) and the shape of c2 is (1,5,4,3) then the shape of the combined object will be (2,5,4,5). """ - combined_matrices = math.concat([self.A, other.A], axis=0) - combined_vectors = math.concat([self.b, other.b], axis=0) - - a0s = self.c.shape[1:] - a1s = other.c.shape[1:] - if a0s == a1s: - combined_arrays = math.concat([self.c, other.c], axis=0) - else: - s_max = np.maximum(np.array(a0s), np.array(a1s)) - - padding_array0 = np.array( - ( - np.zeros(len(s_max) + 1), - np.concatenate((np.array([0]), np.array((s_max - a0s)))), - ), - dtype=int, - ).T - padding_tuple0 = tuple(tuple(padding_array0[i]) for i in range(len(s_max) + 1)) - - padding_array1 = np.array( - ( - np.zeros(len(s_max) + 1), - np.concatenate((np.array([0]), np.array((s_max - a1s)))), - ), - dtype=int, - ).T - padding_tuple1 = tuple(tuple(padding_array1[i]) for i in range(len(s_max) + 1)) - a0_new = np.pad(self.c, padding_tuple0, "constant") - a1_new = np.pad(other.c, padding_tuple1, "constant") - combined_arrays = math.concat([a0_new, a1_new], axis=0) - # note output is not simplified - return Bargmann(combined_matrices, combined_vectors, combined_arrays) + try: + combined_matrices = math.concat([self.A, other.A], axis=0) + combined_vectors = math.concat([self.b, other.b], axis=0) + + a0s = self.c.shape[1:] + a1s = other.c.shape[1:] + if a0s == a1s: + combined_arrays = math.concat([self.c, other.c], axis=0) + else: + s_max = np.maximum(np.array(a0s), np.array(a1s)) + + padding_array0 = np.array( + ( + np.zeros(len(s_max) + 1), + np.concatenate((np.array([0]), np.array((s_max - a0s)))), + ), + dtype=int, + ).T + padding_tuple0 = tuple(tuple(padding_array0[i]) for i in range(len(s_max) + 1)) + + padding_array1 = np.array( + ( + np.zeros(len(s_max) + 1), + np.concatenate((np.array([0]), np.array((s_max - a1s)))), + ), + dtype=int, + ).T + padding_tuple1 = tuple(tuple(padding_array1[i]) for i in range(len(s_max) + 1)) + a0_new = np.pad(self.c, padding_tuple0, "constant") + a1_new = np.pad(other.c, padding_tuple1, "constant") + combined_arrays = math.concat([a0_new, a1_new], axis=0) + # note output is not simplified + return Bargmann(combined_matrices, combined_vectors, combined_arrays) + except Exception as e: + raise TypeError(f"Cannot add {self.__class__} and {other.__class__}.") from e def __and__(self, other: Bargmann) -> Bargmann: r""" diff --git a/tests/test_physics/test_representations/test_bargmann.py b/tests/test_physics/test_representations/test_bargmann.py index 843cdf992..f65c1b896 100644 --- a/tests/test_physics/test_representations/test_bargmann.py +++ b/tests/test_physics/test_representations/test_bargmann.py @@ -27,6 +27,8 @@ complex_gaussian_integral, ) from mrmustard.physics.representations.bargmann import Bargmann +from mrmustard.physics.representations.fock import Fock + from ...random import Abc_triple # original settings @@ -53,51 +55,152 @@ def test_init_non_batched(self, triple): assert np.allclose(bargmann.b, b) assert np.allclose(bargmann.c, c) - @pytest.mark.parametrize("triple", [Abc_n1, Abc_n2, Abc_n3]) - def test_conj(self, triple): - A, b, c = triple - bargmann = Bargmann(*triple).conj + @pytest.mark.parametrize("n", [1, 2, 3]) + def test_add(self, n): + triple1 = Abc_triple(n) + triple2 = Abc_triple(n) - assert np.allclose(bargmann.A, math.conj(A)) - assert np.allclose(bargmann.b, math.conj(b)) - assert np.allclose(bargmann.c, math.conj(c)) + bargmann1 = Bargmann(*triple1) + bargmann2 = Bargmann(*triple2) + bargmann_add = bargmann1 + bargmann2 + + assert np.allclose(bargmann_add.A, math.concat([bargmann1.A, bargmann2.A], axis=0)) + assert np.allclose(bargmann_add.b, math.concat([bargmann1.b, bargmann2.b], axis=0)) + assert np.allclose(bargmann_add.c, math.concat([bargmann1.c, bargmann2.c], axis=0)) + + def test_add_error(self): + bargmann = Bargmann(*Abc_triple(3)) + fock = Fock(np.random.random((1, 4, 4, 4)), batched=True) + + with pytest.raises(TypeError, match="Cannot add"): + bargmann + fock # pylint: disable=pointless-statement @pytest.mark.parametrize("n", [1, 2, 3]) def test_and(self, n): triple1 = Abc_triple(n) triple2 = Abc_triple(n) - temp1 = Bargmann(*triple1) - print(temp1.A.shape) - bargmann = Bargmann(*triple1) & Bargmann(*triple2) assert bargmann.A.shape == (1, 2 * n, 2 * n) assert bargmann.b.shape == (1, 2 * n) assert bargmann.c.shape == (1,) - @pytest.mark.parametrize("scalar", [0.5, 1.2]) + def test_call(self): + A, b, c = Abc_triple(5) + ansatz = Bargmann(A, b, c) + + assert np.allclose(ansatz(z=math.zeros_like(b)), c) + + A, b, _ = Abc_triple(4) + c = np.random.random(size=(1, 3, 3, 3)) + ansatz = Bargmann(A, b, c) + z = np.random.uniform(-10, 10, size=(7, 2)) + with pytest.raises( + Exception, match="The sum of the dimension of the argument and polynomial" + ): + ansatz(z) + + A = np.array([[0, 1], [1, 0]]) + b = np.zeros(2) + c = c = np.zeros(10, dtype=complex).reshape(1, -1) + c[0, -1] = 1 + obj1 = Bargmann(A, b, c) + + nine_factorial = np.prod(np.arange(1, 9)) + assert np.allclose(obj1(np.array([[0.1]])), 0.1**9 / np.sqrt(nine_factorial)) + + def test_call_none(self): + A1, b1, _ = Abc_triple(7) + A2, b2, _ = Abc_triple(7) + A3, b3, _ = Abc_triple(7) + + batch = 3 + c = np.random.random(size=(batch, 5, 5, 5)) / 1000 + + obj = Bargmann([A1, A2, A3], [b1, b2, b3], c) + z0 = np.array([[None, 2, None, 5]]) + z1 = np.array([[1, 2, 4, 5]]) + z2 = np.array([[1, 4]]) + obj_none = obj(z0) + val1 = obj(z1) + val2 = obj_none(z2) + assert np.allclose(val1, val2) + + obj1 = Bargmann(A1, b1, c[0].reshape(1, 5, 5, 5)) + z0 = np.array([[None, 2, None, 5], [None, 1, None, 4]]) + z1 = np.array([[1, 2, 4, 5], [2, 1, 4, 4]]) + z2 = np.array([[1, 4], [2, 4]]) + obj1_none = obj1(z0) + obj1_none0 = Bargmann(obj1_none.A[0], obj1_none.b[0], obj1_none.c[0].reshape(1, 5, 5, 5)) + obj1_none1 = Bargmann(obj1_none.A[1], obj1_none.b[1], obj1_none.c[1].reshape(1, 5, 5, 5)) + val1 = obj1(z1) + val2 = np.array( + (obj1_none0(z2[0].reshape(1, -1)), obj1_none1(z2[1].reshape(1, -1))) + ).reshape(-1) + assert np.allclose(val1, val2) + @pytest.mark.parametrize("triple", [Abc_n1, Abc_n2, Abc_n3]) - def test_mul_with_scalar(self, scalar, triple): - bargmann1 = Bargmann(*triple) - bargmann_mul = bargmann1 * scalar + def test_conj(self, triple): + A, b, c = triple + bargmann = Bargmann(*triple).conj - assert np.allclose(bargmann1.A, bargmann_mul.A) - assert np.allclose(bargmann1.b, bargmann_mul.b) - assert np.allclose(bargmann1.c * scalar, bargmann_mul.c) + assert np.allclose(bargmann.A, math.conj(A)) + assert np.allclose(bargmann.b, math.conj(b)) + assert np.allclose(bargmann.c, math.conj(c)) + + def test_decompose_ansatz(self): + A, b, _ = Abc_triple(4) + c = np.random.uniform(-10, 10, size=(1, 3, 3, 3)) + ansatz = Bargmann(A, b, c) + + decomp_ansatz = ansatz.decompose_ansatz() + z = np.random.uniform(-10, 10, size=(1, 1)) + assert np.allclose(ansatz(z), decomp_ansatz(z)) + assert np.allclose(decomp_ansatz.A.shape, (1, 2, 2)) + + def test_decompose_ansatz_batch(self): + """ + In this test the batch dimension of both ``z`` and ``Abc`` is tested. + """ + A1, b1, _ = Abc_triple(4) + c1 = np.random.uniform(-10, 10, size=(3, 3, 3)) + A2, b2, _ = Abc_triple(4) + c2 = np.random.uniform(-10, 10, size=(3, 3, 3)) + ansatz = Bargmann([A1, A2], [b1, b2], [c1, c2]) + + decomp_ansatz = ansatz.decompose_ansatz() + z = np.random.uniform(-10, 10, size=(3, 1)) + assert np.allclose(ansatz(z), decomp_ansatz(z)) + assert np.allclose(decomp_ansatz.A.shape, (2, 2, 2)) + assert np.allclose(decomp_ansatz.b.shape, (2, 2)) + assert np.allclose(decomp_ansatz.c.shape, (2, 9)) + + A1, b1, _ = Abc_triple(5) + c1 = np.random.uniform(-10, 10, size=(3, 3, 3)) + A2, b2, _ = Abc_triple(5) + c2 = np.random.uniform(-10, 10, size=(3, 3, 3)) + ansatz = Bargmann([A1, A2], [b1, b2], [c1, c2]) + + decomp_ansatz = ansatz.decompose_ansatz() + z = np.random.uniform(-10, 10, size=(3, 2)) + assert np.allclose(ansatz(z), decomp_ansatz(z)) + assert np.allclose(decomp_ansatz.A.shape, (2, 4, 4)) + assert np.allclose(decomp_ansatz.b.shape, (2, 4)) + assert np.allclose(decomp_ansatz.c.shape, (2, 9, 9)) @pytest.mark.parametrize("n", [1, 2, 3]) - def test_mul(self, n): + def test_div(self, n): triple1 = Abc_triple(n) triple2 = Abc_triple(n) bargmann1 = Bargmann(*triple1) bargmann2 = Bargmann(*triple2) - bargmann_mul = bargmann1 * bargmann2 + bargmann_div = bargmann1 / bargmann2 - assert np.allclose(bargmann_mul.A, bargmann1.A + bargmann2.A) - assert np.allclose(bargmann_mul.b, bargmann1.b + bargmann2.b) - assert np.allclose(bargmann_mul.c, bargmann1.c * bargmann2.c) + assert np.allclose(bargmann_div.A, bargmann1.A - bargmann2.A) + assert np.allclose(bargmann_div.b, bargmann1.b - bargmann2.b) + assert np.allclose(bargmann_div.c, bargmann1.c / bargmann2.c) @pytest.mark.parametrize("scalar", [0.5, 1.2]) @pytest.mark.parametrize("triple", [Abc_n1, Abc_n2, Abc_n3]) @@ -109,38 +212,128 @@ def test_div_with_scalar(self, scalar, triple): assert np.allclose(bargmann1.b, bargmann_div.b) assert np.allclose(bargmann1.c / scalar, bargmann_div.c) - @pytest.mark.parametrize("n", [1, 2, 3]) - def test_div(self, n): - triple1 = Abc_triple(n) - triple2 = Abc_triple(n) + def test_eq(self): + A, b, c = Abc_triple(5) - bargmann1 = Bargmann(*triple1) - bargmann2 = Bargmann(*triple2) - bargmann_div = bargmann1 / bargmann2 + ansatz = Bargmann(A, b, c) + ansatz2 = Bargmann(2 * A, 2 * b, 2 * c) - assert np.allclose(bargmann_div.A, bargmann1.A - bargmann2.A) - assert np.allclose(bargmann_div.b, bargmann1.b - bargmann2.b) - assert np.allclose(bargmann_div.c, bargmann1.c / bargmann2.c) + assert ansatz == ansatz + assert ansatz2 == ansatz2 + assert ansatz != ansatz2 + assert ansatz2 != ansatz + + def test_matmul_barg_barg(self): + triple1 = Abc_triple(3) + triple2 = Abc_triple(3) + + res1 = Bargmann(*triple1) @ Bargmann(*triple2) + exp1 = contract_two_Abc(triple1, triple2, [], []) + assert np.allclose(res1.A, exp1[0]) + assert np.allclose(res1.b, exp1[1]) + assert np.allclose(res1.c, exp1[2]) @pytest.mark.parametrize("n", [1, 2, 3]) - def test_add(self, n): + def test_mul(self, n): triple1 = Abc_triple(n) triple2 = Abc_triple(n) bargmann1 = Bargmann(*triple1) bargmann2 = Bargmann(*triple2) - bargmann_add = bargmann1 + bargmann2 + bargmann_mul = bargmann1 * bargmann2 - assert np.allclose(bargmann_add.A, math.concat([bargmann1.A, bargmann2.A], axis=0)) - assert np.allclose(bargmann_add.b, math.concat([bargmann1.b, bargmann2.b], axis=0)) - assert np.allclose(bargmann_add.c, math.concat([bargmann1.c, bargmann2.c], axis=0)) + assert np.allclose(bargmann_mul.A, bargmann1.A + bargmann2.A) + assert np.allclose(bargmann_mul.b, bargmann1.b + bargmann2.b) + assert np.allclose(bargmann_mul.c, bargmann1.c * bargmann2.c) + + @pytest.mark.parametrize("scalar", [0.5, 1.2]) + @pytest.mark.parametrize("triple", [Abc_n1, Abc_n2, Abc_n3]) + def test_mul_with_scalar(self, scalar, triple): + bargmann1 = Bargmann(*triple) + bargmann_mul = bargmann1 * scalar + + assert np.allclose(bargmann1.A, bargmann_mul.A) + assert np.allclose(bargmann1.b, bargmann_mul.b) + assert np.allclose(bargmann1.c * scalar, bargmann_mul.c) + + def test_order_batch(self): + ansatz = Bargmann( + A=[np.array([[0]]), np.array([[1]])], + b=[np.array([1]), np.array([0])], + c=[1, 2], + ) + ansatz._order_batch() # pylint: disable=protected-access - # def test_add_error(self): - # bargmann = Bargmann(*Abc_triple(3)) - # fock = Fock(np.random.random((1, 4, 4, 4)), batched=True) + assert np.allclose(ansatz.A[0], np.array([[1]])) + assert np.allclose(ansatz.b[0], np.array([0])) + assert ansatz.c[0] == 2 + assert np.allclose(ansatz.A[1], np.array([[0]])) + assert np.allclose(ansatz.b[1], np.array([1])) + assert ansatz.c[1] == 1 - # with pytest.raises(ValueError): - # bargmann + fock # pylint: disable=pointless-statement + def test_polynomial_shape(self): + A, b, _ = Abc_triple(4) + c = np.array([[1, 2, 3]]) + ansatz = Bargmann(A, b, c) + + poly_dim, poly_shape = ansatz.polynomial_shape + assert np.allclose(poly_dim, 1) + assert np.allclose(poly_shape, (3,)) + + A1, b1, _ = Abc_triple(4) + c1 = np.array([[1, 2, 3]]) + ansatz1 = Bargmann(A1, b1, c1) + + A2, b2, _ = Abc_triple(4) + c2 = np.array([[1, 2, 3]]) + ansatz2 = Bargmann(A2, b2, c2) + + ansatz3 = ansatz1 * ansatz2 + + poly_dim, poly_shape = ansatz3.polynomial_shape + assert np.allclose(poly_dim, 2) + assert np.allclose(poly_shape, (3, 3)) + + def test_reorder(self): + triple = Abc_triple(3) + bargmann = Bargmann(*triple).reorder((0, 2, 1)) + + assert np.allclose(bargmann.A[0], triple[0][[0, 2, 1], :][:, [0, 2, 1]]) + assert np.allclose(bargmann.b[0], triple[1][[0, 2, 1]]) + + def test_simplify(self): + A, b, c = Abc_triple(5) + + ansatz = Bargmann(A, b, c) + + ansatz = ansatz + ansatz + + assert np.allclose(ansatz.A[0], ansatz.A[1]) + assert np.allclose(ansatz.A[0], A) + assert np.allclose(ansatz.b[0], ansatz.b[1]) + assert np.allclose(ansatz.b[0], b) + + ansatz.simplify() + assert len(ansatz.A) == 1 + assert len(ansatz.b) == 1 + assert ansatz.c == 2 * c + + def test_simplify_v2(self): + A, b, c = Abc_triple(5) + + ansatz = Bargmann(A, b, c) + + ansatz = ansatz + ansatz + + assert np.allclose(ansatz.A[0], ansatz.A[1]) + assert np.allclose(ansatz.A[0], A) + assert np.allclose(ansatz.b[0], ansatz.b[1]) + assert np.allclose(ansatz.b[0], b) + + ansatz.simplify_v2() + assert len(ansatz.A) == 1 + assert len(ansatz.b) == 1 + assert np.allclose(ansatz.c, 2 * c) @pytest.mark.parametrize("n", [1, 2, 3]) def test_sub(self, n): @@ -164,23 +357,6 @@ def test_trace(self): assert np.allclose(bargmann.b, b) assert np.allclose(bargmann.c, c) - def test_reorder(self): - triple = Abc_triple(3) - bargmann = Bargmann(*triple).reorder((0, 2, 1)) - - assert np.allclose(bargmann.A[0], triple[0][[0, 2, 1], :][:, [0, 2, 1]]) - assert np.allclose(bargmann.b[0], triple[1][[0, 2, 1]]) - - def test_matmul_barg_barg(self): - triple1 = Abc_triple(3) - triple2 = Abc_triple(3) - - res1 = Bargmann(*triple1) @ Bargmann(*triple2) - exp1 = contract_two_Abc(triple1, triple2, [], []) - assert np.allclose(res1.A, exp1[0]) - assert np.allclose(res1.b, exp1[1]) - assert np.allclose(res1.c, exp1[2]) - # @patch("mrmustard.physics.representations.bargmann.display") # def test_ipython_repr(self, mock_display): # """Test the IPython repr function.""" From fabf1d898e457607e213c20c7de1c3d0127eaaf0 Mon Sep 17 00:00:00 2001 From: Anthony Date: Mon, 7 Oct 2024 10:38:08 -0400 Subject: [PATCH 24/87] some codefactor --- tests/test_physics/test_bargmann_utils.py | 16 ++++++++++++++++ tests/test_physics/test_fock_utils.py | 6 +++++- .../test_representations/test_bargmann.py | 4 ++-- .../test_representations/test_fock.py | 10 +++++----- 4 files changed, 28 insertions(+), 8 deletions(-) diff --git a/tests/test_physics/test_bargmann_utils.py b/tests/test_physics/test_bargmann_utils.py index e0d128ae3..7545b15be 100644 --- a/tests/test_physics/test_bargmann_utils.py +++ b/tests/test_physics/test_bargmann_utils.py @@ -1,3 +1,19 @@ +# Copyright 2021 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the bargmann_utils.py file.""" + import numpy as np from mrmustard import math diff --git a/tests/test_physics/test_fock_utils.py b/tests/test_physics/test_fock_utils.py index d1cf3295b..2465d7464 100644 --- a/tests/test_physics/test_fock_utils.py +++ b/tests/test_physics/test_fock_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for the fock.py file.""" +"""Tests for the fock_utils.py file.""" # pylint: disable=pointless-statement @@ -472,12 +472,14 @@ def test_displacement_values(): @given(x=st.floats(-1, 1), y=st.floats(-1, 1)) def test_number_means(x, y): + """Tests the mean photon number.""" assert np.allclose(State(ket=Coherent(x, y).ket([80])).number_means, x * x + y * y) assert np.allclose(State(dm=Coherent(x, y).dm([80])).number_means, x * x + y * y) @given(x=st.floats(-1, 1), y=st.floats(-1, 1)) def test_number_variances_coh(x, y): + """Tests the variance of the number operator.""" assert np.allclose( fock_utils.number_variances(Coherent(x, y).ket([80]), False)[0], x * x + y * y ) @@ -485,10 +487,12 @@ def test_number_variances_coh(x, y): def test_number_variances_fock(): + """Tests the variance of the number operator in Fock.""" assert np.allclose(fock_utils.number_variances(Fock(n=1).ket(), False), 0) assert np.allclose(fock_utils.number_variances(Fock(n=1).dm(), True), 0) def test_normalize_dm(): + """Tests normalizing a DM.""" dm = np.array([[0.2, 0], [0, 0.2]]) assert np.allclose(fock_utils.normalize(dm, True), np.array([[0.5, 0], [0, 0.5]])) diff --git a/tests/test_physics/test_representations/test_bargmann.py b/tests/test_physics/test_representations/test_bargmann.py index f65c1b896..4dc0040f1 100644 --- a/tests/test_physics/test_representations/test_bargmann.py +++ b/tests/test_physics/test_representations/test_bargmann.py @@ -218,8 +218,8 @@ def test_eq(self): ansatz = Bargmann(A, b, c) ansatz2 = Bargmann(2 * A, 2 * b, 2 * c) - assert ansatz == ansatz - assert ansatz2 == ansatz2 + assert ansatz == ansatz # pylint: disable= comparison-with-itself + assert ansatz2 == ansatz2 # pylint: disable= comparison-with-itself assert ansatz != ansatz2 assert ansatz2 != ansatz diff --git a/tests/test_physics/test_representations/test_fock.py b/tests/test_physics/test_representations/test_fock.py index 3d95912c6..a021fdb20 100644 --- a/tests/test_physics/test_representations/test_fock.py +++ b/tests/test_physics/test_representations/test_fock.py @@ -65,19 +65,19 @@ def test_algebra_with_different_shape_of_array_raise_errors(self): aa2 = Fock(array=array2) with pytest.raises(Exception, match="Cannot add"): - aa1 + aa2 + aa1 + aa2 # pylint: disable=pointless-statement with pytest.raises(Exception, match="Cannot add"): - aa1 - aa2 + aa1 - aa2 # pylint: disable=pointless-statement with pytest.raises(Exception, match="Cannot multiply"): - aa1 * aa2 + aa1 * aa2 # pylint: disable=pointless-statement with pytest.raises(Exception, match="Cannot divide"): - aa1 / aa2 + aa1 / aa2 # pylint: disable=pointless-statement with pytest.raises(Exception): - aa1 == aa2 + aa1 == aa2 # pylint: disable=pointless-statement def test_and(self): fock1 = Fock(self.array1578, batched=True) From e2012c08b98a04c7da8424cb3ad4208ed7a83bd7 Mon Sep 17 00:00:00 2001 From: Anthony Date: Mon, 7 Oct 2024 10:39:43 -0400 Subject: [PATCH 25/87] some codefactor --- tests/test_physics/test_representations/test_bargmann.py | 2 ++ tests/test_physics/test_representations/test_fock.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_physics/test_representations/test_bargmann.py b/tests/test_physics/test_representations/test_bargmann.py index 4dc0040f1..022927d09 100644 --- a/tests/test_physics/test_representations/test_bargmann.py +++ b/tests/test_physics/test_representations/test_bargmann.py @@ -14,6 +14,8 @@ """This module contains tests for ``Representation`` objects.""" +# pylint: disable = too-many-public-methods + from unittest.mock import patch import numpy as np diff --git a/tests/test_physics/test_representations/test_fock.py b/tests/test_physics/test_representations/test_fock.py index a021fdb20..614dc8fa8 100644 --- a/tests/test_physics/test_representations/test_fock.py +++ b/tests/test_physics/test_representations/test_fock.py @@ -14,6 +14,8 @@ """This module contains tests for ``Representation`` objects.""" +# pylint: disable = missing-function-docstring + from unittest.mock import patch import numpy as np @@ -27,8 +29,6 @@ # original settings autocutoff_max0 = settings.AUTOCUTOFF_MAX_CUTOFF -# pylint: disable = missing-function-docstring - class TestFockRepresentation: # pylint:disable=too-many-public-methods r"""Tests the Fock Representation.""" From 3080a6c70bd22fecaf185581245c1199589b6e64 Mon Sep 17 00:00:00 2001 From: Anthony Date: Mon, 7 Oct 2024 10:41:47 -0400 Subject: [PATCH 26/87] docs --- doc/code/physics.rst | 1 - doc/code/physics/ansatze.rst | 8 -------- 2 files changed, 9 deletions(-) delete mode 100644 doc/code/physics/ansatze.rst diff --git a/doc/code/physics.rst b/doc/code/physics.rst index 5429fa8c4..1efa7ec53 100644 --- a/doc/code/physics.rst +++ b/doc/code/physics.rst @@ -4,7 +4,6 @@ mrmustard.physics .. toctree:: :maxdepth: 1 - physics/ansatze physics/representations .. toctree:: diff --git a/doc/code/physics/ansatze.rst b/doc/code/physics/ansatze.rst deleted file mode 100644 index 85718bc67..000000000 --- a/doc/code/physics/ansatze.rst +++ /dev/null @@ -1,8 +0,0 @@ -The representations' Ansatze -============================ - -.. currentmodule:: mrmustard.physics.ansatze - -.. automodapi:: mrmustard.physics.ansatze - :no-heading: - :include-all-objects: From 640259be6214e4d6ed7e6e1dc7f060019a1b50d8 Mon Sep 17 00:00:00 2001 From: Anthony Date: Mon, 7 Oct 2024 10:47:27 -0400 Subject: [PATCH 27/87] docs --- doc/code/physics/utils/bargmann_calculations.rst | 4 ++-- doc/code/physics/utils/fock_calculations.rst | 4 ++-- mrmustard/math/backend_numpy.py | 2 +- mrmustard/math/backend_tensorflow.py | 2 +- .../math/lattice/strategies/compactFock/diagonal_amps.py | 6 +++--- .../math/lattice/strategies/compactFock/diagonal_grad.py | 8 ++++---- .../strategies/compactFock/singleLeftoverMode_amps.py | 6 +++--- .../strategies/compactFock/singleLeftoverMode_grad.py | 6 +++--- .../lattice/strategies/julia/compactFock/diagonal_grad.jl | 2 +- 9 files changed, 20 insertions(+), 20 deletions(-) diff --git a/doc/code/physics/utils/bargmann_calculations.rst b/doc/code/physics/utils/bargmann_calculations.rst index 22362d468..08e1e953e 100644 --- a/doc/code/physics/utils/bargmann_calculations.rst +++ b/doc/code/physics/utils/bargmann_calculations.rst @@ -1,8 +1,8 @@ Calculations on Bargmann objects ================================ -.. currentmodule:: mrmustard.physics.bargmann +.. currentmodule:: mrmustard.physics.bargmann_utils -.. automodapi:: mrmustard.physics.bargmann +.. automodapi:: mrmustard.physics.bargmann_utils :no-heading: :include-all-objects: diff --git a/doc/code/physics/utils/fock_calculations.rst b/doc/code/physics/utils/fock_calculations.rst index a26b9f512..8a3f3d936 100644 --- a/doc/code/physics/utils/fock_calculations.rst +++ b/doc/code/physics/utils/fock_calculations.rst @@ -1,8 +1,8 @@ Calculations on Fock objects ============================ -.. currentmodule:: mrmustard.physics.fock +.. currentmodule:: mrmustard.physics.fock_utils -.. automodapi:: mrmustard.physics.fock +.. automodapi:: mrmustard.physics.fock_utils :no-heading: :include-all-objects: diff --git a/mrmustard/math/backend_numpy.py b/mrmustard/math/backend_numpy.py index 2ebde8f9e..18bffa73b 100644 --- a/mrmustard/math/backend_numpy.py +++ b/mrmustard/math/backend_numpy.py @@ -535,7 +535,7 @@ def hermite_renormalized_binomial( def reorder_AB_bargmann(self, A: np.ndarray, B: np.ndarray) -> tuple[np.ndarray, np.ndarray]: r"""In mrmustard.math.numba.compactFock~ dimensions of the Fock representation are ordered like [mode0,mode0,mode1,mode1,...] - while in mrmustard.physics.bargmann the ordering is [mode0,mode1,...,mode0,mode1,...]. Here we reorder A and B. + while in mrmustard.physics.bargmann_utils the ordering is [mode0,mode1,...,mode0,mode1,...]. Here we reorder A and B. """ ordering = np.arange(2 * A.shape[0] // 2).reshape(2, -1).T.flatten() A = self.gather(A, ordering, axis=1) diff --git a/mrmustard/math/backend_tensorflow.py b/mrmustard/math/backend_tensorflow.py index 593083e5c..ede56c91c 100644 --- a/mrmustard/math/backend_tensorflow.py +++ b/mrmustard/math/backend_tensorflow.py @@ -554,7 +554,7 @@ def grad(dLdGconj): def reorder_AB_bargmann(self, A: tf.Tensor, B: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]: r"""In mrmustard.math.compactFock.compactFock~ dimensions of the Fock representation are ordered like [mode0,mode0,mode1,mode1,...] - while in mrmustard.physics.bargmann the ordering is [mode0,mode1,...,mode0,mode1,...]. Here we reorder A and B. + while in mrmustard.physics.bargmann_utils the ordering is [mode0,mode1,...,mode0,mode1,...]. Here we reorder A and B. """ ordering = ( np.arange(2 * A.shape[0] // 2).reshape(2, -1).T.flatten() diff --git a/mrmustard/math/lattice/strategies/compactFock/diagonal_amps.py b/mrmustard/math/lattice/strategies/compactFock/diagonal_amps.py index 4a765fe52..1b6ac8800 100644 --- a/mrmustard/math/lattice/strategies/compactFock/diagonal_amps.py +++ b/mrmustard/math/lattice/strategies/compactFock/diagonal_amps.py @@ -21,7 +21,7 @@ def use_offDiag_pivot( """ Apply recurrence relation for pivot of type [a+1,a,b,b,c,c,...] / [a,a,b+1,b,c,c,...] / [a,a,b,b,c+1,c,...] Args: - A, B (array, vector): required input for recurrence relation (given by mrmustard.physics.fock.ABC) + A, B (array, vector): required input for recurrence relation (given by mrmustard.physics.fock_utils.ABC) M (int): number of modes cutoffs (tuple): upper bounds for the number of photons in each mode params (tuple): (a,b,c,...) @@ -88,7 +88,7 @@ def use_diag_pivot(A, B, M, cutoffs, params, arr0, arr1): # pragma: no cover """ Apply recurrence relation for pivot of type [a,a,b,b,c,c...] Args: - A, B (array, vector): required input for recurrence relation (given by mrmustard.physics.fock.ABC) + A, B (array, vector): required input for recurrence relation (given by mrmustard.physics.fock_utils.ABC) M (int): number of modes cutoffs (tuple): upper bounds for the number of photons in each mode params (tuple): (a,b,c,...) @@ -139,7 +139,7 @@ def fock_representation_diagonal_amps_NUMBA( Returns the PNR probabilities of a mixed state according to algorithm 1 of: https://doi.org/10.22331/q-2023-08-29-1097 Args: - A, B (array, vector): required input for recurrence relation (given by mrmustard.physics.fock.ABC) + A, B (array, vector): required input for recurrence relation (given by mrmustard.physics.fock_utils.ABC) M (int): number of modes cutoffs (tuple): upper bounds for the number of photons in each mode arr0 (array): submatrix of the fock representation that contains Fock amplitudes of the type [a,a,b,b,c,c...] diff --git a/mrmustard/math/lattice/strategies/compactFock/diagonal_grad.py b/mrmustard/math/lattice/strategies/compactFock/diagonal_grad.py index 3279c991d..eba86a9ec 100644 --- a/mrmustard/math/lattice/strategies/compactFock/diagonal_grad.py +++ b/mrmustard/math/lattice/strategies/compactFock/diagonal_grad.py @@ -21,7 +21,7 @@ def calc_dA_dB(i, G_in_dA, G_in_dB, G_in, A, B, K_l, K_i, M, pivot_val, pivot_va Args: i (int): the element of the multidim index that is increased G_in, G_in_dA, G_in_dB (array, array, array): all Fock amplitudes from the 'read' group in the recurrence relation and their derivatives w.r.t. A and B - A, B (array, vector): required input for recurrence relation (given by mrmustard.physics.fock.ABC) + A, B (array, vector): required input for recurrence relation (given by mrmustard.physics.fock_utils.ABC) K_l, K_i (vector, vector): SQRT[pivot], SQRT[pivot + 1] M (int): number of modes pivot_val, pivot_val_dA, pivot_val_dB (array, array, array): Fock amplitude at the position of the pivot and its derivatives w.r.t. A and B @@ -63,7 +63,7 @@ def use_offDiag_pivot_grad( """ Apply recurrence relation for pivot of type [a+1,a,b,b,c,c,...] / [a,a,b+1,b,c,c,...] / [a,a,b,b,c+1,c,...] Args: - A, B (array, vector): required input for recurrence relation (given by mrmustard.physics.fock.ABC) + A, B (array, vector): required input for recurrence relation (given by mrmustard.physics.fock_utils.ABC) M (int): number of modes cutoffs (tuple): upper bounds for the number of photons in each mode params (tuple): (a,b,c,...) @@ -201,7 +201,7 @@ def use_diag_pivot_grad(A, B, M, cutoffs, params, arr0, arr1, arr0_dA, arr1_dA, """ Apply recurrence relation for pivot of type [a,a,b,b,c,c...] Args: - A, B (array, vector): required input for recurrence relation (given by mrmustard.physics.fock.ABC) + A, B (array, vector): required input for recurrence relation (given by mrmustard.physics.fock_utils.ABC) M (int): number of modes cutoffs (tuple): upper bounds for the number of photons in each mode params (tuple): (a,b,c,...) @@ -265,7 +265,7 @@ def fock_representation_diagonal_grad_NUMBA( Returns the gradients of the PNR probabilities of a mixed state according to algorithm 1 of https://doi.org/10.22331/q-2023-08-29-1097 Args: - A, B (array, vector): required input for recurrence relation (given by mrmustard.physics.fock.ABC) + A, B (array, vector): required input for recurrence relation (given by mrmustard.physics.fock_utils.ABC) M (int): number of modes cutoffs (tuple): upper bounds for the number of photons in each mode arr0 (array): submatrix of the fock representation that contains Fock amplitudes of the type [a,a,b,b,c,c...] diff --git a/mrmustard/math/lattice/strategies/compactFock/singleLeftoverMode_amps.py b/mrmustard/math/lattice/strategies/compactFock/singleLeftoverMode_amps.py index 67f74da4f..6f6adbfa9 100644 --- a/mrmustard/math/lattice/strategies/compactFock/singleLeftoverMode_amps.py +++ b/mrmustard/math/lattice/strategies/compactFock/singleLeftoverMode_amps.py @@ -100,7 +100,7 @@ def use_offDiag_pivot( """ Apply recurrence relation for pivot of type [a+1,a,b,b,c,c,...] / [a,a,b+1,b,c,c,...] / [a,a,b,b,c+1,c,...] Args: - A, B (array, Vector): required input for recurrence relation (given by mrmustard.physics.fock.ABC) + A, B (array, Vector): required input for recurrence relation (given by mrmustard.physics.fock_utils.ABC) M (int): number of detected modes cutoffs (tuple): upper bounds for the number of photons in each mode params (tuple): (a,b,c,...) @@ -202,7 +202,7 @@ def use_diag_pivot(A, B, M, cutoff_leftoverMode, cutoffs_tail, params, arr0, arr """ Apply recurrence relation for pivot of type [a,a,b,b,c,c...] Args: - A, B (array, Vector): required input for recurrence relation (given by mrmustard.physics.fock.ABC) + A, B (array, Vector): required input for recurrence relation (given by mrmustard.physics.fock_utils.ABC) M (int): number of detected modes cutoffs (tuple): upper bounds for the number of photons in each mode params (tuple): (a,b,c,...) @@ -281,7 +281,7 @@ def fock_representation_1leftoverMode_amps_NUMBA( Returns the density matrices in the upper, undetected mode of a circuit when all other modes are PNR detected according to algorithm 2 of https://doi.org/10.22331/q-2023-08-29-1097 Args: - A, B (array, vector): required input for recurrence relation (given by mrmustard.physics.fock.ABC) + A, B (array, vector): required input for recurrence relation (given by mrmustard.physics.fock_utils.ABC) M (int): number of modes cutoffs (tuple): upper bounds for the number of photons in each mode arr0 (array): submatrix of the fock representation that contains Fock amplitudes of the type [a,a,b,b,c,c...] diff --git a/mrmustard/math/lattice/strategies/compactFock/singleLeftoverMode_grad.py b/mrmustard/math/lattice/strategies/compactFock/singleLeftoverMode_grad.py index 6d1df9f73..6c6c63c04 100644 --- a/mrmustard/math/lattice/strategies/compactFock/singleLeftoverMode_grad.py +++ b/mrmustard/math/lattice/strategies/compactFock/singleLeftoverMode_grad.py @@ -298,7 +298,7 @@ def use_offDiag_pivot_grad( """ Apply recurrence relation for pivot of type [a+1,a,b,b,c,c,...] / [a,a,b+1,b,c,c,...] / [a,a,b,b,c+1,c,...] Args: - A, B (array, Vector): required input for recurrence relation (given by mrmustard.physics.fock.ABC) + A, B (array, Vector): required input for recurrence relation (given by mrmustard.physics.fock_utils.ABC) M (int): number of detected modes cutoffs (tuple): upper bounds for the number of photons in each mode params (tuple): (a,b,c,...) @@ -499,7 +499,7 @@ def use_diag_pivot_grad( """ Apply recurrence relation for pivot of type [a,a,b,b,c,c...] Args: - A, B (array, Vector): required input for recurrence relation (given by mrmustard.physics.fock.ABC) + A, B (array, Vector): required input for recurrence relation (given by mrmustard.physics.fock_utils.ABC) M (int): number of detected modes cutoffs (tuple): upper bounds for the number of photons in each mode params (tuple): (a,b,c,...) @@ -594,7 +594,7 @@ def fock_representation_1leftoverMode_grad_NUMBA( Returns the gradients of the density matrices in the upper, undetected mode of a circuit when all other modes are PNR detected (according to algorithm 2 of https://doi.org/10.22331/q-2023-08-29-1097) Args: - A, B (array, Vector): required input for recurrence relation (given by mrmustard.physics.fock.ABC) + A, B (array, Vector): required input for recurrence relation (given by mrmustard.physics.fock_utils.ABC) M (int): number of modes cutoffs (tuple): upper bounds for the number of photons in each mode arr0 (array): submatrix of the fock representation that contains Fock amplitudes of the type [a,a,b,b,c,c...] diff --git a/mrmustard/math/lattice/strategies/julia/compactFock/diagonal_grad.jl b/mrmustard/math/lattice/strategies/julia/compactFock/diagonal_grad.jl index 15324c59c..9dfce8a0d 100644 --- a/mrmustard/math/lattice/strategies/julia/compactFock/diagonal_grad.jl +++ b/mrmustard/math/lattice/strategies/julia/compactFock/diagonal_grad.jl @@ -8,7 +8,7 @@ function calc_dA_dB(i, G_in_dA, G_in_dB, G_in, A, B, K_l, K_i, M, pivot_val, piv Args: i (int): the element of the multidim index that is increased G_in, G_in_dA, G_in_dB (array, array, array): all Fock amplitudes from the 'read' group in the recurrence relation and their derivatives w.r.t. A and B - A, B (array, vector): required input for recurrence relation (given by mrmustard.physics.fock.ABC) + A, B (array, vector): required input for recurrence relation (given by mrmustard.physics.fock_utils.ABC) K_l, K_i (vector, vector): SQRT[pivot], SQRT[pivot + 1] M (int): number of modes pivot_val, pivot_val_dA, pivot_val_dB (array, array, array): Fock amplitude at the position of the pivot and its derivatives w.r.t. A and B From dfe55bcf67b05b0391466441499dc47d3d59c57c Mon Sep 17 00:00:00 2001 From: Anthony Date: Mon, 7 Oct 2024 10:50:10 -0400 Subject: [PATCH 28/87] workflow --- .github/workflows/tests_docs.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/tests_docs.yml b/.github/workflows/tests_docs.yml index 913c1165b..f46240532 100644 --- a/.github/workflows/tests_docs.yml +++ b/.github/workflows/tests_docs.yml @@ -39,6 +39,5 @@ jobs: - name: Run tests run: | python -m pytest --doctest-modules mrmustard/math/parameter_set.py - python -m pytest --doctest-modules mrmustard/physics/ansatze.py - python -m pytest --doctest-modules mrmustard/physics/representations.py + python -m pytest --doctest-modules mrmustard/physics/representations python -m pytest --doctest-modules mrmustard/lab_dev From ba536ad83eb89fe3b55f9895b942e6038abedca7 Mon Sep 17 00:00:00 2001 From: Anthony Date: Mon, 7 Oct 2024 11:44:49 -0400 Subject: [PATCH 29/87] some coverage --- .../test_representations/test_bargmann.py | 17 +++++++++++++++++ .../test_representations/test_fock.py | 5 +++++ 2 files changed, 22 insertions(+) diff --git a/tests/test_physics/test_representations/test_bargmann.py b/tests/test_physics/test_representations/test_bargmann.py index 022927d09..9c3a6c59e 100644 --- a/tests/test_physics/test_representations/test_bargmann.py +++ b/tests/test_physics/test_representations/test_bargmann.py @@ -70,6 +70,23 @@ def test_add(self, n): assert np.allclose(bargmann_add.b, math.concat([bargmann1.b, bargmann2.b], axis=0)) assert np.allclose(bargmann_add.c, math.concat([bargmann1.c, bargmann2.c], axis=0)) + A1, b1, _ = Abc_triple(5) + c1 = np.random.random(size=(1, 3, 3)) + A2, b2, _ = Abc_triple(5) + c2 = np.random.random(size=(1, 2, 2)) + + bargmann3 = Bargmann(A1, b1, c1) + bargmann4 = Bargmann(A2, b2, c2) + + bargmann_add2 = bargmann3 + bargmann4 + + assert np.allclose(bargmann_add2.A[0], A1) + assert np.allclose(bargmann_add2.b[0], b1) + assert np.allclose(bargmann_add2.c[0], c1[0]) + assert np.allclose(bargmann_add2.A[1], A2) + assert np.allclose(bargmann_add2.b[1], b2) + assert np.allclose(bargmann_add2.c[1][:2, :2], c2[0]) + def test_add_error(self): bargmann = Bargmann(*Abc_triple(3)) fock = Fock(np.random.random((1, 4, 4, 4)), batched=True) diff --git a/tests/test_physics/test_representations/test_fock.py b/tests/test_physics/test_representations/test_fock.py index 614dc8fa8..41ffa69b1 100644 --- a/tests/test_physics/test_representations/test_fock.py +++ b/tests/test_physics/test_representations/test_fock.py @@ -89,6 +89,11 @@ def test_and(self): math.reshape(np.einsum("bcde, pfgh -> bpcdefgh", self.array1578, self.array5578), -1), ) + def test_call(self): + fock = Fock(self.array1578, batched=True) + with pytest.raises(AttributeError, match="Cannot call"): + fock(0) + def test_conj(self): fock = Fock(self.array1578, batched=True) fock_conj = fock.conj From d5b347dcbec41d262ae955807426749a9cb664d0 Mon Sep 17 00:00:00 2001 From: Anthony Date: Mon, 7 Oct 2024 13:57:03 -0400 Subject: [PATCH 30/87] coverage --- .../test_representations/test_bargmann.py | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/tests/test_physics/test_representations/test_bargmann.py b/tests/test_physics/test_representations/test_bargmann.py index 9c3a6c59e..c8bbd6bc9 100644 --- a/tests/test_physics/test_representations/test_bargmann.py +++ b/tests/test_physics/test_representations/test_bargmann.py @@ -178,6 +178,11 @@ def test_decompose_ansatz(self): assert np.allclose(ansatz(z), decomp_ansatz(z)) assert np.allclose(decomp_ansatz.A.shape, (1, 2, 2)) + c2 = np.random.uniform(-10, 10, size=(1, 4)) + ansatz2 = Bargmann(A, b, c2) + decomp_ansatz2 = ansatz2.decompose_ansatz() + assert np.allclose(decomp_ansatz2.A, ansatz2.A) + def test_decompose_ansatz_batch(self): """ In this test the batch dimension of both ``z`` and ``Abc`` is tested. @@ -344,15 +349,22 @@ def test_simplify_v2(self): ansatz = ansatz + ansatz - assert np.allclose(ansatz.A[0], ansatz.A[1]) - assert np.allclose(ansatz.A[0], A) - assert np.allclose(ansatz.b[0], ansatz.b[1]) - assert np.allclose(ansatz.b[0], b) + assert math.allclose(ansatz.A[0], ansatz.A[1]) + assert math.allclose(ansatz.A[0], A) + assert math.allclose(ansatz.b[0], ansatz.b[1]) + assert math.allclose(ansatz.b[0], b) ansatz.simplify_v2() assert len(ansatz.A) == 1 assert len(ansatz.b) == 1 - assert np.allclose(ansatz.c, 2 * c) + assert math.allclose(ansatz.c, 2 * c) + + A, b, c = ansatz.triple + + ansatz.simplify_v2() + assert math.allclose(ansatz.A, A) + assert math.allclose(ansatz.b, b) + assert math.allclose(ansatz.c, c) @pytest.mark.parametrize("n", [1, 2, 3]) def test_sub(self, n): From ed1dbb66219d91693afa311dd20e97dd04ff5d31 Mon Sep 17 00:00:00 2001 From: Anthony Date: Mon, 7 Oct 2024 14:25:48 -0400 Subject: [PATCH 31/87] widgets --- mrmustard/widgets/__init__.py | 16 +--- .../test_representations/test_bargmann.py | 91 +++++++++---------- .../test_representations/test_fock.py | 74 +++++++-------- 3 files changed, 82 insertions(+), 99 deletions(-) diff --git a/mrmustard/widgets/__init__.py b/mrmustard/widgets/__init__.py index edb902a77..0d560e3bf 100644 --- a/mrmustard/widgets/__init__.py +++ b/mrmustard/widgets/__init__.py @@ -70,10 +70,7 @@ def fock(rep): header_widget = widgets.HTML("

Fock Representation

") table_widget = widgets.HTML( - TABLE + "" - f"" - f"" - "
Ansatz{rep.ansatz.__class__.__qualname__}
Shape{shape}
" + TABLE + "" f"" "
Shape{shape}
" ) left_widget = widgets.VBox(children=[header_widget, table_widget]) plot_widget.layout.padding = "10px" @@ -120,15 +117,6 @@ def get_abc_str(A, b, c, round_val): round_w = widgets.IntText(value=round_default, description="Rounding (negative -> none):") round_w.style.description_width = "230px" header_w = widgets.HTML("

Bargmann Representation

") - sub_w = widgets.HBox( - [ - widgets.HTML( - '
Ansatz:
' - f"{rep.ansatz.__class__.__qualname__}
" - ), - round_w, - ] - ) triple_w = widgets.HTML(TABLE + triple_fstr.format(*get_abc_str(A, b, c, round_default))) eigs_header_w = widgets.HTML("

Eigenvalues of A

") eigvals_w = go.FigureWidget( @@ -175,7 +163,7 @@ def on_value_change(change): eigs_vbox = widgets.VBox([eigs_header_w, eigvals_w]) return widgets.Box( - [widgets.VBox([header_w, sub_w, triple_w]), eigs_vbox], + [widgets.VBox([header_w, round_w, triple_w]), eigs_vbox], layout=widgets.Layout(max_width="50%", flex_flow="row wrap"), ) diff --git a/tests/test_physics/test_representations/test_bargmann.py b/tests/test_physics/test_representations/test_bargmann.py index c8bbd6bc9..b903a37d4 100644 --- a/tests/test_physics/test_representations/test_bargmann.py +++ b/tests/test_physics/test_representations/test_bargmann.py @@ -388,51 +388,46 @@ def test_trace(self): assert np.allclose(bargmann.b, b) assert np.allclose(bargmann.c, c) - # @patch("mrmustard.physics.representations.bargmann.display") - # def test_ipython_repr(self, mock_display): - # """Test the IPython repr function.""" - # rep = Bargmann(*Abc_triple(2)) - # rep._ipython_display_() # pylint:disable=protected-access - # [box] = mock_display.call_args.args - # assert isinstance(box, Box) - # assert box.layout.max_width == "50%" - - # # data on left, eigvals on right - # [data_vbox, eigs_vbox] = box.children - # assert isinstance(data_vbox, VBox) - # assert isinstance(eigs_vbox, VBox) - - # # data forms a stack: header, ansatz, triple data - # [header, sub, table] = data_vbox.children - # assert isinstance(header, HTML) - # assert isinstance(sub, HBox) - # assert isinstance(table, HTML) - - # # ansatz goes beside button to modify rounding - # [ansatz, round_w] = sub.children - # assert isinstance(ansatz, HTML) - # assert isinstance(round_w, IntText) - - # # eigvals have a header and a unit circle plot - # [eig_header, unit_circle] = eigs_vbox.children - # assert isinstance(eig_header, HTML) - # assert isinstance(unit_circle, FigureWidget) - - # @patch("mrmustard.physics.representations.bargmann.display") - # def test_ipython_repr_batched(self, mock_display): - # """Test the IPython repr function for a batched repr.""" - # A1, b1, c1 = Abc_triple(2) - # A2, b2, c2 = Abc_triple(2) - # rep = Bargmann(np.array([A1, A2]), np.array([b1, b2]), np.array([c1, c2])) - # rep._ipython_display_() # pylint:disable=protected-access - # [vbox] = mock_display.call_args.args - # assert isinstance(vbox, VBox) - - # [slider, stack] = vbox.children - # assert isinstance(slider, IntSlider) - # assert slider.max == 1 # the batch size - 1 - # assert isinstance(stack, Stack) - - # # max_width is spot-check that this is bargmann widget - # assert len(stack.children) == 2 - # assert all(box.layout.max_width == "50%" for box in stack.children) + @patch("mrmustard.physics.representations.bargmann.display") + def test_ipython_repr(self, mock_display): + """Test the IPython repr function.""" + rep = Bargmann(*Abc_triple(2)) + rep._ipython_display_() # pylint:disable=protected-access + [box] = mock_display.call_args.args + assert isinstance(box, Box) + assert box.layout.max_width == "50%" + + # data on left, eigvals on right + [data_vbox, eigs_vbox] = box.children + assert isinstance(data_vbox, VBox) + assert isinstance(eigs_vbox, VBox) + + # data forms a stack: header, ansatz, triple data + [header, sub, table] = data_vbox.children + assert isinstance(header, HTML) + assert isinstance(sub, IntText) + assert isinstance(table, HTML) + + # eigvals have a header and a unit circle plot + [eig_header, unit_circle] = eigs_vbox.children + assert isinstance(eig_header, HTML) + assert isinstance(unit_circle, FigureWidget) + + @patch("mrmustard.physics.representations.bargmann.display") + def test_ipython_repr_batched(self, mock_display): + """Test the IPython repr function for a batched repr.""" + A1, b1, c1 = Abc_triple(2) + A2, b2, c2 = Abc_triple(2) + rep = Bargmann(np.array([A1, A2]), np.array([b1, b2]), np.array([c1, c2])) + rep._ipython_display_() # pylint:disable=protected-access + [vbox] = mock_display.call_args.args + assert isinstance(vbox, VBox) + + [slider, stack] = vbox.children + assert isinstance(slider, IntSlider) + assert slider.max == 1 # the batch size - 1 + assert isinstance(stack, Stack) + + # max_width is spot-check that this is bargmann widget + assert len(stack.children) == 2 + assert all(box.layout.max_width == "50%" for box in stack.children) diff --git a/tests/test_physics/test_representations/test_fock.py b/tests/test_physics/test_representations/test_fock.py index 41ffa69b1..ba674aeed 100644 --- a/tests/test_physics/test_representations/test_fock.py +++ b/tests/test_physics/test_representations/test_fock.py @@ -222,40 +222,40 @@ def test_truediv_a_scalar(self): assert isinstance(aa1_scalar, Fock) assert np.allclose(aa1_scalar.array, array / 6) - # @pytest.mark.parametrize("shape", [(1, 8), (1, 8, 8)]) - # @patch("mrmustard.physics.representations.display") - # def test_ipython_repr(self, mock_display, shape): - # """Test the IPython repr function.""" - # rep = Fock(np.random.random(shape), batched=True) - # rep._ipython_display_() # pylint:disable=protected-access - # [hbox] = mock_display.call_args.args - # assert isinstance(hbox, HBox) - - # # the CSS, the header+ansatz, and the tabs of plots - # [css, left, plots] = hbox.children - # assert isinstance(css, HTML) - # assert isinstance(left, VBox) - # assert isinstance(plots, Tab) - - # # left contains header and ansatz - # left = left.children - # assert len(left) == 2 and all(isinstance(w, HTML) for w in left) - - # # one plot for magnitude, another for phase - # assert plots.titles == ("Magnitude", "Phase") - # plots = plots.children - # assert len(plots) == 2 and all(isinstance(p, FigureWidget) for p in plots) - - # @patch("mrmustard.physics.representations.display") - # def test_ipython_repr_expects_batch_1(self, mock_display): - # """Test the IPython repr function does nothing with real batch.""" - # rep = Fock(np.random.random((2, 8)), batched=True) - # rep._ipython_display_() # pylint:disable=protected-access - # mock_display.assert_not_called() - - # @patch("mrmustard.physics.representations.display") - # def test_ipython_repr_expects_3_dims_or_less(self, mock_display): - # """Test the IPython repr function does nothing with 4+ dims.""" - # rep = Fock(np.random.random((1, 4, 4, 4)), batched=True) - # rep._ipython_display_() # pylint:disable=protected-access - # mock_display.assert_not_called() + @pytest.mark.parametrize("shape", [(1, 8), (1, 8, 8)]) + @patch("mrmustard.physics.representations.fock.display") + def test_ipython_repr(self, mock_display, shape): + """Test the IPython repr function.""" + rep = Fock(np.random.random(shape), batched=True) + rep._ipython_display_() # pylint:disable=protected-access + [hbox] = mock_display.call_args.args + assert isinstance(hbox, HBox) + + # the CSS, the header+ansatz, and the tabs of plots + [css, left, plots] = hbox.children + assert isinstance(css, HTML) + assert isinstance(left, VBox) + assert isinstance(plots, Tab) + + # left contains header and ansatz + left = left.children + assert len(left) == 2 and all(isinstance(w, HTML) for w in left) + + # one plot for magnitude, another for phase + assert plots.titles == ("Magnitude", "Phase") + plots = plots.children + assert len(plots) == 2 and all(isinstance(p, FigureWidget) for p in plots) + + @patch("mrmustard.physics.representations.fock.display") + def test_ipython_repr_expects_batch_1(self, mock_display): + """Test the IPython repr function does nothing with real batch.""" + rep = Fock(np.random.random((2, 8)), batched=True) + rep._ipython_display_() # pylint:disable=protected-access + mock_display.assert_not_called() + + @patch("mrmustard.physics.representations.fock.display") + def test_ipython_repr_expects_3_dims_or_less(self, mock_display): + """Test the IPython repr function does nothing with 4+ dims.""" + rep = Fock(np.random.random((1, 4, 4, 4)), batched=True) + rep._ipython_display_() # pylint:disable=protected-access + mock_display.assert_not_called() From b4c4c07696259a98cc27f8dddd6fabd4a14b1758 Mon Sep 17 00:00:00 2001 From: Anthony Date: Mon, 7 Oct 2024 14:29:18 -0400 Subject: [PATCH 32/87] widgets --- tests/test_lab_dev/test_circuit_components.py | 26 +++++++++---------- .../test_representations/test_bargmann.py | 2 +- .../test_representations/test_fock.py | 2 +- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/test_lab_dev/test_circuit_components.py b/tests/test_lab_dev/test_circuit_components.py index aeb4ac73e..7b119d663 100644 --- a/tests/test_lab_dev/test_circuit_components.py +++ b/tests/test_lab_dev/test_circuit_components.py @@ -478,19 +478,19 @@ def test_quadrature_channel(self): back = Channel.from_quadrature([0], [0], C.quadrature_triple()) assert C == back - # @pytest.mark.parametrize("is_fock,widget_cls", [(False, Box), (True, HBox)]) - # @patch("mrmustard.lab_dev.circuit_components.display") - # def test_ipython_repr(self, mock_display, is_fock, widget_cls): - # """Test the IPython repr function.""" - # dgate = Dgate([1], x=0.1, y=0.1) - # if is_fock: - # dgate = dgate.to_fock() - # dgate._ipython_display_() # pylint:disable=protected-access - # [box] = mock_display.call_args.args - # assert isinstance(box, Box) - # [wires_widget, rep_widget] = box.children - # assert isinstance(wires_widget, HTML) - # assert type(rep_widget) is widget_cls + @pytest.mark.parametrize("is_fock,widget_cls", [(False, Box), (True, HBox)]) + @patch("mrmustard.lab_dev.circuit_components.display") + def test_ipython_repr(self, mock_display, is_fock, widget_cls): + """Test the IPython repr function.""" + dgate = Dgate([1], x=0.1, y=0.1) + if is_fock: + dgate = dgate.to_fock() + dgate._ipython_display_() # pylint:disable=protected-access + [box] = mock_display.call_args.args + assert isinstance(box, Box) + [wires_widget, rep_widget] = box.children + assert isinstance(wires_widget, HTML) + assert isinstance(rep_widget, widget_cls) @patch("mrmustard.lab_dev.circuit_components.display") def test_ipython_repr_invalid_obj(self, mock_display): diff --git a/tests/test_physics/test_representations/test_bargmann.py b/tests/test_physics/test_representations/test_bargmann.py index b903a37d4..d3732f083 100644 --- a/tests/test_physics/test_representations/test_bargmann.py +++ b/tests/test_physics/test_representations/test_bargmann.py @@ -19,7 +19,7 @@ from unittest.mock import patch import numpy as np -from ipywidgets import Box, HBox, VBox, HTML, IntText, Stack, IntSlider +from ipywidgets import Box, VBox, HTML, IntText, Stack, IntSlider from plotly.graph_objs import FigureWidget import pytest diff --git a/tests/test_physics/test_representations/test_fock.py b/tests/test_physics/test_representations/test_fock.py index ba674aeed..066152d79 100644 --- a/tests/test_physics/test_representations/test_fock.py +++ b/tests/test_physics/test_representations/test_fock.py @@ -19,7 +19,7 @@ from unittest.mock import patch import numpy as np -from ipywidgets import Box, HBox, VBox, HTML, Tab +from ipywidgets import HBox, VBox, HTML, Tab from plotly.graph_objs import FigureWidget import pytest From 5bfb17af0d36f91b62980adf47d3c3a3c831236a Mon Sep 17 00:00:00 2001 From: Anthony Date: Mon, 7 Oct 2024 14:43:15 -0400 Subject: [PATCH 33/87] cleanup --- .../test_representations/test_bargmann.py | 11 +++-------- tests/test_physics/test_representations/test_fock.py | 11 ++++------- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/tests/test_physics/test_representations/test_bargmann.py b/tests/test_physics/test_representations/test_bargmann.py index d3732f083..7a3444a36 100644 --- a/tests/test_physics/test_representations/test_bargmann.py +++ b/tests/test_physics/test_representations/test_bargmann.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""This module contains tests for ``Representation`` objects.""" +"""This module contains tests for ``Bargmann`` objects.""" -# pylint: disable = too-many-public-methods +# pylint: disable = too-many-public-methods, missing-function-docstring from unittest.mock import patch @@ -23,7 +23,7 @@ from plotly.graph_objs import FigureWidget import pytest -from mrmustard import math, settings +from mrmustard import math from mrmustard.physics.gaussian_integrals import ( contract_two_Abc, complex_gaussian_integral, @@ -33,11 +33,6 @@ from ...random import Abc_triple -# original settings -autocutoff_max0 = settings.AUTOCUTOFF_MAX_CUTOFF - -# pylint: disable = missing-function-docstring - class TestBargmannRepresentation: r""" diff --git a/tests/test_physics/test_representations/test_fock.py b/tests/test_physics/test_representations/test_fock.py index 066152d79..c9abe2e3e 100644 --- a/tests/test_physics/test_representations/test_fock.py +++ b/tests/test_physics/test_representations/test_fock.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""This module contains tests for ``Representation`` objects.""" +"""This module contains tests for ``Fock`` objects.""" -# pylint: disable = missing-function-docstring +# pylint: disable = missing-function-docstring, disable=too-many-public-methods from unittest.mock import patch @@ -23,14 +23,11 @@ from plotly.graph_objs import FigureWidget import pytest -from mrmustard import math, settings +from mrmustard import math from mrmustard.physics.representations.fock import Fock -# original settings -autocutoff_max0 = settings.AUTOCUTOFF_MAX_CUTOFF - -class TestFockRepresentation: # pylint:disable=too-many-public-methods +class TestFockRepresentation: r"""Tests the Fock Representation.""" array578 = np.random.random((5, 7, 8)) From b7930f31af1bfa09526b5772eef0fe8ed8ad3f70 Mon Sep 17 00:00:00 2001 From: Anthony Date: Mon, 7 Oct 2024 16:27:15 -0400 Subject: [PATCH 34/87] unused imports --- mrmustard/physics/representations.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index 938a11ff0..38967647f 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -48,8 +48,6 @@ ) from mrmustard import widgets -from .wires import Wires - __all__ = ["Representation", "Bargmann", "Fock"] From d8c271314ae6cc5cd759ccb35644ea0c43ae41b0 Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 8 Oct 2024 10:10:05 -0400 Subject: [PATCH 35/87] ansatz representation --- mrmustard/lab_dev/circuit_components.py | 42 +++---- .../circuit_components_utils/b_to_ps.py | 4 +- .../circuit_components_utils/b_to_q.py | 6 +- .../circuit_components_utils/trace_out.py | 4 +- mrmustard/lab_dev/states/base.py | 32 ++--- mrmustard/lab_dev/states/coherent.py | 9 +- .../lab_dev/states/displaced_squeezed.py | 8 +- mrmustard/lab_dev/states/number.py | 9 +- .../lab_dev/states/quadrature_eigenstate.py | 10 +- mrmustard/lab_dev/states/sauron.py | 8 +- mrmustard/lab_dev/states/squeezed_vacuum.py | 10 +- mrmustard/lab_dev/states/thermal.py | 8 +- .../states/two_mode_squeezed_vacuum.py | 8 +- mrmustard/lab_dev/states/vacuum.py | 4 +- .../lab_dev/transformations/amplifier.py | 8 +- .../lab_dev/transformations/attenuator.py | 9 +- mrmustard/lab_dev/transformations/base.py | 28 ++--- mrmustard/lab_dev/transformations/bsgate.py | 8 +- mrmustard/lab_dev/transformations/cft.py | 10 +- mrmustard/lab_dev/transformations/dgate.py | 13 ++- .../lab_dev/transformations/fockdamping.py | 8 +- mrmustard/lab_dev/transformations/ggate.py | 8 +- mrmustard/lab_dev/transformations/identity.py | 4 +- mrmustard/lab_dev/transformations/rgate.py | 8 +- mrmustard/lab_dev/transformations/s2gate.py | 10 +- mrmustard/lab_dev/transformations/sgate.py | 8 +- .../{representations => ansatz}/__init__.py | 4 +- .../fock.py => ansatz/array_ansatz.py} | 70 +++++------ .../{representations => ansatz}/base.py | 38 +++--- .../bargmann.py => ansatz/polyexp_ansatz.py} | 74 ++++++------ ..._representations.py => representations.py} | 109 +++++++++--------- tests/test_lab_dev/test_circuit_components.py | 32 ++--- .../test_circuit_components_utils.py | 4 +- .../test_lab_dev/test_states/test_thermal.py | 4 +- .../test_transformations/test_dgate.py | 4 +- .../test_transformations_base.py | 13 ++- .../__init__.py | 0 .../test_array_ansatz.py} | 92 +++++++-------- .../test_polyexp_ansatz.py} | 100 ++++++++-------- tests/test_physics/test_triples.py | 6 +- 40 files changed, 425 insertions(+), 409 deletions(-) rename mrmustard/physics/{representations => ansatz}/__init__.py (92%) rename mrmustard/physics/{representations/fock.py => ansatz/array_ansatz.py} (84%) rename mrmustard/physics/{representations => ansatz}/base.py (83%) rename mrmustard/physics/{representations/bargmann.py => ansatz/polyexp_ansatz.py} (94%) rename mrmustard/physics/{multi_representations.py => representations.py} (67%) rename tests/test_physics/{test_representations => test_ansatz}/__init__.py (100%) rename tests/test_physics/{test_representations/test_fock.py => test_ansatz/test_array_ansatz.py} (77%) rename tests/test_physics/{test_representations/test_bargmann.py => test_ansatz/test_polyexp_ansatz.py} (83%) diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index 8132d167a..9475afbed 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -39,12 +39,12 @@ Vector, Batch, ) -from mrmustard.physics.representations import Representation, Bargmann, Fock +from mrmustard.physics.ansatz import Ansatz, PolyExpAnsatz, ArrayAnsatz from mrmustard.physics.fock_utils import quadrature_basis from mrmustard.math.parameter_set import ParameterSet from mrmustard.math.parameters import Constant, Variable from mrmustard.physics.wires import Wires -from mrmustard.physics.multi_representations import MultiRepresentation +from mrmustard.physics.representations import Representation __all__ = ["CircuitComponent"] @@ -69,7 +69,7 @@ class CircuitComponent: def __init__( self, - representation: Bargmann | Fock | None = None, + representation: PolyExpAnsatz | ArrayAnsatz | None = None, wires: Wires | Sequence[tuple[int]] | None = None, name: str | None = None, ) -> None: @@ -105,12 +105,10 @@ def __init__( + tuple(np.argsort(modes_in_ket) + offsets[2]) ) if representation is not None: - self._multi_rep = MultiRepresentation( - representation.reorder(tuple(perm)), wires - ) + self._multi_rep = Representation(representation.reorder(tuple(perm)), wires) if not hasattr(self, "_multi_rep"): - self._multi_rep = MultiRepresentation(representation, wires) + self._multi_rep = Representation(representation, wires) def _serialize(self) -> tuple[dict[str, Any], dict[str, ArrayLike]]: """ @@ -244,11 +242,11 @@ def parameter_set(self) -> ParameterSet: return self._parameter_set @property - def representation(self) -> Representation | None: + def representation(self) -> Ansatz | None: r""" A representation of this circuit component. """ - return self._multi_rep.representation + return self._multi_rep.ansatz @property def wires(self) -> Wires: @@ -281,7 +279,7 @@ def from_bargmann( Returns: A circuit component with the given Bargmann representation. """ - repr = Bargmann(*triple) + repr = PolyExpAnsatz(*triple) wires = Wires(set(modes_out_bra), set(modes_in_bra), set(modes_out_ket), set(modes_in_ket)) return cls._from_attributes(repr, wires, name) @@ -320,7 +318,7 @@ def from_quadrature( QtoB_ok = BtoQ(modes_out_ket, phi).inverse() # output ket QtoB_ik = BtoQ(modes_in_ket, phi).inverse().dual # input ket # NOTE: the representation is Bargmann here because we use the inverse of BtoQ on the B side - QQQQ = CircuitComponent._from_attributes(Bargmann(*triple), wires) + QQQQ = CircuitComponent._from_attributes(PolyExpAnsatz(*triple), wires) BBBB = QtoB_ib @ (QtoB_ik @ QQQQ @ QtoB_ok) @ QtoB_ob return cls._from_attributes(BBBB.representation, wires, name) @@ -343,7 +341,7 @@ def to_quadrature(self, phi: float = 0.0) -> CircuitComponent: BtoQ_ik = BtoQ(self.wires.input.ket.modes, phi).dual object_to_convert = self - if isinstance(self.representation, Fock): + if isinstance(self.representation, ArrayAnsatz): object_to_convert = self.to_bargmann() QQQQ = BtoQ_ib @ (BtoQ_ik @ object_to_convert @ BtoQ_ok) @ BtoQ_ob @@ -375,7 +373,7 @@ def quadrature(self, quad: Batch[Vector], phi: float = 0.0) -> ComplexTensor: A circuit component with the given quadrature representation. """ - if isinstance(self.representation, Fock): + if isinstance(self.representation, ArrayAnsatz): fock_arrays = self.representation.array # Find where all the bras and kets are so they can be conjugated appropriately conjugates = [i not in self.wires.ket.indices for i in range(len(self.wires.indices))] @@ -390,7 +388,7 @@ def quadrature(self, quad: Batch[Vector], phi: float = 0.0) -> ComplexTensor: @classmethod def _from_attributes( cls, - representation: Representation, + representation: Ansatz, wires: Wires, name: str | None = None, ) -> CircuitComponent: @@ -424,7 +422,7 @@ def _from_attributes( if tp.__name__ in types: ret = tp() ret._name = name - ret._multi_rep = MultiRepresentation(representation, wires) + ret._multi_rep = Representation(representation, wires) return ret return CircuitComponent(representation, wires, name) @@ -532,7 +530,7 @@ def to_bargmann(self) -> CircuitComponent: >>> assert d_bargmann.wires == d.wires >>> assert isinstance(d_bargmann.representation, Bargmann) """ - if isinstance(self.representation, Bargmann): + if isinstance(self.representation, PolyExpAnsatz): return self else: mult_rep = self._multi_rep.to_bargmann() @@ -540,7 +538,7 @@ def to_bargmann(self) -> CircuitComponent: ret = self._getitem_builtin(self.modes) ret._multi_rep = mult_rep except TypeError: - ret = self._from_attributes(mult_rep.representation, mult_rep.wires, self.name) + ret = self._from_attributes(mult_rep.ansatz, mult_rep.wires, self.name) if "manual_shape" in ret.__dict__: del ret.manual_shape return ret @@ -571,7 +569,7 @@ def to_fock(self, shape: int | Sequence[int] | None = None) -> CircuitComponent: ret = self._getitem_builtin(self.modes) ret._multi_rep = mult_rep except TypeError: - ret = self._from_attributes(mult_rep.representation, mult_rep.wires, self.name) + ret = self._from_attributes(mult_rep.ansatz, mult_rep.wires, self.name) if "manual_shape" in ret.__dict__: del ret.manual_shape return ret @@ -612,7 +610,7 @@ def _light_copy(self, wires: Wires | None = None) -> CircuitComponent: """ instance = super().__new__(self.__class__) instance.__dict__ = self.__dict__.copy() - instance.__dict__["_multi_rep"] = MultiRepresentation( + instance.__dict__["_multi_rep"] = Representation( self.representation, wires or Wires(*self.wires.args) ) return instance @@ -668,7 +666,7 @@ def __matmul__(self, other: CircuitComponent | Scalar) -> CircuitComponent: if isinstance(other, (numbers.Number, np.ndarray)): return self * other result = self._multi_rep @ other._multi_rep - return CircuitComponent._from_attributes(result.representation, result.wires, None) + return CircuitComponent._from_attributes(result.ansatz, result.wires, None) def __mul__(self, other: Scalar) -> CircuitComponent: r""" @@ -783,7 +781,9 @@ def __truediv__(self, other: Scalar) -> CircuitComponent: def _ipython_display_(self): # both reps might return None - rep_fn = mmwidgets.fock if isinstance(self.representation, Fock) else mmwidgets.bargmann + rep_fn = ( + mmwidgets.fock if isinstance(self.representation, ArrayAnsatz) else mmwidgets.bargmann + ) rep_widget = rep_fn(self.representation) wires_widget = mmwidgets.wires(self.wires) if not rep_widget: diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py b/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py index ad9cba741..f1c8e65f7 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py @@ -23,7 +23,7 @@ from mrmustard.math.parameters import Constant from ..transformations.base import Map -from ...physics.representations import Bargmann +from ...physics.ansatz import PolyExpAnsatz __all__ = ["BtoPS"] @@ -46,7 +46,7 @@ def __init__( super().__init__( modes_out=modes, modes_in=modes, - representation=Bargmann.from_function( + representation=PolyExpAnsatz.from_function( fn=triples.displacement_map_s_parametrized_Abc, s=s, n_modes=len(modes) ), name="BtoPS", diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py index 33855b323..7ead7bea8 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py @@ -26,8 +26,8 @@ from mrmustard.math.parameters import Constant from ..transformations.base import Operation -from ...physics.representations import Bargmann -from ...physics.multi_representations import RepEnum +from ...physics.ansatz import PolyExpAnsatz +from ...physics.representations import RepEnum from ..circuit_components import CircuitComponent __all__ = ["BtoQ"] @@ -49,7 +49,7 @@ def __init__( modes: Sequence[int], phi: float = 0.0, ): - repr = Bargmann.from_function( + repr = PolyExpAnsatz.from_function( fn=triples.bargmann_to_quadrature_Abc, n_modes=len(modes), phi=phi ) super().__init__( diff --git a/mrmustard/lab_dev/circuit_components_utils/trace_out.py b/mrmustard/lab_dev/circuit_components_utils/trace_out.py index 39beecaed..665fdaa21 100644 --- a/mrmustard/lab_dev/circuit_components_utils/trace_out.py +++ b/mrmustard/lab_dev/circuit_components_utils/trace_out.py @@ -23,7 +23,7 @@ from mrmustard.physics import triples from ..circuit_components import CircuitComponent -from ...physics.representations import Bargmann +from ...physics.ansatz import PolyExpAnsatz __all__ = ["TraceOut"] @@ -63,7 +63,7 @@ def __init__( ): super().__init__( wires=[(), modes, (), modes], - representation=Bargmann.from_function(fn=triples.identity_Abc, n_modes=len(modes)), + representation=PolyExpAnsatz.from_function(fn=triples.identity_Abc, n_modes=len(modes)), name="Tr", ) diff --git a/mrmustard/lab_dev/states/base.py b/mrmustard/lab_dev/states/base.py index af630afae..3c4505bac 100644 --- a/mrmustard/lab_dev/states/base.py +++ b/mrmustard/lab_dev/states/base.py @@ -55,7 +55,7 @@ ) from mrmustard.math.lattice.strategies.vanilla import autoshape_numba from mrmustard.physics.gaussian import purity -from mrmustard.physics.representations import Bargmann, Fock +from mrmustard.physics.ansatz import PolyExpAnsatz, ArrayAnsatz from mrmustard.lab_dev.utils import shape_check from mrmustard.physics.bargmann_utils import ( bargmann_Abc_to_phasespace_cov_means, @@ -63,7 +63,7 @@ from mrmustard.lab_dev.circuit_components_utils import BtoPS, BtoQ, TraceOut from mrmustard.lab_dev.circuit_components import CircuitComponent from mrmustard.physics.wires import Wires -from mrmustard.physics.multi_representations import MultiRepresentation +from mrmustard.physics.representations import Representation __all__ = ["State", "DM", "Ket"] @@ -350,7 +350,7 @@ def phase_space(self, s: float) -> tuple: Returns: The covariance matrix, the mean vector and the coefficient of the state in s-parametrized phase space. """ - if not isinstance(self.representation, Bargmann): + if not isinstance(self.representation, PolyExpAnsatz): raise ValueError("Can calculate phase space only for Bargmann states.") new_state = self >> BtoPS(self.modes, s=s) @@ -635,7 +635,7 @@ def visualize_dm( def _ipython_display_(self): # pragma: no cover is_ket = isinstance(self, Ket) - is_fock = isinstance(self.representation, Fock) + is_fock = isinstance(self.representation, ArrayAnsatz) display(widgets.state(self, is_ket=is_ket, is_fock=is_fock)) @@ -654,7 +654,7 @@ class DM(State): def __init__( self, modes: Sequence[int] = (), - representation: Bargmann | Fock | None = None, + representation: PolyExpAnsatz | ArrayAnsatz | None = None, name: str | None = None, ): if representation and representation.num_vars != 2 * len(modes): @@ -665,7 +665,7 @@ def __init__( wires=[modes, (), modes, ()], name=name, ) - self._multi_rep = MultiRepresentation(representation, self.wires) + self._multi_rep = Representation(representation, self.wires) @property def is_positive(self) -> bool: @@ -729,7 +729,7 @@ def from_bargmann( triple: tuple[ComplexMatrix, ComplexVector, complex], name: str | None = None, ) -> State: - return DM(modes, Bargmann(*triple), name) + return DM(modes, PolyExpAnsatz(*triple), name) @classmethod def from_fock( @@ -739,7 +739,7 @@ def from_fock( name: str | None = None, batched: bool = False, ) -> State: - return DM(modes, Fock(array, batched), name) + return DM(modes, ArrayAnsatz(array, batched), name) @classmethod def from_phase_space( @@ -767,7 +767,7 @@ def from_phase_space( shape_check(cov, means, 2 * len(modes), "Phase space") return coeff * DM( modes, - Bargmann.from_function(fn=wigner_to_bargmann_rho, cov=cov, means=means), + PolyExpAnsatz.from_function(fn=wigner_to_bargmann_rho, cov=cov, means=means), name, ) @@ -797,7 +797,7 @@ def from_quadrature( with the number of modes. """ QtoB = BtoQ(modes, phi).inverse() - Q = DM(modes, Bargmann(*triple)) + Q = DM(modes, PolyExpAnsatz(*triple)) return DM(modes, (Q >> QtoB).representation, name) @classmethod @@ -983,7 +983,7 @@ class Ket(State): def __init__( self, modes: Sequence[int] = (), - representation: Bargmann | Fock | None = None, + representation: PolyExpAnsatz | ArrayAnsatz | None = None, name: str | None = None, ): if representation and representation.num_vars != len(modes): @@ -994,7 +994,7 @@ def __init__( wires=[(), (), modes, ()], name=name, ) - self._multi_rep = MultiRepresentation(representation, self.wires) + self._multi_rep = Representation(representation, self.wires) @property def is_physical(self) -> bool: @@ -1034,7 +1034,7 @@ def from_bargmann( triple: tuple[ComplexMatrix, ComplexVector, complex], name: str | None = None, ) -> State: - return Ket(modes, Bargmann(*triple), name) + return Ket(modes, PolyExpAnsatz(*triple), name) @classmethod def from_fock( @@ -1044,7 +1044,7 @@ def from_fock( name: str | None = None, batched: bool = False, ) -> State: - return Ket(modes, Fock(array, batched), name) + return Ket(modes, ArrayAnsatz(array, batched), name) @classmethod def from_phase_space( @@ -1065,7 +1065,7 @@ def from_phase_space( raise ValueError(msg) return Ket( modes, - coeff * Bargmann.from_function(fn=wigner_to_bargmann_psi, cov=cov, means=means), + coeff * PolyExpAnsatz.from_function(fn=wigner_to_bargmann_psi, cov=cov, means=means), name, ) @@ -1078,7 +1078,7 @@ def from_quadrature( name: str | None = None, ) -> State: QtoB = BtoQ(modes, phi).inverse() - Q = Ket(modes, Bargmann(*triple)) + Q = Ket(modes, PolyExpAnsatz(*triple)) return Ket(modes, (Q >> QtoB).representation, name) @classmethod diff --git a/mrmustard/lab_dev/states/coherent.py b/mrmustard/lab_dev/states/coherent.py index d5caa67e6..8243ddb34 100644 --- a/mrmustard/lab_dev/states/coherent.py +++ b/mrmustard/lab_dev/states/coherent.py @@ -20,8 +20,8 @@ from typing import Sequence -from mrmustard.physics.multi_representations import MultiRepresentation -from mrmustard.physics.representations import Bargmann +from mrmustard.physics.representations import Representation +from mrmustard.physics.ansatz import PolyExpAnsatz from mrmustard.physics import triples from .base import Ket from ..utils import make_parameter, reshape_params @@ -83,6 +83,7 @@ def __init__( self._add_parameter(make_parameter(x_trainable, xs, "x", x_bounds)) self._add_parameter(make_parameter(y_trainable, ys, "y", y_bounds)) - self._multi_rep = MultiRepresentation( - Bargmann.from_function(fn=triples.coherent_state_Abc, x=self.x, y=self.y), self.wires + self._multi_rep = Representation( + PolyExpAnsatz.from_function(fn=triples.coherent_state_Abc, x=self.x, y=self.y), + self.wires, ) diff --git a/mrmustard/lab_dev/states/displaced_squeezed.py b/mrmustard/lab_dev/states/displaced_squeezed.py index a6ac4aaba..56f1770fd 100644 --- a/mrmustard/lab_dev/states/displaced_squeezed.py +++ b/mrmustard/lab_dev/states/displaced_squeezed.py @@ -20,8 +20,8 @@ from typing import Sequence -from mrmustard.physics.multi_representations import MultiRepresentation -from mrmustard.physics.representations import Bargmann +from mrmustard.physics.representations import Representation +from mrmustard.physics.ansatz import PolyExpAnsatz from mrmustard.physics import triples from .base import Ket from ..utils import make_parameter, reshape_params @@ -85,8 +85,8 @@ def __init__( self._add_parameter(make_parameter(r_trainable, rs, "r", r_bounds)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - self._multi_rep = MultiRepresentation( - Bargmann.from_function( + self._multi_rep = Representation( + PolyExpAnsatz.from_function( fn=triples.displaced_squeezed_vacuum_state_Abc, x=self.x, y=self.y, diff --git a/mrmustard/lab_dev/states/number.py b/mrmustard/lab_dev/states/number.py index 878323f91..30ee38b0b 100644 --- a/mrmustard/lab_dev/states/number.py +++ b/mrmustard/lab_dev/states/number.py @@ -20,8 +20,8 @@ from typing import Sequence -from mrmustard.physics.multi_representations import MultiRepresentation -from mrmustard.physics.representations import Fock +from mrmustard.physics.representations import Representation +from mrmustard.physics.ansatz import ArrayAnsatz from mrmustard.physics.fock_utils import fock_state from .base import Ket from ..utils import make_parameter, reshape_params @@ -75,6 +75,7 @@ def __init__( for i, cutoff in enumerate(self.cutoffs.value): self.manual_shape[i] = int(cutoff) + 1 - self._multi_rep = MultiRepresentation( - Fock.from_function(fock_state, n=self.n.value, cutoffs=self.cutoffs.value), self.wires + self._multi_rep = Representation( + ArrayAnsatz.from_function(fock_state, n=self.n.value, cutoffs=self.cutoffs.value), + self.wires, ) diff --git a/mrmustard/lab_dev/states/quadrature_eigenstate.py b/mrmustard/lab_dev/states/quadrature_eigenstate.py index 72ce0f174..4eb045002 100644 --- a/mrmustard/lab_dev/states/quadrature_eigenstate.py +++ b/mrmustard/lab_dev/states/quadrature_eigenstate.py @@ -22,8 +22,8 @@ import numpy as np -from mrmustard.physics.multi_representations import MultiRepresentation -from mrmustard.physics.representations import Bargmann +from mrmustard.physics.representations import Representation +from mrmustard.physics.ansatz import PolyExpAnsatz from mrmustard.physics import triples from .base import Ket from ..utils import make_parameter, reshape_params @@ -68,8 +68,10 @@ def __init__( xs, phis = list(reshape_params(len(modes), x=x, phi=phi)) self._add_parameter(make_parameter(x_trainable, xs, "x", x_bounds)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - self._multi_rep = MultiRepresentation( - Bargmann.from_function(fn=triples.quadrature_eigenstates_Abc, x=self.x, phi=self.phi), + self._multi_rep = Representation( + PolyExpAnsatz.from_function( + fn=triples.quadrature_eigenstates_Abc, x=self.x, phi=self.phi + ), self.wires, ) self.manual_shape = (50,) diff --git a/mrmustard/lab_dev/states/sauron.py b/mrmustard/lab_dev/states/sauron.py index ed15f7ead..827ace06c 100644 --- a/mrmustard/lab_dev/states/sauron.py +++ b/mrmustard/lab_dev/states/sauron.py @@ -16,8 +16,8 @@ from typing import Sequence from mrmustard.lab_dev.states.base import Ket -from mrmustard.physics.multi_representations import MultiRepresentation -from mrmustard.physics.representations import Bargmann +from mrmustard.physics.representations import Representation +from mrmustard.physics.ansatz import PolyExpAnsatz from mrmustard.physics import triples from ..utils import make_parameter @@ -43,8 +43,8 @@ def __init__(self, modes: Sequence[int], n: int, epsilon: float = 0.1): super().__init__(name=f"Sauron-{n}", modes=modes) self._add_parameter(make_parameter(False, n, "n", (None, None), dtype="int64")) self._add_parameter(make_parameter(False, epsilon, "epsilon", (None, None))) - self._multi_rep = MultiRepresentation( - Bargmann.from_function( + self._multi_rep = Representation( + PolyExpAnsatz.from_function( triples.sauron_state_Abc, n=self.n.value, epsilon=self.epsilon.value ), self.wires, diff --git a/mrmustard/lab_dev/states/squeezed_vacuum.py b/mrmustard/lab_dev/states/squeezed_vacuum.py index 5c449bd1e..4c3431018 100644 --- a/mrmustard/lab_dev/states/squeezed_vacuum.py +++ b/mrmustard/lab_dev/states/squeezed_vacuum.py @@ -20,8 +20,8 @@ from typing import Sequence -from mrmustard.physics.multi_representations import MultiRepresentation -from mrmustard.physics.representations import Bargmann +from mrmustard.physics.representations import Representation +from mrmustard.physics.ansatz import PolyExpAnsatz from mrmustard.physics import triples from .base import Ket from ..utils import make_parameter, reshape_params @@ -69,7 +69,9 @@ def __init__( rs, phis = list(reshape_params(len(modes), r=r, phi=phi)) self._add_parameter(make_parameter(r_trainable, rs, "r", r_bounds)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - self._multi_rep = MultiRepresentation( - Bargmann.from_function(fn=triples.squeezed_vacuum_state_Abc, r=self.r, phi=self.phi), + self._multi_rep = Representation( + PolyExpAnsatz.from_function( + fn=triples.squeezed_vacuum_state_Abc, r=self.r, phi=self.phi + ), self.wires, ) diff --git a/mrmustard/lab_dev/states/thermal.py b/mrmustard/lab_dev/states/thermal.py index 1c9d00e67..8aa37d97f 100644 --- a/mrmustard/lab_dev/states/thermal.py +++ b/mrmustard/lab_dev/states/thermal.py @@ -20,8 +20,8 @@ from typing import Sequence -from mrmustard.physics.multi_representations import MultiRepresentation -from mrmustard.physics.representations import Bargmann +from mrmustard.physics.representations import Representation +from mrmustard.physics.ansatz import PolyExpAnsatz from mrmustard.physics import triples from .base import DM from ..utils import make_parameter, reshape_params @@ -62,6 +62,6 @@ def __init__( super().__init__(modes=modes, name="Thermal") (nbars,) = list(reshape_params(len(modes), nbar=nbar)) self._add_parameter(make_parameter(nbar_trainable, nbars, "nbar", nbar_bounds)) - self._multi_rep = MultiRepresentation( - Bargmann.from_function(fn=triples.thermal_state_Abc, nbar=self.nbar), self.wires + self._multi_rep = Representation( + PolyExpAnsatz.from_function(fn=triples.thermal_state_Abc, nbar=self.nbar), self.wires ) diff --git a/mrmustard/lab_dev/states/two_mode_squeezed_vacuum.py b/mrmustard/lab_dev/states/two_mode_squeezed_vacuum.py index 6a6fcbd27..33a6ce4c4 100644 --- a/mrmustard/lab_dev/states/two_mode_squeezed_vacuum.py +++ b/mrmustard/lab_dev/states/two_mode_squeezed_vacuum.py @@ -20,8 +20,8 @@ from typing import Sequence -from mrmustard.physics.multi_representations import MultiRepresentation -from mrmustard.physics.representations import Bargmann +from mrmustard.physics.representations import Representation +from mrmustard.physics.ansatz import PolyExpAnsatz from mrmustard.physics import triples from .base import Ket from ..utils import make_parameter, reshape_params @@ -67,8 +67,8 @@ def __init__( rs, phis = list(reshape_params(int(len(modes) / 2), r=r, phi=phi)) self._add_parameter(make_parameter(r_trainable, rs, "r", r_bounds)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - self._multi_rep = MultiRepresentation( - Bargmann.from_function( + self._multi_rep = Representation( + PolyExpAnsatz.from_function( fn=triples.two_mode_squeezed_vacuum_state_Abc, r=self.r, phi=self.phi ), self.wires, diff --git a/mrmustard/lab_dev/states/vacuum.py b/mrmustard/lab_dev/states/vacuum.py index 54f897990..00b3619de 100644 --- a/mrmustard/lab_dev/states/vacuum.py +++ b/mrmustard/lab_dev/states/vacuum.py @@ -20,7 +20,7 @@ from typing import Sequence -from mrmustard.physics.representations import Bargmann +from mrmustard.physics.ansatz import PolyExpAnsatz from mrmustard.physics import triples from .base import Ket @@ -60,7 +60,7 @@ def __init__( self, modes: Sequence[int], ) -> None: - rep = Bargmann.from_function(fn=triples.vacuum_state_Abc, n_modes=len(modes)) + rep = PolyExpAnsatz.from_function(fn=triples.vacuum_state_Abc, n_modes=len(modes)) super().__init__(modes=modes, representation=rep, name="Vac") for i in range(len(modes)): diff --git a/mrmustard/lab_dev/transformations/amplifier.py b/mrmustard/lab_dev/transformations/amplifier.py index 223a3de34..1bf3e12cf 100644 --- a/mrmustard/lab_dev/transformations/amplifier.py +++ b/mrmustard/lab_dev/transformations/amplifier.py @@ -21,8 +21,8 @@ from typing import Sequence from .base import Channel -from ...physics.multi_representations import MultiRepresentation -from ...physics.representations import Bargmann +from ...physics.representations import Representation +from ...physics.ansatz import PolyExpAnsatz from ...physics import triples from ..utils import make_parameter, reshape_params @@ -96,6 +96,6 @@ def __init__( None, ) ) - self._multi_rep = MultiRepresentation( - Bargmann.from_function(fn=triples.amplifier_Abc, g=self.gain), self.wires + self._multi_rep = Representation( + PolyExpAnsatz.from_function(fn=triples.amplifier_Abc, g=self.gain), self.wires ) diff --git a/mrmustard/lab_dev/transformations/attenuator.py b/mrmustard/lab_dev/transformations/attenuator.py index c728ec582..32b484c2b 100644 --- a/mrmustard/lab_dev/transformations/attenuator.py +++ b/mrmustard/lab_dev/transformations/attenuator.py @@ -21,8 +21,8 @@ from typing import Sequence from .base import Channel -from ...physics.multi_representations import MultiRepresentation -from ...physics.representations import Bargmann +from ...physics.representations import Representation +from ...physics.ansatz import PolyExpAnsatz from ...physics import triples from ..utils import make_parameter, reshape_params @@ -96,6 +96,7 @@ def __init__( None, ) ) - self._multi_rep = MultiRepresentation( - Bargmann.from_function(fn=triples.attenuator_Abc, eta=self.transmissivity), self.wires + self._multi_rep = Representation( + PolyExpAnsatz.from_function(fn=triples.attenuator_Abc, eta=self.transmissivity), + self.wires, ) diff --git a/mrmustard/lab_dev/transformations/base.py b/mrmustard/lab_dev/transformations/base.py index a1b1b9e1c..e1ab29d16 100644 --- a/mrmustard/lab_dev/transformations/base.py +++ b/mrmustard/lab_dev/transformations/base.py @@ -29,7 +29,7 @@ from typing import Sequence from mrmustard import math, settings -from mrmustard.physics.representations import Bargmann, Fock +from mrmustard.physics.ansatz import PolyExpAnsatz, ArrayAnsatz from mrmustard.utils.typing import ComplexMatrix from mrmustard.physics.bargmann_utils import au2Symplectic, symplectic2Au, XY_of_channel from ..circuit_components import CircuitComponent @@ -89,7 +89,7 @@ def inverse(self) -> Transformation: raise NotImplementedError( "Only Transformations with the same number of input and output wires are supported." ) - if not isinstance(self.representation, Bargmann): + if not isinstance(self.representation, PolyExpAnsatz): raise NotImplementedError("Only Bargmann representation is supported.") if self.representation.batch_size > 1: raise NotImplementedError("Batched transformations are not supported.") @@ -97,12 +97,12 @@ def inverse(self) -> Transformation: # compute the inverse A, b, _ = self.dual.representation.conj.triple # apply X(.)X almost_inverse = self._from_attributes( - Bargmann(math.inv(A[0]), -math.inv(A[0]) @ b[0], 1 + 0j), self.wires + PolyExpAnsatz(math.inv(A[0]), -math.inv(A[0]) @ b[0], 1 + 0j), self.wires ) almost_identity = self @ almost_inverse invert_this_c = almost_identity.representation.c actual_inverse = self._from_attributes( - Bargmann(math.inv(A[0]), -math.inv(A[0]) @ b[0], 1 / invert_this_c), + PolyExpAnsatz(math.inv(A[0]), -math.inv(A[0]) @ b[0], 1 / invert_this_c), self.wires, self.name + "_inv", ) @@ -121,7 +121,7 @@ def __init__( self, modes_out: tuple[int, ...] = (), modes_in: tuple[int, ...] = (), - representation: Bargmann | Fock | None = None, + representation: PolyExpAnsatz | ArrayAnsatz | None = None, name: str | None = None, ): super().__init__( @@ -138,7 +138,7 @@ def from_bargmann( triple: tuple, name: str | None = None, ) -> Transformation: - return Operation(modes_out, modes_in, Bargmann(*triple), name) + return Operation(modes_out, modes_in, PolyExpAnsatz(*triple), name) @classmethod def from_quadrature( @@ -153,7 +153,7 @@ def from_quadrature( QtoB_out = BtoQ(modes_out, phi).inverse() QtoB_in = BtoQ(modes_in, phi).inverse().dual - QQ = Operation(modes_out, modes_in, Bargmann(*triple)) + QQ = Operation(modes_out, modes_in, PolyExpAnsatz(*triple)) BB = QtoB_in >> QQ >> QtoB_out return Operation(modes_out, modes_in, BB.representation, name) @@ -188,7 +188,7 @@ def from_bargmann( triple: tuple, name: str | None = None, ) -> Transformation: - return Unitary(modes_out, modes_in, Bargmann(*triple), name) + return Unitary(modes_out, modes_in, PolyExpAnsatz(*triple), name) @classmethod def from_quadrature( @@ -203,7 +203,7 @@ def from_quadrature( QtoB_out = BtoQ(modes_out, phi).inverse() QtoB_in = BtoQ(modes_in, phi).inverse().dual - QQ = Unitary(modes_out, modes_in, Bargmann(*triple)) + QQ = Unitary(modes_out, modes_in, PolyExpAnsatz(*triple)) BB = QtoB_in >> QQ >> QtoB_out return Unitary(modes_out, modes_in, BB.representation, name) @@ -277,7 +277,7 @@ def __init__( self, modes_out: tuple[int, ...] = (), modes_in: tuple[int, ...] = (), - representation: Bargmann | Fock | None = None, + representation: PolyExpAnsatz | ArrayAnsatz | None = None, name: str | None = None, ): super().__init__( @@ -294,7 +294,7 @@ def from_bargmann( triple: tuple, name: str | None = None, ) -> Transformation: - return Map(modes_out, modes_in, Bargmann(*triple), name) + return Map(modes_out, modes_in, PolyExpAnsatz(*triple), name) @classmethod def from_quadrature( @@ -309,7 +309,7 @@ def from_quadrature( QtoB_out = BtoQ(modes_out, phi).inverse() QtoB_in = BtoQ(modes_in, phi).inverse().dual - QQ = Map(modes_out, modes_in, Bargmann(*triple)) + QQ = Map(modes_out, modes_in, PolyExpAnsatz(*triple)) BB = QtoB_in >> QQ >> QtoB_out return Map(modes_out, modes_in, BB.representation, name) @@ -382,7 +382,7 @@ def from_bargmann( triple: tuple, name: str | None = None, ) -> Transformation: - return Channel(modes_out, modes_in, Bargmann(*triple), name) + return Channel(modes_out, modes_in, PolyExpAnsatz(*triple), name) @classmethod def from_quadrature( @@ -397,7 +397,7 @@ def from_quadrature( QtoB_out = BtoQ(modes_out, phi).inverse() QtoB_in = BtoQ(modes_in, phi).inverse().dual - QQ = Channel(modes_out, modes_in, Bargmann(*triple)) + QQ = Channel(modes_out, modes_in, PolyExpAnsatz(*triple)) BB = QtoB_in >> QQ >> QtoB_out return Channel(modes_out, modes_in, BB.representation, name) diff --git a/mrmustard/lab_dev/transformations/bsgate.py b/mrmustard/lab_dev/transformations/bsgate.py index f3807953f..15a89a133 100644 --- a/mrmustard/lab_dev/transformations/bsgate.py +++ b/mrmustard/lab_dev/transformations/bsgate.py @@ -21,8 +21,8 @@ from typing import Sequence from .base import Unitary -from ...physics.multi_representations import MultiRepresentation -from ...physics.representations import Bargmann +from ...physics.representations import Representation +from ...physics.ansatz import PolyExpAnsatz from ...physics import triples from ..utils import make_parameter @@ -105,8 +105,8 @@ def __init__( super().__init__(modes_out=modes, modes_in=modes, name="BSgate") self._add_parameter(make_parameter(theta_trainable, theta, "theta", theta_bounds)) self._add_parameter(make_parameter(phi_trainable, phi, "phi", phi_bounds)) - self._multi_rep = MultiRepresentation( - Bargmann.from_function( + self._multi_rep = Representation( + PolyExpAnsatz.from_function( fn=triples.beamsplitter_gate_Abc, theta=self.theta, phi=self.phi ), self.wires, diff --git a/mrmustard/lab_dev/transformations/cft.py b/mrmustard/lab_dev/transformations/cft.py index 6233c10c8..d0ccd7f05 100644 --- a/mrmustard/lab_dev/transformations/cft.py +++ b/mrmustard/lab_dev/transformations/cft.py @@ -18,8 +18,8 @@ from typing import Sequence from mrmustard.lab_dev.transformations.base import Map -from mrmustard.physics.multi_representations import MultiRepresentation -from mrmustard.physics.representations import Bargmann +from mrmustard.physics.representations import Representation +from mrmustard.physics.ansatz import PolyExpAnsatz from mrmustard.physics import triples __all__ = ["CFT"] @@ -48,7 +48,9 @@ def __init__( modes_in=modes, name="CFT", ) - self._multi_rep = MultiRepresentation( - Bargmann.from_function(fn=triples.complex_fourier_transform_Abc, n_modes=len(modes)), + self._multi_rep = Representation( + PolyExpAnsatz.from_function( + fn=triples.complex_fourier_transform_Abc, n_modes=len(modes) + ), self.wires, ) diff --git a/mrmustard/lab_dev/transformations/dgate.py b/mrmustard/lab_dev/transformations/dgate.py index f730f10d3..e7b756356 100644 --- a/mrmustard/lab_dev/transformations/dgate.py +++ b/mrmustard/lab_dev/transformations/dgate.py @@ -24,8 +24,8 @@ from mrmustard import math from .base import Unitary -from ...physics.multi_representations import MultiRepresentation -from ...physics.representations import Bargmann, Fock +from ...physics.representations import Representation +from ...physics.ansatz import PolyExpAnsatz, ArrayAnsatz from ...physics import triples, fock_utils from ..utils import make_parameter, reshape_params @@ -95,8 +95,9 @@ def __init__( xs, ys = list(reshape_params(len(modes), x=x, y=y)) self._add_parameter(make_parameter(x_trainable, xs, "x", x_bounds)) self._add_parameter(make_parameter(y_trainable, ys, "y", y_bounds)) - self._multi_rep = MultiRepresentation( - Bargmann.from_function(fn=triples.displacement_gate_Abc, x=self.x, y=self.y), self.wires + self._multi_rep = Representation( + PolyExpAnsatz.from_function(fn=triples.displacement_gate_Abc, x=self.x, y=self.y), + self.wires, ) def fock(self, shape: int | Sequence[int] = None, batched=False) -> ComplexTensor: @@ -144,8 +145,8 @@ def fock(self, shape: int | Sequence[int] = None, batched=False) -> ComplexTenso return arrays def to_fock(self, shape: int | Sequence[int] | None = None) -> Dgate: - fock = Fock(self.fock(shape, batched=True), batched=True) + fock = ArrayAnsatz(self.fock(shape, batched=True), batched=True) fock._original_abc_data = self.representation.triple ret = self._getitem_builtin(self.modes) - ret._multi_rep = MultiRepresentation(fock, self.wires) + ret._multi_rep = Representation(fock, self.wires) return ret diff --git a/mrmustard/lab_dev/transformations/fockdamping.py b/mrmustard/lab_dev/transformations/fockdamping.py index 8d5b38e0c..9d4426d8b 100644 --- a/mrmustard/lab_dev/transformations/fockdamping.py +++ b/mrmustard/lab_dev/transformations/fockdamping.py @@ -21,8 +21,8 @@ from typing import Sequence from .base import Operation -from ...physics.multi_representations import MultiRepresentation -from ...physics.representations import Bargmann +from ...physics.representations import Representation +from ...physics.ansatz import PolyExpAnsatz from ...physics import triples from ..utils import make_parameter, reshape_params @@ -86,6 +86,6 @@ def __init__( None, ) ) - self._multi_rep = MultiRepresentation( - Bargmann.from_function(fn=triples.fock_damping_Abc, beta=self.damping), self.wires + self._multi_rep = Representation( + PolyExpAnsatz.from_function(fn=triples.fock_damping_Abc, beta=self.damping), self.wires ) diff --git a/mrmustard/lab_dev/transformations/ggate.py b/mrmustard/lab_dev/transformations/ggate.py index c4572c430..1b953cab0 100644 --- a/mrmustard/lab_dev/transformations/ggate.py +++ b/mrmustard/lab_dev/transformations/ggate.py @@ -22,8 +22,8 @@ from mrmustard.utils.typing import RealMatrix from .base import Unitary -from ...physics.multi_representations import MultiRepresentation -from ...physics.representations import Bargmann +from ...physics.representations import Representation +from ...physics.ansatz import PolyExpAnsatz from ..utils import make_parameter __all__ = ["Ggate"] @@ -58,8 +58,8 @@ def __init__( super().__init__(modes_out=modes, modes_in=modes, name="Ggate") S = make_parameter(symplectic_trainable, symplectic, "symplectic", (None, None)) self.parameter_set.add_parameter(S) - self._multi_rep = MultiRepresentation( - Bargmann.from_function( + self._multi_rep = Representation( + PolyExpAnsatz.from_function( fn=lambda s: Unitary.from_symplectic(modes, s).bargmann_triple(), s=self.parameter_set.symplectic, ), diff --git a/mrmustard/lab_dev/transformations/identity.py b/mrmustard/lab_dev/transformations/identity.py index 3487a9dbd..ff33738cc 100644 --- a/mrmustard/lab_dev/transformations/identity.py +++ b/mrmustard/lab_dev/transformations/identity.py @@ -21,7 +21,7 @@ from typing import Sequence from .base import Unitary -from ...physics.representations import Bargmann +from ...physics.ansatz import PolyExpAnsatz from ...physics import triples __all__ = ["Identity"] @@ -51,5 +51,5 @@ def __init__( self, modes: Sequence[int], ): - rep = Bargmann.from_function(fn=triples.identity_Abc, n_modes=len(modes)) + rep = PolyExpAnsatz.from_function(fn=triples.identity_Abc, n_modes=len(modes)) super().__init__(modes_out=modes, modes_in=modes, representation=rep, name="Identity") diff --git a/mrmustard/lab_dev/transformations/rgate.py b/mrmustard/lab_dev/transformations/rgate.py index ff120f1b5..9e58106d3 100644 --- a/mrmustard/lab_dev/transformations/rgate.py +++ b/mrmustard/lab_dev/transformations/rgate.py @@ -21,8 +21,8 @@ from typing import Sequence from .base import Unitary -from ...physics.multi_representations import MultiRepresentation -from ...physics.representations import Bargmann +from ...physics.representations import Representation +from ...physics.ansatz import PolyExpAnsatz from ...physics import triples from ..utils import make_parameter, reshape_params @@ -63,6 +63,6 @@ def __init__( super().__init__(modes_out=modes, modes_in=modes, name="Rgate") (phis,) = list(reshape_params(len(modes), phi=phi)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - self._multi_rep = MultiRepresentation( - Bargmann.from_function(fn=triples.rotation_gate_Abc, theta=self.phi), self.wires + self._multi_rep = Representation( + PolyExpAnsatz.from_function(fn=triples.rotation_gate_Abc, theta=self.phi), self.wires ) diff --git a/mrmustard/lab_dev/transformations/s2gate.py b/mrmustard/lab_dev/transformations/s2gate.py index 0f4dd9dfd..d5c292744 100644 --- a/mrmustard/lab_dev/transformations/s2gate.py +++ b/mrmustard/lab_dev/transformations/s2gate.py @@ -21,8 +21,8 @@ from typing import Sequence from .base import Unitary -from ...physics.multi_representations import MultiRepresentation -from ...physics.representations import Bargmann +from ...physics.representations import Representation +from ...physics.ansatz import PolyExpAnsatz from ...physics import triples from ..utils import make_parameter @@ -88,7 +88,9 @@ def __init__( super().__init__(modes_out=modes, modes_in=modes, name="S2gate") self._add_parameter(make_parameter(r_trainable, r, "r", r_bounds)) self._add_parameter(make_parameter(phi_trainable, phi, "phi", phi_bounds)) - self._multi_rep = MultiRepresentation( - Bargmann.from_function(fn=triples.twomode_squeezing_gate_Abc, r=self.r, phi=self.phi), + self._multi_rep = Representation( + PolyExpAnsatz.from_function( + fn=triples.twomode_squeezing_gate_Abc, r=self.r, phi=self.phi + ), self.wires, ) diff --git a/mrmustard/lab_dev/transformations/sgate.py b/mrmustard/lab_dev/transformations/sgate.py index 227bc78ff..c4d0a1233 100644 --- a/mrmustard/lab_dev/transformations/sgate.py +++ b/mrmustard/lab_dev/transformations/sgate.py @@ -21,8 +21,8 @@ from typing import Sequence from .base import Unitary -from ...physics.multi_representations import MultiRepresentation -from ...physics.representations import Bargmann +from ...physics.representations import Representation +from ...physics.ansatz import PolyExpAnsatz from ...physics import triples from ..utils import make_parameter, reshape_params @@ -95,7 +95,7 @@ def __init__( rs, phis = list(reshape_params(len(modes), r=r, phi=phi)) self._add_parameter(make_parameter(r_trainable, rs, "r", r_bounds)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - self._multi_rep = MultiRepresentation( - Bargmann.from_function(fn=triples.squeezing_gate_Abc, r=self.r, delta=self.phi), + self._multi_rep = Representation( + PolyExpAnsatz.from_function(fn=triples.squeezing_gate_Abc, r=self.r, delta=self.phi), self.wires, ) diff --git a/mrmustard/physics/representations/__init__.py b/mrmustard/physics/ansatz/__init__.py similarity index 92% rename from mrmustard/physics/representations/__init__.py rename to mrmustard/physics/ansatz/__init__.py index a75d9469e..dbb2cee1b 100644 --- a/mrmustard/physics/representations/__init__.py +++ b/mrmustard/physics/ansatz/__init__.py @@ -17,5 +17,5 @@ """ from .base import * -from .bargmann import * -from .fock import * +from .polyexp_ansatz import * +from .array_ansatz import * diff --git a/mrmustard/physics/representations/fock.py b/mrmustard/physics/ansatz/array_ansatz.py similarity index 84% rename from mrmustard/physics/representations/fock.py rename to mrmustard/physics/ansatz/array_ansatz.py index b248bcdb9..9091bc8d7 100644 --- a/mrmustard/physics/representations/fock.py +++ b/mrmustard/physics/ansatz/array_ansatz.py @@ -14,7 +14,7 @@ """ -This module contains the Fock representation. +This module contains the array ansatz. """ from __future__ import annotations @@ -30,12 +30,12 @@ from mrmustard import math, widgets from mrmustard.utils.typing import Batch, Scalar, Tensor, Vector -from .base import Representation +from .base import Ansatz -__all__ = ["Fock"] +__all__ = ["ArrayAnsatz"] -class Fock(Representation): +class ArrayAnsatz(Ansatz): r""" The Fock representation of a broad class of quantum states, transformations, measurements, channels, etc. @@ -109,7 +109,7 @@ def batch_size(self): @property def conj(self): - ret = Fock(math.conj(self.array), batched=True) + ret = ArrayAnsatz(math.conj(self.array), batched=True) ret._contract_idxs = self._contract_idxs # pylint: disable=protected-access return ret @@ -142,17 +142,17 @@ def triple(self) -> tuple: return self._original_abc_data @classmethod - def from_dict(cls, data: dict[str, ArrayLike]) -> Fock: + def from_dict(cls, data: dict[str, ArrayLike]) -> ArrayAnsatz: return cls(data["array"], batched=True) @classmethod - def from_function(cls, fn: Callable, **kwargs: Any) -> Fock: + def from_function(cls, fn: Callable, **kwargs: Any) -> ArrayAnsatz: ret = cls(None, True) ret._fn = fn ret._kwargs = kwargs return ret - def reduce(self, shape: int | Sequence[int]) -> Fock: + def reduce(self, shape: int | Sequence[int]) -> ArrayAnsatz: r""" Returns a new ``Fock`` with a sliced array. @@ -195,27 +195,27 @@ def reduce(self, shape: int | Sequence[int]) -> Fock: self.array, [(0, 0)] + [(0, s - t) for s, t in zip(shape, self.array.shape[1:])], ) - return Fock(padded, batched=True) + return ArrayAnsatz(padded, batched=True) ret = self.array[(slice(0, None),) + tuple(slice(0, s) for s in shape)] - return Fock(array=ret, batched=True) + return ArrayAnsatz(array=ret, batched=True) - def reorder(self, order: tuple[int, ...] | list[int]) -> Fock: - return Fock(math.transpose(self.array, [0] + [i + 1 for i in order]), batched=True) + def reorder(self, order: tuple[int, ...] | list[int]) -> ArrayAnsatz: + return ArrayAnsatz(math.transpose(self.array, [0] + [i + 1 for i in order]), batched=True) - def sum_batch(self) -> Fock: + def sum_batch(self) -> ArrayAnsatz: r""" Sums over the batch dimension of the array. Turns an object with any batch size to a batch size of 1. Returns: The collapsed Fock object. """ - return Fock(math.expand_dims(math.sum(self.array, axes=[0]), 0), batched=True) + return ArrayAnsatz(math.expand_dims(math.sum(self.array, axes=[0]), 0), batched=True) def to_dict(self) -> dict[str, ArrayLike]: return {"array": self.data} - def trace(self, idxs1: tuple[int, ...], idxs2: tuple[int, ...]) -> Fock: + def trace(self, idxs1: tuple[int, ...], idxs2: tuple[int, ...]) -> ArrayAnsatz: if len(idxs1) != len(idxs2) or not set(idxs1).isdisjoint(idxs2): raise ValueError("idxs must be of equal length and disjoint") order = ( @@ -228,7 +228,7 @@ def trace(self, idxs1: tuple[int, ...], idxs2: tuple[int, ...]) -> Fock: n = np.prod(new_array.shape[-len(idxs2) :]) new_array = math.reshape(new_array, new_array.shape[: -2 * len(idxs1)] + (n, n)) trace = math.trace(new_array) - return Fock([trace] if trace.shape == () else trace, batched=True) + return ArrayAnsatz([trace] if trace.shape == () else trace, batched=True) def _generate_ansatz(self): if self._array is None: @@ -241,7 +241,7 @@ def _ipython_display_(self): return display(w) - def __add__(self, other: Fock) -> Fock: + def __add__(self, other: ArrayAnsatz) -> ArrayAnsatz: try: diff = sum(self.array.shape[1:]) - sum(other.array.shape[1:]) if diff < 0: @@ -252,35 +252,35 @@ def __add__(self, other: Fock) -> Fock: new_array = [ a + b for a in self.array for b in other.reduce(self.array.shape[1:]).array ] - return Fock(array=new_array, batched=True) + return ArrayAnsatz(array=new_array, batched=True) except Exception as e: raise TypeError(f"Cannot add {self.__class__} and {other.__class__}.") from e - def __and__(self, other: Fock) -> Fock: + def __and__(self, other: ArrayAnsatz) -> ArrayAnsatz: new_array = [math.outer(a, b) for a in self.array for b in other.array] - return Fock(array=new_array, batched=True) + return ArrayAnsatz(array=new_array, batched=True) def __call__(self, z: Batch[Vector]) -> Scalar: raise AttributeError("Cannot call Fock.") - def __eq__(self, other: Representation) -> bool: + def __eq__(self, other: Ansatz) -> bool: slices = (slice(0, None),) + tuple( slice(0, min(si, oi)) for si, oi in zip(self.array.shape[1:], other.array.shape[1:]) ) return np.allclose(self.array[slices], other.array[slices], atol=1e-10) - def __getitem__(self, idx: int | tuple[int, ...]) -> Fock: + def __getitem__(self, idx: int | tuple[int, ...]) -> ArrayAnsatz: idx = (idx,) if isinstance(idx, int) else idx for i in idx: if i >= self.num_vars: raise IndexError( f"Index {i} out of bounds for representation with {self.num_vars} variables." ) - ret = Fock(self.array, batched=True) + ret = ArrayAnsatz(self.array, batched=True) ret._contract_idxs = idx return ret - def __matmul__(self, other: Fock) -> Fock: + def __matmul__(self, other: ArrayAnsatz) -> ArrayAnsatz: idx_s = list(self._contract_idxs) idx_o = list(other._contract_idxs) @@ -306,10 +306,10 @@ def __matmul__(self, other: Fock) -> Fock: for i in range(n_batches_s): for j in range(n_batches_o): batched_array.append(math.tensordot(reduced_s.array[i], reduced_o.array[j], axes)) - return Fock(batched_array, batched=True) + return ArrayAnsatz(batched_array, batched=True) - def __mul__(self, other: Scalar | Fock) -> Fock: - if isinstance(other, Fock): + def __mul__(self, other: Scalar | ArrayAnsatz) -> ArrayAnsatz: + if isinstance(other, ArrayAnsatz): try: diff = sum(self.array.shape[1:]) - sum(other.array.shape[1:]) if diff < 0: @@ -320,11 +320,11 @@ def __mul__(self, other: Scalar | Fock) -> Fock: new_array = [ a * b for a in self.array for b in other.reduce(self.array.shape[1:]).array ] - return Fock(array=new_array, batched=True) + return ArrayAnsatz(array=new_array, batched=True) except Exception as e: raise TypeError(f"Cannot multiply {self.__class__} and {other.__class__}.") from e else: - ret = Fock(array=self.array * other, batched=True) + ret = ArrayAnsatz(array=self.array * other, batched=True) ret._original_abc_data = ( tuple(i * j for i, j in zip(self._original_abc_data, (1, 1, other))) if self._original_abc_data is not None @@ -332,11 +332,11 @@ def __mul__(self, other: Scalar | Fock) -> Fock: ) return ret - def __neg__(self) -> Fock: - return Fock(array=-self.array, batched=True) + def __neg__(self) -> ArrayAnsatz: + return ArrayAnsatz(array=-self.array, batched=True) - def __truediv__(self, other: Scalar | Fock) -> Fock: - if isinstance(other, Fock): + def __truediv__(self, other: Scalar | ArrayAnsatz) -> ArrayAnsatz: + if isinstance(other, ArrayAnsatz): try: diff = sum(self.array.shape[1:]) - sum(other.array.shape[1:]) if diff < 0: @@ -347,11 +347,11 @@ def __truediv__(self, other: Scalar | Fock) -> Fock: new_array = [ a / b for a in self.array for b in other.reduce(self.array.shape[1:]).array ] - return Fock(array=new_array, batched=True) + return ArrayAnsatz(array=new_array, batched=True) except Exception as e: raise TypeError(f"Cannot divide {self.__class__} and {other.__class__}.") from e else: - ret = Fock(array=self.array / other, batched=True) + ret = ArrayAnsatz(array=self.array / other, batched=True) ret._original_abc_data = ( tuple(i / j for i, j in zip(self._original_abc_data, (1, 1, other))) if self._original_abc_data is not None diff --git a/mrmustard/physics/representations/base.py b/mrmustard/physics/ansatz/base.py similarity index 83% rename from mrmustard/physics/representations/base.py rename to mrmustard/physics/ansatz/base.py index 74adf755a..593be5efd 100644 --- a/mrmustard/physics/representations/base.py +++ b/mrmustard/physics/ansatz/base.py @@ -14,7 +14,7 @@ """ -This module contains the base representation class. +This module contains the base ansatz class. """ from __future__ import annotations @@ -33,10 +33,10 @@ Vector, ) -__all__ = ["Representation"] +__all__ = ["Ansatz"] -class Representation(ABC): +class Ansatz(ABC): r""" A base class for representations. """ @@ -55,7 +55,7 @@ def batch_size(self) -> int: @property @abstractmethod - def conj(self) -> Representation: + def conj(self) -> Ansatz: r""" The conjugate of the representation. """ @@ -94,20 +94,20 @@ def triple( @classmethod @abstractmethod - def from_dict(cls, data: dict[str, ArrayLike]) -> Representation: + def from_dict(cls, data: dict[str, ArrayLike]) -> Ansatz: r""" Deserialize a Representation. """ @classmethod @abstractmethod - def from_function(cls, fn: Callable, **kwargs: Any) -> Representation: + def from_function(cls, fn: Callable, **kwargs: Any) -> Ansatz: r""" Returns a representation from a function and kwargs. """ @abstractmethod - def reorder(self, order: tuple[int, ...] | list[int]) -> Representation: + def reorder(self, order: tuple[int, ...] | list[int]) -> Ansatz: r""" Reorders the representation indices. """ @@ -119,7 +119,7 @@ def to_dict(self) -> dict[str, ArrayLike]: """ @abstractmethod - def trace(self, idxs1: tuple[int, ...], idxs2: tuple[int, ...]) -> Representation: + def trace(self, idxs1: tuple[int, ...], idxs2: tuple[int, ...]) -> Ansatz: r""" Implements the partial trace over the given index pairs. @@ -139,7 +139,7 @@ def _generate_ansatz(self): """ @abstractmethod - def __add__(self, other: Representation) -> Representation: + def __add__(self, other: Ansatz) -> Ansatz: r""" Adds this representation and another representation. @@ -151,7 +151,7 @@ def __add__(self, other: Representation) -> Representation: """ @abstractmethod - def __and__(self, other: Representation) -> Representation: + def __and__(self, other: Ansatz) -> Ansatz: r""" Tensor product of this representation with another. @@ -163,7 +163,7 @@ def __and__(self, other: Representation) -> Representation: """ @abstractmethod - def __call__(self, z: Batch[Vector]) -> Scalar | Representation: + def __call__(self, z: Batch[Vector]) -> Scalar | Ansatz: r""" Evaluates this representation at a given point in the domain. @@ -175,19 +175,19 @@ def __call__(self, z: Batch[Vector]) -> Scalar | Representation: """ @abstractmethod - def __eq__(self, other: Representation) -> bool: + def __eq__(self, other: Ansatz) -> bool: r""" Whether this representation is equal to another. """ @abstractmethod - def __getitem__(self, idx: int | tuple[int, ...]) -> Representation: + def __getitem__(self, idx: int | tuple[int, ...]) -> Ansatz: r""" Returns a copy of self with the given indices marked for contraction. """ @abstractmethod - def __matmul__(self, other: Representation) -> Representation: + def __matmul__(self, other: Ansatz) -> Ansatz: r""" Implements the inner product of representations over the marked indices. @@ -199,7 +199,7 @@ def __matmul__(self, other: Representation) -> Representation: """ @abstractmethod - def __mul__(self, other: Scalar | Representation) -> Representation: + def __mul__(self, other: Scalar | Ansatz) -> Ansatz: r""" Multiplies this representation by a scalar or another representation. @@ -214,18 +214,18 @@ def __mul__(self, other: Scalar | Representation) -> Representation: """ @abstractmethod - def __neg__(self) -> Representation: + def __neg__(self) -> Ansatz: r""" Negates the values in the representation. """ - def __rmul__(self, other: Representation | Scalar) -> Representation: + def __rmul__(self, other: Ansatz | Scalar) -> Ansatz: r""" Multiplies this representation by another or by a scalar on the right. """ return self.__mul__(other) - def __sub__(self, other: Representation) -> Representation: + def __sub__(self, other: Ansatz) -> Ansatz: r""" Subtracts other from this representation. """ @@ -235,7 +235,7 @@ def __sub__(self, other: Representation) -> Representation: raise TypeError(f"Cannot subtract {self.__class__} and {other.__class__}.") from e @abstractmethod - def __truediv__(self, other: Scalar | Representation) -> Representation: + def __truediv__(self, other: Scalar | Ansatz) -> Ansatz: r""" Divides this representation by another representation. diff --git a/mrmustard/physics/representations/bargmann.py b/mrmustard/physics/ansatz/polyexp_ansatz.py similarity index 94% rename from mrmustard/physics/representations/bargmann.py rename to mrmustard/physics/ansatz/polyexp_ansatz.py index bd44861d3..d5d09f98d 100644 --- a/mrmustard/physics/representations/bargmann.py +++ b/mrmustard/physics/ansatz/polyexp_ansatz.py @@ -50,13 +50,13 @@ from mrmustard.utils.argsort import argsort_gen -from .base import Representation +from .base import Ansatz -__all__ = ["Bargmann"] +__all__ = ["PolyExpAnsatz"] # pylint: disable=too-many-instance-attributes -class Bargmann(Representation): +class PolyExpAnsatz(Ansatz): r""" The Fock-Bargmann representation of a broad class of quantum states, transformations, measurements, channels, etc. @@ -212,7 +212,7 @@ def c(self, value): @property def conj(self): - ret = Bargmann(math.conj(self.A), math.conj(self.b), math.conj(self.c)) + ret = PolyExpAnsatz(math.conj(self.A), math.conj(self.b), math.conj(self.c)) ret._contract_idxs = self._contract_idxs # pylint: disable=protected-access return ret @@ -251,17 +251,17 @@ def triple( return self.A, self.b, self.c @classmethod - def from_dict(cls, data: dict[str, ArrayLike]) -> Bargmann: + def from_dict(cls, data: dict[str, ArrayLike]) -> PolyExpAnsatz: return cls(**data) @classmethod - def from_function(cls, fn: Callable, **kwargs: Any) -> Bargmann: + def from_function(cls, fn: Callable, **kwargs: Any) -> PolyExpAnsatz: ret = cls(None, None, None) ret._fn = fn ret._kwargs = kwargs return ret - def decompose_ansatz(self) -> Bargmann: + def decompose_ansatz(self) -> PolyExpAnsatz: r""" This method decomposes a Bargmann representation. Given a representation of dimensions: A=(batch,n+m,n+m), b=(batch,n+m), c = (batch,k_1,k_2,...,k_m), @@ -285,9 +285,9 @@ def decompose_ansatz(self) -> Bargmann: b_decomp.append(b_decomp_i) c_decomp.append(c_decomp_i) - return Bargmann(A_decomp, b_decomp, c_decomp) + return PolyExpAnsatz(A_decomp, b_decomp, c_decomp) else: - return Bargmann(self.A, self.b, self.c) + return PolyExpAnsatz(self.A, self.b, self.c) def plot( self, @@ -348,9 +348,9 @@ def plot( plt.show(block=False) return fig, ax - def reorder(self, order: tuple[int, ...] | list[int]) -> Bargmann: + def reorder(self, order: tuple[int, ...] | list[int]) -> PolyExpAnsatz: A, b, c = reorder_abc(self.triple, order) - return Bargmann(A, b, c) + return PolyExpAnsatz(A, b, c) def simplify(self) -> None: r""" @@ -401,16 +401,16 @@ def simplify_v2(self) -> None: def to_dict(self) -> dict[str, ArrayLike]: return {"A": self.A, "b": self.b, "c": self.c} - def trace(self, idxs1: tuple[int, ...], idxs2: tuple[int, ...]) -> Bargmann: + def trace(self, idxs1: tuple[int, ...], idxs2: tuple[int, ...]) -> PolyExpAnsatz: A, b, c = [], [], [] for Abc in zip(self.A, self.b, self.c): Aij, bij, cij = complex_gaussian_integral(Abc, idxs1, idxs2, measure=-1.0) A.append(Aij) b.append(bij) c.append(cij) - return Bargmann(A, b, c) + return PolyExpAnsatz(A, b, c) - def _call_all(self, z: Batch[Vector]) -> Bargmann: + def _call_all(self, z: Batch[Vector]) -> PolyExpAnsatz: r""" Value of this representation at ``z``. If ``z`` is batched a value of the function at each of the batches are returned. If ``Abc`` is batched it is thought of as a linear combination, and thus the results are added linearly together. @@ -475,7 +475,7 @@ def _call_all(self, z: Batch[Vector]) -> Bargmann: ) # (b_arg) return val - def _call_none(self, z: Batch[Vector]) -> Bargmann: + def _call_none(self, z: Batch[Vector]) -> PolyExpAnsatz: r""" Returns a new ansatz that corresponds to currying (partially evaluate) the current one. For example, if ``self`` represents the function ``F(z1,z2)``, the call ``self._call_none([np.array([1.0, None]])`` @@ -506,7 +506,7 @@ def _call_none(self, z: Batch[Vector]) -> Bargmann: "Batch size of the ansatz and argument must match or one of the batch sizes must be 1." ) A, b, c = zip(*Abc) - return Bargmann(A=A, b=b, c=c) + return PolyExpAnsatz(A=A, b=b, c=c) def _call_none_single(self, Ai, bi, ci, zi): r""" @@ -599,7 +599,7 @@ def _decompose_ansatz_single(self, Ai, bi, ci): b_decomp = math.concat((bi[:dim_alpha], math.zeros((dim_alpha), dtype=bi.dtype)), axis=0) return A_decomp, b_decomp, c_decomp - def _equal_no_array(self, other: Bargmann) -> bool: + def _equal_no_array(self, other: PolyExpAnsatz) -> bool: self.simplify() other.simplify() return np.allclose(self.b, other.b, atol=1e-10) and np.allclose(self.A, other.A, atol=1e-10) @@ -650,7 +650,7 @@ def _order_batch(self): self.b = math.gather(self.b, sorted_indices, axis=0) self.c = math.gather(self.c, sorted_indices, axis=0) - def __add__(self, other: Bargmann) -> Bargmann: + def __add__(self, other: PolyExpAnsatz) -> PolyExpAnsatz: r""" Adds two Bargmann representations together. This means concatenating them in the batch dimension. In the case where c is a polynomial of different shapes it will add padding zeros to make @@ -689,11 +689,11 @@ def __add__(self, other: Bargmann) -> Bargmann: a1_new = np.pad(other.c, padding_tuple1, "constant") combined_arrays = math.concat([a0_new, a1_new], axis=0) # note output is not simplified - return Bargmann(combined_matrices, combined_vectors, combined_arrays) + return PolyExpAnsatz(combined_matrices, combined_vectors, combined_arrays) except Exception as e: raise TypeError(f"Cannot add {self.__class__} and {other.__class__}.") from e - def __and__(self, other: Bargmann) -> Bargmann: + def __and__(self, other: PolyExpAnsatz) -> PolyExpAnsatz: r""" Tensor product of this Bargmann with another Bargmann. Equivalent to :math:`F(a) * G(b)` (with different arguments, that is). @@ -778,9 +778,9 @@ def andc(c1, c2): ] bs = [andb(b1, b2, dim_alpha1, dim_alpha2) for b1, b2 in itertools.product(self.b, other.b)] cs = [andc(c1, c2) for c1, c2 in itertools.product(self.c, other.c)] - return Bargmann(As, bs, cs) + return PolyExpAnsatz(As, bs, cs) - def __call__(self, z: Batch[Vector]) -> Scalar | Bargmann: + def __call__(self, z: Batch[Vector]) -> Scalar | PolyExpAnsatz: r""" Returns either the value of the representation or a new representation depending on the argument. If the argument contains None, returns a new representation. @@ -798,21 +798,21 @@ def __call__(self, z: Batch[Vector]) -> Scalar | Bargmann: else: return self._call_all(z) - def __eq__(self, other: Bargmann) -> bool: + def __eq__(self, other: PolyExpAnsatz) -> bool: return self._equal_no_array(other) and np.allclose(self.c, other.c, atol=1e-10) - def __getitem__(self, idx: int | tuple[int, ...]) -> Bargmann: + def __getitem__(self, idx: int | tuple[int, ...]) -> PolyExpAnsatz: idx = (idx,) if isinstance(idx, int) else idx for i in idx: if i >= self.num_vars: raise IndexError( f"Index {i} out of bounds for representation of dimension {self.num_vars}." ) - ret = Bargmann(self.A, self.b, self.c) + ret = PolyExpAnsatz(self.A, self.b, self.c) ret._contract_idxs = idx return ret - def __matmul__(self, other: Bargmann) -> Bargmann: + def __matmul__(self, other: PolyExpAnsatz) -> PolyExpAnsatz: idx_s = self._contract_idxs idx_o = other._contract_idxs @@ -832,9 +832,9 @@ def __matmul__(self, other: Bargmann) -> Bargmann: Abc.append(contract_two_Abc_poly((A1, b1, c1), (A2, b2, c2), idx_s, idx_o)) A, b, c = zip(*Abc) - return Bargmann(A, b, c) + return PolyExpAnsatz(A, b, c) - def __mul__(self, other: Scalar | Bargmann) -> Bargmann: + def __mul__(self, other: Scalar | PolyExpAnsatz) -> PolyExpAnsatz: def mul_A(A1, A2, dim_alpha, dim_beta1, dim_beta2): A3 = math.block( [ @@ -868,7 +868,7 @@ def mul_c(c1, c2): c3 = math.reshape(math.outer(c1, c2), (c1.shape + c2.shape)) return c3 - if isinstance(other, Bargmann): + if isinstance(other, PolyExpAnsatz): dim_beta1, _ = self.polynomial_shape dim_beta2, _ = other.polynomial_shape @@ -891,17 +891,17 @@ def mul_c(c1, c2): new_b = [mul_b(b1, b2, dim_alpha) for b1, b2 in itertools.product(self.b, other.b)] new_c = [mul_c(c1, c2) for c1, c2 in itertools.product(self.c, other.c)] - return Bargmann(A=new_a, b=new_b, c=new_c) + return PolyExpAnsatz(A=new_a, b=new_b, c=new_c) else: try: - return Bargmann(self.A, self.b, self.c * other) + return PolyExpAnsatz(self.A, self.b, self.c * other) except Exception as e: raise TypeError(f"Cannot multiply {self.__class__} and {other.__class__}.") from e - def __neg__(self) -> Bargmann: - return Bargmann(self.A, self.b, -self.c) + def __neg__(self) -> PolyExpAnsatz: + return PolyExpAnsatz(self.A, self.b, -self.c) - def __truediv__(self, other: Scalar | Bargmann) -> Bargmann: + def __truediv__(self, other: Scalar | PolyExpAnsatz) -> PolyExpAnsatz: def div_A(A1, A2, dim_alpha, dim_beta1, dim_beta2): A3 = math.block( [ @@ -935,7 +935,7 @@ def div_c(c1, c2): c3 = math.reshape(math.outer(c1, c2), (c1.shape + c2.shape)) return c3 - if isinstance(other, Bargmann): + if isinstance(other, PolyExpAnsatz): dim_beta1, _ = self.polynomial_shape dim_beta2, _ = other.polynomial_shape if dim_beta1 == 0 and dim_beta2 == 0: @@ -958,11 +958,11 @@ def div_c(c1, c2): new_b = [div_b(b1, -b2, dim_alpha) for b1, b2 in itertools.product(self.b, other.b)] new_c = [div_c(c1, 1 / c2) for c1, c2 in itertools.product(self.c, other.c)] - return Bargmann(A=new_a, b=new_b, c=new_c) + return PolyExpAnsatz(A=new_a, b=new_b, c=new_c) else: raise NotImplementedError("Only implemented if both c are scalars") else: try: - return Bargmann(self.A, self.b, self.c / other) + return PolyExpAnsatz(self.A, self.b, self.c / other) except Exception as e: raise TypeError(f"Cannot divide {self.__class__} and {other.__class__}.") from e diff --git a/mrmustard/physics/multi_representations.py b/mrmustard/physics/representations.py similarity index 67% rename from mrmustard/physics/multi_representations.py rename to mrmustard/physics/representations.py index 488242448..c9afdd2f0 100644 --- a/mrmustard/physics/multi_representations.py +++ b/mrmustard/physics/representations.py @@ -1,4 +1,4 @@ -# Copyright 2023 Xanadu Quantum Technologies Inc. +# Copyright 2024 Xanadu Quantum Technologies Inc. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ """ -This module contains the class for multi-representations. +This module contains the class for representations. """ from __future__ import annotations @@ -29,11 +29,11 @@ Batch, ) -from .representations import Representation, Bargmann, Fock +from .ansatz import Ansatz, PolyExpAnsatz, ArrayAnsatz from .triples import identity_Abc from .wires import Wires -__all__ = ["MultiRepresentation"] +__all__ = ["Representation"] class RepEnum(Enum): @@ -48,11 +48,16 @@ class RepEnum(Enum): PHASESPACE = 4 @classmethod - def from_representation(cls, value: Representation): + def from_ansatz(cls, value: Ansatz): r""" - Returns a ``RepEnum`` from a ``Representation``. + Returns a ``RepEnum`` from an ``Ansatz``. """ - return cls[value.__class__.__name__.upper()] + if isinstance(value, PolyExpAnsatz): + return cls(1) + elif isinstance(value, ArrayAnsatz): + return cls(2) + else: + return cls(0) @classmethod def _missing_(cls, value): @@ -62,44 +67,40 @@ def __repr__(self) -> str: return self.name -class MultiRepresentation: - # TODO: merge current Representation and Anstaz -> Ansatz - # TODO: rename to Representation +class Representation: r""" - A class for multi-representations. + A class for representations. - A multi-representation handles the underlying representation, the wires of - said representation and keeps track of representation conversions. + A representation handles the underlying ansatz, wires and keeps track + of each wire's representation. Args: - representation: A representation for this multi-representation. - wires: The wires of this multi-representation. + ansatz: An ansatz for this representation. + wires: The wires of this representation. wire_reps: An optional dictionary for keeping track of each wire's representation. """ def __init__( self, - representation: Representation | None, + ansatz: Ansatz | None, wires: Wires | None, wire_reps: dict | None = None, ) -> None: - self._representation = representation + self._ansatz = ansatz self._wires = wires - self._wire_reps = wire_reps or dict.fromkeys( - wires.modes, RepEnum.from_representation(representation) - ) + self._wire_reps = wire_reps or dict.fromkeys(wires.modes, RepEnum.from_ansatz(ansatz)) @property - def representation(self) -> Representation | None: + def ansatz(self) -> Ansatz | None: r""" - The underlying representation of this multi-representation. + The underlying ansatz of this representation. """ - return self._representation + return self._ansatz @property def wires(self) -> Wires | None: r""" - The wires of this multi-representation. + The wires of this representation. """ return self._wires @@ -107,7 +108,7 @@ def bargmann_triple( self, batched: bool = False ) -> tuple[Batch[ComplexMatrix], Batch[ComplexVector], Batch[ComplexTensor]]: r""" - The Bargmann parametrization of this multi-representation, if available. + The Bargmann parametrization of this representation, if available. It returns a triple (A, b, c) such that the Bargmann function of this is :math:`F(z) = c \exp\left(\frac{1}{2} z^T A z + b^T z\right)` @@ -117,8 +118,8 @@ def bargmann_triple( batched: Whether to return the triple batched. """ try: - A, b, c = self.representation.triple - if not batched and self.representation.batch_size == 1: + A, b, c = self.ansatz.triple + if not batched and self.ansatz.batch_size == 1: return A[0], b[0], c[0] else: return A, b, c @@ -139,7 +140,7 @@ def fock(self, shape: int | Sequence[int], batched=False) -> ComplexTensor: Returns: array: The Fock representation of this component. """ - num_vars = self.representation.num_vars + num_vars = self.ansatz.num_vars if isinstance(shape, int): shape = (shape,) * num_vars try: @@ -148,7 +149,7 @@ def fock(self, shape: int | Sequence[int], batched=False) -> ComplexTensor: raise ValueError( f"Expected Fock shape of length {num_vars}, got length {len(shape)}" ) - if self.representation.polynomial_shape[0] == 0: + if self.ansatz.polynomial_shape[0] == 0: arrays = [math.hermite_renormalized(A, b, c, shape) for A, b, c in zip(As, bs, cs)] else: arrays = [ @@ -165,46 +166,44 @@ def fock(self, shape: int | Sequence[int], batched=False) -> ComplexTensor: raise ValueError( f"Expected Fock shape of length {num_vars}, got length {len(shape)}" ) - arrays = self.representation.reduce(shape).array + arrays = self.ansatz.reduce(shape).array array = math.sum(arrays, axes=[0]) arrays = math.expand_dims(array, 0) if batched else array return arrays - def to_bargmann(self) -> MultiRepresentation: + def to_bargmann(self) -> Representation: r""" Returns a new circuit component with the same attributes as this and a ``Bargmann`` representation. """ - if isinstance(self.representation, Bargmann): + if isinstance(self.ansatz, PolyExpAnsatz): return self else: - if self.representation._original_abc_data: - A, b, c = self.representation._original_abc_data + if self.ansatz._original_abc_data: + A, b, c = self.ansatz._original_abc_data else: A, b, _ = identity_Abc(len(self.wires.quantum)) - c = self.representation.data - bargmann = Bargmann(A, b, c) - return MultiRepresentation(bargmann, self.wires) + c = self.ansatz.data + bargmann = PolyExpAnsatz(A, b, c) + return Representation(bargmann, self.wires) - def to_fock(self, shape: int | Sequence[int]) -> MultiRepresentation: + def to_fock(self, shape: int | Sequence[int]) -> Representation: r""" - Returns a new multi-representation with a ``Fock`` representation. + Returns a new representation with an ``ArrayAnsatz``. Args: shape: The shape of the returned representation. If ``shape``is given as an ``int``, it is broadcasted to all the dimensions. If ``None``, it defaults to the value of ``AUTOSHAPE_MAX`` in the settings. """ - fock = Fock(self.fock(shape, batched=True), batched=True) + fock = ArrayAnsatz(self.fock(shape, batched=True), batched=True) try: - if self.representation.polynomial_shape[0] == 0: - fock._original_abc_data = self.representation.triple + if self.ansatz.polynomial_shape[0] == 0: + fock._original_abc_data = self.ansatz.triple except AttributeError: fock._original_abc_data = None - return MultiRepresentation(fock, self.wires) + return Representation(fock, self.wires) - def _matmul_indices( - self, other: MultiRepresentation - ) -> tuple[tuple[int, ...], tuple[int, ...]]: + def _matmul_indices(self, other: Representation) -> tuple[tuple[int, ...], tuple[int, ...]]: r""" Finds the indices of the wires being contracted when ``self @ other`` is called. """ @@ -219,24 +218,24 @@ def _matmul_indices( return idx_z, idx_zconj def __eq__(self, other): - if isinstance(other, MultiRepresentation): + if isinstance(other, Representation): return ( - self.representation == other.representation + self.ansatz == other.ansatz and self.wires == other.wires and self._wire_reps == other._wire_reps ) return False - def __matmul__(self, other: MultiRepresentation): + def __matmul__(self, other: Representation): wires_result, perm = self.wires @ other.wires idx_z, idx_zconj = self._matmul_indices(other) - if type(self.representation) is type(other.representation): - self_rep = self.representation - other_rep = other.representation + if type(self.ansatz) is type(other.ansatz): + self_rep = self.ansatz + other_rep = other.ansatz else: - self_rep = self.to_bargmann().representation - other_rep = other.to_bargmann().representation + self_rep = self.to_bargmann().ansatz + other_rep = other.to_bargmann().ansatz rep = self_rep[idx_z] @ other_rep[idx_zconj] rep = rep.reorder(perm) if perm else rep - return MultiRepresentation(rep, wires_result) + return Representation(rep, wires_result) diff --git a/tests/test_lab_dev/test_circuit_components.py b/tests/test_lab_dev/test_circuit_components.py index 9abcb0455..82d28f476 100644 --- a/tests/test_lab_dev/test_circuit_components.py +++ b/tests/test_lab_dev/test_circuit_components.py @@ -25,7 +25,7 @@ from mrmustard import math, settings from mrmustard.math.parameters import Constant, Variable from mrmustard.physics.triples import displacement_gate_Abc -from mrmustard.physics.representations import Bargmann, Fock +from mrmustard.physics.ansatz import PolyExpAnsatz, ArrayAnsatz from mrmustard.lab_dev.circuit_components import CircuitComponent from mrmustard.lab_dev.states import ( Ket, @@ -38,7 +38,7 @@ ) from mrmustard.lab_dev.transformations import Dgate, Attenuator, Unitary, Sgate, Channel from mrmustard.physics.wires import Wires -from mrmustard.physics.multi_representations import MultiRepresentation +from mrmustard.physics.representations import Representation from ..random import Abc_triple @@ -56,7 +56,7 @@ class TestCircuitComponent: @pytest.mark.parametrize("y", [0.4, [0.5, 0.6]]) def test_init(self, x, y): name = "my_component" - representation = Bargmann(*displacement_gate_Abc(x, y)) + representation = PolyExpAnsatz(*displacement_gate_Abc(x, y)) cc = CircuitComponent(representation, wires=[(), (), (1, 8), (1, 8)], name=name) assert cc.name == name @@ -67,21 +67,21 @@ def test_init(self, x, y): def test_missing_name(self): cc = CircuitComponent( - Bargmann(*displacement_gate_Abc(0.1, 0.2)), wires=[(), (), (1, 8), (1, 8)] + PolyExpAnsatz(*displacement_gate_Abc(0.1, 0.2)), wires=[(), (), (1, 8), (1, 8)] ) cc._name = None assert cc.name == "CC18" def test_from_bargmann(self): cc = CircuitComponent.from_bargmann(displacement_gate_Abc(0.1, 0.2), {}, {}, {0}, {0}) - assert cc.representation == Bargmann(*displacement_gate_Abc(0.1, 0.2)) + assert cc.representation == PolyExpAnsatz(*displacement_gate_Abc(0.1, 0.2)) def test_modes_init_out_of_order(self): m1 = (8, 1) m2 = (1, 8) - r1 = Bargmann(*displacement_gate_Abc(x=[0.1, 0.2])) - r2 = Bargmann(*displacement_gate_Abc(x=[0.2, 0.1])) + r1 = PolyExpAnsatz(*displacement_gate_Abc(x=[0.1, 0.2])) + r2 = PolyExpAnsatz(*displacement_gate_Abc(x=[0.2, 0.1])) cc1 = CircuitComponent(r1, wires=[(), (), m1, m1]) cc2 = CircuitComponent(r2, wires=[(), (), m2, m2]) @@ -155,7 +155,7 @@ def test_dual(self): def test_light_copy(self): d1 = CircuitComponent( - Bargmann(*displacement_gate_Abc(0.1, 0.1)), wires=[(), (), (1,), (1,)] + PolyExpAnsatz(*displacement_gate_Abc(0.1, 0.1)), wires=[(), (), (1,), (1,)] ) d1_cp = d1._light_copy() @@ -185,17 +185,17 @@ def test_on_error(self): def test_to_fock_ket(self): vac = Vacuum([1, 2]) vac_fock = vac.to_fock(shape=[1, 2]) - assert vac_fock.representation == Fock(np.array([[1], [0]])) + assert vac_fock.representation == ArrayAnsatz(np.array([[1], [0]])) def test_to_fock_Number(self): num = Number([3], n=4) num_f = num.to_fock(shape=(6,)) - assert num_f.representation == Fock(np.array([0, 0, 0, 0, 1, 0])) + assert num_f.representation == ArrayAnsatz(np.array([0, 0, 0, 0, 1, 0])) def test_to_fock_Dgate(self): d = Dgate([1], x=0.1, y=0.1) d_fock = d.to_fock(shape=(4, 6)) - assert d_fock.representation == Fock( + assert d_fock.representation == ArrayAnsatz( math.hermite_renormalized(*displacement_gate_Abc(x=0.1, y=0.1), shape=(4, 6)) ) @@ -209,7 +209,7 @@ def test_to_fock_bargmann_Dgate(self): def test_to_fock_poly_exp(self): A, b, _ = Abc_triple(3) c = np.random.random((1, 5)) - barg = Bargmann(A, b, c) + barg = PolyExpAnsatz(A, b, c) fock_cc = CircuitComponent(barg, wires=[(), (), (0, 1), ()]).to_fock(shape=(10, 10)) poly = math.hermite_renormalized(A, b, 1, (10, 10, 5)) assert fock_cc.representation._original_abc_data is None @@ -400,7 +400,7 @@ def test_rshift_bargmann_and_fock(self, shape): def test_rshift_error(self): vac012 = Vacuum([0, 1, 2]) d0 = Dgate([0], x=0.1, y=0.1) - d0._multi_rep = MultiRepresentation(d0.representation, Wires()) + d0._multi_rep = Representation(d0.representation, Wires()) with pytest.raises(ValueError, match="not clear"): vac012 >> d0 @@ -507,13 +507,13 @@ def test_ipython_repr_invalid_obj(self, mock_display): def test_serialize_default_behaviour(self): """Test the default serializer.""" name = "my_component" - rep = Bargmann(*displacement_gate_Abc(0.1, 0.4)) + rep = PolyExpAnsatz(*displacement_gate_Abc(0.1, 0.4)) cc = CircuitComponent(rep, wires=[(), (), (1, 8), (1, 8)], name=name) kwargs, arrays = cc._serialize() assert kwargs == { "class": f"{CircuitComponent.__module__}.CircuitComponent", "wires": cc.wires.sorted_args, - "rep_class": f"{Bargmann.__module__}.Bargmann", + "rep_class": f"{PolyExpAnsatz.__module__}.PolyExpAnsatz", "name": name, } assert arrays == {"A": rep.A, "b": rep.b, "c": rep.c} @@ -527,7 +527,7 @@ class MyComponent(CircuitComponent): def __init__(self, rep, custom_modes): super().__init__(rep, wires=[custom_modes] * 4, name="my_component") - cc = MyComponent(Bargmann(*displacement_gate_Abc(0.1, 0.4)), [0, 1]) + cc = MyComponent(PolyExpAnsatz(*displacement_gate_Abc(0.1, 0.4)), [0, 1]) with pytest.raises( TypeError, match="MyComponent does not seem to have any wires construction method" ): diff --git a/tests/test_lab_dev/test_circuit_components_utils.py b/tests/test_lab_dev/test_circuit_components_utils.py index a6bcea70f..f7f1f0fcb 100644 --- a/tests/test_lab_dev/test_circuit_components_utils.py +++ b/tests/test_lab_dev/test_circuit_components_utils.py @@ -29,7 +29,7 @@ join_Abc, join_Abc_real, ) -from mrmustard.physics.representations import Bargmann +from mrmustard.physics.ansatz import PolyExpAnsatz from mrmustard.lab_dev.circuit_components_utils import TraceOut, BtoPS, BtoQ from mrmustard.lab_dev.circuit_components import CircuitComponent from mrmustard.lab_dev.states import Coherent, DM @@ -51,7 +51,7 @@ def test_init(self, modes): assert tr.name == "Tr" assert tr.wires == Wires(modes_in_bra=set(modes), modes_in_ket=set(modes)) - assert tr.representation == Bargmann(*identity_Abc(len(modes))) + assert tr.representation == PolyExpAnsatz(*identity_Abc(len(modes))) def test_trace_out_bargmann_states(self): state = Coherent([0, 1, 2], x=1) diff --git a/tests/test_lab_dev/test_states/test_thermal.py b/tests/test_lab_dev/test_states/test_thermal.py index ad61a5121..968a0dd0e 100644 --- a/tests/test_lab_dev/test_states/test_thermal.py +++ b/tests/test_lab_dev/test_states/test_thermal.py @@ -18,7 +18,7 @@ import pytest -from mrmustard.physics.representations import Bargmann +from mrmustard.physics.ansatz import PolyExpAnsatz from mrmustard.physics.triples import thermal_state_Abc from mrmustard.lab_dev.states import Thermal @@ -49,7 +49,7 @@ def test_init_error(self): @pytest.mark.parametrize("nbar", [1, [2, 3], [4, 4]]) def test_representation(self, nbar): rep = Thermal([0, 1], nbar).representation - exp = Bargmann(*thermal_state_Abc([nbar, nbar] if isinstance(nbar, int) else nbar)) + exp = PolyExpAnsatz(*thermal_state_Abc([nbar, nbar] if isinstance(nbar, int) else nbar)) assert rep == exp def test_representation_error(self): diff --git a/tests/test_lab_dev/test_transformations/test_dgate.py b/tests/test_lab_dev/test_transformations/test_dgate.py index 9148988bf..d72ed4719 100644 --- a/tests/test_lab_dev/test_transformations/test_dgate.py +++ b/tests/test_lab_dev/test_transformations/test_dgate.py @@ -20,7 +20,7 @@ import numpy as np from mrmustard import math from mrmustard.lab_dev import Dgate, SqueezedVacuum -from mrmustard.physics.representations import Fock +from mrmustard.physics.ansatz import ArrayAnsatz class TestDgate: @@ -85,7 +85,7 @@ def test_trainable_parameters(self): assert gate3.y.value == 2 gate_fock = gate3.to_fock() - assert isinstance(gate_fock.representation, Fock) + assert isinstance(gate_fock.representation, ArrayAnsatz) assert gate_fock.y.value == 2 def test_representation_error(self): diff --git a/tests/test_lab_dev/test_transformations/test_transformations_base.py b/tests/test_lab_dev/test_transformations/test_transformations_base.py index 2b9369345..0d091c0c1 100644 --- a/tests/test_lab_dev/test_transformations/test_transformations_base.py +++ b/tests/test_lab_dev/test_transformations/test_transformations_base.py @@ -83,11 +83,12 @@ def test_repr(self): u_component = CircuitComponent._from_attributes( unitary1.representation, unitary1.wires, unitary1.name ) # pylint: disable=protected-access - assert repr(unitary1) == "Dgate(modes=[0, 1], name=Dgate, repr=Bargmann)" - assert repr(unitary1.to_fock(5)) == "Dgate(modes=[0, 1], name=Dgate, repr=Fock)" - assert repr(u_component) == "CircuitComponent(modes=[0, 1], name=Dgate, repr=Bargmann)" + assert repr(unitary1) == "Dgate(modes=[0, 1], name=Dgate, repr=PolyExpAnsatz)" + assert repr(unitary1.to_fock(5)) == "Dgate(modes=[0, 1], name=Dgate, repr=ArrayAnsatz)" + assert repr(u_component) == "CircuitComponent(modes=[0, 1], name=Dgate, repr=PolyExpAnsatz)" assert ( - repr(u_component.to_fock(5)) == "CircuitComponent(modes=[0, 1], name=Dgate, repr=Fock)" + repr(u_component.to_fock(5)) + == "CircuitComponent(modes=[0, 1], name=Dgate, repr=ArrayAnsatz)" ) def test_init_from_bargmann(self): @@ -167,8 +168,8 @@ def test_repr(self): channel1.representation, channel1.wires, channel1.name ) # pylint: disable=protected-access - assert repr(channel1) == "Attenuator(modes=[0, 1], name=Att, repr=Bargmann)" - assert repr(ch_component) == "CircuitComponent(modes=[0, 1], name=Att, repr=Bargmann)" + assert repr(channel1) == "Attenuator(modes=[0, 1], name=Att, repr=PolyExpAnsatz)" + assert repr(ch_component) == "CircuitComponent(modes=[0, 1], name=Att, repr=PolyExpAnsatz)" def test_inverse_channel(self): gate = Sgate([0], 0.1, 0.2) >> Dgate([0], 0.1, 0.2) >> Attenuator([0], 0.5) diff --git a/tests/test_physics/test_representations/__init__.py b/tests/test_physics/test_ansatz/__init__.py similarity index 100% rename from tests/test_physics/test_representations/__init__.py rename to tests/test_physics/test_ansatz/__init__.py diff --git a/tests/test_physics/test_representations/test_fock.py b/tests/test_physics/test_ansatz/test_array_ansatz.py similarity index 77% rename from tests/test_physics/test_representations/test_fock.py rename to tests/test_physics/test_ansatz/test_array_ansatz.py index c9abe2e3e..5fecae9bf 100644 --- a/tests/test_physics/test_representations/test_fock.py +++ b/tests/test_physics/test_ansatz/test_array_ansatz.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""This module contains tests for ``Fock`` objects.""" +"""This module contains tests for ``ArrayAnsatz`` objects.""" # pylint: disable = missing-function-docstring, disable=too-many-public-methods @@ -24,11 +24,11 @@ import pytest from mrmustard import math -from mrmustard.physics.representations.fock import Fock +from mrmustard.physics.ansatz.array_ansatz import ArrayAnsatz -class TestFockRepresentation: - r"""Tests the Fock Representation.""" +class TestArrayAnsatz: + r"""Tests the array ansatz.""" array578 = np.random.random((5, 7, 8)) array1578 = np.random.random((1, 5, 7, 8)) @@ -36,19 +36,19 @@ class TestFockRepresentation: array5578 = np.random.random((5, 5, 7, 8)) def test_init_batched(self): - fock = Fock(self.array1578, batched=True) - assert isinstance(fock, Fock) + fock = ArrayAnsatz(self.array1578, batched=True) + assert isinstance(fock, ArrayAnsatz) assert np.allclose(fock.array, self.array1578) def test_init_non_batched(self): - fock = Fock(self.array578, batched=False) - assert isinstance(fock, Fock) + fock = ArrayAnsatz(self.array578, batched=False) + assert isinstance(fock, ArrayAnsatz) assert fock.array.shape == (1, 5, 7, 8) assert np.allclose(fock.array[0, :, :, :], self.array578) def test_add(self): - fock1 = Fock(self.array2578, batched=True) - fock2 = Fock(self.array5578, batched=True) + fock1 = ArrayAnsatz(self.array2578, batched=True) + fock2 = ArrayAnsatz(self.array5578, batched=True) fock1_add_fock2 = fock1 + fock2 assert fock1_add_fock2.array.shape == (10, 5, 7, 8) assert np.allclose(fock1_add_fock2.array[0], self.array2578[0] + self.array5578[0]) @@ -58,8 +58,8 @@ def test_add(self): def test_algebra_with_different_shape_of_array_raise_errors(self): array = np.random.random((2, 4, 5)) array2 = np.random.random((3, 4, 8, 9)) - aa1 = Fock(array=array) - aa2 = Fock(array=array2) + aa1 = ArrayAnsatz(array=array) + aa2 = ArrayAnsatz(array=array2) with pytest.raises(Exception, match="Cannot add"): aa1 + aa2 # pylint: disable=pointless-statement @@ -77,8 +77,8 @@ def test_algebra_with_different_shape_of_array_raise_errors(self): aa1 == aa2 # pylint: disable=pointless-statement def test_and(self): - fock1 = Fock(self.array1578, batched=True) - fock2 = Fock(self.array5578, batched=True) + fock1 = ArrayAnsatz(self.array1578, batched=True) + fock2 = ArrayAnsatz(self.array5578, batched=True) fock_test = fock1 & fock2 assert fock_test.array.shape == (5, 5, 7, 8, 5, 7, 8) assert np.allclose( @@ -87,30 +87,30 @@ def test_and(self): ) def test_call(self): - fock = Fock(self.array1578, batched=True) + fock = ArrayAnsatz(self.array1578, batched=True) with pytest.raises(AttributeError, match="Cannot call"): fock(0) def test_conj(self): - fock = Fock(self.array1578, batched=True) + fock = ArrayAnsatz(self.array1578, batched=True) fock_conj = fock.conj assert np.allclose(fock_conj.array, np.conj(self.array1578)) def test_divide_on_a_scalar(self): - fock1 = Fock(self.array1578, batched=True) + fock1 = ArrayAnsatz(self.array1578, batched=True) fock_test = fock1 / 1.5 assert np.allclose(fock_test.array, self.array1578 / 1.5) def test_equal(self): array = np.random.random((2, 4, 5)) - aa1 = Fock(array=array) - aa2 = Fock(array=array) + aa1 = ArrayAnsatz(array=array) + aa2 = ArrayAnsatz(array=array) assert aa1 == aa2 def test_matmul_fock_fock(self): array2 = math.astensor(np.random.random((5, 6, 7, 8, 10))) - fock1 = Fock(self.array2578, batched=True) - fock2 = Fock(array2, batched=True) + fock1 = ArrayAnsatz(self.array2578, batched=True) + fock2 = ArrayAnsatz(array2, batched=True) fock_test = fock1[2] @ fock2[2] assert fock_test.array.shape == (10, 5, 7, 6, 7, 10) assert np.allclose( @@ -119,8 +119,8 @@ def test_matmul_fock_fock(self): ) def test_mul(self): - fock1 = Fock(self.array1578, batched=True) - fock2 = Fock(self.array5578, batched=True) + fock1 = ArrayAnsatz(self.array1578, batched=True) + fock2 = ArrayAnsatz(self.array5578, batched=True) fock1_mul_fock2 = fock1 * fock2 assert fock1_mul_fock2.array.shape == (5, 5, 7, 8) assert np.allclose( @@ -129,37 +129,37 @@ def test_mul(self): ) def test_multiply_a_scalar(self): - fock1 = Fock(self.array1578, batched=True) + fock1 = ArrayAnsatz(self.array1578, batched=True) fock_test = 1.3 * fock1 assert np.allclose(fock_test.array, 1.3 * self.array1578) def test_neg(self): array = np.random.random((2, 4, 5)) - aa = Fock(array=array) + aa = ArrayAnsatz(array=array) minusaa = -aa - assert isinstance(minusaa, Fock) + assert isinstance(minusaa, ArrayAnsatz) assert np.allclose(minusaa.array, -array) @pytest.mark.parametrize("batched", [True, False]) def test_reduce(self, batched): shape = (1, 3, 3, 3) if batched else (3, 3, 3) array1 = math.astensor(np.arange(27).reshape(shape)) - fock1 = Fock(array1, batched=batched) + fock1 = ArrayAnsatz(array1, batched=batched) fock2 = fock1.reduce(3) assert fock1 == fock2 fock3 = fock1.reduce(2) array3 = math.astensor([[[0, 1], [3, 4]], [[9, 10], [12, 13]]]) - assert fock3 == Fock(array3) + assert fock3 == ArrayAnsatz(array3) fock4 = fock1.reduce((1, 3, 1)) array4 = math.astensor([[[0], [3], [6]]]) - assert fock4 == Fock(array4) + assert fock4 == ArrayAnsatz(array4) def test_reduce_error(self): array1 = math.astensor(np.arange(27).reshape((3, 3, 3))) - fock1 = Fock(array1) + fock1 = ArrayAnsatz(array1) with pytest.raises(ValueError, match="Expected shape"): fock1.reduce((1, 2)) @@ -168,21 +168,21 @@ def test_reduce_error(self): fock1.reduce((1, 2, 3, 4, 5)) def test_reduce_padded(self): - fock = Fock(self.array578) + fock = ArrayAnsatz(self.array578) with pytest.warns(UserWarning): fock1 = fock.reduce((8, 8, 8)) assert fock1.array.shape == (1, 8, 8, 8) def test_reorder(self): array1 = math.astensor(np.arange(8).reshape((1, 2, 2, 2))) - fock1 = Fock(array1, batched=True) + fock1 = ArrayAnsatz(array1, batched=True) fock2 = fock1.reorder(order=(2, 1, 0)) assert np.allclose(fock2.array, np.array([[[[0, 4], [2, 6]], [[1, 5], [3, 7]]]])) assert np.allclose(fock2.array, np.arange(8).reshape((1, 2, 2, 2), order="F")) def test_sub(self): - fock1 = Fock(self.array2578, batched=True) - fock2 = Fock(self.array5578, batched=True) + fock1 = ArrayAnsatz(self.array2578, batched=True) + fock2 = ArrayAnsatz(self.array5578, batched=True) fock1_sub_fock2 = fock1 - fock2 assert fock1_sub_fock2.array.shape == (10, 5, 7, 8) assert np.allclose(fock1_sub_fock2.array[0], self.array2578[0] - self.array5578[0]) @@ -190,21 +190,21 @@ def test_sub(self): assert np.allclose(fock1_sub_fock2.array[9], self.array2578[1] - self.array5578[4]) def test_sum_batch(self): - fock = Fock(self.array2578, batched=True) + fock = ArrayAnsatz(self.array2578, batched=True) fock_collapsed = fock.sum_batch()[0] assert fock_collapsed.array.shape == (1, 5, 7, 8) assert np.allclose(fock_collapsed.array, np.sum(self.array2578, axis=0)) def test_trace(self): array1 = math.astensor(np.random.random((2, 5, 5, 1, 7, 4, 1, 7, 3))) - fock1 = Fock(array1, batched=True) + fock1 = ArrayAnsatz(array1, batched=True) fock2 = fock1.trace(idxs1=[0, 3], idxs2=[1, 6]) assert fock2.array.shape == (2, 1, 4, 1, 3) assert np.allclose(fock2.array, np.einsum("bccefghfj -> beghj", array1)) def test_truediv(self): - fock1 = Fock(self.array1578, batched=True) - fock2 = Fock(self.array5578, batched=True) + fock1 = ArrayAnsatz(self.array1578, batched=True) + fock2 = ArrayAnsatz(self.array5578, batched=True) fock1_mul_fock2 = fock1 / fock2 assert fock1_mul_fock2.array.shape == (5, 5, 7, 8) assert np.allclose( @@ -214,16 +214,16 @@ def test_truediv(self): def test_truediv_a_scalar(self): array = np.random.random((2, 4, 5)) - aa1 = Fock(array=array) + aa1 = ArrayAnsatz(array=array) aa1_scalar = aa1 / 6 - assert isinstance(aa1_scalar, Fock) + assert isinstance(aa1_scalar, ArrayAnsatz) assert np.allclose(aa1_scalar.array, array / 6) @pytest.mark.parametrize("shape", [(1, 8), (1, 8, 8)]) - @patch("mrmustard.physics.representations.fock.display") + @patch("mrmustard.physics.ansatz.array_ansatz.display") def test_ipython_repr(self, mock_display, shape): """Test the IPython repr function.""" - rep = Fock(np.random.random(shape), batched=True) + rep = ArrayAnsatz(np.random.random(shape), batched=True) rep._ipython_display_() # pylint:disable=protected-access [hbox] = mock_display.call_args.args assert isinstance(hbox, HBox) @@ -243,16 +243,16 @@ def test_ipython_repr(self, mock_display, shape): plots = plots.children assert len(plots) == 2 and all(isinstance(p, FigureWidget) for p in plots) - @patch("mrmustard.physics.representations.fock.display") + @patch("mrmustard.physics.ansatz.array_ansatz.display") def test_ipython_repr_expects_batch_1(self, mock_display): """Test the IPython repr function does nothing with real batch.""" - rep = Fock(np.random.random((2, 8)), batched=True) + rep = ArrayAnsatz(np.random.random((2, 8)), batched=True) rep._ipython_display_() # pylint:disable=protected-access mock_display.assert_not_called() - @patch("mrmustard.physics.representations.fock.display") + @patch("mrmustard.physics.ansatz.array_ansatz.display") def test_ipython_repr_expects_3_dims_or_less(self, mock_display): """Test the IPython repr function does nothing with 4+ dims.""" - rep = Fock(np.random.random((1, 4, 4, 4)), batched=True) + rep = ArrayAnsatz(np.random.random((1, 4, 4, 4)), batched=True) rep._ipython_display_() # pylint:disable=protected-access mock_display.assert_not_called() diff --git a/tests/test_physics/test_representations/test_bargmann.py b/tests/test_physics/test_ansatz/test_polyexp_ansatz.py similarity index 83% rename from tests/test_physics/test_representations/test_bargmann.py rename to tests/test_physics/test_ansatz/test_polyexp_ansatz.py index 7a3444a36..b01d335d5 100644 --- a/tests/test_physics/test_representations/test_bargmann.py +++ b/tests/test_physics/test_ansatz/test_polyexp_ansatz.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""This module contains tests for ``Bargmann`` objects.""" +"""This module contains tests for ``PolyExpAnsatz`` objects.""" # pylint: disable = too-many-public-methods, missing-function-docstring @@ -28,15 +28,15 @@ contract_two_Abc, complex_gaussian_integral, ) -from mrmustard.physics.representations.bargmann import Bargmann -from mrmustard.physics.representations.fock import Fock +from mrmustard.physics.ansatz.polyexp_ansatz import PolyExpAnsatz +from mrmustard.physics.ansatz.array_ansatz import ArrayAnsatz from ...random import Abc_triple -class TestBargmannRepresentation: +class TestPolyExpAnsatz: r""" - Tests the Bargmann Representation. + Tests the polyexp ansatz. """ Abc_n1 = Abc_triple(1) @@ -46,7 +46,7 @@ class TestBargmannRepresentation: @pytest.mark.parametrize("triple", [Abc_n1, Abc_n2, Abc_n3]) def test_init_non_batched(self, triple): A, b, c = triple - bargmann = Bargmann(*triple) + bargmann = PolyExpAnsatz(*triple) assert np.allclose(bargmann.A, A) assert np.allclose(bargmann.b, b) @@ -57,8 +57,8 @@ def test_add(self, n): triple1 = Abc_triple(n) triple2 = Abc_triple(n) - bargmann1 = Bargmann(*triple1) - bargmann2 = Bargmann(*triple2) + bargmann1 = PolyExpAnsatz(*triple1) + bargmann2 = PolyExpAnsatz(*triple2) bargmann_add = bargmann1 + bargmann2 assert np.allclose(bargmann_add.A, math.concat([bargmann1.A, bargmann2.A], axis=0)) @@ -70,8 +70,8 @@ def test_add(self, n): A2, b2, _ = Abc_triple(5) c2 = np.random.random(size=(1, 2, 2)) - bargmann3 = Bargmann(A1, b1, c1) - bargmann4 = Bargmann(A2, b2, c2) + bargmann3 = PolyExpAnsatz(A1, b1, c1) + bargmann4 = PolyExpAnsatz(A2, b2, c2) bargmann_add2 = bargmann3 + bargmann4 @@ -83,8 +83,8 @@ def test_add(self, n): assert np.allclose(bargmann_add2.c[1][:2, :2], c2[0]) def test_add_error(self): - bargmann = Bargmann(*Abc_triple(3)) - fock = Fock(np.random.random((1, 4, 4, 4)), batched=True) + bargmann = PolyExpAnsatz(*Abc_triple(3)) + fock = ArrayAnsatz(np.random.random((1, 4, 4, 4)), batched=True) with pytest.raises(TypeError, match="Cannot add"): bargmann + fock # pylint: disable=pointless-statement @@ -94,7 +94,7 @@ def test_and(self, n): triple1 = Abc_triple(n) triple2 = Abc_triple(n) - bargmann = Bargmann(*triple1) & Bargmann(*triple2) + bargmann = PolyExpAnsatz(*triple1) & PolyExpAnsatz(*triple2) assert bargmann.A.shape == (1, 2 * n, 2 * n) assert bargmann.b.shape == (1, 2 * n) @@ -102,13 +102,13 @@ def test_and(self, n): def test_call(self): A, b, c = Abc_triple(5) - ansatz = Bargmann(A, b, c) + ansatz = PolyExpAnsatz(A, b, c) assert np.allclose(ansatz(z=math.zeros_like(b)), c) A, b, _ = Abc_triple(4) c = np.random.random(size=(1, 3, 3, 3)) - ansatz = Bargmann(A, b, c) + ansatz = PolyExpAnsatz(A, b, c) z = np.random.uniform(-10, 10, size=(7, 2)) with pytest.raises( Exception, match="The sum of the dimension of the argument and polynomial" @@ -119,7 +119,7 @@ def test_call(self): b = np.zeros(2) c = c = np.zeros(10, dtype=complex).reshape(1, -1) c[0, -1] = 1 - obj1 = Bargmann(A, b, c) + obj1 = PolyExpAnsatz(A, b, c) nine_factorial = np.prod(np.arange(1, 9)) assert np.allclose(obj1(np.array([[0.1]])), 0.1**9 / np.sqrt(nine_factorial)) @@ -132,7 +132,7 @@ def test_call_none(self): batch = 3 c = np.random.random(size=(batch, 5, 5, 5)) / 1000 - obj = Bargmann([A1, A2, A3], [b1, b2, b3], c) + obj = PolyExpAnsatz([A1, A2, A3], [b1, b2, b3], c) z0 = np.array([[None, 2, None, 5]]) z1 = np.array([[1, 2, 4, 5]]) z2 = np.array([[1, 4]]) @@ -141,13 +141,17 @@ def test_call_none(self): val2 = obj_none(z2) assert np.allclose(val1, val2) - obj1 = Bargmann(A1, b1, c[0].reshape(1, 5, 5, 5)) + obj1 = PolyExpAnsatz(A1, b1, c[0].reshape(1, 5, 5, 5)) z0 = np.array([[None, 2, None, 5], [None, 1, None, 4]]) z1 = np.array([[1, 2, 4, 5], [2, 1, 4, 4]]) z2 = np.array([[1, 4], [2, 4]]) obj1_none = obj1(z0) - obj1_none0 = Bargmann(obj1_none.A[0], obj1_none.b[0], obj1_none.c[0].reshape(1, 5, 5, 5)) - obj1_none1 = Bargmann(obj1_none.A[1], obj1_none.b[1], obj1_none.c[1].reshape(1, 5, 5, 5)) + obj1_none0 = PolyExpAnsatz( + obj1_none.A[0], obj1_none.b[0], obj1_none.c[0].reshape(1, 5, 5, 5) + ) + obj1_none1 = PolyExpAnsatz( + obj1_none.A[1], obj1_none.b[1], obj1_none.c[1].reshape(1, 5, 5, 5) + ) val1 = obj1(z1) val2 = np.array( (obj1_none0(z2[0].reshape(1, -1)), obj1_none1(z2[1].reshape(1, -1))) @@ -157,7 +161,7 @@ def test_call_none(self): @pytest.mark.parametrize("triple", [Abc_n1, Abc_n2, Abc_n3]) def test_conj(self, triple): A, b, c = triple - bargmann = Bargmann(*triple).conj + bargmann = PolyExpAnsatz(*triple).conj assert np.allclose(bargmann.A, math.conj(A)) assert np.allclose(bargmann.b, math.conj(b)) @@ -166,7 +170,7 @@ def test_conj(self, triple): def test_decompose_ansatz(self): A, b, _ = Abc_triple(4) c = np.random.uniform(-10, 10, size=(1, 3, 3, 3)) - ansatz = Bargmann(A, b, c) + ansatz = PolyExpAnsatz(A, b, c) decomp_ansatz = ansatz.decompose_ansatz() z = np.random.uniform(-10, 10, size=(1, 1)) @@ -174,7 +178,7 @@ def test_decompose_ansatz(self): assert np.allclose(decomp_ansatz.A.shape, (1, 2, 2)) c2 = np.random.uniform(-10, 10, size=(1, 4)) - ansatz2 = Bargmann(A, b, c2) + ansatz2 = PolyExpAnsatz(A, b, c2) decomp_ansatz2 = ansatz2.decompose_ansatz() assert np.allclose(decomp_ansatz2.A, ansatz2.A) @@ -186,7 +190,7 @@ def test_decompose_ansatz_batch(self): c1 = np.random.uniform(-10, 10, size=(3, 3, 3)) A2, b2, _ = Abc_triple(4) c2 = np.random.uniform(-10, 10, size=(3, 3, 3)) - ansatz = Bargmann([A1, A2], [b1, b2], [c1, c2]) + ansatz = PolyExpAnsatz([A1, A2], [b1, b2], [c1, c2]) decomp_ansatz = ansatz.decompose_ansatz() z = np.random.uniform(-10, 10, size=(3, 1)) @@ -199,7 +203,7 @@ def test_decompose_ansatz_batch(self): c1 = np.random.uniform(-10, 10, size=(3, 3, 3)) A2, b2, _ = Abc_triple(5) c2 = np.random.uniform(-10, 10, size=(3, 3, 3)) - ansatz = Bargmann([A1, A2], [b1, b2], [c1, c2]) + ansatz = PolyExpAnsatz([A1, A2], [b1, b2], [c1, c2]) decomp_ansatz = ansatz.decompose_ansatz() z = np.random.uniform(-10, 10, size=(3, 2)) @@ -213,8 +217,8 @@ def test_div(self, n): triple1 = Abc_triple(n) triple2 = Abc_triple(n) - bargmann1 = Bargmann(*triple1) - bargmann2 = Bargmann(*triple2) + bargmann1 = PolyExpAnsatz(*triple1) + bargmann2 = PolyExpAnsatz(*triple2) bargmann_div = bargmann1 / bargmann2 assert np.allclose(bargmann_div.A, bargmann1.A - bargmann2.A) @@ -224,7 +228,7 @@ def test_div(self, n): @pytest.mark.parametrize("scalar", [0.5, 1.2]) @pytest.mark.parametrize("triple", [Abc_n1, Abc_n2, Abc_n3]) def test_div_with_scalar(self, scalar, triple): - bargmann1 = Bargmann(*triple) + bargmann1 = PolyExpAnsatz(*triple) bargmann_div = bargmann1 / scalar assert np.allclose(bargmann1.A, bargmann_div.A) @@ -234,8 +238,8 @@ def test_div_with_scalar(self, scalar, triple): def test_eq(self): A, b, c = Abc_triple(5) - ansatz = Bargmann(A, b, c) - ansatz2 = Bargmann(2 * A, 2 * b, 2 * c) + ansatz = PolyExpAnsatz(A, b, c) + ansatz2 = PolyExpAnsatz(2 * A, 2 * b, 2 * c) assert ansatz == ansatz # pylint: disable= comparison-with-itself assert ansatz2 == ansatz2 # pylint: disable= comparison-with-itself @@ -246,7 +250,7 @@ def test_matmul_barg_barg(self): triple1 = Abc_triple(3) triple2 = Abc_triple(3) - res1 = Bargmann(*triple1) @ Bargmann(*triple2) + res1 = PolyExpAnsatz(*triple1) @ PolyExpAnsatz(*triple2) exp1 = contract_two_Abc(triple1, triple2, [], []) assert np.allclose(res1.A, exp1[0]) assert np.allclose(res1.b, exp1[1]) @@ -257,8 +261,8 @@ def test_mul(self, n): triple1 = Abc_triple(n) triple2 = Abc_triple(n) - bargmann1 = Bargmann(*triple1) - bargmann2 = Bargmann(*triple2) + bargmann1 = PolyExpAnsatz(*triple1) + bargmann2 = PolyExpAnsatz(*triple2) bargmann_mul = bargmann1 * bargmann2 assert np.allclose(bargmann_mul.A, bargmann1.A + bargmann2.A) @@ -268,7 +272,7 @@ def test_mul(self, n): @pytest.mark.parametrize("scalar", [0.5, 1.2]) @pytest.mark.parametrize("triple", [Abc_n1, Abc_n2, Abc_n3]) def test_mul_with_scalar(self, scalar, triple): - bargmann1 = Bargmann(*triple) + bargmann1 = PolyExpAnsatz(*triple) bargmann_mul = bargmann1 * scalar assert np.allclose(bargmann1.A, bargmann_mul.A) @@ -276,7 +280,7 @@ def test_mul_with_scalar(self, scalar, triple): assert np.allclose(bargmann1.c * scalar, bargmann_mul.c) def test_order_batch(self): - ansatz = Bargmann( + ansatz = PolyExpAnsatz( A=[np.array([[0]]), np.array([[1]])], b=[np.array([1]), np.array([0])], c=[1, 2], @@ -293,7 +297,7 @@ def test_order_batch(self): def test_polynomial_shape(self): A, b, _ = Abc_triple(4) c = np.array([[1, 2, 3]]) - ansatz = Bargmann(A, b, c) + ansatz = PolyExpAnsatz(A, b, c) poly_dim, poly_shape = ansatz.polynomial_shape assert np.allclose(poly_dim, 1) @@ -301,11 +305,11 @@ def test_polynomial_shape(self): A1, b1, _ = Abc_triple(4) c1 = np.array([[1, 2, 3]]) - ansatz1 = Bargmann(A1, b1, c1) + ansatz1 = PolyExpAnsatz(A1, b1, c1) A2, b2, _ = Abc_triple(4) c2 = np.array([[1, 2, 3]]) - ansatz2 = Bargmann(A2, b2, c2) + ansatz2 = PolyExpAnsatz(A2, b2, c2) ansatz3 = ansatz1 * ansatz2 @@ -315,7 +319,7 @@ def test_polynomial_shape(self): def test_reorder(self): triple = Abc_triple(3) - bargmann = Bargmann(*triple).reorder((0, 2, 1)) + bargmann = PolyExpAnsatz(*triple).reorder((0, 2, 1)) assert np.allclose(bargmann.A[0], triple[0][[0, 2, 1], :][:, [0, 2, 1]]) assert np.allclose(bargmann.b[0], triple[1][[0, 2, 1]]) @@ -323,7 +327,7 @@ def test_reorder(self): def test_simplify(self): A, b, c = Abc_triple(5) - ansatz = Bargmann(A, b, c) + ansatz = PolyExpAnsatz(A, b, c) ansatz = ansatz + ansatz @@ -340,7 +344,7 @@ def test_simplify(self): def test_simplify_v2(self): A, b, c = Abc_triple(5) - ansatz = Bargmann(A, b, c) + ansatz = PolyExpAnsatz(A, b, c) ansatz = ansatz + ansatz @@ -366,8 +370,8 @@ def test_sub(self, n): triple1 = Abc_triple(n) triple2 = Abc_triple(n) - bargmann1 = Bargmann(*triple1) - bargmann2 = Bargmann(*triple2) + bargmann1 = PolyExpAnsatz(*triple1) + bargmann2 = PolyExpAnsatz(*triple2) bargmann_add = bargmann1 - bargmann2 assert np.allclose(bargmann_add.A, math.concat([bargmann1.A, bargmann2.A], axis=0)) @@ -376,17 +380,17 @@ def test_sub(self, n): def test_trace(self): triple = Abc_triple(4) - bargmann = Bargmann(*triple).trace([0], [2]) + bargmann = PolyExpAnsatz(*triple).trace([0], [2]) A, b, c = complex_gaussian_integral(triple, [0], [2]) assert np.allclose(bargmann.A, A) assert np.allclose(bargmann.b, b) assert np.allclose(bargmann.c, c) - @patch("mrmustard.physics.representations.bargmann.display") + @patch("mrmustard.physics.ansatz.polyexp_ansatz.display") def test_ipython_repr(self, mock_display): """Test the IPython repr function.""" - rep = Bargmann(*Abc_triple(2)) + rep = PolyExpAnsatz(*Abc_triple(2)) rep._ipython_display_() # pylint:disable=protected-access [box] = mock_display.call_args.args assert isinstance(box, Box) @@ -408,12 +412,12 @@ def test_ipython_repr(self, mock_display): assert isinstance(eig_header, HTML) assert isinstance(unit_circle, FigureWidget) - @patch("mrmustard.physics.representations.bargmann.display") + @patch("mrmustard.physics.ansatz.polyexp_ansatz.display") def test_ipython_repr_batched(self, mock_display): """Test the IPython repr function for a batched repr.""" A1, b1, c1 = Abc_triple(2) A2, b2, c2 = Abc_triple(2) - rep = Bargmann(np.array([A1, A2]), np.array([b1, b2]), np.array([c1, c2])) + rep = PolyExpAnsatz(np.array([A1, A2]), np.array([b1, b2]), np.array([c1, c2])) rep._ipython_display_() # pylint:disable=protected-access [vbox] = mock_display.call_args.args assert isinstance(vbox, VBox) diff --git a/tests/test_physics/test_triples.py b/tests/test_physics/test_triples.py index 358e515cc..b066b259b 100644 --- a/tests/test_physics/test_triples.py +++ b/tests/test_physics/test_triples.py @@ -19,7 +19,7 @@ from mrmustard import math from mrmustard.physics import triples -from mrmustard.physics.representations import Bargmann +from mrmustard.physics.ansatz import PolyExpAnsatz # pylint: disable = missing-function-docstring @@ -333,6 +333,6 @@ def test_displacement_gate_s_parametrized_Abc(self): @pytest.mark.parametrize("eta", [0.0, 0.1, 0.5, 0.9, 1.0]) def test_attenuator_kraus_Abc(self, eta): - B = Bargmann(*triples.attenuator_kraus_Abc(eta)) - Att = Bargmann(*triples.attenuator_Abc(eta)) + B = PolyExpAnsatz(*triples.attenuator_kraus_Abc(eta)) + Att = PolyExpAnsatz(*triples.attenuator_Abc(eta)) assert B[2] @ B[2] == Att From e779692b3a92196c694f6cbe9ade1757d8057f14 Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 8 Oct 2024 10:12:25 -0400 Subject: [PATCH 36/87] cleanup --- mrmustard/physics/ansatz/array_ansatz.py | 1 - mrmustard/physics/ansatz/base.py | 1 - mrmustard/physics/ansatz/polyexp_ansatz.py | 4 ++-- mrmustard/physics/representations.py | 1 - 4 files changed, 2 insertions(+), 5 deletions(-) diff --git a/mrmustard/physics/ansatz/array_ansatz.py b/mrmustard/physics/ansatz/array_ansatz.py index 9091bc8d7..50c518422 100644 --- a/mrmustard/physics/ansatz/array_ansatz.py +++ b/mrmustard/physics/ansatz/array_ansatz.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - """ This module contains the array ansatz. """ diff --git a/mrmustard/physics/ansatz/base.py b/mrmustard/physics/ansatz/base.py index 593be5efd..d7399c162 100644 --- a/mrmustard/physics/ansatz/base.py +++ b/mrmustard/physics/ansatz/base.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - """ This module contains the base ansatz class. """ diff --git a/mrmustard/physics/ansatz/polyexp_ansatz.py b/mrmustard/physics/ansatz/polyexp_ansatz.py index d5d09f98d..344890fd8 100644 --- a/mrmustard/physics/ansatz/polyexp_ansatz.py +++ b/mrmustard/physics/ansatz/polyexp_ansatz.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. - """ This module contains the Bargmann representation. """ +# pylint: disable=too-many-instance-attributes + from __future__ import annotations from typing import Any, Callable @@ -55,7 +56,6 @@ __all__ = ["PolyExpAnsatz"] -# pylint: disable=too-many-instance-attributes class PolyExpAnsatz(Ansatz): r""" The Fock-Bargmann representation of a broad class of quantum states, transformations, diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index c9afdd2f0..55138d35f 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - """ This module contains the class for representations. """ From e4ca9f574314482490f1a770cd4579a6f8914c17 Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 8 Oct 2024 10:16:17 -0400 Subject: [PATCH 37/87] docs --- mrmustard/physics/ansatz/base.py | 58 ++++++++++++++++---------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/mrmustard/physics/ansatz/base.py b/mrmustard/physics/ansatz/base.py index d7399c162..a2bd44d57 100644 --- a/mrmustard/physics/ansatz/base.py +++ b/mrmustard/physics/ansatz/base.py @@ -37,7 +37,7 @@ class Ansatz(ABC): r""" - A base class for representations. + A base class for ansatz. """ def __init__(self) -> None: @@ -49,21 +49,21 @@ def __init__(self) -> None: @abstractmethod def batch_size(self) -> int: r""" - The batch size of the representation. + The batch size of the ansatz. """ @property @abstractmethod def conj(self) -> Ansatz: r""" - The conjugate of the representation. + The conjugate of the ansatz. """ @property @abstractmethod def data(self) -> tuple | Tensor: r""" - The data of the representation. + The data of the ansatz. For now, it's the triple for Bargmann and the array for Fock. """ @@ -71,14 +71,14 @@ def data(self) -> tuple | Tensor: @abstractmethod def num_vars(self) -> int: r""" - The number of variables in the representation. + The number of variables in the ansatz. """ @property @abstractmethod def scalar(self) -> Scalar: r""" - The scalar part of the representation. + The scalar part of the ansatz. For now it's ``c`` for Bargmann and the array for Fock. """ @@ -102,13 +102,13 @@ def from_dict(cls, data: dict[str, ArrayLike]) -> Ansatz: @abstractmethod def from_function(cls, fn: Callable, **kwargs: Any) -> Ansatz: r""" - Returns a representation from a function and kwargs. + Returns an ansatz from a function and kwargs. """ @abstractmethod def reorder(self, order: tuple[int, ...] | list[int]) -> Ansatz: r""" - Reorders the representation indices. + Reorders the ansatz indices. """ @abstractmethod @@ -127,7 +127,7 @@ def trace(self, idxs1: tuple[int, ...], idxs2: tuple[int, ...]) -> Ansatz: idxs2: The second part. Returns: - The traced-over representation. + The traced-over ansatz. """ @abstractmethod @@ -140,31 +140,31 @@ def _generate_ansatz(self): @abstractmethod def __add__(self, other: Ansatz) -> Ansatz: r""" - Adds this representation and another representation. + Adds this ansatz and another ansatz. Args: - other: Another representation. + other: Another ansatz. Returns: - The addition of this representation and other. + The addition of this ansatz and other. """ @abstractmethod def __and__(self, other: Ansatz) -> Ansatz: r""" - Tensor product of this representation with another. + Tensor product of this ansatz with another. Args: - other: Another representation. + other: Another ansatz. Returns: - The tensor product of this representation and other. + The tensor product of this ansatz and other. """ @abstractmethod def __call__(self, z: Batch[Vector]) -> Scalar | Ansatz: r""" - Evaluates this representation at a given point in the domain. + Evaluates this ansatz at a given point in the domain. Args: z: point in C^n where the function is evaluated @@ -176,7 +176,7 @@ def __call__(self, z: Batch[Vector]) -> Scalar | Ansatz: @abstractmethod def __eq__(self, other: Ansatz) -> bool: r""" - Whether this representation is equal to another. + Whether this ansatz is equal to another. """ @abstractmethod @@ -191,42 +191,42 @@ def __matmul__(self, other: Ansatz) -> Ansatz: Implements the inner product of representations over the marked indices. Args: - other: Another representation. + other: Another ansatz. Returns: - The resulting representation. + The resulting ansatz. """ @abstractmethod def __mul__(self, other: Scalar | Ansatz) -> Ansatz: r""" - Multiplies this representation by a scalar or another representation. + Multiplies this ansatz by a scalar or another ansatz. Args: - other: A scalar or another representation. + other: A scalar or another ansatz. Raises: - TypeError: If other is neither a scalar nor a representation. + TypeError: If other is neither a scalar nor an ansatz. Returns: - The product of this representation and other. + The product of this ansatz and other. """ @abstractmethod def __neg__(self) -> Ansatz: r""" - Negates the values in the representation. + Negates the values in the ansatz. """ def __rmul__(self, other: Ansatz | Scalar) -> Ansatz: r""" - Multiplies this representation by another or by a scalar on the right. + Multiplies this ansatz by another or by a scalar on the right. """ return self.__mul__(other) def __sub__(self, other: Ansatz) -> Ansatz: r""" - Subtracts other from this representation. + Subtracts other from this ansatz. """ try: return self.__add__(-other) @@ -236,11 +236,11 @@ def __sub__(self, other: Ansatz) -> Ansatz: @abstractmethod def __truediv__(self, other: Scalar | Ansatz) -> Ansatz: r""" - Divides this representation by another representation. + Divides this ansatz by another ansatz. Args: - other: A scalar or another representation. + other: A scalar or another ansatz. Returns: - The division of this representation and other. + The division of this ansatz and other. """ From ad6d5bfd01dce9edde9000462e0863ec156ebbd1 Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 8 Oct 2024 10:18:08 -0400 Subject: [PATCH 38/87] cleanup --- mrmustard/physics/ansatz/polyexp_ansatz.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mrmustard/physics/ansatz/polyexp_ansatz.py b/mrmustard/physics/ansatz/polyexp_ansatz.py index 344890fd8..835798626 100644 --- a/mrmustard/physics/ansatz/polyexp_ansatz.py +++ b/mrmustard/physics/ansatz/polyexp_ansatz.py @@ -13,7 +13,7 @@ # limitations under the License. """ -This module contains the Bargmann representation. +This module contains the PolyExp ansatz. """ # pylint: disable=too-many-instance-attributes From 122a300eb06b43ac481fe27feb4d743740a4cbde Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 8 Oct 2024 10:54:05 -0400 Subject: [PATCH 39/87] workflow --- .github/workflows/tests_docs.yml | 2 +- mrmustard/lab_dev/circuit_components.py | 1 - mrmustard/lab_dev/circuit_components_utils/b_to_q.py | 3 --- 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/.github/workflows/tests_docs.yml b/.github/workflows/tests_docs.yml index f46240532..eb0b8d792 100644 --- a/.github/workflows/tests_docs.yml +++ b/.github/workflows/tests_docs.yml @@ -39,5 +39,5 @@ jobs: - name: Run tests run: | python -m pytest --doctest-modules mrmustard/math/parameter_set.py - python -m pytest --doctest-modules mrmustard/physics/representations + python -m pytest --doctest-modules mrmustard/physics/ansatz python -m pytest --doctest-modules mrmustard/lab_dev diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index 9475afbed..6d60835f0 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -620,7 +620,6 @@ def _rshift_return( ) -> CircuitComponent | np.ndarray | complex: "internal convenience method for right-shift, to return the right type of object" if len(ret.wires) > 0: - print("ret", ret._multi_rep._wire_reps) return ret scalar = ret.representation.scalar return math.sum(scalar) if not settings.UNSAFE_ZIP_BATCH else scalar diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py index 7ead7bea8..e4460a8e1 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py @@ -93,9 +93,6 @@ def __custom_rrshift__(self, other: CircuitComponent | complex) -> CircuitCompon msg += "whether or where to add bra wires. Use ``@`` instead and specify all the components." raise ValueError(msg) - # update ret._multi_rep._wire_reps temp = dict.fromkeys(self.modes, RepEnum.QUADRATURE) - print("1", ret._multi_rep._wire_reps) ret._multi_rep._wire_reps.update(temp) - print("2", ret._multi_rep._wire_reps) return self._rshift_return(ret) From 6034e8536c0176ff4fa21d2c02df3ccb5cd5d45a Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 8 Oct 2024 13:18:00 -0400 Subject: [PATCH 40/87] some rename --- mrmustard/lab_dev/circuit_components.py | 95 ++++++++++--------- .../circuit_components_utils/b_to_q.py | 2 +- .../branch_and_bound.py | 2 +- .../circuit_components_utils/trace_out.py | 8 +- mrmustard/lab_dev/samplers.py | 2 +- mrmustard/lab_dev/states/base.py | 54 +++++------ mrmustard/lab_dev/states/coherent.py | 2 +- .../lab_dev/states/displaced_squeezed.py | 2 +- mrmustard/lab_dev/states/number.py | 2 +- .../lab_dev/states/quadrature_eigenstate.py | 2 +- mrmustard/lab_dev/states/sauron.py | 2 +- mrmustard/lab_dev/states/squeezed_vacuum.py | 2 +- mrmustard/lab_dev/states/thermal.py | 2 +- .../states/two_mode_squeezed_vacuum.py | 2 +- .../lab_dev/transformations/amplifier.py | 2 +- .../lab_dev/transformations/attenuator.py | 2 +- mrmustard/lab_dev/transformations/base.py | 38 ++++---- mrmustard/lab_dev/transformations/bsgate.py | 2 +- mrmustard/lab_dev/transformations/cft.py | 2 +- mrmustard/lab_dev/transformations/dgate.py | 8 +- .../lab_dev/transformations/fockdamping.py | 2 +- mrmustard/lab_dev/transformations/ggate.py | 2 +- mrmustard/lab_dev/transformations/rgate.py | 2 +- mrmustard/lab_dev/transformations/s2gate.py | 2 +- mrmustard/lab_dev/transformations/sgate.py | 2 +- tests/test_lab_dev/test_circuit_components.py | 86 ++++++++--------- .../test_circuit_components_utils.py | 32 +++---- tests/test_lab_dev/test_circuits.py | 2 +- .../test_lab_dev/test_states/test_coherent.py | 14 +-- .../test_states/test_displaced_squeezed.py | 6 +- tests/test_lab_dev/test_states/test_number.py | 6 +- .../test_states/test_quadrature_eigenstate.py | 2 +- .../test_states/test_squeezed_vacuum.py | 6 +- .../test_states/test_states_base.py | 32 +++---- .../test_lab_dev/test_states/test_thermal.py | 4 +- .../test_two_mode_squeezed_vacuum.py | 6 +- tests/test_lab_dev/test_states/test_vacuum.py | 2 +- .../test_transformations/test_amplifier.py | 18 ++-- .../test_transformations/test_attenuator.py | 4 +- .../test_transformations/test_bsgate.py | 4 +- .../test_transformations/test_cft.py | 4 +- .../test_transformations/test_dgate.py | 10 +- .../test_transformations/test_fockdamping.py | 8 +- .../test_transformations/test_identity.py | 4 +- .../test_transformations/test_rgate.py | 8 +- .../test_transformations/test_s2gate.py | 6 +- .../test_transformations/test_sgate.py | 8 +- .../test_transformations_base.py | 32 +++---- tests/test_physics/test_bargmann_utils.py | 2 +- 49 files changed, 277 insertions(+), 272 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index 6d60835f0..39a10f2bb 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -105,10 +105,12 @@ def __init__( + tuple(np.argsort(modes_in_ket) + offsets[2]) ) if representation is not None: - self._multi_rep = Representation(representation.reorder(tuple(perm)), wires) + self._representation = Representation( + representation.reorder(tuple(perm)), wires + ) - if not hasattr(self, "_multi_rep"): - self._multi_rep = Representation(representation, wires) + if not hasattr(self, "_representation"): + self._representation = Representation(representation, wires) def _serialize(self) -> tuple[dict[str, Any], dict[str, ArrayLike]]: """ @@ -121,11 +123,11 @@ def _serialize(self) -> tuple[dict[str, Any], dict[str, ArrayLike]]: serializable = {"class": f"{cls.__module__}.{cls.__qualname__}"} params = signature(cls).parameters if "name" in params: # assume abstract type, serialize the representation - rep_cls = type(self.representation) + rep_cls = type(self.ansatz) serializable["name"] = self.name serializable["wires"] = self.wires.sorted_args serializable["rep_class"] = f"{rep_cls.__module__}.{rep_cls.__qualname__}" - return serializable, self.representation.to_dict() + return serializable, self.ansatz.to_dict() # handle modes parameter if "modes" in params: @@ -163,7 +165,7 @@ def adjoint(self) -> CircuitComponent: """ bras = self.wires.bra.indices kets = self.wires.ket.indices - rep = self.representation.reorder(kets + bras).conj if self.representation else None + rep = self.ansatz.reorder(kets + bras).conj if self.ansatz else None ret = CircuitComponent(rep, self.wires.adjoint, self.name) ret.short_name = self.short_name @@ -182,7 +184,7 @@ def dual(self) -> CircuitComponent: ik = self.wires.ket.input.indices ib = self.wires.bra.input.indices ob = self.wires.bra.output.indices - rep = self.representation.reorder(ib + ob + ik + ok).conj if self.representation else None + rep = self.ansatz.reorder(ib + ob + ik + ok).conj if self.ansatz else None ret = CircuitComponent(rep, self.wires.dual, self.name) ret.short_name = self.short_name @@ -205,7 +207,7 @@ def manual_shape(self) -> list[int | None]: in the `.wires` attribute. """ try: # to read it from array ansatz - return list(self.representation.array.shape[1:]) + return list(self.ansatz.array.shape[1:]) except AttributeError: # bargmann return [None] * len(self.wires) @@ -242,18 +244,25 @@ def parameter_set(self) -> ParameterSet: return self._parameter_set @property - def representation(self) -> Ansatz | None: + def ansatz(self) -> Ansatz | None: r""" - A representation of this circuit component. + The ansatz of this circuit component. """ - return self._multi_rep.ansatz + return self._representation.ansatz + + @property + def representation(self) -> Representation | None: + r""" + The representation of this circuit component. + """ + return self._representation @property def wires(self) -> Wires: r""" The wires of this component. """ - return self._multi_rep.wires + return self._representation.wires @classmethod def from_bargmann( @@ -320,7 +329,7 @@ def from_quadrature( # NOTE: the representation is Bargmann here because we use the inverse of BtoQ on the B side QQQQ = CircuitComponent._from_attributes(PolyExpAnsatz(*triple), wires) BBBB = QtoB_ib @ (QtoB_ik @ QQQQ @ QtoB_ok) @ QtoB_ob - return cls._from_attributes(BBBB.representation, wires, name) + return cls._from_attributes(BBBB.ansatz, wires, name) def to_quadrature(self, phi: float = 0.0) -> CircuitComponent: r""" @@ -341,7 +350,7 @@ def to_quadrature(self, phi: float = 0.0) -> CircuitComponent: BtoQ_ik = BtoQ(self.wires.input.ket.modes, phi).dual object_to_convert = self - if isinstance(self.representation, ArrayAnsatz): + if isinstance(self.ansatz, ArrayAnsatz): object_to_convert = self.to_bargmann() QQQQ = BtoQ_ib @ (BtoQ_ik @ object_to_convert @ BtoQ_ok) @ BtoQ_ob @@ -359,7 +368,7 @@ def quadrature_triple( Returns: A,b,c triple of the quadrature representation """ - return self.to_quadrature(phi=phi).representation.data + return self.to_quadrature(phi=phi).ansatz.data def quadrature(self, quad: Batch[Vector], phi: float = 0.0) -> ComplexTensor: r""" @@ -373,8 +382,8 @@ def quadrature(self, quad: Batch[Vector], phi: float = 0.0) -> ComplexTensor: A circuit component with the given quadrature representation. """ - if isinstance(self.representation, ArrayAnsatz): - fock_arrays = self.representation.array + if isinstance(self.ansatz, ArrayAnsatz): + fock_arrays = self.ansatz.array # Find where all the bras and kets are so they can be conjugated appropriately conjugates = [i not in self.wires.ket.indices for i in range(len(self.wires.indices))] quad_basis = math.sum( @@ -383,7 +392,7 @@ def quadrature(self, quad: Batch[Vector], phi: float = 0.0) -> ComplexTensor: return quad_basis QQQQ = self.to_quadrature(phi=phi) - return QQQQ.representation(quad) + return QQQQ.ansatz(quad) @classmethod def _from_attributes( @@ -422,7 +431,7 @@ def _from_attributes( if tp.__name__ in types: ret = tp() ret._name = name - ret._multi_rep = Representation(representation, wires) + ret._representation = Representation(representation, wires) return ret return CircuitComponent(representation, wires, name) @@ -454,7 +463,7 @@ def bargmann_triple( >>> assert isinstance(coh_cc, CircuitComponent) >>> assert coh == coh_cc # equality looks at representation and wires """ - return self._multi_rep.bargmann_triple(batched) + return self._representation.bargmann_triple(batched) def fock(self, shape: int | Sequence[int] | None = None, batched=False) -> ComplexTensor: r""" @@ -470,7 +479,7 @@ def fock(self, shape: int | Sequence[int] | None = None, batched=False) -> Compl Returns: array: The Fock representation of this component. """ - return self._multi_rep.fock(shape or self.auto_shape(), batched) + return self._representation.fock(shape or self.auto_shape(), batched) def on(self, modes: Sequence[int]) -> CircuitComponent: r""" @@ -530,15 +539,15 @@ def to_bargmann(self) -> CircuitComponent: >>> assert d_bargmann.wires == d.wires >>> assert isinstance(d_bargmann.representation, Bargmann) """ - if isinstance(self.representation, PolyExpAnsatz): + if isinstance(self.ansatz, PolyExpAnsatz): return self else: - mult_rep = self._multi_rep.to_bargmann() + rep = self._representation.to_bargmann() try: ret = self._getitem_builtin(self.modes) - ret._multi_rep = mult_rep + ret._representation = rep except TypeError: - ret = self._from_attributes(mult_rep.ansatz, mult_rep.wires, self.name) + ret = self._from_attributes(rep.ansatz, rep.wires, self.name) if "manual_shape" in ret.__dict__: del ret.manual_shape return ret @@ -564,12 +573,12 @@ def to_fock(self, shape: int | Sequence[int] | None = None) -> CircuitComponent: an ``int``, it is broadcasted to all the dimensions. If ``None``, it defaults to the value of ``AUTOSHAPE_MAX`` in the settings. """ - mult_rep = self._multi_rep.to_fock(shape or self.auto_shape()) + rep = self._representation.to_fock(shape or self.auto_shape()) try: ret = self._getitem_builtin(self.modes) - ret._multi_rep = mult_rep + ret._representation = rep except TypeError: - ret = self._from_attributes(mult_rep.ansatz, mult_rep.wires, self.name) + ret = self._from_attributes(rep.ansatz, rep.wires, self.name) if "manual_shape" in ret.__dict__: del ret.manual_shape return ret @@ -610,8 +619,8 @@ def _light_copy(self, wires: Wires | None = None) -> CircuitComponent: """ instance = super().__new__(self.__class__) instance.__dict__ = self.__dict__.copy() - instance.__dict__["_multi_rep"] = Representation( - self.representation, wires or Wires(*self.wires.args) + instance.__dict__["_representation"] = Representation( + self.ansatz, wires or Wires(*self.wires.args) ) return instance @@ -621,7 +630,7 @@ def _rshift_return( "internal convenience method for right-shift, to return the right type of object" if len(ret.wires) > 0: return ret - scalar = ret.representation.scalar + scalar = ret.ansatz.scalar return math.sum(scalar) if not settings.UNSAFE_ZIP_BATCH else scalar def __add__(self, other: CircuitComponent) -> CircuitComponent: @@ -630,7 +639,7 @@ def __add__(self, other: CircuitComponent) -> CircuitComponent: """ if self.wires != other.wires: raise ValueError("Cannot add components with different wires.") - rep = self.representation + other.representation + rep = self.ansatz + other.ansatz name = self.name if self.name == other.name else "" return self._from_attributes(rep, self.wires, name) @@ -638,11 +647,11 @@ def __eq__(self, other) -> bool: r""" Whether this component is equal to another component. - Compares multi-representations, but not the other attributes + Compares representations, but not the other attributes (e.g. name and parameter set). """ if isinstance(other, CircuitComponent): - return self._multi_rep == other._multi_rep + return self._representation == other._representation return False def __matmul__(self, other: CircuitComponent | Scalar) -> CircuitComponent: @@ -664,17 +673,17 @@ def __matmul__(self, other: CircuitComponent | Scalar) -> CircuitComponent: """ if isinstance(other, (numbers.Number, np.ndarray)): return self * other - result = self._multi_rep @ other._multi_rep + result = self._representation @ other._representation return CircuitComponent._from_attributes(result.ansatz, result.wires, None) def __mul__(self, other: Scalar) -> CircuitComponent: r""" Implements the multiplication by a scalar from the right. """ - return self._from_attributes(self.representation * other, self.wires, self.name) + return self._from_attributes(self.ansatz * other, self.wires, self.name) def __repr__(self) -> str: - repr = self.representation + repr = self.ansatz repr_name = repr.__class__.__name__ if repr_name == "NoneType": return self.__class__.__name__ + f"(modes={self.modes}, name={self.name})" @@ -708,7 +717,7 @@ def __rrshift__(self, other: Scalar) -> CircuitComponent | np.array: not be called, and something else will be returned. """ ret = self * other - return ret.representation.scalar + return ret.ansatz.scalar def __rshift__(self, other: CircuitComponent | numbers.Number) -> CircuitComponent | np.ndarray: r""" @@ -768,7 +777,7 @@ def __sub__(self, other: CircuitComponent) -> CircuitComponent: """ if self.wires != other.wires: raise ValueError("Cannot subtract components with different wires.") - rep = self.representation - other.representation + rep = self.ansatz - other.ansatz name = self.name if self.name == other.name else "" return self._from_attributes(rep, self.wires, name) @@ -776,14 +785,12 @@ def __truediv__(self, other: Scalar) -> CircuitComponent: r""" Implements the division by a scalar for circuit components. """ - return self._from_attributes(self.representation / other, self.wires, self.name) + return self._from_attributes(self.ansatz / other, self.wires, self.name) def _ipython_display_(self): # both reps might return None - rep_fn = ( - mmwidgets.fock if isinstance(self.representation, ArrayAnsatz) else mmwidgets.bargmann - ) - rep_widget = rep_fn(self.representation) + rep_fn = mmwidgets.fock if isinstance(self.ansatz, ArrayAnsatz) else mmwidgets.bargmann + rep_widget = rep_fn(self.ansatz) wires_widget = mmwidgets.wires(self.wires) if not rep_widget: title_widget = widgets.HTML(f"

{self.name or type(self).__name__}

") diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py index e4460a8e1..d90c082c0 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py @@ -94,5 +94,5 @@ def __custom_rrshift__(self, other: CircuitComponent | complex) -> CircuitCompon raise ValueError(msg) temp = dict.fromkeys(self.modes, RepEnum.QUADRATURE) - ret._multi_rep._wire_reps.update(temp) + ret._representation._wire_reps.update(temp) return self._rshift_return(ret) diff --git a/mrmustard/lab_dev/circuit_components_utils/branch_and_bound.py b/mrmustard/lab_dev/circuit_components_utils/branch_and_bound.py index 2f50f1bde..82bd471f7 100644 --- a/mrmustard/lab_dev/circuit_components_utils/branch_and_bound.py +++ b/mrmustard/lab_dev/circuit_components_utils/branch_and_bound.py @@ -68,7 +68,7 @@ def from_circuitcomponent(cls, c: CircuitComponent): c: A CircuitComponent. """ return GraphComponent( - repr=str(c.representation.__class__.__name__), + repr=str(c.ansatz.__class__.__name__), wires=Wires(*c.wires.args), shape=c.auto_shape(), name=c.__class__.__name__, diff --git a/mrmustard/lab_dev/circuit_components_utils/trace_out.py b/mrmustard/lab_dev/circuit_components_utils/trace_out.py index 665fdaa21..67263da07 100644 --- a/mrmustard/lab_dev/circuit_components_utils/trace_out.py +++ b/mrmustard/lab_dev/circuit_components_utils/trace_out.py @@ -80,14 +80,14 @@ def __custom_rrshift__(self, other: CircuitComponent | complex) -> CircuitCompon idx_zconj = [bra[m].indices[0] for m in self.wires.modes & bra.modes] idx_z = [ket[m].indices[0] for m in self.wires.modes & ket.modes] if len(self.wires) == 0: - repr = other.representation + repr = other.ansatz wires = other.wires elif not ket or not bra: - repr = other.representation.conj[idx_z] @ other.representation[idx_z] + repr = other.ansatz.conj[idx_z] @ other.ansatz[idx_z] wires, _ = (other.wires.adjoint @ other.wires)[0] @ self.wires else: - repr = other.representation.trace(idx_z, idx_zconj) + repr = other.ansatz.trace(idx_z, idx_zconj) wires, _ = other.wires @ self.wires cpt = other._from_attributes(repr, wires) # pylint:disable=protected-access - return math.sum(cpt.representation.scalar) if len(cpt.wires) == 0 else cpt + return math.sum(cpt.ansatz.scalar) if len(cpt.wires) == 0 else cpt diff --git a/mrmustard/lab_dev/samplers.py b/mrmustard/lab_dev/samplers.py index c7fda9e15..c088061e8 100644 --- a/mrmustard/lab_dev/samplers.py +++ b/mrmustard/lab_dev/samplers.py @@ -229,7 +229,7 @@ def sample(self, state: State, n_samples: int = 1000, seed: int | None = None) - for unique_sample, counts in zip(unique_samples, counts): quad = np.array([[unique_sample] + [None] * (state.n_modes - 1)]) quad = quad if isinstance(state, Ket) else math.tile(quad, (1, 2)) - reduced_rep = (state >> BtoQ([initial_mode], phi=self._phi)).representation(quad) + reduced_rep = (state >> BtoQ([initial_mode], phi=self._phi)).ansatz(quad) reduced_state = state.__class__.from_bargmann(state.modes[1:], reduced_rep.triple) prob = probs[initial_samples.tolist().index(unique_sample)] / self._step norm = math.sqrt(prob) if isinstance(state, Ket) else prob diff --git a/mrmustard/lab_dev/states/base.py b/mrmustard/lab_dev/states/base.py index 3c4505bac..ed5dcdea6 100644 --- a/mrmustard/lab_dev/states/base.py +++ b/mrmustard/lab_dev/states/base.py @@ -350,7 +350,7 @@ def phase_space(self, s: float) -> tuple: Returns: The covariance matrix, the mean vector and the coefficient of the state in s-parametrized phase space. """ - if not isinstance(self.representation, PolyExpAnsatz): + if not isinstance(self.ansatz, PolyExpAnsatz): raise ValueError("Can calculate phase space only for Bargmann states.") new_state = self >> BtoPS(self.modes, s=s) @@ -426,7 +426,7 @@ def visualize_2d( shape = [max(min_shape, d) for d in self.auto_shape()] state = self.to_fock(tuple(shape)) state = state if isinstance(state, DM) else state.dm() - dm = math.sum(state.representation.array, axes=[0]) + dm = math.sum(state.ansatz.array, axes=[0]) x, prob_x = quadrature_distribution(dm) p, prob_p = quadrature_distribution(dm, np.pi / 2) @@ -542,7 +542,7 @@ def visualize_3d( shape = [max(min_shape, d) for d in self.auto_shape()] state = self.to_fock(tuple(shape)) state = state if isinstance(state, DM) else state.dm() - dm = math.sum(state.representation.array, axes=[0]) + dm = math.sum(state.ansatz.array, axes=[0]) xvec = np.linspace(*xbounds, resolution) pvec = np.linspace(*pbounds, resolution) @@ -616,7 +616,7 @@ def visualize_dm( raise ValueError("DM visualization not available for multi-mode states.") state = self.to_fock(cutoff) state = state if isinstance(state, DM) else state.dm() - dm = math.sum(state.representation.array, axes=[0]) + dm = math.sum(state.ansatz.array, axes=[0]) fig = go.Figure( data=go.Heatmap(z=abs(dm), colorscale="viridis", name="abs(ρ)", showscale=False) @@ -635,7 +635,7 @@ def visualize_dm( def _ipython_display_(self): # pragma: no cover is_ket = isinstance(self, Ket) - is_fock = isinstance(self.representation, ArrayAnsatz) + is_fock = isinstance(self.ansatz, ArrayAnsatz) display(widgets.state(self, is_ket=is_ket, is_fock=is_fock)) @@ -665,19 +665,19 @@ def __init__( wires=[modes, (), modes, ()], name=name, ) - self._multi_rep = Representation(representation, self.wires) + self._representation = Representation(representation, self.wires) @property def is_positive(self) -> bool: r""" Whether this DM is a positive operator. """ - batch_dim = self.representation.batch_size + batch_dim = self.ansatz.batch_size if batch_dim > 1: raise ValueError( "Physicality conditions are not implemented for batch dimension larger than 1." ) - A = self.representation.A[0] + A = self.ansatz.A[0] m = A.shape[-1] // 2 gamma_A = A[:m, m:] @@ -712,7 +712,7 @@ def _probabilities(self) -> RealVector: """ idx_ket = self.wires.output.ket.indices idx_bra = self.wires.output.bra.indices - rep = self.representation.trace(idx_ket, idx_bra) + rep = self.ansatz.trace(idx_ket, idx_bra) return math.real(math.sum(rep.scalar)) @property @@ -798,7 +798,7 @@ def from_quadrature( """ QtoB = BtoQ(modes, phi).inverse() Q = DM(modes, PolyExpAnsatz(*triple)) - return DM(modes, (Q >> QtoB).representation, name) + return DM(modes, (Q >> QtoB).ansatz, name) @classmethod def random(cls, modes: Sequence[int], m: int | None = None, max_r: float = 1.0) -> DM: @@ -838,12 +838,12 @@ def auto_shape( respect_manual_shape: Whether to respect the non-None values in ``manual_shape``. """ # experimental: - if self.representation.batch_size == 1: + if self.ansatz.batch_size == 1: try: # fock - shape = self.representation.array.shape[1:] + shape = self.ansatz.array.shape[1:] except AttributeError: # bargmann - if self.representation.polynomial_shape[0] == 0: - repr = self.representation + if self.ansatz.polynomial_shape[0] == 0: + repr = self.ansatz A, b, c = repr.A[0], repr.b[0], repr.c[0] repr = repr / self.probability shape = autoshape_numba( @@ -942,7 +942,7 @@ def __getitem__(self, modes: int | Sequence[int]) -> State: idxz = [i for i, m in enumerate(self.modes) if m not in modes] idxz_conj = [i + len(self.modes) for i, m in enumerate(self.modes) if m not in modes] - representation = self.representation.trace(idxz, idxz_conj) + representation = self.ansatz.trace(idxz, idxz_conj) return self.__class__._from_attributes( representation, wires, self.name @@ -964,7 +964,7 @@ def __rshift__(self, other: CircuitComponent) -> CircuitComponent: w = result.wires if not w.input and w.bra.modes == w.ket.modes: - return DM(w.modes, result.representation) + return DM(w.modes, result.ansatz) return result @@ -994,20 +994,20 @@ def __init__( wires=[(), (), modes, ()], name=name, ) - self._multi_rep = Representation(representation, self.wires) + self._representation = Representation(representation, self.wires) @property def is_physical(self) -> bool: r""" Whether the ket object is a physical one. """ - batch_dim = self.representation.batch_size + batch_dim = self.ansatz.batch_size if batch_dim > 1: raise ValueError( "Physicality conditions are not implemented for batch dimension larger than 1." ) - A = self.representation.A[0] + A = self.ansatz.A[0] return all(math.abs(math.eigvals(A)) < 1) and math.allclose( self.probability, 1, settings.ATOL @@ -1079,7 +1079,7 @@ def from_quadrature( ) -> State: QtoB = BtoQ(modes, phi).inverse() Q = Ket(modes, PolyExpAnsatz(*triple)) - return Ket(modes, (Q >> QtoB).representation, name) + return Ket(modes, (Q >> QtoB).ansatz, name) @classmethod def random(cls, modes: Sequence[int], max_r: float = 1.0) -> Ket: @@ -1135,12 +1135,12 @@ def auto_shape( respect_manual_shape: Whether to respect the non-None values in ``manual_shape``. """ # experimental: - if self.representation.batch_size == 1: + if self.ansatz.batch_size == 1: try: # fock - shape = self.representation.array.shape[1:] + shape = self.ansatz.array.shape[1:] except AttributeError: # bargmann - if self.representation.polynomial_shape[0] == 0: - repr = self.representation.conj & self.representation + if self.ansatz.polynomial_shape[0] == 0: + repr = self.ansatz.conj & self.ansatz A, b, c = repr.A[0], repr.b[0], repr.c[0] repr = repr / self.probability shape = autoshape_numba( @@ -1164,7 +1164,7 @@ def dm(self) -> DM: The ``DM`` object obtained from this ``Ket``. """ dm = self @ self.adjoint - ret = DM._from_attributes(dm.representation, dm.wires, self.name) + ret = DM._from_attributes(dm.ansatz, dm.wires, self.name) ret.manual_shape = self.manual_shape + self.manual_shape return ret @@ -1258,7 +1258,7 @@ def __rshift__(self, other: CircuitComponent | Scalar) -> CircuitComponent | Bat if not result.wires.input: if not result.wires.bra: - return Ket(result.wires.modes, result.representation) + return Ket(result.wires.modes, result.ansatz) elif result.wires.bra.modes == result.wires.ket.modes: - result = DM(result.wires.modes, result.representation) + result = DM(result.wires.modes, result.ansatz) return result diff --git a/mrmustard/lab_dev/states/coherent.py b/mrmustard/lab_dev/states/coherent.py index 8243ddb34..10a535fc5 100644 --- a/mrmustard/lab_dev/states/coherent.py +++ b/mrmustard/lab_dev/states/coherent.py @@ -83,7 +83,7 @@ def __init__( self._add_parameter(make_parameter(x_trainable, xs, "x", x_bounds)) self._add_parameter(make_parameter(y_trainable, ys, "y", y_bounds)) - self._multi_rep = Representation( + self._representation = Representation( PolyExpAnsatz.from_function(fn=triples.coherent_state_Abc, x=self.x, y=self.y), self.wires, ) diff --git a/mrmustard/lab_dev/states/displaced_squeezed.py b/mrmustard/lab_dev/states/displaced_squeezed.py index 56f1770fd..d5b946dd6 100644 --- a/mrmustard/lab_dev/states/displaced_squeezed.py +++ b/mrmustard/lab_dev/states/displaced_squeezed.py @@ -85,7 +85,7 @@ def __init__( self._add_parameter(make_parameter(r_trainable, rs, "r", r_bounds)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - self._multi_rep = Representation( + self._representation = Representation( PolyExpAnsatz.from_function( fn=triples.displaced_squeezed_vacuum_state_Abc, x=self.x, diff --git a/mrmustard/lab_dev/states/number.py b/mrmustard/lab_dev/states/number.py index 30ee38b0b..6a7edf46f 100644 --- a/mrmustard/lab_dev/states/number.py +++ b/mrmustard/lab_dev/states/number.py @@ -75,7 +75,7 @@ def __init__( for i, cutoff in enumerate(self.cutoffs.value): self.manual_shape[i] = int(cutoff) + 1 - self._multi_rep = Representation( + self._representation = Representation( ArrayAnsatz.from_function(fock_state, n=self.n.value, cutoffs=self.cutoffs.value), self.wires, ) diff --git a/mrmustard/lab_dev/states/quadrature_eigenstate.py b/mrmustard/lab_dev/states/quadrature_eigenstate.py index 4eb045002..276e19811 100644 --- a/mrmustard/lab_dev/states/quadrature_eigenstate.py +++ b/mrmustard/lab_dev/states/quadrature_eigenstate.py @@ -68,7 +68,7 @@ def __init__( xs, phis = list(reshape_params(len(modes), x=x, phi=phi)) self._add_parameter(make_parameter(x_trainable, xs, "x", x_bounds)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - self._multi_rep = Representation( + self._representation = Representation( PolyExpAnsatz.from_function( fn=triples.quadrature_eigenstates_Abc, x=self.x, phi=self.phi ), diff --git a/mrmustard/lab_dev/states/sauron.py b/mrmustard/lab_dev/states/sauron.py index 827ace06c..846018389 100644 --- a/mrmustard/lab_dev/states/sauron.py +++ b/mrmustard/lab_dev/states/sauron.py @@ -43,7 +43,7 @@ def __init__(self, modes: Sequence[int], n: int, epsilon: float = 0.1): super().__init__(name=f"Sauron-{n}", modes=modes) self._add_parameter(make_parameter(False, n, "n", (None, None), dtype="int64")) self._add_parameter(make_parameter(False, epsilon, "epsilon", (None, None))) - self._multi_rep = Representation( + self._representation = Representation( PolyExpAnsatz.from_function( triples.sauron_state_Abc, n=self.n.value, epsilon=self.epsilon.value ), diff --git a/mrmustard/lab_dev/states/squeezed_vacuum.py b/mrmustard/lab_dev/states/squeezed_vacuum.py index 4c3431018..fd713e7ca 100644 --- a/mrmustard/lab_dev/states/squeezed_vacuum.py +++ b/mrmustard/lab_dev/states/squeezed_vacuum.py @@ -69,7 +69,7 @@ def __init__( rs, phis = list(reshape_params(len(modes), r=r, phi=phi)) self._add_parameter(make_parameter(r_trainable, rs, "r", r_bounds)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - self._multi_rep = Representation( + self._representation = Representation( PolyExpAnsatz.from_function( fn=triples.squeezed_vacuum_state_Abc, r=self.r, phi=self.phi ), diff --git a/mrmustard/lab_dev/states/thermal.py b/mrmustard/lab_dev/states/thermal.py index 8aa37d97f..35da098b8 100644 --- a/mrmustard/lab_dev/states/thermal.py +++ b/mrmustard/lab_dev/states/thermal.py @@ -62,6 +62,6 @@ def __init__( super().__init__(modes=modes, name="Thermal") (nbars,) = list(reshape_params(len(modes), nbar=nbar)) self._add_parameter(make_parameter(nbar_trainable, nbars, "nbar", nbar_bounds)) - self._multi_rep = Representation( + self._representation = Representation( PolyExpAnsatz.from_function(fn=triples.thermal_state_Abc, nbar=self.nbar), self.wires ) diff --git a/mrmustard/lab_dev/states/two_mode_squeezed_vacuum.py b/mrmustard/lab_dev/states/two_mode_squeezed_vacuum.py index 33a6ce4c4..c8ed7a66b 100644 --- a/mrmustard/lab_dev/states/two_mode_squeezed_vacuum.py +++ b/mrmustard/lab_dev/states/two_mode_squeezed_vacuum.py @@ -67,7 +67,7 @@ def __init__( rs, phis = list(reshape_params(int(len(modes) / 2), r=r, phi=phi)) self._add_parameter(make_parameter(r_trainable, rs, "r", r_bounds)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - self._multi_rep = Representation( + self._representation = Representation( PolyExpAnsatz.from_function( fn=triples.two_mode_squeezed_vacuum_state_Abc, r=self.r, phi=self.phi ), diff --git a/mrmustard/lab_dev/transformations/amplifier.py b/mrmustard/lab_dev/transformations/amplifier.py index 1bf3e12cf..8d6807e8e 100644 --- a/mrmustard/lab_dev/transformations/amplifier.py +++ b/mrmustard/lab_dev/transformations/amplifier.py @@ -96,6 +96,6 @@ def __init__( None, ) ) - self._multi_rep = Representation( + self._representation = Representation( PolyExpAnsatz.from_function(fn=triples.amplifier_Abc, g=self.gain), self.wires ) diff --git a/mrmustard/lab_dev/transformations/attenuator.py b/mrmustard/lab_dev/transformations/attenuator.py index 32b484c2b..8847e3387 100644 --- a/mrmustard/lab_dev/transformations/attenuator.py +++ b/mrmustard/lab_dev/transformations/attenuator.py @@ -96,7 +96,7 @@ def __init__( None, ) ) - self._multi_rep = Representation( + self._representation = Representation( PolyExpAnsatz.from_function(fn=triples.attenuator_Abc, eta=self.transmissivity), self.wires, ) diff --git a/mrmustard/lab_dev/transformations/base.py b/mrmustard/lab_dev/transformations/base.py index e1ab29d16..b89710df7 100644 --- a/mrmustard/lab_dev/transformations/base.py +++ b/mrmustard/lab_dev/transformations/base.py @@ -89,18 +89,18 @@ def inverse(self) -> Transformation: raise NotImplementedError( "Only Transformations with the same number of input and output wires are supported." ) - if not isinstance(self.representation, PolyExpAnsatz): + if not isinstance(self.ansatz, PolyExpAnsatz): raise NotImplementedError("Only Bargmann representation is supported.") - if self.representation.batch_size > 1: + if self.ansatz.batch_size > 1: raise NotImplementedError("Batched transformations are not supported.") # compute the inverse - A, b, _ = self.dual.representation.conj.triple # apply X(.)X + A, b, _ = self.dual.ansatz.conj.triple # apply X(.)X almost_inverse = self._from_attributes( PolyExpAnsatz(math.inv(A[0]), -math.inv(A[0]) @ b[0], 1 + 0j), self.wires ) almost_identity = self @ almost_inverse - invert_this_c = almost_identity.representation.c + invert_this_c = almost_identity.ansatz.c actual_inverse = self._from_attributes( PolyExpAnsatz(math.inv(A[0]), -math.inv(A[0]) @ b[0], 1 / invert_this_c), self.wires, @@ -155,7 +155,7 @@ def from_quadrature( QtoB_in = BtoQ(modes_in, phi).inverse().dual QQ = Operation(modes_out, modes_in, PolyExpAnsatz(*triple)) BB = QtoB_in >> QQ >> QtoB_out - return Operation(modes_out, modes_in, BB.representation, name) + return Operation(modes_out, modes_in, BB.ansatz, name) class Unitary(Operation): @@ -177,8 +177,8 @@ def symplectic(self): r""" Returns the symplectic matrix that corresponds to this unitary """ - batch_size = self.representation.batch_size - return [au2Symplectic(self.representation.A[batch, :, :]) for batch in range(batch_size)] + batch_size = self.ansatz.batch_size + return [au2Symplectic(self.ansatz.A[batch, :, :]) for batch in range(batch_size)] @classmethod def from_bargmann( @@ -205,7 +205,7 @@ def from_quadrature( QtoB_in = BtoQ(modes_in, phi).inverse().dual QQ = Unitary(modes_out, modes_in, PolyExpAnsatz(*triple)) BB = QtoB_in >> QQ >> QtoB_out - return Unitary(modes_out, modes_in, BB.representation, name) + return Unitary(modes_out, modes_in, BB.ansatz, name) @classmethod def from_symplectic(cls, modes, S) -> Unitary: @@ -235,7 +235,7 @@ def random(cls, modes, max_r=1): def inverse(self) -> Unitary: unitary_dual = self.dual return Unitary._from_attributes( - representation=unitary_dual.representation, + representation=unitary_dual.ansatz, wires=unitary_dual.wires, name=unitary_dual.name, ) @@ -254,9 +254,9 @@ def __rshift__(self, other: CircuitComponent) -> CircuitComponent: ret = super().__rshift__(other) if isinstance(other, Unitary): - return Unitary._from_attributes(ret.representation, ret.wires) + return Unitary._from_attributes(ret.ansatz, ret.wires) elif isinstance(other, Channel): - return Channel._from_attributes(ret.representation, ret.wires) + return Channel._from_attributes(ret.ansatz, ret.wires) return ret @@ -311,7 +311,7 @@ def from_quadrature( QtoB_in = BtoQ(modes_in, phi).inverse().dual QQ = Map(modes_out, modes_in, PolyExpAnsatz(*triple)) BB = QtoB_in >> QQ >> QtoB_out - return Map(modes_out, modes_in, BB.representation, name) + return Map(modes_out, modes_in, BB.ansatz, name) class Channel(Map): @@ -332,12 +332,12 @@ def is_CP(self) -> bool: r""" Whether this channel is completely positive (CP). """ - batch_dim = self.representation.batch_size + batch_dim = self.ansatz.batch_size if batch_dim > 1: raise ValueError( "Physicality conditions are not implemented for batch dimension larger than 1." ) - A = self.representation.A + A = self.ansatz.A m = A.shape[-1] // 2 gamma_A = A[0, :m, m:] @@ -353,7 +353,7 @@ def is_TP(self) -> bool: r""" Whether this channel is trace preserving (TP). """ - A = self.representation.A + A = self.ansatz.A m = A.shape[-1] // 2 gamma_A = A[0, :m, m:] lambda_A = A[0, m:, m:] @@ -372,7 +372,7 @@ def XY(self) -> tuple[ComplexMatrix, ComplexMatrix]: r""" Returns the X and Y matrix corresponding to the channel. """ - return XY_of_channel(self.representation.A[0]) + return XY_of_channel(self.ansatz.A[0]) @classmethod def from_bargmann( @@ -399,7 +399,7 @@ def from_quadrature( QtoB_in = BtoQ(modes_in, phi).inverse().dual QQ = Channel(modes_out, modes_in, PolyExpAnsatz(*triple)) BB = QtoB_in >> QQ >> QtoB_out - return Channel(modes_out, modes_in, BB.representation, name) + return Channel(modes_out, modes_in, BB.ansatz, name) @classmethod def random(cls, modes: Sequence[int], max_r: float = 1.0) -> Channel: @@ -415,7 +415,7 @@ def random(cls, modes: Sequence[int], max_r: float = 1.0) -> Channel: m = len(modes) U = Unitary.random(range(3 * m), max_r) u_psi = Vacuum(range(2 * m)) >> U - A = u_psi.representation + A = u_psi.ansatz kraus = A.conj[range(2 * m)] @ A[range(2 * m)] return Channel.from_bargmann(modes, modes, kraus.triple) @@ -428,5 +428,5 @@ def __rshift__(self, other: CircuitComponent) -> CircuitComponent: """ ret = super().__rshift__(other) if isinstance(other, (Channel, Unitary)): - return Channel._from_attributes(ret.representation, ret.wires) + return Channel._from_attributes(ret.ansatz, ret.wires) return ret diff --git a/mrmustard/lab_dev/transformations/bsgate.py b/mrmustard/lab_dev/transformations/bsgate.py index 15a89a133..4953a1dee 100644 --- a/mrmustard/lab_dev/transformations/bsgate.py +++ b/mrmustard/lab_dev/transformations/bsgate.py @@ -105,7 +105,7 @@ def __init__( super().__init__(modes_out=modes, modes_in=modes, name="BSgate") self._add_parameter(make_parameter(theta_trainable, theta, "theta", theta_bounds)) self._add_parameter(make_parameter(phi_trainable, phi, "phi", phi_bounds)) - self._multi_rep = Representation( + self._representation = Representation( PolyExpAnsatz.from_function( fn=triples.beamsplitter_gate_Abc, theta=self.theta, phi=self.phi ), diff --git a/mrmustard/lab_dev/transformations/cft.py b/mrmustard/lab_dev/transformations/cft.py index d0ccd7f05..669f07a46 100644 --- a/mrmustard/lab_dev/transformations/cft.py +++ b/mrmustard/lab_dev/transformations/cft.py @@ -48,7 +48,7 @@ def __init__( modes_in=modes, name="CFT", ) - self._multi_rep = Representation( + self._representation = Representation( PolyExpAnsatz.from_function( fn=triples.complex_fourier_transform_Abc, n_modes=len(modes) ), diff --git a/mrmustard/lab_dev/transformations/dgate.py b/mrmustard/lab_dev/transformations/dgate.py index e7b756356..48835ef56 100644 --- a/mrmustard/lab_dev/transformations/dgate.py +++ b/mrmustard/lab_dev/transformations/dgate.py @@ -95,7 +95,7 @@ def __init__( xs, ys = list(reshape_params(len(modes), x=x, y=y)) self._add_parameter(make_parameter(x_trainable, xs, "x", x_bounds)) self._add_parameter(make_parameter(y_trainable, ys, "y", y_bounds)) - self._multi_rep = Representation( + self._representation = Representation( PolyExpAnsatz.from_function(fn=triples.displacement_gate_Abc, x=self.x, y=self.y), self.wires, ) @@ -114,7 +114,7 @@ def fock(self, shape: int | Sequence[int] = None, batched=False) -> ComplexTenso array: The Fock representation of this component. """ if isinstance(shape, int): - shape = (shape,) * self.representation.num_vars + shape = (shape,) * self.ansatz.num_vars auto_shape = self.auto_shape() shape = shape or auto_shape if len(shape) != len(auto_shape): @@ -146,7 +146,7 @@ def fock(self, shape: int | Sequence[int] = None, batched=False) -> ComplexTenso def to_fock(self, shape: int | Sequence[int] | None = None) -> Dgate: fock = ArrayAnsatz(self.fock(shape, batched=True), batched=True) - fock._original_abc_data = self.representation.triple + fock._original_abc_data = self.ansatz.triple ret = self._getitem_builtin(self.modes) - ret._multi_rep = Representation(fock, self.wires) + ret._representation = Representation(fock, self.wires) return ret diff --git a/mrmustard/lab_dev/transformations/fockdamping.py b/mrmustard/lab_dev/transformations/fockdamping.py index 9d4426d8b..9ff01a718 100644 --- a/mrmustard/lab_dev/transformations/fockdamping.py +++ b/mrmustard/lab_dev/transformations/fockdamping.py @@ -86,6 +86,6 @@ def __init__( None, ) ) - self._multi_rep = Representation( + self._representation = Representation( PolyExpAnsatz.from_function(fn=triples.fock_damping_Abc, beta=self.damping), self.wires ) diff --git a/mrmustard/lab_dev/transformations/ggate.py b/mrmustard/lab_dev/transformations/ggate.py index 1b953cab0..5ecbf54e5 100644 --- a/mrmustard/lab_dev/transformations/ggate.py +++ b/mrmustard/lab_dev/transformations/ggate.py @@ -58,7 +58,7 @@ def __init__( super().__init__(modes_out=modes, modes_in=modes, name="Ggate") S = make_parameter(symplectic_trainable, symplectic, "symplectic", (None, None)) self.parameter_set.add_parameter(S) - self._multi_rep = Representation( + self._representation = Representation( PolyExpAnsatz.from_function( fn=lambda s: Unitary.from_symplectic(modes, s).bargmann_triple(), s=self.parameter_set.symplectic, diff --git a/mrmustard/lab_dev/transformations/rgate.py b/mrmustard/lab_dev/transformations/rgate.py index 9e58106d3..f2e769daa 100644 --- a/mrmustard/lab_dev/transformations/rgate.py +++ b/mrmustard/lab_dev/transformations/rgate.py @@ -63,6 +63,6 @@ def __init__( super().__init__(modes_out=modes, modes_in=modes, name="Rgate") (phis,) = list(reshape_params(len(modes), phi=phi)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - self._multi_rep = Representation( + self._representation = Representation( PolyExpAnsatz.from_function(fn=triples.rotation_gate_Abc, theta=self.phi), self.wires ) diff --git a/mrmustard/lab_dev/transformations/s2gate.py b/mrmustard/lab_dev/transformations/s2gate.py index d5c292744..b9661a22b 100644 --- a/mrmustard/lab_dev/transformations/s2gate.py +++ b/mrmustard/lab_dev/transformations/s2gate.py @@ -88,7 +88,7 @@ def __init__( super().__init__(modes_out=modes, modes_in=modes, name="S2gate") self._add_parameter(make_parameter(r_trainable, r, "r", r_bounds)) self._add_parameter(make_parameter(phi_trainable, phi, "phi", phi_bounds)) - self._multi_rep = Representation( + self._representation = Representation( PolyExpAnsatz.from_function( fn=triples.twomode_squeezing_gate_Abc, r=self.r, phi=self.phi ), diff --git a/mrmustard/lab_dev/transformations/sgate.py b/mrmustard/lab_dev/transformations/sgate.py index c4d0a1233..d5e687a70 100644 --- a/mrmustard/lab_dev/transformations/sgate.py +++ b/mrmustard/lab_dev/transformations/sgate.py @@ -95,7 +95,7 @@ def __init__( rs, phis = list(reshape_params(len(modes), r=r, phi=phi)) self._add_parameter(make_parameter(r_trainable, rs, "r", r_bounds)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - self._multi_rep = Representation( + self._representation = Representation( PolyExpAnsatz.from_function(fn=triples.squeezing_gate_Abc, r=self.r, delta=self.phi), self.wires, ) diff --git a/tests/test_lab_dev/test_circuit_components.py b/tests/test_lab_dev/test_circuit_components.py index 82d28f476..7865820c9 100644 --- a/tests/test_lab_dev/test_circuit_components.py +++ b/tests/test_lab_dev/test_circuit_components.py @@ -62,7 +62,7 @@ def test_init(self, x, y): assert cc.name == name assert list(cc.modes) == [1, 8] assert cc.wires == Wires(modes_out_ket={1, 8}, modes_in_ket={1, 8}) - assert cc.representation == representation + assert cc.ansatz == representation assert cc.manual_shape == [None] * 4 def test_missing_name(self): @@ -74,7 +74,7 @@ def test_missing_name(self): def test_from_bargmann(self): cc = CircuitComponent.from_bargmann(displacement_gate_Abc(0.1, 0.2), {}, {}, {0}, {0}) - assert cc.representation == PolyExpAnsatz(*displacement_gate_Abc(0.1, 0.2)) + assert cc.ansatz == PolyExpAnsatz(*displacement_gate_Abc(0.1, 0.2)) def test_modes_init_out_of_order(self): m1 = (8, 1) @@ -87,19 +87,19 @@ def test_modes_init_out_of_order(self): cc2 = CircuitComponent(r2, wires=[(), (), m2, m2]) assert cc1 == cc2 - r3 = (cc1.adjoint @ cc1).representation + r3 = (cc1.adjoint @ cc1).ansatz cc3 = CircuitComponent(r3, wires=[m2, m2, m2, m1]) cc4 = CircuitComponent(r3, wires=[m2, m2, m2, m2]) - assert cc3.representation == cc4.representation.reorder([0, 1, 2, 3, 4, 5, 7, 6]) + assert cc3.ansatz == cc4.ansatz.reorder([0, 1, 2, 3, 4, 5, 7, 6]) @pytest.mark.parametrize("x", [0.1, [0.2, 0.3]]) @pytest.mark.parametrize("y", [0.4, [0.5, 0.6]]) def test_from_attributes(self, x, y): cc = Dgate([1, 8], x=x, y=y) - cc1 = Dgate._from_attributes(cc.representation, cc.wires, cc.name) - cc2 = Unitary._from_attributes(cc.representation, cc.wires, cc.name) - cc3 = CircuitComponent._from_attributes(cc.representation, cc.wires, cc.name) + cc1 = Dgate._from_attributes(cc.ansatz, cc.wires, cc.name) + cc2 = Unitary._from_attributes(cc.ansatz, cc.wires, cc.name) + cc3 = CircuitComponent._from_attributes(cc.ansatz, cc.wires, cc.name) assert cc1 == cc assert cc2 == cc @@ -111,7 +111,7 @@ def test_from_attributes(self, x, y): def test_from_to_quadrature(self): c = Dgate([0], x=0.1, y=0.2) >> Sgate([0], r=1.0, phi=0.1) - cc = CircuitComponent._from_attributes(c.representation, c.wires, c.name) + cc = CircuitComponent._from_attributes(c.ansatz, c.wires, c.name) ccc = CircuitComponent.from_quadrature(tuple(), tuple(), (0,), (0,), cc.quadrature_triple()) assert cc == ccc @@ -123,16 +123,14 @@ def test_adjoint(self): assert d1_adj.name == d1.name assert d1_adj.wires == d1.wires.adjoint assert d1_adj.parameter_set == d1.parameter_set - assert ( - d1_adj.representation == d1.representation.conj - ) # this holds for the Dgate but not in general + assert d1_adj.ansatz == d1.ansatz.conj # this holds for the Dgate but not in general d1_adj_adj = d1_adj.adjoint assert isinstance(d1_adj_adj, CircuitComponent) assert d1_adj_adj.wires == d1.wires assert d1_adj_adj.parameter_set == d1_adj.parameter_set assert d1_adj_adj.parameter_set == d1.parameter_set - assert d1_adj_adj.representation == d1.representation + assert d1_adj_adj.ansatz == d1.ansatz def test_dual(self): d1 = Dgate([1, 8], x=0.1, y=0.2) @@ -143,15 +141,15 @@ def test_dual(self): assert d1_dual.name == d1.name assert d1_dual.wires == d1.wires.dual assert d1_dual.parameter_set == d1.parameter_set - assert (vac >> d1 >> d1_dual).representation == vac.representation - assert (vac >> d1_dual >> d1).representation == vac.representation + assert (vac >> d1 >> d1_dual).ansatz == vac.ansatz + assert (vac >> d1_dual >> d1).ansatz == vac.ansatz d1_dual_dual = d1_dual.dual assert isinstance(d1_dual_dual, CircuitComponent) assert d1_dual_dual.parameter_set == d1_dual.parameter_set assert d1_dual_dual.parameter_set == d1.parameter_set assert d1_dual_dual.wires == d1.wires - assert d1_dual_dual.representation == d1.representation + assert d1_dual_dual.ansatz == d1.ansatz def test_light_copy(self): d1 = CircuitComponent( @@ -160,7 +158,7 @@ def test_light_copy(self): d1_cp = d1._light_copy() assert d1_cp.parameter_set is d1.parameter_set - assert d1_cp.representation is d1.representation + assert d1_cp.ansatz is d1.ansatz assert d1_cp.wires is not d1.wires def test_on(self): @@ -176,7 +174,7 @@ def test_on(self): assert isinstance(d67.r, Variable) assert math.allclose(d89.r.value, d67.r.value) assert bool(d67.parameter_set) is True - assert d67.representation is d89.representation + assert d67.ansatz is d89.ansatz def test_on_error(self): with pytest.raises(ValueError): @@ -185,17 +183,17 @@ def test_on_error(self): def test_to_fock_ket(self): vac = Vacuum([1, 2]) vac_fock = vac.to_fock(shape=[1, 2]) - assert vac_fock.representation == ArrayAnsatz(np.array([[1], [0]])) + assert vac_fock.ansatz == ArrayAnsatz(np.array([[1], [0]])) def test_to_fock_Number(self): num = Number([3], n=4) num_f = num.to_fock(shape=(6,)) - assert num_f.representation == ArrayAnsatz(np.array([0, 0, 0, 0, 1, 0])) + assert num_f.ansatz == ArrayAnsatz(np.array([0, 0, 0, 0, 1, 0])) def test_to_fock_Dgate(self): d = Dgate([1], x=0.1, y=0.1) d_fock = d.to_fock(shape=(4, 6)) - assert d_fock.representation == ArrayAnsatz( + assert d_fock.ansatz == ArrayAnsatz( math.hermite_renormalized(*displacement_gate_Abc(x=0.1, y=0.1), shape=(4, 6)) ) @@ -203,7 +201,7 @@ def test_to_fock_bargmann_Dgate(self): d = Dgate([1], x=0.1, y=0.1) d_fock = d.to_fock(shape=(4, 6)) d_barg = d_fock.to_bargmann() - assert d_fock.representation._original_abc_data == d.representation.triple + assert d_fock.ansatz._original_abc_data == d.ansatz.triple assert d_barg == d def test_to_fock_poly_exp(self): @@ -212,33 +210,33 @@ def test_to_fock_poly_exp(self): barg = PolyExpAnsatz(A, b, c) fock_cc = CircuitComponent(barg, wires=[(), (), (0, 1), ()]).to_fock(shape=(10, 10)) poly = math.hermite_renormalized(A, b, 1, (10, 10, 5)) - assert fock_cc.representation._original_abc_data is None - assert np.allclose(fock_cc.representation.data, np.einsum("ijk,k", poly, c[0])) + assert fock_cc.ansatz._original_abc_data is None + assert np.allclose(fock_cc.ansatz.data, np.einsum("ijk,k", poly, c[0])) def test_add(self): d1 = Dgate([1], x=0.1, y=0.1) d2 = Dgate([1], x=0.2, y=0.2) d12 = d1 + d2 - assert d12.representation == d1.representation + d2.representation + assert d12.ansatz == d1.ansatz + d2.ansatz def test_sub(self): s1 = DisplacedSqueezed([1], x=1.0, y=0.5, r=0.1) s2 = DisplacedSqueezed([1], x=0.5, y=0.2, r=0.2) s12 = s1 - s2 - assert s12.representation == s1.representation - s2.representation + assert s12.ansatz == s1.ansatz - s2.ansatz def test_mul(self): d1 = Dgate([1], x=0.1, y=0.1) - assert (d1 * 3).representation == d1.representation * 3 - assert (3 * d1).representation == d1.representation * 3 + assert (d1 * 3).ansatz == d1.ansatz * 3 + assert (3 * d1).ansatz == d1.ansatz * 3 assert isinstance(d1 * 3, Unitary) def test_truediv(self): d1 = Dgate([1], x=0.1, y=0.1) - assert (d1 / 3).representation == d1.representation / 3 + assert (d1 / 3).ansatz == d1.ansatz / 3 assert isinstance(d1 / 3, Unitary) def test_add_error(self): @@ -266,9 +264,9 @@ def test_matmul(self): result = result @ result.adjoint @ a0 @ a1 @ a2 assert result.wires == Wires(modes_out_bra={0, 1, 2}, modes_out_ket={0, 1, 2}) - assert np.allclose(result.representation.A, 0) + assert np.allclose(result.ansatz.A, 0) assert np.allclose( - result.representation.b, + result.ansatz.b, [ 0.08944272 - 0.08944272j, 0.08944272 - 0.08944272j, @@ -278,7 +276,7 @@ def test_matmul(self): 0.083666 + 0.083666j, ], ) - assert np.allclose(result.representation.c, 0.95504196) + assert np.allclose(result.ansatz.c, 0.95504196) def test_matmul_one_mode_Dgate_contraction(self): r""" @@ -295,7 +293,7 @@ def test_matmul_one_mode_Dgate_contraction(self): (alpha * np.conj(beta) - np.conj(alpha) * beta) / 2 ) - assert np.allclose(result1.representation.c, correct_c) + assert np.allclose(result1.ansatz.c, correct_c) def test_matmul_is_associative(self): d0 = Dgate([0], x=0.1, y=0.1) @@ -317,13 +315,13 @@ def test_matmul_is_associative(self): def test_matmul_scalar(self): d0 = Dgate([0], x=0.1, y=0.1) result = d0 @ 0.8 - assert math.allclose(result.representation.A, d0.representation.A) - assert math.allclose(result.representation.b, d0.representation.b) - assert math.allclose(result.representation.c, 0.8 * d0.representation.c) + assert math.allclose(result.ansatz.A, d0.ansatz.A) + assert math.allclose(result.ansatz.b, d0.ansatz.b) + assert math.allclose(result.ansatz.c, 0.8 * d0.ansatz.c) result2 = 0.8 @ d0 - assert math.allclose(result2.representation.A, d0.representation.A) - assert math.allclose(result2.representation.b, d0.representation.b) - assert math.allclose(result2.representation.c, 0.8 * d0.representation.c) + assert math.allclose(result2.ansatz.A, d0.ansatz.A) + assert math.allclose(result2.ansatz.b, d0.ansatz.b) + assert math.allclose(result2.ansatz.c, 0.8 * d0.ansatz.c) def test_rshift_all_bargmann(self): vac012 = Vacuum([0, 1, 2]) @@ -337,9 +335,9 @@ def test_rshift_all_bargmann(self): result = vac012 >> d0 >> d1 >> d2 >> a0 >> a1 >> a2 assert result.wires == Wires(modes_out_bra={0, 1, 2}, modes_out_ket={0, 1, 2}) - assert np.allclose(result.representation.A, 0) + assert np.allclose(result.ansatz.A, 0) assert np.allclose( - result.representation.b, + result.ansatz.b, [ 0.08944272 - 0.08944272j, 0.08944272 - 0.08944272j, @@ -349,7 +347,7 @@ def test_rshift_all_bargmann(self): 0.083666 + 0.083666j, ], ) - assert np.allclose(result.representation.c, 0.95504196) + assert np.allclose(result.ansatz.c, 0.95504196) def test_rshift_all_fock(self): vac012 = Vacuum([0, 1, 2]) @@ -400,7 +398,7 @@ def test_rshift_bargmann_and_fock(self, shape): def test_rshift_error(self): vac012 = Vacuum([0, 1, 2]) d0 = Dgate([0], x=0.1, y=0.1) - d0._multi_rep = Representation(d0.representation, Wires()) + d0._representation = Representation(d0.ansatz, Wires()) with pytest.raises(ValueError, match="not clear"): vac012 >> d0 @@ -432,10 +430,10 @@ def test_rshift_is_associative(self): def test_rshift_scalar(self): d0 = Dgate([0], x=0.1, y=0.1) result = 0.8 >> d0 - assert math.allclose(result, 0.8 * d0.representation.c) + assert math.allclose(result, 0.8 * d0.ansatz.c) result2 = d0 >> 0.8 - assert math.allclose(result2.representation.c, 0.8 * d0.representation.c) + assert math.allclose(result2.ansatz.c, 0.8 * d0.ansatz.c) def test_repr(self): c1 = CircuitComponent(wires=Wires(modes_out_ket=(0, 1, 2))) diff --git a/tests/test_lab_dev/test_circuit_components_utils.py b/tests/test_lab_dev/test_circuit_components_utils.py index f7f1f0fcb..2a5cf0a77 100644 --- a/tests/test_lab_dev/test_circuit_components_utils.py +++ b/tests/test_lab_dev/test_circuit_components_utils.py @@ -51,7 +51,7 @@ def test_init(self, modes): assert tr.name == "Tr" assert tr.wires == Wires(modes_in_bra=set(modes), modes_in_ket=set(modes)) - assert tr.representation == PolyExpAnsatz(*identity_Abc(len(modes))) + assert tr.ansatz == PolyExpAnsatz(*identity_Abc(len(modes))) def test_trace_out_bargmann_states(self): state = Coherent([0, 1, 2], x=1) @@ -99,13 +99,13 @@ def test_init(self, modes, s): assert dsmap.modes == [modes] if not isinstance(modes, list) else sorted(modes) def test_representation(self): - rep1 = BtoPS(modes=[0], s=0).representation # pylint: disable=protected-access + rep1 = BtoPS(modes=[0], s=0).ansatz # pylint: disable=protected-access A_correct, b_correct, c_correct = displacement_map_s_parametrized_Abc(s=0, n_modes=1) assert math.allclose(rep1.A[0], A_correct) assert math.allclose(rep1.b[0], b_correct) assert math.allclose(rep1.c[0], c_correct) - rep2 = BtoPS(modes=[5, 10], s=1).representation # pylint: disable=protected-access + rep2 = BtoPS(modes=[5, 10], s=1).ansatz # pylint: disable=protected-access A_correct, b_correct, c_correct = displacement_map_s_parametrized_Abc(s=1, n_modes=2) assert math.allclose(rep2.A[0], A_correct) assert math.allclose(rep2.b[0], b_correct) @@ -178,9 +178,9 @@ def testBtoQ_works_correctly_by_applying_it_twice_on_a_state(self): modes = [0, 1] BtoQ_CC1 = BtoQ(modes, 0.0) step1A, step1b, step1c = ( - BtoQ_CC1.representation.A[0], - BtoQ_CC1.representation.b[0], - BtoQ_CC1.representation.c[0], + BtoQ_CC1.ansatz.A[0], + BtoQ_CC1.ansatz.b[0], + BtoQ_CC1.ansatz.c[0], ) Ainter, binter, cinter = complex_gaussian_integral( join_Abc((A0, b0, c0), (step1A, step1b, step1c)), @@ -190,9 +190,9 @@ def testBtoQ_works_correctly_by_applying_it_twice_on_a_state(self): ) QtoBMap_CC2 = BtoQ(modes, 0.0).dual step2A, step2b, step2c = ( - QtoBMap_CC2.representation.A[0], - QtoBMap_CC2.representation.b[0], - QtoBMap_CC2.representation.c[0], + QtoBMap_CC2.ansatz.A[0], + QtoBMap_CC2.ansatz.b[0], + QtoBMap_CC2.ansatz.c[0], ) new_A, new_b, new_c = join_Abc_real( @@ -212,9 +212,9 @@ def testBtoQ_works_correctly_by_applying_it_twice_on_a_state(self): modes = [0] BtoQ_CC1 = BtoQ(modes, 0.0) step1A, step1b, step1c = ( - BtoQ_CC1.representation.A[0], - BtoQ_CC1.representation.b[0], - BtoQ_CC1.representation.c[0], + BtoQ_CC1.ansatz.A[0], + BtoQ_CC1.ansatz.b[0], + BtoQ_CC1.ansatz.c[0], ) Ainter, binter, cinter = complex_gaussian_integral( join_Abc((A0, b0, c0), (step1A, step1b, step1c)), @@ -226,9 +226,9 @@ def testBtoQ_works_correctly_by_applying_it_twice_on_a_state(self): ) QtoBMap_CC2 = BtoQ(modes, 0.0).dual step2A, step2b, step2c = ( - QtoBMap_CC2.representation.A[0], - QtoBMap_CC2.representation.b[0], - QtoBMap_CC2.representation.c[0], + QtoBMap_CC2.ansatz.A[0], + QtoBMap_CC2.ansatz.b[0], + QtoBMap_CC2.ansatz.c[0], ) new_A, new_b, new_c = join_Abc_real( @@ -261,6 +261,6 @@ def wavefunction_coh(alpha, quad, axis_angle): quad = np.random.random() state = Coherent([0], x, y) - wavefunction = (state >> BtoQ([0], axis_angle)).representation + wavefunction = (state >> BtoQ([0], axis_angle)).ansatz assert np.allclose(wavefunction(quad), wavefunction_coh(x + 1j * y, quad, axis_angle)) diff --git a/tests/test_lab_dev/test_circuits.py b/tests/test_lab_dev/test_circuits.py index 6d8ddac90..f25db3d69 100644 --- a/tests/test_lab_dev/test_circuits.py +++ b/tests/test_lab_dev/test_circuits.py @@ -179,7 +179,7 @@ def test_repr(self): n12 = Number([0, 1], n=3) n2 = Number([2], n=3) cc = CircuitComponent._from_attributes( - bs01.representation, bs01.wires, "my_cc" + bs01.ansatz, bs01.wires, "my_cc" ) # pylint: disable=protected-access assert repr(Circuit()) == "" diff --git a/tests/test_lab_dev/test_states/test_coherent.py b/tests/test_lab_dev/test_states/test_coherent.py index fc5c70870..c1ec6dac6 100644 --- a/tests/test_lab_dev/test_states/test_coherent.py +++ b/tests/test_lab_dev/test_states/test_coherent.py @@ -61,24 +61,24 @@ def test_trainable_parameters(self): assert state3.y.value == 2 def test_representation(self): - rep1 = Coherent(modes=[0], x=0.1, y=0.2).representation + rep1 = Coherent(modes=[0], x=0.1, y=0.2).ansatz assert math.allclose(rep1.A, np.zeros((1, 1, 1))) assert math.allclose(rep1.b, [[0.1 + 0.2j]]) assert math.allclose(rep1.c, [0.97530991]) - rep2 = Coherent(modes=[0, 1], x=0.1, y=[0.2, 0.3]).representation + rep2 = Coherent(modes=[0, 1], x=0.1, y=[0.2, 0.3]).ansatz assert math.allclose(rep2.A, np.zeros((1, 2, 2))) assert math.allclose(rep2.b, [[0.1 + 0.2j, 0.1 + 0.3j]]) assert math.allclose(rep2.c, [0.9277434863]) - rep3 = Coherent(modes=[1], x=0.1).representation + rep3 = Coherent(modes=[1], x=0.1).ansatz assert math.allclose(rep3.A, np.zeros((1, 1, 1))) assert math.allclose(rep3.b, [[0.1]]) assert math.allclose(rep3.c, [0.9950124791926823]) def test_representation_error(self): with pytest.raises(ValueError): - Coherent(modes=[0], x=[0.1, 0.2]).representation + Coherent(modes=[0], x=[0.1, 0.2]).ansatz def test_linear_combinations(self): state1 = Coherent([0], x=1, y=2) @@ -86,11 +86,11 @@ def test_linear_combinations(self): state3 = Coherent([0], x=3, y=4) lc = state1 + state2 - state3 - assert lc.representation.batch_size == 3 + assert lc.ansatz.batch_size == 3 - assert (lc @ lc.dual).representation.batch_size == 9 + assert (lc @ lc.dual).ansatz.batch_size == 9 settings.UNSAFE_ZIP_BATCH = True - assert (lc @ lc.dual).representation.batch_size == 3 # not 9 + assert (lc @ lc.dual).ansatz.batch_size == 3 # not 9 settings.UNSAFE_ZIP_BATCH = False def test_vacuum_shape(self): diff --git a/tests/test_lab_dev/test_states/test_displaced_squeezed.py b/tests/test_lab_dev/test_states/test_displaced_squeezed.py index 9951d14d1..276ec5be1 100644 --- a/tests/test_lab_dev/test_states/test_displaced_squeezed.py +++ b/tests/test_lab_dev/test_states/test_displaced_squeezed.py @@ -63,10 +63,10 @@ def test_trainable_parameters(self): @pytest.mark.parametrize("modes,x,y,r,phi", zip(modes, x, y, r, phi)) def test_representation(self, modes, x, y, r, phi): - rep = DisplacedSqueezed(modes, x, y, r, phi).representation - exp = (Vacuum(modes) >> Sgate(modes, r, phi) >> Dgate(modes, x, y)).representation + rep = DisplacedSqueezed(modes, x, y, r, phi).ansatz + exp = (Vacuum(modes) >> Sgate(modes, r, phi) >> Dgate(modes, x, y)).ansatz assert rep == exp def test_representation_error(self): with pytest.raises(ValueError): - DisplacedSqueezed(modes=[0], x=[0.1, 0.2]).representation + DisplacedSqueezed(modes=[0], x=[0.1, 0.2]).ansatz diff --git a/tests/test_lab_dev/test_states/test_number.py b/tests/test_lab_dev/test_states/test_number.py index c67c3601a..998923a62 100644 --- a/tests/test_lab_dev/test_states/test_number.py +++ b/tests/test_lab_dev/test_states/test_number.py @@ -50,13 +50,13 @@ def test_init_error(self): @pytest.mark.parametrize("n", [2, [2, 3], [4, 4]]) @pytest.mark.parametrize("cutoffs", [None, [4, 5], [5, 5]]) def test_representation(self, n, cutoffs): - rep1 = Number([0, 1], n, cutoffs).representation.array + rep1 = Number([0, 1], n, cutoffs).ansatz.array exp1 = fock_state((n,) * 2 if isinstance(n, int) else n, cutoffs) assert math.allclose(rep1, math.asnumpy(exp1).reshape(1, *exp1.shape)) - rep2 = Number([0, 1], n, cutoffs).to_fock().representation.array + rep2 = Number([0, 1], n, cutoffs).to_fock().ansatz.array assert math.allclose(rep2, rep1) def test_representation_error(self): with pytest.raises(ValueError): - Coherent(modes=[0], x=[0.1, 0.2]).representation + Coherent(modes=[0], x=[0.1, 0.2]).ansatz diff --git a/tests/test_lab_dev/test_states/test_quadrature_eigenstate.py b/tests/test_lab_dev/test_states/test_quadrature_eigenstate.py index 07f31d1a0..b203545ea 100644 --- a/tests/test_lab_dev/test_states/test_quadrature_eigenstate.py +++ b/tests/test_lab_dev/test_states/test_quadrature_eigenstate.py @@ -70,7 +70,7 @@ def test_probability_hbar(self, hbar): def test_representation_error(self): with pytest.raises(ValueError): - QuadratureEigenstate(modes=[0], x=[0.1, 0.2]).representation + QuadratureEigenstate(modes=[0], x=[0.1, 0.2]).ansatz def test_trainable_parameters(self): state1 = QuadratureEigenstate([0, 1], 1, 1) diff --git a/tests/test_lab_dev/test_states/test_squeezed_vacuum.py b/tests/test_lab_dev/test_states/test_squeezed_vacuum.py index 7ca95cef0..ccb0d6233 100644 --- a/tests/test_lab_dev/test_states/test_squeezed_vacuum.py +++ b/tests/test_lab_dev/test_states/test_squeezed_vacuum.py @@ -65,10 +65,10 @@ def test_trainable_parameters(self): @pytest.mark.parametrize("modes,r,phi", zip(modes, r, phi)) def test_representation(self, modes, r, phi): - rep = SqueezedVacuum(modes, r, phi).representation - exp = (Vacuum(modes) >> Sgate(modes, r, phi)).representation + rep = SqueezedVacuum(modes, r, phi).ansatz + exp = (Vacuum(modes) >> Sgate(modes, r, phi)).ansatz assert rep == exp def test_representation_error(self): with pytest.raises(ValueError): - SqueezedVacuum(modes=[0], r=[0.1, 0.2]).representation + SqueezedVacuum(modes=[0], r=[0.1, 0.2]).ansatz diff --git a/tests/test_lab_dev/test_states/test_states_base.py b/tests/test_lab_dev/test_states/test_states_base.py index e07a08be5..58895c328 100644 --- a/tests/test_lab_dev/test_states/test_states_base.py +++ b/tests/test_lab_dev/test_states/test_states_base.py @@ -141,7 +141,7 @@ def test_to_from_fock(self, modes): state_in_fock = state_in.to_fock(5) array_in = state_in.fock(5, batched=True) - assert math.allclose(array_in, state_in_fock.representation.array) + assert math.allclose(array_in, state_in_fock.ansatz.array) state_out = Ket.from_fock(modes, array_in, "my_ket", True) assert state_in_fock == state_out @@ -205,7 +205,7 @@ def test_dm(self): dm = ket.dm() assert dm.name == ket.name - assert dm.representation == (ket @ ket.adjoint).representation + assert dm.ansatz == (ket @ ket.adjoint).ansatz assert dm.wires == (ket @ ket.adjoint).wires @pytest.mark.parametrize("phi", [0, 0.3, np.pi / 4, np.pi / 2]) @@ -268,7 +268,7 @@ def test_expectation_bargmann(self): assert math.allclose(ket.expectation(k0), res_k0) assert math.allclose(ket.expectation(k1), res_k1) - assert math.allclose(ket.expectation(k01), math.sum(res_k01.representation.c)) + assert math.allclose(ket.expectation(k01), math.sum(res_k01.ansatz.c)) dm0 = Coherent([0], x=1, y=2).dm() dm1 = Coherent([1], x=1, y=3).dm() @@ -280,7 +280,7 @@ def test_expectation_bargmann(self): assert math.allclose(ket.expectation(dm0), res_dm0) assert math.allclose(ket.expectation(dm1), res_dm1) - assert math.allclose(ket.expectation(dm01), math.sum(res_dm01.representation.c)) + assert math.allclose(ket.expectation(dm01), math.sum(res_dm01.ansatz.c)) u0 = Dgate([1], x=0.1) u1 = Dgate([0], x=0.2) @@ -317,7 +317,7 @@ def test_expectation_fock(self): res_dm0 = (ket @ ket.adjoint @ dm0.dual) >> TraceOut([1]) res_dm1 = (ket @ ket.adjoint @ dm1.dual) >> TraceOut([0]) - res_dm01 = (ket @ ket.adjoint @ dm01.dual).to_fock(10).representation.array + res_dm01 = (ket @ ket.adjoint @ dm01.dual).to_fock(10).ansatz.array assert math.allclose(ket.expectation(dm0), res_dm0) assert math.allclose(ket.expectation(dm1), res_dm1) @@ -327,9 +327,9 @@ def test_expectation_fock(self): u1 = Dgate([0], x=0.2) u01 = Dgate([0, 1], x=[0.3, 0.4]) - res_u0 = (ket @ u0 @ ket.dual).to_fock(10).representation.array - res_u1 = (ket @ u1 @ ket.dual).to_fock(10).representation.array - res_u01 = (ket @ u01 @ ket.dual).to_fock(10).representation.array + res_u0 = (ket @ u0 @ ket.dual).to_fock(10).ansatz.array + res_u1 = (ket @ u1 @ ket.dual).to_fock(10).ansatz.array + res_u01 = (ket @ u01 @ ket.dual).to_fock(10).ansatz.array assert math.allclose(ket.expectation(u0), res_u0[0]) assert math.allclose(ket.expectation(u1), res_u1[0]) @@ -356,11 +356,11 @@ def test_rshift(self): ket = Coherent([0, 1], 1) unitary = Dgate([0], 1) u_component = CircuitComponent._from_attributes( - unitary.representation, unitary.wires, unitary.name + unitary.ansatz, unitary.wires, unitary.name ) # pylint: disable=protected-access channel = Attenuator([1], 1) ch_component = CircuitComponent._from_attributes( - channel.representation, + channel.ansatz, channel.wires, channel.name, ) # pylint: disable=protected-access @@ -434,7 +434,7 @@ def test_unsafe_batch_zipping(self): @pytest.mark.parametrize("max_sq", [1, 2, 3]) def test_random_states(self, max_sq): psi = Ket.random([1, 22], max_sq) - A = psi.representation.A[0] + A = psi.ansatz.A[0] assert np.isclose(psi.probability, 1) # checks if the state is normalized assert np.allclose(A - np.transpose(A), np.zeros(2)) # checks if the A matrix is symmetric @@ -598,7 +598,7 @@ def test_to_from_fock(self, modes): state_in_fock = state_in.to_fock(5) array_in = state_in.fock(5, batched=True) - assert math.allclose(array_in, state_in_fock.representation.array) + assert math.allclose(array_in, state_in_fock.ansatz.array) state_out = DM.from_fock(modes, array_in, "my_dm", True) assert state_in_fock == state_out @@ -732,7 +732,7 @@ def test_expectation_bargmann_ket(self): assert math.allclose(dm.expectation(k0), res_k0) assert math.allclose(dm.expectation(k1), res_k1) - assert math.allclose(dm.expectation(k01), res_k01.representation.c[0]) + assert math.allclose(dm.expectation(k01), res_k01.ansatz.c[0]) def test_expectation_bargmann_dm(self): dm0 = Coherent([0], x=1, y=2).dm() @@ -822,11 +822,11 @@ def test_rshift(self): ket = Coherent([0, 1], 1) unitary = Dgate([0], 1) u_component = CircuitComponent._from_attributes( - unitary.representation, unitary.wires, unitary.name + unitary.ansatz, unitary.wires, unitary.name ) # pylint: disable=protected-access channel = Attenuator([1], 1) ch_component = CircuitComponent._from_attributes( - channel.representation, channel.wires, channel.name + channel.ansatz, channel.wires, channel.name ) # pylint: disable=protected-access dm = ket >> channel @@ -846,7 +846,7 @@ def test_rshift(self): def test_random(self, modes): m = len(modes) dm = DM.random(modes) - A = dm.representation.A[0] + A = dm.ansatz.A[0] Gamma = A[:m, m:] Lambda = A[m:, m:] Temp = Gamma + math.conj(Lambda.T) @ math.inv(1 - Gamma.T) @ Lambda diff --git a/tests/test_lab_dev/test_states/test_thermal.py b/tests/test_lab_dev/test_states/test_thermal.py index 968a0dd0e..d73ffe4dc 100644 --- a/tests/test_lab_dev/test_states/test_thermal.py +++ b/tests/test_lab_dev/test_states/test_thermal.py @@ -48,10 +48,10 @@ def test_init_error(self): @pytest.mark.parametrize("nbar", [1, [2, 3], [4, 4]]) def test_representation(self, nbar): - rep = Thermal([0, 1], nbar).representation + rep = Thermal([0, 1], nbar).ansatz exp = PolyExpAnsatz(*thermal_state_Abc([nbar, nbar] if isinstance(nbar, int) else nbar)) assert rep == exp def test_representation_error(self): with pytest.raises(ValueError): - Thermal(modes=[0], nbar=[0.1, 0.2]).representation + Thermal(modes=[0], nbar=[0.1, 0.2]).ansatz diff --git a/tests/test_lab_dev/test_states/test_two_mode_squeezed_vacuum.py b/tests/test_lab_dev/test_states/test_two_mode_squeezed_vacuum.py index e9115b3c3..4846761b7 100644 --- a/tests/test_lab_dev/test_states/test_two_mode_squeezed_vacuum.py +++ b/tests/test_lab_dev/test_states/test_two_mode_squeezed_vacuum.py @@ -61,10 +61,10 @@ def test_trainable_parameters(self): @pytest.mark.parametrize("modes,r,phi", zip(modes, r, phi)) def test_representation(self, modes, r, phi): - rep = TwoModeSqueezedVacuum(modes, r, phi).representation - exp = (Vacuum(modes) >> S2gate(modes, r, phi)).representation + rep = TwoModeSqueezedVacuum(modes, r, phi).ansatz + exp = (Vacuum(modes) >> S2gate(modes, r, phi)).ansatz assert rep == exp def test_representation_error(self): with pytest.raises(ValueError): - TwoModeSqueezedVacuum(modes=[0], r=[0.1, 0.2]).representation + TwoModeSqueezedVacuum(modes=[0], r=[0.1, 0.2]).ansatz diff --git a/tests/test_lab_dev/test_states/test_vacuum.py b/tests/test_lab_dev/test_states/test_vacuum.py index 580446389..19a750296 100644 --- a/tests/test_lab_dev/test_states/test_vacuum.py +++ b/tests/test_lab_dev/test_states/test_vacuum.py @@ -38,7 +38,7 @@ def test_init(self, modes): @pytest.mark.parametrize("n_modes", [1, 3]) def test_representation(self, n_modes): - rep = Vacuum(range(n_modes)).representation + rep = Vacuum(range(n_modes)).ansatz assert math.allclose(rep.A, np.zeros((1, n_modes, n_modes))) assert math.allclose(rep.b, np.zeros((1, n_modes))) diff --git a/tests/test_lab_dev/test_transformations/test_amplifier.py b/tests/test_lab_dev/test_transformations/test_amplifier.py index 21b332a88..2862d9dda 100644 --- a/tests/test_lab_dev/test_transformations/test_amplifier.py +++ b/tests/test_lab_dev/test_transformations/test_amplifier.py @@ -44,7 +44,7 @@ def test_init_error(self): Amplifier(modes=[0, 1], gain=[1.2, 1.3, 1.4]) def test_representation(self): - rep1 = Amplifier(modes=[0], gain=1.1).representation + rep1 = Amplifier(modes=[0], gain=1.1).ansatz g1 = 0.95346258 g2 = 0.09090909 assert math.allclose( @@ -65,7 +65,7 @@ def test_trainable_parameters(self): def test_representation_error(self): with pytest.raises(ValueError): - Amplifier(modes=[0], gain=[1.1, 1.2]).representation + Amplifier(modes=[0], gain=[1.1, 1.2]).ansatz def test_operation(self): amp_channel = Amplifier(modes=[0], gain=1.5) @@ -73,7 +73,7 @@ def test_operation(self): operation = amp_channel >> att_channel assert math.allclose( - operation.representation.A, + operation.ansatz.A, [ [ [0.0 + 0.0j, 0.75903339 + 0.0j, 0.25925926 + 0.0j, 0.0 + 0.0j], @@ -83,8 +83,8 @@ def test_operation(self): ] ], ) - assert math.allclose(operation.representation.b, np.zeros((1, 4))) - assert math.allclose(operation.representation.c, [0.74074074 + 0.0j]) + assert math.allclose(operation.ansatz.b, np.zeros((1, 4))) + assert math.allclose(operation.ansatz.c, [0.74074074 + 0.0j]) def test_circuit_identity(self): amp_channel = Amplifier(modes=[0], gain=2) @@ -92,12 +92,12 @@ def test_circuit_identity(self): input_state = Coherent(modes=[0], x=0.5, y=0.7) assert math.allclose( - (input_state >> amp_channel).representation.A, - (input_state >> att_channel.dual).representation.A, + (input_state >> amp_channel).ansatz.A, + (input_state >> att_channel.dual).ansatz.A, ) assert math.allclose( - (input_state >> amp_channel).representation.b, - (input_state >> att_channel.dual).representation.b, + (input_state >> amp_channel).ansatz.b, + (input_state >> att_channel.dual).ansatz.b, ) @pytest.mark.parametrize("n", [1, 2, 3, 4, 5]) diff --git a/tests/test_lab_dev/test_transformations/test_attenuator.py b/tests/test_lab_dev/test_transformations/test_attenuator.py index a8e2d9dec..ef7ebaff0 100644 --- a/tests/test_lab_dev/test_transformations/test_attenuator.py +++ b/tests/test_lab_dev/test_transformations/test_attenuator.py @@ -43,7 +43,7 @@ def test_init_error(self): Attenuator(modes=[0, 1], transmissivity=[0.2, 0.3, 0.4]) def test_representation(self): - rep1 = Attenuator(modes=[0], transmissivity=0.1).representation + rep1 = Attenuator(modes=[0], transmissivity=0.1).ansatz e = 0.31622777 assert math.allclose(rep1.A, [[[0, e, 0, 0], [e, 0, 0, 0.9], [0, 0, 0, e], [0, 0.9, e, 0]]]) assert math.allclose(rep1.b, np.zeros((1, 4))) @@ -63,4 +63,4 @@ def test_trainable_parameters(self): def test_representation_error(self): with pytest.raises(ValueError): - Attenuator(modes=[0], transmissivity=[0.1, 0.2]).representation + Attenuator(modes=[0], transmissivity=[0.1, 0.2]).ansatz diff --git a/tests/test_lab_dev/test_transformations/test_bsgate.py b/tests/test_lab_dev/test_transformations/test_bsgate.py index a5066be08..62a2aa50d 100644 --- a/tests/test_lab_dev/test_transformations/test_bsgate.py +++ b/tests/test_lab_dev/test_transformations/test_bsgate.py @@ -45,7 +45,7 @@ def test_init_error(self): BSgate([1, 2, 3]) def test_representation(self): - rep1 = BSgate([0, 1], 0.1, 0.2).representation + rep1 = BSgate([0, 1], 0.1, 0.2).ansatz A_exp = [ [ [0, 0, 0.99500417, -0.0978434 + 0.01983384j], @@ -58,7 +58,7 @@ def test_representation(self): assert math.allclose(rep1.b, np.zeros((1, 4))) assert math.allclose(rep1.c, [1]) - rep2 = BSgate([0, 1], 0.1).representation + rep2 = BSgate([0, 1], 0.1).ansatz A_exp = [ [ [0, 0, 9.95004165e-01, -9.98334166e-02], diff --git a/tests/test_lab_dev/test_transformations/test_cft.py b/tests/test_lab_dev/test_transformations/test_cft.py index d89fde6aa..b434b41b7 100644 --- a/tests/test_lab_dev/test_transformations/test_cft.py +++ b/tests/test_lab_dev/test_transformations/test_cft.py @@ -40,11 +40,11 @@ def test_wigner_function(self): state = Ket.random([0]) >> Dgate([0], x=1.0, y=0.1) - dm = math.sum(state.to_fock(100).dm().representation.array, axes=[0]) + dm = math.sum(state.to_fock(100).dm().ansatz.array, axes=[0]) vec = np.linspace(-5, 5, 100) wigner, _, _ = wigner_discretized(dm, vec, vec) - Wigner = (state >> CFT([0]).inverse() >> BtoPS([0], s=0)).representation + Wigner = (state >> CFT([0]).inverse() >> BtoPS([0], s=0)).ansatz X, Y = np.meshgrid( vec * np.sqrt(2 / settings.HBAR), vec * np.sqrt(2 / settings.HBAR) ) # scaling to take care of HBAR diff --git a/tests/test_lab_dev/test_transformations/test_dgate.py b/tests/test_lab_dev/test_transformations/test_dgate.py index d72ed4719..bbcaebd99 100644 --- a/tests/test_lab_dev/test_transformations/test_dgate.py +++ b/tests/test_lab_dev/test_transformations/test_dgate.py @@ -55,17 +55,17 @@ def test_to_fock_method(self): assert np.all(math.abs(dgate.fock(150)) < 1) def test_representation(self): - rep1 = Dgate(modes=[0], x=0.1, y=0.1).representation + rep1 = Dgate(modes=[0], x=0.1, y=0.1).ansatz assert math.allclose(rep1.A, [[[0, 1], [1, 0]]]) assert math.allclose(rep1.b, [[0.1 + 0.1j, -0.1 + 0.1j]]) assert math.allclose(rep1.c, [0.990049833749168]) - rep2 = Dgate(modes=[0, 1], x=[0.1, 0.2], y=0.1).representation + rep2 = Dgate(modes=[0, 1], x=[0.1, 0.2], y=0.1).ansatz assert math.allclose(rep2.A, [[[0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0]]]) assert math.allclose(rep2.b, [[0.1 + 0.1j, 0.2 + 0.1j, -0.1 + 0.1j, -0.2 + 0.1j]]) assert math.allclose(rep2.c, [0.9656054162575665]) - rep3 = Dgate(modes=[1, 8], x=[0.1, 0.2]).representation + rep3 = Dgate(modes=[1, 8], x=[0.1, 0.2]).ansatz assert math.allclose(rep3.A, [[[0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0]]]) assert math.allclose(rep3.b, [[0.1, 0.2, -0.1, -0.2]]) assert math.allclose(rep3.c, [0.9753099120283327]) @@ -85,9 +85,9 @@ def test_trainable_parameters(self): assert gate3.y.value == 2 gate_fock = gate3.to_fock() - assert isinstance(gate_fock.representation, ArrayAnsatz) + assert isinstance(gate_fock.ansatz, ArrayAnsatz) assert gate_fock.y.value == 2 def test_representation_error(self): with pytest.raises(ValueError): - Dgate(modes=[0], x=[0.1, 0.2]).representation + Dgate(modes=[0], x=[0.1, 0.2]).ansatz diff --git a/tests/test_lab_dev/test_transformations/test_fockdamping.py b/tests/test_lab_dev/test_transformations/test_fockdamping.py index 25254879b..a59308d0f 100644 --- a/tests/test_lab_dev/test_transformations/test_fockdamping.py +++ b/tests/test_lab_dev/test_transformations/test_fockdamping.py @@ -40,7 +40,7 @@ def test_init(self, modes, damping): assert np.allclose(gate.damping.value, damping) def test_representation(self): - rep1 = FockDamping(modes=[0], damping=0.1).representation + rep1 = FockDamping(modes=[0], damping=0.1).ansatz e = math.exp(-0.1) assert math.allclose( rep1.A, @@ -69,11 +69,11 @@ def test_trainable_parameters(self): def test_representation_error(self): with pytest.raises(ValueError): - FockDamping(modes=[0], damping=[0.1, 0.2]).representation + FockDamping(modes=[0], damping=[0.1, 0.2]).ansatz def test_identity(self): - rep1 = FockDamping(modes=[0, 1], damping=0.0).representation - rep2 = Identity(modes=[0, 1]).representation + rep1 = FockDamping(modes=[0, 1], damping=0.0).ansatz + rep2 = Identity(modes=[0, 1]).ansatz assert math.allclose(rep1.A, rep2.A) assert math.allclose(rep1.b, rep2.b) diff --git a/tests/test_lab_dev/test_transformations/test_identity.py b/tests/test_lab_dev/test_transformations/test_identity.py index c437f6d1d..1d5691959 100644 --- a/tests/test_lab_dev/test_transformations/test_identity.py +++ b/tests/test_lab_dev/test_transformations/test_identity.py @@ -45,7 +45,7 @@ def test_init_error(self): Identity() def test_representation(self): - rep1 = Identity(modes=[0]).representation + rep1 = Identity(modes=[0]).ansatz assert math.allclose( rep1.A, [ @@ -58,7 +58,7 @@ def test_representation(self): assert math.allclose(rep1.b, np.zeros((1, 2))) assert math.allclose(rep1.c, [1.0 + 0.0j]) - rep2 = Identity(modes=[0, 1]).representation + rep2 = Identity(modes=[0, 1]).ansatz assert math.allclose( rep2.A, [ diff --git a/tests/test_lab_dev/test_transformations/test_rgate.py b/tests/test_lab_dev/test_transformations/test_rgate.py index 4ce4b5afd..d1b9b9786 100644 --- a/tests/test_lab_dev/test_transformations/test_rgate.py +++ b/tests/test_lab_dev/test_transformations/test_rgate.py @@ -43,7 +43,7 @@ def test_init_error(self): Rgate(modes=[0, 1], phi=[2, 3, 4]) def test_representation(self): - rep1 = Rgate(modes=[0], phi=0.1).representation + rep1 = Rgate(modes=[0], phi=0.1).ansatz assert math.allclose( rep1.A, [ @@ -56,7 +56,7 @@ def test_representation(self): assert math.allclose(rep1.b, np.zeros((1, 2))) assert math.allclose(rep1.c, [1.0 + 0.0j]) - rep2 = Rgate(modes=[0, 1], phi=[0.1, 0.3]).representation + rep2 = Rgate(modes=[0, 1], phi=[0.1, 0.3]).ansatz assert math.allclose( rep2.A, [ @@ -71,7 +71,7 @@ def test_representation(self): assert math.allclose(rep2.b, np.zeros((1, 4))) assert math.allclose(rep2.c, [1.0 + 0.0j]) - rep3 = Rgate(modes=[1], phi=0.1).representation + rep3 = Rgate(modes=[1], phi=0.1).ansatz assert math.allclose( rep3.A, [ @@ -96,4 +96,4 @@ def test_trainable_parameters(self): def test_representation_error(self): with pytest.raises(ValueError): - Rgate(modes=[0], phi=[0.1, 0.2]).representation + Rgate(modes=[0], phi=[0.1, 0.2]).ansatz diff --git a/tests/test_lab_dev/test_transformations/test_s2gate.py b/tests/test_lab_dev/test_transformations/test_s2gate.py index e5639d4c0..030ff51bd 100644 --- a/tests/test_lab_dev/test_transformations/test_s2gate.py +++ b/tests/test_lab_dev/test_transformations/test_s2gate.py @@ -46,7 +46,7 @@ def test_init_error(self): S2gate([1, 2, 3]) def test_representation(self): - rep1 = S2gate([0, 1], 0.1, 0.2).representation + rep1 = S2gate([0, 1], 0.1, 0.2).ansatz tanhr = np.exp(1j * 0.2) * np.sinh(0.1) / np.cosh(0.1) sechr = 1 / np.cosh(0.1) @@ -77,8 +77,8 @@ def test_trainable_parameters(self): assert gate3.phi.value == 2 def test_operation(self): - rep1 = (Vacuum([0]) >> Vacuum([1]) >> S2gate(modes=[0, 1], r=1, phi=0.5)).representation - rep2 = (TwoModeSqueezedVacuum(modes=[0, 1], r=1, phi=0.5)).representation + rep1 = (Vacuum([0]) >> Vacuum([1]) >> S2gate(modes=[0, 1], r=1, phi=0.5)).ansatz + rep2 = (TwoModeSqueezedVacuum(modes=[0, 1], r=1, phi=0.5)).ansatz assert math.allclose(rep1.A, rep2.A) assert math.allclose(rep1.b, rep2.b) diff --git a/tests/test_lab_dev/test_transformations/test_sgate.py b/tests/test_lab_dev/test_transformations/test_sgate.py index 8f5263900..73a927044 100644 --- a/tests/test_lab_dev/test_transformations/test_sgate.py +++ b/tests/test_lab_dev/test_transformations/test_sgate.py @@ -47,7 +47,7 @@ def test_init_error(self): Sgate(modes=[0, 1], r=1, phi=[2, 3, 4]) def test_representation(self): - rep1 = Sgate(modes=[0], r=0.1, phi=0.2).representation + rep1 = Sgate(modes=[0], r=0.1, phi=0.2).ansatz assert math.allclose( rep1.A, [ @@ -60,7 +60,7 @@ def test_representation(self): assert math.allclose(rep1.b, np.zeros((1, 2))) assert math.allclose(rep1.c, [0.9975072676192522]) - rep2 = Sgate(modes=[0, 1], r=[0.1, 0.3], phi=0.2).representation + rep2 = Sgate(modes=[0, 1], r=[0.1, 0.3], phi=0.2).ansatz assert math.allclose( rep2.A, [ @@ -75,7 +75,7 @@ def test_representation(self): assert math.allclose(rep2.b, np.zeros((1, 4))) assert math.allclose(rep2.c, [0.9756354961606032]) - rep3 = Sgate(modes=[1], r=0.1).representation + rep3 = Sgate(modes=[1], r=0.1).ansatz assert math.allclose( rep3.A, [ @@ -104,4 +104,4 @@ def test_trainable_parameters(self): def test_representation_error(self): with pytest.raises(ValueError): - Sgate(modes=[0], r=[0.1, 0.2]).representation + Sgate(modes=[0], r=[0.1, 0.2]).ansatz diff --git a/tests/test_lab_dev/test_transformations/test_transformations_base.py b/tests/test_lab_dev/test_transformations/test_transformations_base.py index 0d091c0c1..aa6e384bf 100644 --- a/tests/test_lab_dev/test_transformations/test_transformations_base.py +++ b/tests/test_lab_dev/test_transformations/test_transformations_base.py @@ -44,8 +44,8 @@ def test_init_from_bargmann(self): b = np.array([0, 1, 5]) c = 1 operator = Operation.from_bargmann([0], [1, 2], (A, b, c), "my_operator") - assert np.allclose(operator.representation.A[None, ...], A) - assert np.allclose(operator.representation.b[None, ...], b) + assert np.allclose(operator.ansatz.A[None, ...], A) + assert np.allclose(operator.ansatz.b[None, ...], b) class TestUnitary: @@ -66,11 +66,11 @@ def test_rshift(self): unitary1 = Dgate([0, 1], 1) unitary2 = Dgate([1, 2], 2) u_component = CircuitComponent._from_attributes( - unitary1.representation, unitary1.wires, unitary1.name + unitary1.ansatz, unitary1.wires, unitary1.name ) # pylint: disable=protected-access channel = Attenuator([1], 1) ch_component = CircuitComponent._from_attributes( - channel.representation, channel.wires, channel.name + channel.ansatz, channel.wires, channel.name ) # pylint: disable=protected-access assert isinstance(unitary1 >> unitary2, Unitary) @@ -81,7 +81,7 @@ def test_rshift(self): def test_repr(self): unitary1 = Dgate([0, 1], 1) u_component = CircuitComponent._from_attributes( - unitary1.representation, unitary1.wires, unitary1.name + unitary1.ansatz, unitary1.wires, unitary1.name ) # pylint: disable=protected-access assert repr(unitary1) == "Dgate(modes=[0, 1], name=Dgate, repr=PolyExpAnsatz)" assert repr(unitary1.to_fock(5)) == "Dgate(modes=[0, 1], name=Dgate, repr=ArrayAnsatz)" @@ -96,8 +96,8 @@ def test_init_from_bargmann(self): b = np.array([0, 0]) c = 1 gate = Unitary.from_bargmann([2], [2], (A, b, c), "my_unitary") - assert np.allclose(gate.representation.A[None, ...], A) - assert np.allclose(gate.representation.b[None, ...], b) + assert np.allclose(gate.ansatz.A[None, ...], A) + assert np.allclose(gate.ansatz.b[None, ...], b) def test_init_from_symplectic(self): S = math.random_symplectic(2) @@ -111,7 +111,7 @@ def test_inverse_unitary(self): gate_inv_inv = gate_inv.inverse() assert gate_inv_inv == gate should_be_identity = gate >> gate_inv - assert should_be_identity.representation == Dgate([0], 0.0, 0.0).representation + assert should_be_identity.ansatz == Dgate([0], 0.0, 0.0).ansatz def test_random(self): modes = [3, 1, 20] @@ -143,18 +143,18 @@ def test_init_from_bargmann(self): b = np.array([0, 1, 2, 3]) c = 1 channel = Channel.from_bargmann([0], [0], (A, b, c), "my_channel") - assert np.allclose(channel.representation.A[None, ...], A) - assert np.allclose(channel.representation.b[None, ...], b) + assert np.allclose(channel.ansatz.A[None, ...], A) + assert np.allclose(channel.ansatz.b[None, ...], b) def test_rshift(self): unitary = Dgate([0, 1], 1) u_component = CircuitComponent._from_attributes( - unitary.representation, unitary.wires, unitary.name + unitary.ansatz, unitary.wires, unitary.name ) # pylint: disable=protected-access channel1 = Attenuator([1, 2], 0.9) channel2 = Attenuator([2, 3], 0.9) ch_component = CircuitComponent._from_attributes( - channel1.representation, channel1.wires, channel1.name + channel1.ansatz, channel1.wires, channel1.name ) # pylint: disable=protected-access assert isinstance(channel1 >> unitary, Channel) @@ -165,7 +165,7 @@ def test_rshift(self): def test_repr(self): channel1 = Attenuator([0, 1], 0.9) ch_component = CircuitComponent._from_attributes( - channel1.representation, channel1.wires, channel1.name + channel1.ansatz, channel1.wires, channel1.name ) # pylint: disable=protected-access assert repr(channel1) == "Attenuator(modes=[0, 1], name=Att, repr=PolyExpAnsatz)" @@ -174,7 +174,7 @@ def test_repr(self): def test_inverse_channel(self): gate = Sgate([0], 0.1, 0.2) >> Dgate([0], 0.1, 0.2) >> Attenuator([0], 0.5) should_be_identity = gate >> gate.inverse() - assert should_be_identity.representation == Attenuator([0], 1.0).representation + assert should_be_identity.ansatz == Attenuator([0], 1.0).ansatz def test_random(self): @@ -183,7 +183,7 @@ def test_random(self): @pytest.mark.parametrize("modes", [[0], [0, 1], [0, 1, 2]]) def test_is_CP(self, modes): - u = Unitary.random(modes).representation + u = Unitary.random(modes).ansatz kraus = u @ u.conj assert Channel.from_bargmann(modes, modes, kraus.triple).is_CP @@ -195,7 +195,7 @@ def test_is_physical(self): def test_XY(self): U = Unitary.random([0, 1]) - u = U.representation + u = U.ansatz unitary_channel = Channel.from_bargmann([0, 1], [0, 1], (u.conj @ u).triple) X, Y = unitary_channel.XY assert np.allclose(X, U.symplectic) and np.allclose(Y, np.zeros(4)) diff --git a/tests/test_physics/test_bargmann_utils.py b/tests/test_physics/test_bargmann_utils.py index 7545b15be..4a7a4f668 100644 --- a/tests/test_physics/test_bargmann_utils.py +++ b/tests/test_physics/test_bargmann_utils.py @@ -169,7 +169,7 @@ def test_XY_of_channel(): Tests the function X_of_channel. """ - X, Y = XY_of_channel(Channel.random([0]).representation.A[0]) + X, Y = XY_of_channel(Channel.random([0]).ansatz.A[0]) omega = np.array([[0, 1j], [-1j, 0]]) channel_check = X @ omega @ X.T / 2 - omega / 2 + Y assert np.all([mu > 0 for mu in np.linalg.eigvals(channel_check)]) From a79284add866a1cd2cdb3918c99d11f7b8ed437e Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 8 Oct 2024 13:33:07 -0400 Subject: [PATCH 41/87] cleanup --- mrmustard/lab_dev/circuit_components.py | 42 +++++++++---------- .../circuit_components_utils/b_to_ps.py | 2 +- .../circuit_components_utils/b_to_q.py | 4 +- .../circuit_components_utils/trace_out.py | 10 ++--- mrmustard/lab_dev/states/base.py | 24 +++++------ mrmustard/lab_dev/states/vacuum.py | 4 +- mrmustard/lab_dev/transformations/base.py | 16 +++---- mrmustard/lab_dev/transformations/identity.py | 4 +- mrmustard/physics/representations.py | 2 +- 9 files changed, 53 insertions(+), 55 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index 39a10f2bb..bb9255799 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -57,7 +57,7 @@ class CircuitComponent: and :class:`Representation` classes (and their subclasses) for more details. Args: - representation: A representation for this circuit component. + ansatz: An ansatz for this circuit component. wires: The wires of this component. Alternatively, can be a ``(modes_out_bra, modes_in_bra, modes_out_ket, modes_in_ket)`` where if any of the modes are out of order the representation @@ -69,7 +69,7 @@ class CircuitComponent: def __init__( self, - representation: PolyExpAnsatz | ArrayAnsatz | None = None, + ansatz: PolyExpAnsatz | ArrayAnsatz | None = None, wires: Wires | Sequence[tuple[int]] | None = None, name: str | None = None, ) -> None: @@ -104,13 +104,11 @@ def __init__( + tuple(np.argsort(modes_out_ket) + offsets[1]) + tuple(np.argsort(modes_in_ket) + offsets[2]) ) - if representation is not None: - self._representation = Representation( - representation.reorder(tuple(perm)), wires - ) + if ansatz is not None: + self._representation = Representation(ansatz.reorder(tuple(perm)), wires) if not hasattr(self, "_representation"): - self._representation = Representation(representation, wires) + self._representation = Representation(ansatz, wires) def _serialize(self) -> tuple[dict[str, Any], dict[str, ArrayLike]]: """ @@ -165,9 +163,9 @@ def adjoint(self) -> CircuitComponent: """ bras = self.wires.bra.indices kets = self.wires.ket.indices - rep = self.ansatz.reorder(kets + bras).conj if self.ansatz else None + ansatz = self.ansatz.reorder(kets + bras).conj if self.ansatz else None - ret = CircuitComponent(rep, self.wires.adjoint, self.name) + ret = CircuitComponent(ansatz, self.wires.adjoint, self.name) ret.short_name = self.short_name for param in self.parameter_set.all_parameters.values(): ret._add_parameter(param) @@ -184,9 +182,9 @@ def dual(self) -> CircuitComponent: ik = self.wires.ket.input.indices ib = self.wires.bra.input.indices ob = self.wires.bra.output.indices - rep = self.ansatz.reorder(ib + ob + ik + ok).conj if self.ansatz else None + ansatz = self.ansatz.reorder(ib + ob + ik + ok).conj if self.ansatz else None - ret = CircuitComponent(rep, self.wires.dual, self.name) + ret = CircuitComponent(ansatz, self.wires.dual, self.name) ret.short_name = self.short_name for param in self.parameter_set.all_parameters.values(): ret._add_parameter(param) @@ -288,9 +286,9 @@ def from_bargmann( Returns: A circuit component with the given Bargmann representation. """ - repr = PolyExpAnsatz(*triple) + ansatz = PolyExpAnsatz(*triple) wires = Wires(set(modes_out_bra), set(modes_in_bra), set(modes_out_ket), set(modes_in_ket)) - return cls._from_attributes(repr, wires, name) + return cls._from_attributes(ansatz, wires, name) @classmethod def from_quadrature( @@ -397,7 +395,7 @@ def quadrature(self, quad: Batch[Vector], phi: float = 0.0) -> ComplexTensor: @classmethod def _from_attributes( cls, - representation: Ansatz, + ansatz: Ansatz, wires: Wires, name: str | None = None, ) -> CircuitComponent: @@ -431,9 +429,9 @@ def _from_attributes( if tp.__name__ in types: ret = tp() ret._name = name - ret._representation = Representation(representation, wires) + ret._representation = Representation(ansatz, wires) return ret - return CircuitComponent(representation, wires, name) + return CircuitComponent(ansatz, wires, name) def auto_shape(self, **_) -> tuple[int, ...]: r""" @@ -639,9 +637,9 @@ def __add__(self, other: CircuitComponent) -> CircuitComponent: """ if self.wires != other.wires: raise ValueError("Cannot add components with different wires.") - rep = self.ansatz + other.ansatz + ansatz = self.ansatz + other.ansatz name = self.name if self.name == other.name else "" - return self._from_attributes(rep, self.wires, name) + return self._from_attributes(ansatz, self.wires, name) def __eq__(self, other) -> bool: r""" @@ -683,8 +681,8 @@ def __mul__(self, other: Scalar) -> CircuitComponent: return self._from_attributes(self.ansatz * other, self.wires, self.name) def __repr__(self) -> str: - repr = self.ansatz - repr_name = repr.__class__.__name__ + ansatz = self.ansatz + repr_name = ansatz.__class__.__name__ if repr_name == "NoneType": return self.__class__.__name__ + f"(modes={self.modes}, name={self.name})" else: @@ -777,9 +775,9 @@ def __sub__(self, other: CircuitComponent) -> CircuitComponent: """ if self.wires != other.wires: raise ValueError("Cannot subtract components with different wires.") - rep = self.ansatz - other.ansatz + ansatz = self.ansatz - other.ansatz name = self.name if self.name == other.name else "" - return self._from_attributes(rep, self.wires, name) + return self._from_attributes(ansatz, self.wires, name) def __truediv__(self, other: Scalar) -> CircuitComponent: r""" diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py b/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py index f1c8e65f7..d741f60c6 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py @@ -46,7 +46,7 @@ def __init__( super().__init__( modes_out=modes, modes_in=modes, - representation=PolyExpAnsatz.from_function( + ansatz=PolyExpAnsatz.from_function( fn=triples.displacement_map_s_parametrized_Abc, s=s, n_modes=len(modes) ), name="BtoPS", diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py index d90c082c0..5eee24a13 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py @@ -49,13 +49,13 @@ def __init__( modes: Sequence[int], phi: float = 0.0, ): - repr = PolyExpAnsatz.from_function( + ansatz = PolyExpAnsatz.from_function( fn=triples.bargmann_to_quadrature_Abc, n_modes=len(modes), phi=phi ) super().__init__( modes_out=modes, modes_in=modes, - representation=repr, + ansatz=ansatz, name="BtoQ", ) self._add_parameter(Constant(phi, "phi")) diff --git a/mrmustard/lab_dev/circuit_components_utils/trace_out.py b/mrmustard/lab_dev/circuit_components_utils/trace_out.py index 67263da07..ff862a5d9 100644 --- a/mrmustard/lab_dev/circuit_components_utils/trace_out.py +++ b/mrmustard/lab_dev/circuit_components_utils/trace_out.py @@ -63,7 +63,7 @@ def __init__( ): super().__init__( wires=[(), modes, (), modes], - representation=PolyExpAnsatz.from_function(fn=triples.identity_Abc, n_modes=len(modes)), + ansatz=PolyExpAnsatz.from_function(fn=triples.identity_Abc, n_modes=len(modes)), name="Tr", ) @@ -80,14 +80,14 @@ def __custom_rrshift__(self, other: CircuitComponent | complex) -> CircuitCompon idx_zconj = [bra[m].indices[0] for m in self.wires.modes & bra.modes] idx_z = [ket[m].indices[0] for m in self.wires.modes & ket.modes] if len(self.wires) == 0: - repr = other.ansatz + ansatz = other.ansatz wires = other.wires elif not ket or not bra: - repr = other.ansatz.conj[idx_z] @ other.ansatz[idx_z] + ansatz = other.ansatz.conj[idx_z] @ other.ansatz[idx_z] wires, _ = (other.wires.adjoint @ other.wires)[0] @ self.wires else: - repr = other.ansatz.trace(idx_z, idx_zconj) + ansatz = other.ansatz.trace(idx_z, idx_zconj) wires, _ = other.wires @ self.wires - cpt = other._from_attributes(repr, wires) # pylint:disable=protected-access + cpt = other._from_attributes(ansatz, wires) # pylint:disable=protected-access return math.sum(cpt.ansatz.scalar) if len(cpt.wires) == 0 else cpt diff --git a/mrmustard/lab_dev/states/base.py b/mrmustard/lab_dev/states/base.py index ed5dcdea6..317e6a671 100644 --- a/mrmustard/lab_dev/states/base.py +++ b/mrmustard/lab_dev/states/base.py @@ -645,7 +645,7 @@ class DM(State): Args: modes: The modes of this density matrix. - representation: The representation of this density matrix. + ansatz: The ansatz of this density matrix. name: The name of this density matrix. """ @@ -654,18 +654,18 @@ class DM(State): def __init__( self, modes: Sequence[int] = (), - representation: PolyExpAnsatz | ArrayAnsatz | None = None, + ansatz: PolyExpAnsatz | ArrayAnsatz | None = None, name: str | None = None, ): - if representation and representation.num_vars != 2 * len(modes): + if ansatz and ansatz.num_vars != 2 * len(modes): raise ValueError( - f"Expected a representation with {2*len(modes)} variables, found {representation.num_vars}." + f"Expected a representation with {2*len(modes)} variables, found {ansatz.num_vars}." ) super().__init__( wires=[modes, (), modes, ()], name=name, ) - self._representation = Representation(representation, self.wires) + self._representation = Representation(ansatz, self.wires) @property def is_positive(self) -> bool: @@ -942,10 +942,10 @@ def __getitem__(self, modes: int | Sequence[int]) -> State: idxz = [i for i, m in enumerate(self.modes) if m not in modes] idxz_conj = [i + len(self.modes) for i, m in enumerate(self.modes) if m not in modes] - representation = self.ansatz.trace(idxz, idxz_conj) + ansatz = self.ansatz.trace(idxz, idxz_conj) return self.__class__._from_attributes( - representation, wires, self.name + ansatz, wires, self.name ) # pylint: disable=protected-access def __rshift__(self, other: CircuitComponent) -> CircuitComponent: @@ -974,7 +974,7 @@ class Ket(State): Arguments: modes: The modes of this ket. - representation: The representation of this ket. + ansatz: The ansatz of this ket. name: The name of this ket. """ @@ -983,18 +983,18 @@ class Ket(State): def __init__( self, modes: Sequence[int] = (), - representation: PolyExpAnsatz | ArrayAnsatz | None = None, + ansatz: PolyExpAnsatz | ArrayAnsatz | None = None, name: str | None = None, ): - if representation and representation.num_vars != len(modes): + if ansatz and ansatz.num_vars != len(modes): raise ValueError( - f"Expected a representation with {len(modes)} variables, found {representation.num_vars}." + f"Expected a representation with {len(modes)} variables, found {ansatz.num_vars}." ) super().__init__( wires=[(), (), modes, ()], name=name, ) - self._representation = Representation(representation, self.wires) + self._representation = Representation(ansatz, self.wires) @property def is_physical(self) -> bool: diff --git a/mrmustard/lab_dev/states/vacuum.py b/mrmustard/lab_dev/states/vacuum.py index 00b3619de..d50493a53 100644 --- a/mrmustard/lab_dev/states/vacuum.py +++ b/mrmustard/lab_dev/states/vacuum.py @@ -60,8 +60,8 @@ def __init__( self, modes: Sequence[int], ) -> None: - rep = PolyExpAnsatz.from_function(fn=triples.vacuum_state_Abc, n_modes=len(modes)) - super().__init__(modes=modes, representation=rep, name="Vac") + ansatz = PolyExpAnsatz.from_function(fn=triples.vacuum_state_Abc, n_modes=len(modes)) + super().__init__(modes=modes, ansatz=ansatz, name="Vac") for i in range(len(modes)): self.manual_shape[i] = 1 diff --git a/mrmustard/lab_dev/transformations/base.py b/mrmustard/lab_dev/transformations/base.py index b89710df7..719617c08 100644 --- a/mrmustard/lab_dev/transformations/base.py +++ b/mrmustard/lab_dev/transformations/base.py @@ -121,11 +121,11 @@ def __init__( self, modes_out: tuple[int, ...] = (), modes_in: tuple[int, ...] = (), - representation: PolyExpAnsatz | ArrayAnsatz | None = None, + ansatz: PolyExpAnsatz | ArrayAnsatz | None = None, name: str | None = None, ): super().__init__( - representation=representation, + ansatz=ansatz, wires=[(), (), modes_out, modes_in], name=name, ) @@ -166,7 +166,7 @@ class Unitary(Operation): Arguments: modes_out: The output modes of this Unitary. modes_in: The input modes of this Unitary. - representation: The representation of this Unitary. + ansatz: The ansatz of this Unitary. name: The name of this Unitary. """ @@ -235,7 +235,7 @@ def random(cls, modes, max_r=1): def inverse(self) -> Unitary: unitary_dual = self.dual return Unitary._from_attributes( - representation=unitary_dual.ansatz, + ansatz=unitary_dual.ansatz, wires=unitary_dual.wires, name=unitary_dual.name, ) @@ -267,7 +267,7 @@ class Map(Transformation): Arguments: modes_out: The output modes of this Map. modes_in: The input modes of this Map. - representation: The representation of this Map. + ansatz: The ansatz of this Map. name: The name of this Map. """ @@ -277,11 +277,11 @@ def __init__( self, modes_out: tuple[int, ...] = (), modes_in: tuple[int, ...] = (), - representation: PolyExpAnsatz | ArrayAnsatz | None = None, + ansatz: PolyExpAnsatz | ArrayAnsatz | None = None, name: str | None = None, ): super().__init__( - representation=representation, + ansatz=ansatz, wires=[modes_out, modes_in, modes_out, modes_in], name=name or self.__class__.__name__, ) @@ -321,7 +321,7 @@ class Channel(Map): Arguments: modes_out: The output modes of this Channel. modes_in: The input modes of this Channel. - representation: The representation of this Channel. + ansatz: The ansatz of this Channel. name: The name of this Channel """ diff --git a/mrmustard/lab_dev/transformations/identity.py b/mrmustard/lab_dev/transformations/identity.py index ff33738cc..ab1ce6324 100644 --- a/mrmustard/lab_dev/transformations/identity.py +++ b/mrmustard/lab_dev/transformations/identity.py @@ -51,5 +51,5 @@ def __init__( self, modes: Sequence[int], ): - rep = PolyExpAnsatz.from_function(fn=triples.identity_Abc, n_modes=len(modes)) - super().__init__(modes_out=modes, modes_in=modes, representation=rep, name="Identity") + ansatz = PolyExpAnsatz.from_function(fn=triples.identity_Abc, n_modes=len(modes)) + super().__init__(modes_out=modes, modes_in=modes, ansatz=ansatz, name="Identity") diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index 55138d35f..1a8d1e8ee 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -187,7 +187,7 @@ def to_bargmann(self) -> Representation: def to_fock(self, shape: int | Sequence[int]) -> Representation: r""" - Returns a new representation with an ``ArrayAnsatz``. + Returns a new representation with an ``ArrayAnsatz``. Args: shape: The shape of the returned representation. If ``shape``is given as From d5f9f3adb7fac6843c9b77b9bc78c39ca698a1cf Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 8 Oct 2024 13:41:27 -0400 Subject: [PATCH 42/87] cleanup --- mrmustard/lab_dev/circuit_components.py | 4 ++-- mrmustard/lab_dev/states/base.py | 2 +- mrmustard/lab_dev/transformations/dgate.py | 4 ++-- mrmustard/physics/representations.py | 8 ++++---- tests/test_lab_dev/test_circuit_components.py | 2 +- tests/test_lab_dev/test_states/test_states_base.py | 6 +++--- tests/test_lab_dev/test_transformations/test_dgate.py | 2 +- tests/test_math/test_lattice.py | 8 ++++---- 8 files changed, 18 insertions(+), 18 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index bb9255799..b44060ccf 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -463,7 +463,7 @@ def bargmann_triple( """ return self._representation.bargmann_triple(batched) - def fock(self, shape: int | Sequence[int] | None = None, batched=False) -> ComplexTensor: + def fock_array(self, shape: int | Sequence[int] | None = None, batched=False) -> ComplexTensor: r""" Returns an array representation of this component in the Fock basis with the given shape. If the shape is not given, it defaults to the ``auto_shape`` of the component if it is @@ -477,7 +477,7 @@ def fock(self, shape: int | Sequence[int] | None = None, batched=False) -> Compl Returns: array: The Fock representation of this component. """ - return self._representation.fock(shape or self.auto_shape(), batched) + return self._representation.fock_array(shape or self.auto_shape(), batched) def on(self, modes: Sequence[int]) -> CircuitComponent: r""" diff --git a/mrmustard/lab_dev/states/base.py b/mrmustard/lab_dev/states/base.py index 317e6a671..db96a9186 100644 --- a/mrmustard/lab_dev/states/base.py +++ b/mrmustard/lab_dev/states/base.py @@ -324,7 +324,7 @@ def fock_distribution(self, cutoff: int) -> ComplexTensor: Returns: The Fock distribution. """ - fock_array = self.fock(cutoff) + fock_array = self.fock_array(cutoff) if isinstance(self, Ket): probs = ( math.astensor( diff --git a/mrmustard/lab_dev/transformations/dgate.py b/mrmustard/lab_dev/transformations/dgate.py index 48835ef56..eeb0616df 100644 --- a/mrmustard/lab_dev/transformations/dgate.py +++ b/mrmustard/lab_dev/transformations/dgate.py @@ -100,7 +100,7 @@ def __init__( self.wires, ) - def fock(self, shape: int | Sequence[int] = None, batched=False) -> ComplexTensor: + def fock_array(self, shape: int | Sequence[int] = None, batched=False) -> ComplexTensor: r""" Returns the unitary representation of the Displacement gate using the Laguerre polynomials. If the shape is not given, it defaults to the ``auto_shape`` of the component if it is @@ -145,7 +145,7 @@ def fock(self, shape: int | Sequence[int] = None, batched=False) -> ComplexTenso return arrays def to_fock(self, shape: int | Sequence[int] | None = None) -> Dgate: - fock = ArrayAnsatz(self.fock(shape, batched=True), batched=True) + fock = ArrayAnsatz(self.fock_array(shape, batched=True), batched=True) fock._original_abc_data = self.ansatz.triple ret = self._getitem_builtin(self.modes) ret._representation = Representation(fock, self.wires) diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index 1a8d1e8ee..b7d569858 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -125,7 +125,7 @@ def bargmann_triple( except AttributeError as e: raise AttributeError("No Bargmann data for this component.") from e - def fock(self, shape: int | Sequence[int], batched=False) -> ComplexTensor: + def fock_array(self, shape: int | Sequence[int], batched=False) -> ComplexTensor: r""" Returns an array representation of this component in the Fock basis with the given shape. If the shape is not given, it defaults to the ``auto_shape`` of the component if it is @@ -172,7 +172,7 @@ def fock(self, shape: int | Sequence[int], batched=False) -> ComplexTensor: def to_bargmann(self) -> Representation: r""" - Returns a new circuit component with the same attributes as this and a ``Bargmann`` representation. + Converts this representation to a Bargmann representation. """ if isinstance(self.ansatz, PolyExpAnsatz): return self @@ -187,14 +187,14 @@ def to_bargmann(self) -> Representation: def to_fock(self, shape: int | Sequence[int]) -> Representation: r""" - Returns a new representation with an ``ArrayAnsatz``. + Converts this representation to a Fock representation. Args: shape: The shape of the returned representation. If ``shape``is given as an ``int``, it is broadcasted to all the dimensions. If ``None``, it defaults to the value of ``AUTOSHAPE_MAX`` in the settings. """ - fock = ArrayAnsatz(self.fock(shape, batched=True), batched=True) + fock = ArrayAnsatz(self.fock_array(shape, batched=True), batched=True) try: if self.ansatz.polynomial_shape[0] == 0: fock._original_abc_data = self.ansatz.triple diff --git a/tests/test_lab_dev/test_circuit_components.py b/tests/test_lab_dev/test_circuit_components.py index 7865820c9..5c9bb7776 100644 --- a/tests/test_lab_dev/test_circuit_components.py +++ b/tests/test_lab_dev/test_circuit_components.py @@ -451,7 +451,7 @@ def test_to_fock_keeps_bargmann(self): def test_fock_component_no_bargmann(self): "tests that a fock component doesn't have a bargmann representation by default" coh = Coherent([0], x=1.0) - CC = Ket.from_fock([0], coh.fock(20), batched=False) + CC = Ket.from_fock([0], coh.fock_array(20), batched=False) with pytest.raises(AttributeError): CC.bargmann_triple() # pylint: disable=pointless-statement diff --git a/tests/test_lab_dev/test_states/test_states_base.py b/tests/test_lab_dev/test_states/test_states_base.py index 58895c328..31b7f6a08 100644 --- a/tests/test_lab_dev/test_states/test_states_base.py +++ b/tests/test_lab_dev/test_states/test_states_base.py @@ -139,7 +139,7 @@ def test_normalize(self, modes, x, y, coeff): def test_to_from_fock(self, modes): state_in = Coherent(modes, x=1, y=2) state_in_fock = state_in.to_fock(5) - array_in = state_in.fock(5, batched=True) + array_in = state_in.fock_array(5, batched=True) assert math.allclose(array_in, state_in_fock.ansatz.array) @@ -522,7 +522,7 @@ def test_from_fock_error(self): state01 = Coherent([0, 1], 1).dm() state01 = state01.to_fock(2) with pytest.raises(ValueError): - DM.from_fock([0], state01.fock(5), "my_dm", True) + DM.from_fock([0], state01.fock_array(5), "my_dm", True) def test_bargmann_Abc_to_phasespace_cov_means(self): # The init state cov and means comes from the random state 'state = Gaussian(1) >> Dgate([0.2], [0.3])' @@ -596,7 +596,7 @@ def test_normalize(self, modes, x, y, coeff): def test_to_from_fock(self, modes): state_in = Coherent(modes, x=1, y=2) >> Attenuator([modes[0]], 0.8) state_in_fock = state_in.to_fock(5) - array_in = state_in.fock(5, batched=True) + array_in = state_in.fock_array(5, batched=True) assert math.allclose(array_in, state_in_fock.ansatz.array) diff --git a/tests/test_lab_dev/test_transformations/test_dgate.py b/tests/test_lab_dev/test_transformations/test_dgate.py index bbcaebd99..b2ff27a5e 100644 --- a/tests/test_lab_dev/test_transformations/test_dgate.py +++ b/tests/test_lab_dev/test_transformations/test_dgate.py @@ -52,7 +52,7 @@ def test_to_fock_method(self): # displacement gate in fock representation for large displacement dgate = Dgate([0], x=10.0).to_fock(150) assert (state.to_fock() >> dgate).probability < 1 - assert np.all(math.abs(dgate.fock(150)) < 1) + assert np.all(math.abs(dgate.fock_array(150)) < 1) def test_representation(self): rep1 = Dgate(modes=[0], x=0.1, y=0.1).ansatz diff --git a/tests/test_math/test_lattice.py b/tests/test_math/test_lattice.py index 0c33bdb37..4433793b3 100644 --- a/tests/test_math/test_lattice.py +++ b/tests/test_math/test_lattice.py @@ -122,14 +122,14 @@ def test_diagonalbatchNumba_vs_diagonalNumba(batch_size): def test_bs_schwinger(): "test that the schwinger method to apply a BS works correctly" - G = mmld.Ket.random([0, 1]).fock([20, 20]) + G = mmld.Ket.random([0, 1]).fock_array([20, 20]) G = math.asnumpy(G) BS = beamsplitter((20, 20, 20, 20), 1.0, 1.0) manual = np.einsum("ab, cdab", G, BS) G = apply_BS_schwinger(1.0, 1.0, 0, 1, G) assert np.allclose(manual, G) - Gg = mmld.Unitary.random([0, 1]).fock([20, 20, 20, 20]) + Gg = mmld.Unitary.random([0, 1]).fock_array([20, 20, 20, 20]) Gg = math.asnumpy(Gg) BS = beamsplitter((20, 20, 20, 20), 2.0, -1.0) manual = np.einsum("cdab, abef", BS, Gg) @@ -157,10 +157,10 @@ def test_vanilla_stable(): "tests the vanilla stable against other known stable methods" settings.STABLE_FOCK_CONVERSION = True assert np.allclose( - mmld.Dgate([0], x=4.0, y=4.0).fock([1000, 1000]), + mmld.Dgate([0], x=4.0, y=4.0).fock_array([1000, 1000]), displacement((1000, 1000), 4.0 + 4.0j), ) - sgate = mmld.Sgate([0], r=4.0, phi=2.0).fock([1000, 1000]) + sgate = mmld.Sgate([0], r=4.0, phi=2.0).fock_array([1000, 1000]) assert np.max(np.abs(sgate)) < 1 settings.STABLE_FOCK_CONVERSION = False From f0a969c56baa088b0247ffb191d9bd05abca3c9d Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 8 Oct 2024 13:50:38 -0400 Subject: [PATCH 43/87] docs --- mrmustard/lab_dev/circuit_components.py | 5 ++--- mrmustard/lab_dev/states/base.py | 4 ++-- mrmustard/physics/ansatz/polyexp_ansatz.py | 6 +++--- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index b44060ccf..b2cdc4f18 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -526,16 +526,15 @@ def to_bargmann(self) -> CircuitComponent: .. code-block:: >>> from mrmustard.lab_dev import Dgate - >>> from mrmustard.physics.representations import Bargmann + >>> from mrmustard.physics.ansatz import PolyExpAnsatz >>> d = Dgate([1], x=0.1, y=0.1) >>> d_fock = d.to_fock(shape=3) >>> d_bargmann = d_fock.to_bargmann() - >>> assert d_bargmann.name == d.name >>> assert d_bargmann.wires == d.wires - >>> assert isinstance(d_bargmann.representation, Bargmann) + >>> assert isinstance(d_bargmann.ansatz, PolyExpAnsatz) """ if isinstance(self.ansatz, PolyExpAnsatz): return self diff --git a/mrmustard/lab_dev/states/base.py b/mrmustard/lab_dev/states/base.py index db96a9186..263ec4270 100644 --- a/mrmustard/lab_dev/states/base.py +++ b/mrmustard/lab_dev/states/base.py @@ -184,7 +184,7 @@ def from_bargmann( .. code-block:: - >>> from mrmustard.physics.representations import Bargmann + >>> from mrmustard.physics.ansatz import PolyExpAnsatz >>> from mrmustard.physics.triples import coherent_state_Abc >>> from mrmustard.lab_dev.states.base import Ket @@ -193,7 +193,7 @@ def from_bargmann( >>> coh = Ket.from_bargmann(modes, triple) >>> assert coh.modes == modes - >>> assert coh.representation == Bargmann(*triple) + >>> assert coh.ansatz == PolyExpAnsatz(*triple) >>> assert isinstance(coh, Ket) Args: diff --git a/mrmustard/physics/ansatz/polyexp_ansatz.py b/mrmustard/physics/ansatz/polyexp_ansatz.py index 835798626..f1c2c3445 100644 --- a/mrmustard/physics/ansatz/polyexp_ansatz.py +++ b/mrmustard/physics/ansatz/polyexp_ansatz.py @@ -72,14 +72,14 @@ class PolyExpAnsatz(Ansatz): .. code-block :: - >>> from mrmustard.physics.representations import Bargmann + >>> from mrmustard.physics.representations import PolyExpAnsatz >>> from mrmustard.physics.triples import displacement_gate_Abc, vacuum_state_Abc >>> # bargmann representation of one-mode vacuum - >>> rep_vac = Bargmann(*vacuum_state_Abc(1)) + >>> rep_vac = PolyExpAnsatz(*vacuum_state_Abc(1)) >>> # bargmann representation of one-mode dgate with gamma=1+0j - >>> rep_dgate = Bargmann(*displacement_gate_Abc(1)) + >>> rep_dgate = PolyExpAnsatz(*displacement_gate_Abc(1)) The inner product is defined as the contraction of two Bargmann objects across marked indices. Indices are marked using ``__getitem__``. Once the indices are marked for contraction, they are From 1331c93a78b9fe79b56c03a06e740e81b51d19ab Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 8 Oct 2024 14:00:12 -0400 Subject: [PATCH 44/87] cleanup --- .../branch_and_bound.py | 30 ++++++++++--------- mrmustard/lab_dev/states/base.py | 12 ++++---- mrmustard/physics/wires.py | 4 +-- 3 files changed, 24 insertions(+), 22 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components_utils/branch_and_bound.py b/mrmustard/lab_dev/circuit_components_utils/branch_and_bound.py index 82bd471f7..4be60e961 100644 --- a/mrmustard/lab_dev/circuit_components_utils/branch_and_bound.py +++ b/mrmustard/lab_dev/circuit_components_utils/branch_and_bound.py @@ -39,21 +39,23 @@ class GraphComponent: r""" A lightweight "CircuitComponent" without the actual representation. Basically a wrapper around Wires, so that it can emulate components in - a circuit. It exposes the repr, wires, shape, name and cost of obtaining + a circuit. It exposes the representation, wires, shape, name and cost of obtaining the component from previous contractions. Args: - repr: The name of the representation of the component. + representation: The name of the representation of the component. wires: The wires of the component. shape: The fock shape of the component. name: The name of the component. cost: The cost of obtaining this component. """ - def __init__(self, repr: str, wires: Wires, shape: list[int], name: str = "", cost: int = 0): + def __init__( + self, representation: str, wires: Wires, shape: list[int], name: str = "", cost: int = 0 + ): if None in shape: raise ValueError("Detected `None`s in shape. Please provide a full shape.") - self.repr = repr + self.representation = representation self.wires = wires self.shape = list(shape) self.name = name @@ -68,7 +70,7 @@ def from_circuitcomponent(cls, c: CircuitComponent): c: A CircuitComponent. """ return GraphComponent( - repr=str(c.ansatz.__class__.__name__), + representation=str(c.ansatz.__class__.__name__), wires=Wires(*c.wires.args), shape=c.auto_shape(), name=c.__class__.__name__, @@ -101,7 +103,7 @@ def contraction_cost(self, other: GraphComponent) -> int: m = len(idxA) # same as len(idxB) nA, nB = len(self.shape) - m, len(other.shape) - m - if self.repr == "Bargmann" and other.repr == "Bargmann": + if self.representation == "Bargmann" and other.representation == "Bargmann": cost = ( # +1s to include vector part) m * m * m # M inverse + (m + 1) * m * nA # left matmul @@ -117,8 +119,8 @@ def contraction_cost(self, other: GraphComponent) -> int: ) cost = ( prod_A * prod_B * prod_contracted # matmul - + np.prod(self.shape) * (self.repr == "Bargmann") # conversion - + np.prod(other.shape) * (other.repr == "Bargmann") # conversion + + np.prod(self.shape) * (self.representation == "Bargmann") # conversion + + np.prod(other.shape) * (other.representation == "Bargmann") # conversion ) return int(cost) @@ -136,7 +138,7 @@ def __matmul__(self, other) -> GraphComponent: shape = shape_A + shape_B new_shape = [shape[p] for p in perm] new_component = GraphComponent( - "Bargmann" if self.repr == other.repr == "Bargmann" else "Fock", + "Bargmann" if self.representation == other.representation == "Bargmann" else "Fock", new_wires, new_shape, f"({self.name}@{other.name})", @@ -390,7 +392,7 @@ def assign_costs(graph: Graph, debug: int = 0) -> None: graph.edges[edge]["cost"] = A.contraction_cost(B) if debug > 0: print( - f"cost of edge {edge}: {A.repr}|{A.shape} x {B.repr}|{B.shape} = {graph.edges[edge]['cost']}" + f"cost of edge {edge}: {A.representation}|{A.shape} x {B.representation}|{B.shape} = {graph.edges[edge]['cost']}" ) @@ -414,10 +416,10 @@ def reduce_first(graph: Graph, code: str) -> tuple[Graph, Edge | bool]: r""" Reduces the first pair of nodes that match the pattern in the code. The first number and letter describe a node with that number of - edges and that repr (B for Bargmann, F for Fock), and the last letter - describes the repr of the second node. + edges and that representation (B for Bargmann, F for Fock), and the last letter + describes the representation of the second node. For example 1BB means we will contract the first occurrence of a node - that has one edge (a leaf) connected to a node of repr B with an arbitrary + that has one edge (a leaf) connected to a node of representation B with an arbitrary number of edges. We typically use codes like 1BB, 2BB, 1FF, 2FF by default because they are safe, and codes like 1BF, 1FB optionally as they are not always the best choice. @@ -432,7 +434,7 @@ def reduce_first(graph: Graph, code: str) -> tuple[Graph, Edge | bool]: for edge in list(graph.out_edges(node)) + list(graph.in_edges(node)): A = graph.nodes[edge[0]]["component"] B = graph.nodes[edge[1]]["component"] - if A.repr[0] == tA and B.repr[0] == tB: + if A.representation[0] == tA and B.representation[0] == tB: graph = contract(graph, edge) return graph, edge return graph, False diff --git a/mrmustard/lab_dev/states/base.py b/mrmustard/lab_dev/states/base.py index 263ec4270..4f8fe998a 100644 --- a/mrmustard/lab_dev/states/base.py +++ b/mrmustard/lab_dev/states/base.py @@ -843,9 +843,9 @@ def auto_shape( shape = self.ansatz.array.shape[1:] except AttributeError: # bargmann if self.ansatz.polynomial_shape[0] == 0: - repr = self.ansatz - A, b, c = repr.A[0], repr.b[0], repr.c[0] - repr = repr / self.probability + ansatz = self.ansatz + A, b, c = ansatz.A[0], ansatz.b[0], ansatz.c[0] + ansatz = ansatz / self.probability shape = autoshape_numba( math.asnumpy(A), math.asnumpy(b), @@ -1140,9 +1140,9 @@ def auto_shape( shape = self.ansatz.array.shape[1:] except AttributeError: # bargmann if self.ansatz.polynomial_shape[0] == 0: - repr = self.ansatz.conj & self.ansatz - A, b, c = repr.A[0], repr.b[0], repr.c[0] - repr = repr / self.probability + ansatz = self.ansatz.conj & self.ansatz + A, b, c = ansatz.A[0], ansatz.b[0], ansatz.c[0] + ansatz = ansatz / self.probability shape = autoshape_numba( math.asnumpy(A), math.asnumpy(b), diff --git a/mrmustard/physics/wires.py b/mrmustard/physics/wires.py index 7f896df1d..b68089549 100644 --- a/mrmustard/physics/wires.py +++ b/mrmustard/physics/wires.py @@ -470,9 +470,9 @@ def __matmul__(self, other: Wires) -> tuple[Wires, list[int]]: .. code-block:: - repr = repr1[idx1] @ repr2[idx2] # not in standard order + ansatz = ansatz1[idx1] @ ansatz2[idx2] # not in standard order wires, perm = wires1 @ wires2 # matmul the wires of each component - repr = repr.reorder(perm) # now in standard order + ansatz = ansatz.reorder(perm) # now in standard order Args: other: The wires of the other circuit component. From 219275415fda3ade26e6262aed19b4c8a24b29bf Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 8 Oct 2024 14:02:56 -0400 Subject: [PATCH 45/87] docs --- mrmustard/physics/ansatz/array_ansatz.py | 41 +++-------- mrmustard/physics/ansatz/polyexp_ansatz.py | 84 ++++------------------ 2 files changed, 23 insertions(+), 102 deletions(-) diff --git a/mrmustard/physics/ansatz/array_ansatz.py b/mrmustard/physics/ansatz/array_ansatz.py index 50c518422..cb4204b14 100644 --- a/mrmustard/physics/ansatz/array_ansatz.py +++ b/mrmustard/physics/ansatz/array_ansatz.py @@ -36,48 +36,23 @@ class ArrayAnsatz(Ansatz): r""" - The Fock representation of a broad class of quantum states, transformations, measurements, - channels, etc. + The ansatz of the Fock-Bargmann representation. - The ansatz available in this representation is ``ArrayAnsatz``. - - This function allows for vector space operations on Fock objects including - linear combinations, outer product (``&``), and inner product (``@``). + Represents the ansatz as a multidimensional array. .. code-block:: - >>> from mrmustard.physics.representations import Fock - - >>> # initialize Fock objects - >>> array1 = np.random.random((5,7,8)) - >>> array2 = np.random.random((5,7,8)) - >>> array3 = np.random.random((3,5,7,8)) # where 3 is the batch. - >>> fock1 = Fock(array1) - >>> fock2 = Fock(array2) - >>> fock3 = Fock(array3, batched=True) - - >>> # linear combination can be done with the same batch dimension - >>> fock4 = 1.3 * fock1 - fock2 * 2.1 - - >>> # division by a scalar - >>> fock5 = fock1 / 1.3 + >>> from mrmustard.physics.ansatze import ArrayAnsatz - >>> # inner product by contracting on marked indices - >>> fock6 = fock1[2] @ fock3[2] - - >>> # outer product (tensor product) - >>> fock7 = fock1 & fock3 - - >>> # conjugation - >>> fock8 = fock1.conj + >>> array = np.random.random((2, 4, 5)) + >>> ansatz = ArrayAnsatz(array) Args: - array: the (batched) array in Fock representation. - batched: whether the array input has a batch dimension. + array: A (potentially) batched array. + batched: Whether the array input has a batch dimension. Note: The args can be passed non-batched, as they will be automatically broadcasted to the - correct batch shape. - + correct batch shape if ``batched`` is set to ``False``. """ def __init__(self, array: Batch[Tensor], batched=False): diff --git a/mrmustard/physics/ansatz/polyexp_ansatz.py b/mrmustard/physics/ansatz/polyexp_ansatz.py index f1c2c3445..0b224e1da 100644 --- a/mrmustard/physics/ansatz/polyexp_ansatz.py +++ b/mrmustard/physics/ansatz/polyexp_ansatz.py @@ -58,89 +58,35 @@ class PolyExpAnsatz(Ansatz): r""" - The Fock-Bargmann representation of a broad class of quantum states, transformations, - measurements, channels, etc. + The ansatz of the Fock-Bargmann representation. - The ansatz available in this representation is a linear combination of exponentials - of bilinear forms with a polynomial part: + Represents the ansatz function: - .. math:: - F(z) = \sum_i \textrm{poly}_i(z) \textrm{exp}(z^T A_i z / 2 + z^T b_i) + :math:`F(z) = \sum_i [\sum_k c^{(i)}_k \partial_y^k \textrm{exp}((z,y)^T A_i (z,y) / 2 + (z,y)^T b_i)|_{y=0}]` - This function allows for vector space operations on Bargmann objects including - linear combinations (``+``), outer product (``&``), and inner product (``@``). + with ``k`` being a multi-index. The matrices :math:`A_i` and vectors :math:`b_i` are + parameters of the exponential terms in the ansatz, and :math:`z` is a vector of variables, and and :math:`y` is a vector linked to the polynomial coefficients. + The dimension of ``z + y`` must be equal to the dimension of ``A`` and ``b``. - .. code-block :: + .. code-block:: - >>> from mrmustard.physics.representations import PolyExpAnsatz - >>> from mrmustard.physics.triples import displacement_gate_Abc, vacuum_state_Abc + >>> from mrmustard.physics.ansatze import PolyExpAnsatz - >>> # bargmann representation of one-mode vacuum - >>> rep_vac = PolyExpAnsatz(*vacuum_state_Abc(1)) - >>> # bargmann representation of one-mode dgate with gamma=1+0j - >>> rep_dgate = PolyExpAnsatz(*displacement_gate_Abc(1)) + >>> A = np.array([[1.0, 0.0], [0.0, 1.0]]) + >>> b = np.array([1.0, 1.0]) + >>> c = np.array([[1.0,2.0,3.0]]) - The inner product is defined as the contraction of two Bargmann objects across marked indices. - Indices are marked using ``__getitem__``. Once the indices are marked for contraction, they are - be used the next time the inner product (``@``) is called. For example: + >>> F = PolyExpAnsatz(A, b, c) + >>> z = np.array([[1.0],[2.0],[3.0]]) - .. code-block :: - - >>> import numpy as np - - >>> # mark indices for contraction - >>> idx_vac = [0] - >>> idx_rep = [1] - - >>> # bargmann representation of coh = vacuum >> dgate - >>> rep_coh = rep_vac[idx_vac] @ rep_dgate[idx_rep] - >>> assert np.allclose(rep_coh.A, [[0,],]) - >>> assert np.allclose(rep_coh.b, [1,]) - >>> assert np.allclose(rep_coh.c, 0.6065306597126334) - - This can also be used to contract existing indices in a single Bargmann object, e.g. - to implement the partial trace. - - .. code-block :: - - >>> trace = (rep_coh @ rep_coh.conj).trace([0], [1]) - >>> assert np.allclose(trace.A, 0) - >>> assert np.allclose(trace.b, 0) - >>> assert trace.c == 1 - - The ``A``, ``b``, and ``c`` parameters can be batched to represent superpositions. - - .. code-block :: - - >>> # bargmann representation of one-mode coherent state with gamma=1+0j - >>> A_plus = [[0,],] - >>> b_plus = [1,] - >>> c_plus = 0.6065306597126334 - - >>> # bargmann representation of one-mode coherent state with gamma=-1+0j - >>> A_minus = [[0,],] - >>> b_minus = [-1,] - >>> c_minus = 0.6065306597126334 - - >>> # bargmann representation of a superposition of coherent states - >>> A = [A_plus, A_minus] - >>> b = [b_plus, b_minus] - >>> c = [c_plus, c_minus] - >>> rep_coh_sup = Bargmann(A, b, c) - - Note that the operations that change the shape of the ansatz (outer product and inner - product) do not automatically modify the ordering of the combined or leftover indices. - However, the ``reordering`` method allows reordering the representation after the products - have been carried out. + >>> # calculate the value of the function at the three different ``z``, since z is batched. + >>> val = F(z) Args: A: A batch of quadratic coefficient :math:`A_i`. b: A batch of linear coefficients :math:`b_i`. c: A batch of arrays :math:`c_i`. - - Note: The args can be passed non-batched, as they will be automatically broadcasted to the - correct batch shape. """ def __init__( From 0bb23690769cf54d9471a678566b974bdf48a275 Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 8 Oct 2024 14:43:14 -0400 Subject: [PATCH 46/87] docs --- mrmustard/physics/ansatz/array_ansatz.py | 2 +- mrmustard/physics/ansatz/polyexp_ansatz.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mrmustard/physics/ansatz/array_ansatz.py b/mrmustard/physics/ansatz/array_ansatz.py index cb4204b14..4c9d2348b 100644 --- a/mrmustard/physics/ansatz/array_ansatz.py +++ b/mrmustard/physics/ansatz/array_ansatz.py @@ -42,7 +42,7 @@ class ArrayAnsatz(Ansatz): .. code-block:: - >>> from mrmustard.physics.ansatze import ArrayAnsatz + >>> from mrmustard.physics.ansatz import ArrayAnsatz >>> array = np.random.random((2, 4, 5)) >>> ansatz = ArrayAnsatz(array) diff --git a/mrmustard/physics/ansatz/polyexp_ansatz.py b/mrmustard/physics/ansatz/polyexp_ansatz.py index 0b224e1da..d060ead0c 100644 --- a/mrmustard/physics/ansatz/polyexp_ansatz.py +++ b/mrmustard/physics/ansatz/polyexp_ansatz.py @@ -70,7 +70,7 @@ class PolyExpAnsatz(Ansatz): .. code-block:: - >>> from mrmustard.physics.ansatze import PolyExpAnsatz + >>> from mrmustard.physics.ansatz import PolyExpAnsatz >>> A = np.array([[1.0, 0.0], [0.0, 1.0]]) From fb9cf8e953686484582d1350e2a4226992dc285a Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 8 Oct 2024 16:38:47 -0400 Subject: [PATCH 47/87] docs --- mrmustard/lab_dev/circuit_components.py | 4 ++-- mrmustard/lab_dev/states/base.py | 6 +++--- mrmustard/physics/ansatz/array_ansatz.py | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index b2cdc4f18..379b0edbe 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -556,14 +556,14 @@ def to_fock(self, shape: int | Sequence[int] | None = None) -> CircuitComponent: .. code-block:: >>> from mrmustard.lab_dev import Dgate - >>> from mrmustard.physics.representations import Fock + >>> from mrmustard.physics.ansatz import ArrayAnsatz >>> d = Dgate([1], x=0.1, y=0.1) >>> d_fock = d.to_fock(shape=3) >>> assert d_fock.name == d.name >>> assert d_fock.wires == d.wires - >>> assert isinstance(d_fock.representation, Fock) + >>> assert isinstance(d_fock.representation, ArrayAnsatz) Args: shape: The shape of the returned representation. If ``shape``is given as diff --git a/mrmustard/lab_dev/states/base.py b/mrmustard/lab_dev/states/base.py index 4f8fe998a..bbc60f688 100644 --- a/mrmustard/lab_dev/states/base.py +++ b/mrmustard/lab_dev/states/base.py @@ -224,16 +224,16 @@ def from_fock( .. code-block:: - >>> from mrmustard.physics.representations import Fock + >>> from mrmustard.physics.ansatz import ArrayAnsatz >>> from mrmustard.physics.triples import coherent_state_Abc >>> from mrmustard.lab_dev import Coherent, Ket >>> modes = [0] - >>> array = Coherent(modes, x=0.1).to_fock().representation.array + >>> array = Coherent(modes, x=0.1).to_fock().ansatz.array >>> coh = Ket.from_fock(modes, array, batched=True) >>> assert coh.modes == modes - >>> assert coh.representation == Fock(array, batched=True) + >>> assert coh.ansatz == ArrayAnsatz(array, batched=True) >>> assert isinstance(coh, Ket) Args: diff --git a/mrmustard/physics/ansatz/array_ansatz.py b/mrmustard/physics/ansatz/array_ansatz.py index 4c9d2348b..073de57e5 100644 --- a/mrmustard/physics/ansatz/array_ansatz.py +++ b/mrmustard/physics/ansatz/array_ansatz.py @@ -133,21 +133,21 @@ def reduce(self, shape: int | Sequence[int]) -> ArrayAnsatz: .. code-block:: >>> from mrmustard import math - >>> from mrmustard.physics.representations import Fock + >>> from mrmustard.physics.ansatz import ArrayAnsatz >>> array1 = math.arange(27).reshape((3, 3, 3)) - >>> fock1 = Fock(array1) + >>> fock1 = ArrayAnsatz(array1) >>> fock2 = fock1.reduce(3) >>> assert fock1 == fock2 >>> fock3 = fock1.reduce(2) >>> array3 = [[[0, 1], [3, 4]], [[9, 10], [12, 13]]] - >>> assert fock3 == Fock(array3) + >>> assert fock3 == ArrayAnsatz(array3) >>> fock4 = fock1.reduce((1, 3, 1)) >>> array4 = [[[0], [3], [6]]] - >>> assert fock4 == Fock(array4) + >>> assert fock4 == ArrayAnsatz(array4) Args: shape: The shape of the array of the returned ``Fock``. From 76646d808de69a0834646e29fe94f7c8b5a8807d Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 8 Oct 2024 16:41:48 -0400 Subject: [PATCH 48/87] docs --- mrmustard/lab_dev/states/number.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mrmustard/lab_dev/states/number.py b/mrmustard/lab_dev/states/number.py index 6a7edf46f..3158e0f39 100644 --- a/mrmustard/lab_dev/states/number.py +++ b/mrmustard/lab_dev/states/number.py @@ -36,9 +36,10 @@ class Number(Ket): .. code-block:: >>> from mrmustard.lab_dev import Number + >>> from mrmustard.physics.ansatz import ArrayAnsatz >>> state = Number(modes=[0, 1], n=[10, 20]) - >>> assert state.representation.__class__.__name__ == "Fock" + >>> assert isinstance(state.ansatz, ArrayAnsatz) Args: modes: The modes of the number state. From 1c1c3fa510aa6a8755bc33cfa968e54804543b1a Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 8 Oct 2024 16:49:39 -0400 Subject: [PATCH 49/87] docs --- doc/code/lab_dev.rst | 1 - doc/code/lab_dev/wires.rst | 8 -------- doc/code/physics.rst | 1 + mrmustard/lab_dev/circuit_components.py | 2 +- mrmustard/physics/wires.py | 2 +- 5 files changed, 3 insertions(+), 11 deletions(-) delete mode 100644 doc/code/lab_dev/wires.rst diff --git a/doc/code/lab_dev.rst b/doc/code/lab_dev.rst index 0d5554d98..1342f84b7 100644 --- a/doc/code/lab_dev.rst +++ b/doc/code/lab_dev.rst @@ -4,7 +4,6 @@ mrmustard.lab_dev .. toctree:: :maxdepth: 1 - lab_dev/wires lab_dev/circuit_components lab_dev/states lab_dev/transformations diff --git a/doc/code/lab_dev/wires.rst b/doc/code/lab_dev/wires.rst deleted file mode 100644 index 7f81a5056..000000000 --- a/doc/code/lab_dev/wires.rst +++ /dev/null @@ -1,8 +0,0 @@ -mrmustard.lab_dev.wires -======================= - -.. currentmodule:: mrmustard.lab_dev.wires - -.. automodapi:: mrmustard.lab_dev.wires - :no-heading: - :include-all-objects: diff --git a/doc/code/physics.rst b/doc/code/physics.rst index 1efa7ec53..d98824351 100644 --- a/doc/code/physics.rst +++ b/doc/code/physics.rst @@ -4,6 +4,7 @@ mrmustard.physics .. toctree:: :maxdepth: 1 + physics/wires physics/representations .. toctree:: diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index 379b0edbe..f6dae33c2 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -563,7 +563,7 @@ def to_fock(self, shape: int | Sequence[int] | None = None) -> CircuitComponent: >>> assert d_fock.name == d.name >>> assert d_fock.wires == d.wires - >>> assert isinstance(d_fock.representation, ArrayAnsatz) + >>> assert isinstance(d_fock.ansatz, ArrayAnsatz) Args: shape: The shape of the returned representation. If ``shape``is given as diff --git a/mrmustard/physics/wires.py b/mrmustard/physics/wires.py index b68089549..422701342 100644 --- a/mrmustard/physics/wires.py +++ b/mrmustard/physics/wires.py @@ -80,7 +80,7 @@ class Wires: .. code-block:: - >>> from mrmustard.lab_dev.wires import Wires + >>> from mrmustard.physics.wires import Wires >>> modes_out_bra={0, 1} >>> modes_in_bra={1, 2} From d61a62666afcee6419460c12c644b7b6fa65a2f0 Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 8 Oct 2024 16:55:21 -0400 Subject: [PATCH 50/87] docs --- doc/code/physics/wires.rst | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 doc/code/physics/wires.rst diff --git a/doc/code/physics/wires.rst b/doc/code/physics/wires.rst new file mode 100644 index 000000000..95199818c --- /dev/null +++ b/doc/code/physics/wires.rst @@ -0,0 +1,8 @@ +mrmustard.physics.wires +======================= + +.. currentmodule:: mrmustard.physics.wires + +.. automodapi:: mrmustard.physics.wires + :no-heading: + :include-all-objects: From fd40f6911af881d83f31a1ef45a0d5242ea04c00 Mon Sep 17 00:00:00 2001 From: Anthony Date: Thu, 10 Oct 2024 13:15:39 -0400 Subject: [PATCH 51/87] moving adjoint and dual --- mrmustard/lab_dev/circuit_components.py | 22 ++++++--------------- mrmustard/physics/representations.py | 26 +++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index f6dae33c2..7da4490ab 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -158,14 +158,10 @@ def _deserialize(cls, data: dict) -> CircuitComponent: def adjoint(self) -> CircuitComponent: r""" The adjoint of this component obtained by conjugating the representation and swapping - the ket and bra wires. The returned object is a view of the original component which - applies a conjugation and a swap of the wires, but does not copy the data in memory. + the ket and bra wires. """ - bras = self.wires.bra.indices - kets = self.wires.ket.indices - ansatz = self.ansatz.reorder(kets + bras).conj if self.ansatz else None - - ret = CircuitComponent(ansatz, self.wires.adjoint, self.name) + rep = self.representation.adjoint + ret = CircuitComponent(rep.ansatz, rep.wires, self.name) ret.short_name = self.short_name for param in self.parameter_set.all_parameters.values(): ret._add_parameter(param) @@ -175,16 +171,10 @@ def adjoint(self) -> CircuitComponent: def dual(self) -> CircuitComponent: r""" The dual of this component obtained by conjugating the representation and swapping - the input and output wires. The returned object is a view of the original component which - applies a conjugation and a swap of the wires, but does not copy the data in memory. + the input and output wires. """ - ok = self.wires.ket.output.indices - ik = self.wires.ket.input.indices - ib = self.wires.bra.input.indices - ob = self.wires.bra.output.indices - ansatz = self.ansatz.reorder(ib + ob + ik + ok).conj if self.ansatz else None - - ret = CircuitComponent(ansatz, self.wires.dual, self.name) + rep = self.representation.dual + ret = CircuitComponent(rep.ansatz, rep.wires, self.name) ret.short_name = self.short_name for param in self.parameter_set.all_parameters.values(): ret._add_parameter(param) diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index b7d569858..9456b3a5d 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -89,6 +89,18 @@ def __init__( self._wires = wires self._wire_reps = wire_reps or dict.fromkeys(wires.modes, RepEnum.from_ansatz(ansatz)) + @property + def adjoint(self) -> Representation: + r""" + The adjoint of this representation obtained by conjugating the ansatz and swapping + the ket and bra wires. + """ + bras = self.wires.bra.indices + kets = self.wires.ket.indices + ansatz = self.ansatz.reorder(kets + bras).conj if self.ansatz else None + wires = self.wires.adjoint + return Representation(ansatz, wires) + @property def ansatz(self) -> Ansatz | None: r""" @@ -96,6 +108,20 @@ def ansatz(self) -> Ansatz | None: """ return self._ansatz + @property + def dual(self) -> Representation: + r""" + The dual of this representation obtained by conjugating the ansatz and swapping + the input and output wires. + """ + ok = self.wires.ket.output.indices + ik = self.wires.ket.input.indices + ib = self.wires.bra.input.indices + ob = self.wires.bra.output.indices + ansatz = self.ansatz.reorder(ib + ob + ik + ok).conj if self.ansatz else None + wires = self.wires.dual + return Representation(ansatz, wires) + @property def wires(self) -> Wires | None: r""" From 62bd9445f69b99b6a72b15defe6c78858c9b6591 Mon Sep 17 00:00:00 2001 From: Anthony Date: Thu, 10 Oct 2024 13:30:37 -0400 Subject: [PATCH 52/87] some more cleanup --- mrmustard/physics/ansatz/array_ansatz.py | 22 +++++++------- mrmustard/physics/ansatz/base.py | 12 ++++---- mrmustard/physics/ansatz/polyexp_ansatz.py | 34 +++++++++++----------- mrmustard/physics/representations.py | 10 +++---- 4 files changed, 38 insertions(+), 40 deletions(-) diff --git a/mrmustard/physics/ansatz/array_ansatz.py b/mrmustard/physics/ansatz/array_ansatz.py index 073de57e5..897d26775 100644 --- a/mrmustard/physics/ansatz/array_ansatz.py +++ b/mrmustard/physics/ansatz/array_ansatz.py @@ -64,7 +64,7 @@ def __init__(self, array: Batch[Tensor], batched=False): @property def array(self) -> Batch[Tensor]: r""" - The array of this representation. + The array of this ansatz. """ self._generate_ansatz() if not self._backend_array: @@ -98,8 +98,8 @@ def num_vars(self) -> int: @property def scalar(self) -> Scalar: r""" - The scalar part of the representation. - I.e. the vacuum component of the Fock object, whatever it may be. + The scalar part of the ansatz. + I.e. the vacuum component of the Fock array, whatever it may be. Given that the first axis of the array is the batch axis, this is the first element of the array. """ return self.array[(slice(None),) + (0,) * self.num_vars] @@ -107,12 +107,10 @@ def scalar(self) -> Scalar: @property def triple(self) -> tuple: r""" - The data of the original Bargmann if it exists. + The data of the original PolyExpAnsatz if it exists. """ if self._original_abc_data is None: - raise AttributeError( - "This Fock object does not have an original Bargmann representation." - ) + raise AttributeError("This ArrayAnsatz does not have (A,b,c) data.") return self._original_abc_data @classmethod @@ -128,7 +126,7 @@ def from_function(cls, fn: Callable, **kwargs: Any) -> ArrayAnsatz: def reduce(self, shape: int | Sequence[int]) -> ArrayAnsatz: r""" - Returns a new ``Fock`` with a sliced array. + Returns a new ``ArrayAnsatz`` with a sliced array. .. code-block:: @@ -150,7 +148,7 @@ def reduce(self, shape: int | Sequence[int]) -> ArrayAnsatz: >>> assert fock4 == ArrayAnsatz(array4) Args: - shape: The shape of the array of the returned ``Fock``. + shape: The shape of the array of the returned ``ArrayAnsatz``. """ if shape == self.array.shape[1:]: return self @@ -182,7 +180,7 @@ def sum_batch(self) -> ArrayAnsatz: Sums over the batch dimension of the array. Turns an object with any batch size to a batch size of 1. Returns: - The collapsed Fock object. + The collapsed ArrayAnsatz object. """ return ArrayAnsatz(math.expand_dims(math.sum(self.array, axes=[0]), 0), batched=True) @@ -191,7 +189,7 @@ def to_dict(self) -> dict[str, ArrayLike]: def trace(self, idxs1: tuple[int, ...], idxs2: tuple[int, ...]) -> ArrayAnsatz: if len(idxs1) != len(idxs2) or not set(idxs1).isdisjoint(idxs2): - raise ValueError("idxs must be of equal length and disjoint") + raise ValueError("The idxs must be of equal length and disjoint.") order = ( [0] + [i + 1 for i in range(len(self.array.shape) - 1) if i not in idxs1 + idxs2] @@ -235,7 +233,7 @@ def __and__(self, other: ArrayAnsatz) -> ArrayAnsatz: return ArrayAnsatz(array=new_array, batched=True) def __call__(self, z: Batch[Vector]) -> Scalar: - raise AttributeError("Cannot call Fock.") + raise AttributeError("Cannot call this ArrayAnsatz.") def __eq__(self, other: Ansatz) -> bool: slices = (slice(0, None),) + tuple( diff --git a/mrmustard/physics/ansatz/base.py b/mrmustard/physics/ansatz/base.py index a2bd44d57..9285eb9bd 100644 --- a/mrmustard/physics/ansatz/base.py +++ b/mrmustard/physics/ansatz/base.py @@ -64,7 +64,7 @@ def conj(self) -> Ansatz: def data(self) -> tuple | Tensor: r""" The data of the ansatz. - For now, it's the triple for Bargmann and the array for Fock. + For now, it's the triple for PolyExpAnsatz and the array for ArrayAnsatz. """ @property @@ -79,7 +79,7 @@ def num_vars(self) -> int: def scalar(self) -> Scalar: r""" The scalar part of the ansatz. - For now it's ``c`` for Bargmann and the array for Fock. + For now it's ``c`` for PolyExpAnsatz and the array for ArrayAnsatz. """ @property @@ -95,7 +95,7 @@ def triple( @abstractmethod def from_dict(cls, data: dict[str, ArrayLike]) -> Ansatz: r""" - Deserialize a Representation. + Deserialize an Ansatz. """ @classmethod @@ -114,7 +114,7 @@ def reorder(self, order: tuple[int, ...] | list[int]) -> Ansatz: @abstractmethod def to_dict(self) -> dict[str, ArrayLike]: r""" - Serialize a Representation. + Serialize an Ansatz. """ @abstractmethod @@ -218,13 +218,13 @@ def __neg__(self) -> Ansatz: Negates the values in the ansatz. """ - def __rmul__(self, other: Ansatz | Scalar) -> Ansatz: + def __rmul__(self, other: Scalar | Ansatz) -> Ansatz: r""" Multiplies this ansatz by another or by a scalar on the right. """ return self.__mul__(other) - def __sub__(self, other: Ansatz) -> Ansatz: + def __sub__(self, other: Scalar | Ansatz) -> Ansatz: r""" Subtracts other from this ansatz. """ diff --git a/mrmustard/physics/ansatz/polyexp_ansatz.py b/mrmustard/physics/ansatz/polyexp_ansatz.py index d060ead0c..805893e7a 100644 --- a/mrmustard/physics/ansatz/polyexp_ansatz.py +++ b/mrmustard/physics/ansatz/polyexp_ansatz.py @@ -209,9 +209,9 @@ def from_function(cls, fn: Callable, **kwargs: Any) -> PolyExpAnsatz: def decompose_ansatz(self) -> PolyExpAnsatz: r""" - This method decomposes a Bargmann representation. Given a representation of dimensions: + This method decomposes a PolyExp ansatz. Given an ansatz of dimension: A=(batch,n+m,n+m), b=(batch,n+m), c = (batch,k_1,k_2,...,k_m), - it can be rewritten as a representation of dimensions + it can be rewritten as an ansatz of dimension A=(batch,2n,2n), b=(batch,2n), c = (batch,l_1,l_2,...,l_n), with l_i = sum_j k_j This decomposition is typically favourable if m>n, and will only run if that is the case. The naming convention is ``n = dim_alpha`` and ``m = dim_beta`` and ``(k_1,k_2,...,k_m) = shape_beta`` @@ -300,11 +300,11 @@ def reorder(self, order: tuple[int, ...] | list[int]) -> PolyExpAnsatz: def simplify(self) -> None: r""" - Simplifies the representation by combining together terms that have the same + Simplifies the ansatz by combining together terms that have the same exponential part, i.e. two terms along the batch are considered equal if their matrix and vector are equal. In this case only one is kept and the arrays are added. - Does not run if the representation has already been simplified, so it is safe to call. + Does not run if the ansatz has already been simplified, so it is safe to call. """ if self._simplified: return @@ -358,7 +358,7 @@ def trace(self, idxs1: tuple[int, ...], idxs2: tuple[int, ...]) -> PolyExpAnsatz def _call_all(self, z: Batch[Vector]) -> PolyExpAnsatz: r""" - Value of this representation at ``z``. If ``z`` is batched a value of the function at each of the batches are returned. + Value of this ansatz at ``z``. If ``z`` is batched a value of the function at each of the batches are returned. If ``Abc`` is batched it is thought of as a linear combination, and thus the results are added linearly together. Note that the batch dimension of ``z`` and ``Abc`` can be different. @@ -581,7 +581,7 @@ def _order_batch(self): This method orders the batch dimension by the lexicographical order of the flattened arrays (A, b, c). This is a very cheap way to enforce an ordering of the batch dimension, which is useful for simplification and for - determining (in)equality between two Bargmann representations. + determining (in)equality between two PolyExp ansatz. """ generators = [ itertools.chain( @@ -598,7 +598,7 @@ def _order_batch(self): def __add__(self, other: PolyExpAnsatz) -> PolyExpAnsatz: r""" - Adds two Bargmann representations together. This means concatenating them in the batch dimension. + Adds two PolyExp ansatz together. This means concatenating them in the batch dimension. In the case where c is a polynomial of different shapes it will add padding zeros to make the shapes fit. Example: If the shape of c1 is (1,3,4,5) and the shape of c2 is (1,5,4,3) then the shape of the combined object will be (2,5,4,5). @@ -641,17 +641,17 @@ def __add__(self, other: PolyExpAnsatz) -> PolyExpAnsatz: def __and__(self, other: PolyExpAnsatz) -> PolyExpAnsatz: r""" - Tensor product of this Bargmann with another Bargmann. + Tensor product of this PolyExpAnsatz with another. Equivalent to :math:`F(a) * G(b)` (with different arguments, that is). As it distributes over addition on both self and other, the batch size of the result is the product of the batch - size of this representation and the other one. + size of this ansatz and the other one. Args: - other: Another Barmann. + other: Another PolyExpAnsatz. Returns: - The tensor product of this Bargmann and other. + The tensor product of this PolyExpAnsatz and other. """ def andA(A1, A2, dim_alpha1, dim_alpha2, dim_beta1, dim_beta2): @@ -728,13 +728,13 @@ def andc(c1, c2): def __call__(self, z: Batch[Vector]) -> Scalar | PolyExpAnsatz: r""" - Returns either the value of the representation or a new representation depending on the argument. - If the argument contains None, returns a new representation. - If the argument only contains numbers, returns the value of the representation at that argument. - Note that the batch dimensions are handled differently in the two cases. See subfunctions for furhter information. + Returns either the value of the ansatz or a new ansatz depending on the argument. + If the argument contains None, returns a new ansatz. + If the argument only contains numbers, returns the value of the ansatz at that argument. + Note that the batch dimensions are handled differently in the two cases. See subfunctions for further information. Args: - z: point in C^n where the function is evaluated + z: point in C^n where the function is evaluated. Returns: The value of the function if ``z`` has no ``None``, else it returns a new ansatz. @@ -752,7 +752,7 @@ def __getitem__(self, idx: int | tuple[int, ...]) -> PolyExpAnsatz: for i in idx: if i >= self.num_vars: raise IndexError( - f"Index {i} out of bounds for representation of dimension {self.num_vars}." + f"Index {i} out of bounds for ansatz of dimension {self.num_vars}." ) ret = PolyExpAnsatz(self.A, self.b, self.c) ret._contract_idxs = idx diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index 9456b3a5d..f603dbf67 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -255,12 +255,12 @@ def __matmul__(self, other: Representation): wires_result, perm = self.wires @ other.wires idx_z, idx_zconj = self._matmul_indices(other) if type(self.ansatz) is type(other.ansatz): - self_rep = self.ansatz - other_rep = other.ansatz + self_ansatz = self.ansatz + other_ansatz = other.ansatz else: - self_rep = self.to_bargmann().ansatz - other_rep = other.to_bargmann().ansatz + self_ansatz = self.to_bargmann().ansatz + other_ansatz = other.to_bargmann().ansatz - rep = self_rep[idx_z] @ other_rep[idx_zconj] + rep = self_ansatz[idx_z] @ other_ansatz[idx_zconj] rep = rep.reorder(perm) if perm else rep return Representation(rep, wires_result) From 5389c0286f0ff3783c2a80f7dc2653773d19c4d8 Mon Sep 17 00:00:00 2001 From: Anthony Date: Thu, 10 Oct 2024 13:42:32 -0400 Subject: [PATCH 53/87] some more cleanup --- mrmustard/lab_dev/circuit_components.py | 40 ++------------------ mrmustard/physics/representations.py | 49 ++++++++++++++++++++++--- 2 files changed, 46 insertions(+), 43 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index 7da4490ab..ca03827bc 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -53,14 +53,13 @@ class CircuitComponent: r""" A base class for the circuit components (states, transformations, measurements, and any component made by combining CircuitComponents). CircuitComponents are - defined by their ``representation`` and ``wires`` attributes. See the :class:`Wires` - and :class:`Representation` classes (and their subclasses) for more details. + defined by their ``representation``. See :class:`Representation` for more details. Args: ansatz: An ansatz for this circuit component. wires: The wires of this component. Alternatively, can be a ``(modes_out_bra, modes_in_bra, modes_out_ket, modes_in_ket)`` - where if any of the modes are out of order the representation + where if any of the modes are out of order the ansatz will be reordered. name: The name of this component. """ @@ -75,40 +74,7 @@ def __init__( ) -> None: self._name = name self._parameter_set = ParameterSet() - - if not isinstance(wires, Wires): - modes_out_bra, modes_in_bra, modes_out_ket, modes_in_ket = ( - [tuple(elem) for elem in wires] if wires else [(), (), (), ()] - ) - wires = Wires( - set(modes_out_bra), - set(modes_in_bra), - set(modes_out_ket), - set(modes_in_ket), - ) - # handle out-of-order modes - ob = tuple(sorted(modes_out_bra)) - ib = tuple(sorted(modes_in_bra)) - ok = tuple(sorted(modes_out_ket)) - ik = tuple(sorted(modes_in_ket)) - if ( - ob != modes_out_bra - or ib != modes_in_bra - or ok != modes_out_ket - or ik != modes_in_ket - ): - offsets = [len(ob), len(ob) + len(ib), len(ob) + len(ib) + len(ok)] - perm = ( - tuple(np.argsort(modes_out_bra)) - + tuple(np.argsort(modes_in_bra) + offsets[0]) - + tuple(np.argsort(modes_out_ket) + offsets[1]) - + tuple(np.argsort(modes_in_ket) + offsets[2]) - ) - if ansatz is not None: - self._representation = Representation(ansatz.reorder(tuple(perm)), wires) - - if not hasattr(self, "_representation"): - self._representation = Representation(ansatz, wires) + self._representation = Representation(ansatz, wires) def _serialize(self) -> tuple[dict[str, Any], dict[str, ArrayLike]]: """ diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index f603dbf67..a5843c893 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -20,6 +20,8 @@ from typing import Sequence from enum import Enum +import numpy as np + from mrmustard import math from mrmustard.utils.typing import ( ComplexTensor, @@ -75,17 +77,52 @@ class Representation: Args: ansatz: An ansatz for this representation. - wires: The wires of this representation. + wires: The wires of this representation. Alternatively, can be + a ``(modes_out_bra, modes_in_bra, modes_out_ket, modes_in_ket)`` + sequence where if any of the modes are out of order the ansatz + will be reordered. wire_reps: An optional dictionary for keeping track of each wire's representation. """ def __init__( self, ansatz: Ansatz | None, - wires: Wires | None, + wires: Wires | Sequence[tuple[int]] | None, wire_reps: dict | None = None, ) -> None: self._ansatz = ansatz + + if not isinstance(wires, Wires): + modes_out_bra, modes_in_bra, modes_out_ket, modes_in_ket = ( + [tuple(elem) for elem in wires] if wires else [(), (), (), ()] + ) + wires = Wires( + set(modes_out_bra), + set(modes_in_bra), + set(modes_out_ket), + set(modes_in_ket), + ) + # handle out-of-order modes + ob = tuple(sorted(modes_out_bra)) + ib = tuple(sorted(modes_in_bra)) + ok = tuple(sorted(modes_out_ket)) + ik = tuple(sorted(modes_in_ket)) + if ( + ob != modes_out_bra + or ib != modes_in_bra + or ok != modes_out_ket + or ik != modes_in_ket + ): + offsets = [len(ob), len(ob) + len(ib), len(ob) + len(ib) + len(ok)] + perm = ( + tuple(np.argsort(modes_out_bra)) + + tuple(np.argsort(modes_in_bra) + offsets[0]) + + tuple(np.argsort(modes_out_ket) + offsets[1]) + + tuple(np.argsort(modes_in_ket) + offsets[2]) + ) + if ansatz is not None: + self._ansatz = ansatz.reorder(tuple(perm)) + self._wires = wires self._wire_reps = wire_reps or dict.fromkeys(wires.modes, RepEnum.from_ansatz(ansatz)) @@ -153,17 +190,17 @@ def bargmann_triple( def fock_array(self, shape: int | Sequence[int], batched=False) -> ComplexTensor: r""" - Returns an array representation of this component in the Fock basis with the given shape. + Returns an array of this representation in the Fock basis with the given shape. If the shape is not given, it defaults to the ``auto_shape`` of the component if it is available, otherwise it defaults to the value of ``AUTOSHAPE_MAX`` in the settings. Args: - shape: The shape of the returned representation. If ``shape`` is given as an ``int``, + shape: The shape of the returned array. If ``shape`` is given as an ``int``, it is broadcasted to all the dimensions. If not given, it is estimated. - batched: Whether the returned representation is batched or not. If ``False`` (default) + batched: Whether the returned array is batched or not. If ``False`` (default) it will squeeze the batch dimension if it is 1. Returns: - array: The Fock representation of this component. + array: The Fock array of this representation. """ num_vars = self.ansatz.num_vars if isinstance(shape, int): From 19afefe3cf108f5ee93473be30c2c415e495f5af Mon Sep 17 00:00:00 2001 From: Anthony Date: Thu, 10 Oct 2024 14:51:52 -0400 Subject: [PATCH 54/87] some more cleanup --- mrmustard/lab_dev/circuit_components.py | 4 +++- mrmustard/lab_dev/states/base.py | 22 ++++++++++++++-------- mrmustard/physics/representations.py | 15 +++++++++------ 3 files changed, 26 insertions(+), 15 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index ca03827bc..d20f16526 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -112,7 +112,9 @@ def _serialize(self) -> tuple[dict[str, Any], dict[str, ArrayLike]]: @classmethod def _deserialize(cls, data: dict) -> CircuitComponent: - """Deserialization when within a circuit.""" + r""" + Deserialization when within a circuit. + """ if "rep_class" in data: rep_class, wires, name = map(data.pop, ["rep_class", "wires", "name"]) rep = locate(rep_class).from_dict(data) diff --git a/mrmustard/lab_dev/states/base.py b/mrmustard/lab_dev/states/base.py index bbc60f688..a57b3b236 100644 --- a/mrmustard/lab_dev/states/base.py +++ b/mrmustard/lab_dev/states/base.py @@ -657,15 +657,16 @@ def __init__( ansatz: PolyExpAnsatz | ArrayAnsatz | None = None, name: str | None = None, ): + modes = set(modes) if ansatz and ansatz.num_vars != 2 * len(modes): raise ValueError( f"Expected a representation with {2*len(modes)} variables, found {ansatz.num_vars}." ) super().__init__( - wires=[modes, (), modes, ()], + ansatz=ansatz, + wires=Wires(modes_out_bra=modes, modes_out_ket=modes), name=name, ) - self._representation = Representation(ansatz, self.wires) @property def is_positive(self) -> bool: @@ -697,8 +698,10 @@ def is_physical(self) -> bool: @property def probability(self) -> float: - r"""Probability (trace) of this DM, using the batch dimension of the Ansatz - as a convex combination of states.""" + r""" + Probability (trace) of this DM, using the batch dimension of the Ansatz + as a convex combination of states. + """ return math.sum(self._probabilities) @property @@ -707,7 +710,8 @@ def purity(self) -> float: @property def _probabilities(self) -> RealVector: - r"""Element-wise probabilities along the batch dimension of this DM. + r""" + Element-wise probabilities along the batch dimension of this DM. Useful for cases where the batch dimension does not mean a convex combination of states. """ idx_ket = self.wires.output.ket.indices @@ -717,7 +721,8 @@ def _probabilities(self) -> RealVector: @property def _purities(self) -> RealVector: - r"""Element-wise purities along the batch dimension of this DM. + r""" + Element-wise purities along the batch dimension of this DM. Useful for cases where the batch dimension does not mean a convex combination of states. """ return self._L2_norms / self._probabilities @@ -986,15 +991,16 @@ def __init__( ansatz: PolyExpAnsatz | ArrayAnsatz | None = None, name: str | None = None, ): + modes = set(modes) if ansatz and ansatz.num_vars != len(modes): raise ValueError( f"Expected a representation with {len(modes)} variables, found {ansatz.num_vars}." ) super().__init__( - wires=[(), (), modes, ()], + ansatz=ansatz, + wires=Wires(modes_out_ket=modes), name=name, ) - self._representation = Representation(ansatz, self.wires) @property def is_physical(self) -> bool: diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index a5843c893..426bc1dc1 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -49,13 +49,16 @@ class RepEnum(Enum): PHASESPACE = 4 @classmethod - def from_ansatz(cls, value: Ansatz): + def from_ansatz(cls, ansatz: Ansatz): r""" Returns a ``RepEnum`` from an ``Ansatz``. + + Args: + ansatz: The ansatz. """ - if isinstance(value, PolyExpAnsatz): + if isinstance(ansatz, PolyExpAnsatz): return cls(1) - elif isinstance(value, ArrayAnsatz): + elif isinstance(ansatz, ArrayAnsatz): return cls(2) else: return cls(0) @@ -124,7 +127,7 @@ def __init__( self._ansatz = ansatz.reorder(tuple(perm)) self._wires = wires - self._wire_reps = wire_reps or dict.fromkeys(wires.modes, RepEnum.from_ansatz(ansatz)) + self._wire_reps = wire_reps or dict.fromkeys(wires.indices, RepEnum.from_ansatz(ansatz)) @property def adjoint(self) -> Representation: @@ -136,7 +139,7 @@ def adjoint(self) -> Representation: kets = self.wires.ket.indices ansatz = self.ansatz.reorder(kets + bras).conj if self.ansatz else None wires = self.wires.adjoint - return Representation(ansatz, wires) + return Representation(ansatz, wires, self._wire_reps) @property def ansatz(self) -> Ansatz | None: @@ -157,7 +160,7 @@ def dual(self) -> Representation: ob = self.wires.bra.output.indices ansatz = self.ansatz.reorder(ib + ob + ik + ok).conj if self.ansatz else None wires = self.wires.dual - return Representation(ansatz, wires) + return Representation(ansatz, wires, self._wire_reps) @property def wires(self) -> Wires | None: From 6781fe23532dd016f83fa7f05642eba0a9d1917e Mon Sep 17 00:00:00 2001 From: Anthony Date: Thu, 10 Oct 2024 15:38:52 -0400 Subject: [PATCH 55/87] CC -> representation arg --- mrmustard/lab_dev/circuit_components.py | 22 +++----- .../circuit_components_utils/trace_out.py | 7 ++- mrmustard/lab_dev/states/base.py | 9 +--- mrmustard/lab_dev/transformations/base.py | 8 ++- mrmustard/physics/representations.py | 4 +- tests/test_lab_dev/test_circuit_components.py | 52 +++++++++++-------- .../test_states/test_states_base.py | 5 +- 7 files changed, 54 insertions(+), 53 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index d20f16526..e0d69565c 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -56,11 +56,7 @@ class CircuitComponent: defined by their ``representation``. See :class:`Representation` for more details. Args: - ansatz: An ansatz for this circuit component. - wires: The wires of this component. Alternatively, can be - a ``(modes_out_bra, modes_in_bra, modes_out_ket, modes_in_ket)`` - where if any of the modes are out of order the ansatz - will be reordered. + represetation: The representation of this circuit component. name: The name of this component. """ @@ -68,13 +64,12 @@ class CircuitComponent: def __init__( self, - ansatz: PolyExpAnsatz | ArrayAnsatz | None = None, - wires: Wires | Sequence[tuple[int]] | None = None, + representation: Representation | None = None, name: str | None = None, ) -> None: self._name = name self._parameter_set = ParameterSet() - self._representation = Representation(ansatz, wires) + self._representation = representation or Representation() def _serialize(self) -> tuple[dict[str, Any], dict[str, ArrayLike]]: """ @@ -128,8 +123,7 @@ def adjoint(self) -> CircuitComponent: The adjoint of this component obtained by conjugating the representation and swapping the ket and bra wires. """ - rep = self.representation.adjoint - ret = CircuitComponent(rep.ansatz, rep.wires, self.name) + ret = CircuitComponent(self.representation.adjoint, self.name) ret.short_name = self.short_name for param in self.parameter_set.all_parameters.values(): ret._add_parameter(param) @@ -141,8 +135,7 @@ def dual(self) -> CircuitComponent: The dual of this component obtained by conjugating the representation and swapping the input and output wires. """ - rep = self.representation.dual - ret = CircuitComponent(rep.ansatz, rep.wires, self.name) + ret = CircuitComponent(self.representation.dual, self.name) ret.short_name = self.short_name for param in self.parameter_set.all_parameters.values(): ret._add_parameter(param) @@ -383,13 +376,14 @@ def _from_attributes( A circuit component with the given attributes. """ types = {"Ket", "DM", "Unitary", "Operation", "Channel", "Map"} + rep = Representation(ansatz, wires) for tp in cls.mro(): if tp.__name__ in types: ret = tp() ret._name = name - ret._representation = Representation(ansatz, wires) + ret._representation = rep return ret - return CircuitComponent(ansatz, wires, name) + return CircuitComponent(rep, name) def auto_shape(self, **_) -> tuple[int, ...]: r""" diff --git a/mrmustard/lab_dev/circuit_components_utils/trace_out.py b/mrmustard/lab_dev/circuit_components_utils/trace_out.py index ff862a5d9..f689d40a7 100644 --- a/mrmustard/lab_dev/circuit_components_utils/trace_out.py +++ b/mrmustard/lab_dev/circuit_components_utils/trace_out.py @@ -24,6 +24,7 @@ from ..circuit_components import CircuitComponent from ...physics.ansatz import PolyExpAnsatz +from ...physics.representations import Representation __all__ = ["TraceOut"] @@ -62,8 +63,10 @@ def __init__( modes: Sequence[int], ): super().__init__( - wires=[(), modes, (), modes], - ansatz=PolyExpAnsatz.from_function(fn=triples.identity_Abc, n_modes=len(modes)), + Representation( + PolyExpAnsatz.from_function(fn=triples.identity_Abc, n_modes=len(modes)), + [(), modes, (), modes], + ), name="Tr", ) diff --git a/mrmustard/lab_dev/states/base.py b/mrmustard/lab_dev/states/base.py index a57b3b236..fc26997c9 100644 --- a/mrmustard/lab_dev/states/base.py +++ b/mrmustard/lab_dev/states/base.py @@ -663,8 +663,7 @@ def __init__( f"Expected a representation with {2*len(modes)} variables, found {ansatz.num_vars}." ) super().__init__( - ansatz=ansatz, - wires=Wires(modes_out_bra=modes, modes_out_ket=modes), + Representation(ansatz=ansatz, wires=Wires(modes_out_bra=modes, modes_out_ket=modes)), name=name, ) @@ -996,11 +995,7 @@ def __init__( raise ValueError( f"Expected a representation with {len(modes)} variables, found {ansatz.num_vars}." ) - super().__init__( - ansatz=ansatz, - wires=Wires(modes_out_ket=modes), - name=name, - ) + super().__init__(Representation(ansatz=ansatz, wires=Wires(modes_out_ket=modes)), name=name) @property def is_physical(self) -> bool: diff --git a/mrmustard/lab_dev/transformations/base.py b/mrmustard/lab_dev/transformations/base.py index 719617c08..c179d94b4 100644 --- a/mrmustard/lab_dev/transformations/base.py +++ b/mrmustard/lab_dev/transformations/base.py @@ -30,6 +30,7 @@ from typing import Sequence from mrmustard import math, settings from mrmustard.physics.ansatz import PolyExpAnsatz, ArrayAnsatz +from mrmustard.physics.representations import Representation from mrmustard.utils.typing import ComplexMatrix from mrmustard.physics.bargmann_utils import au2Symplectic, symplectic2Au, XY_of_channel from ..circuit_components import CircuitComponent @@ -125,9 +126,7 @@ def __init__( name: str | None = None, ): super().__init__( - ansatz=ansatz, - wires=[(), (), modes_out, modes_in], - name=name, + Representation(ansatz=ansatz, wires=[(), (), modes_out, modes_in]), name=name ) @classmethod @@ -281,8 +280,7 @@ def __init__( name: str | None = None, ): super().__init__( - ansatz=ansatz, - wires=[modes_out, modes_in, modes_out, modes_in], + Representation(ansatz=ansatz, wires=[modes_out, modes_in, modes_out, modes_in]), name=name or self.__class__.__name__, ) diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index 426bc1dc1..23d492051 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -89,8 +89,8 @@ class Representation: def __init__( self, - ansatz: Ansatz | None, - wires: Wires | Sequence[tuple[int]] | None, + ansatz: Ansatz | None = None, + wires: Wires | Sequence[tuple[int]] | None = None, wire_reps: dict | None = None, ) -> None: self._ansatz = ansatz diff --git a/tests/test_lab_dev/test_circuit_components.py b/tests/test_lab_dev/test_circuit_components.py index 5c9bb7776..5aa3edc8c 100644 --- a/tests/test_lab_dev/test_circuit_components.py +++ b/tests/test_lab_dev/test_circuit_components.py @@ -56,18 +56,20 @@ class TestCircuitComponent: @pytest.mark.parametrize("y", [0.4, [0.5, 0.6]]) def test_init(self, x, y): name = "my_component" - representation = PolyExpAnsatz(*displacement_gate_Abc(x, y)) - cc = CircuitComponent(representation, wires=[(), (), (1, 8), (1, 8)], name=name) + ansatz = PolyExpAnsatz(*displacement_gate_Abc(x, y)) + cc = CircuitComponent(Representation(ansatz, [(), (), (1, 8), (1, 8)]), name=name) assert cc.name == name assert list(cc.modes) == [1, 8] assert cc.wires == Wires(modes_out_ket={1, 8}, modes_in_ket={1, 8}) - assert cc.ansatz == representation + assert cc.ansatz == ansatz assert cc.manual_shape == [None] * 4 def test_missing_name(self): cc = CircuitComponent( - PolyExpAnsatz(*displacement_gate_Abc(0.1, 0.2)), wires=[(), (), (1, 8), (1, 8)] + Representation( + PolyExpAnsatz(*displacement_gate_Abc(0.1, 0.2)), [(), (), (1, 8), (1, 8)] + ) ) cc._name = None assert cc.name == "CC18" @@ -80,16 +82,16 @@ def test_modes_init_out_of_order(self): m1 = (8, 1) m2 = (1, 8) - r1 = PolyExpAnsatz(*displacement_gate_Abc(x=[0.1, 0.2])) - r2 = PolyExpAnsatz(*displacement_gate_Abc(x=[0.2, 0.1])) + a1 = PolyExpAnsatz(*displacement_gate_Abc(x=[0.1, 0.2])) + a2 = PolyExpAnsatz(*displacement_gate_Abc(x=[0.2, 0.1])) - cc1 = CircuitComponent(r1, wires=[(), (), m1, m1]) - cc2 = CircuitComponent(r2, wires=[(), (), m2, m2]) + cc1 = CircuitComponent(Representation(a1, wires=[(), (), m1, m1])) + cc2 = CircuitComponent(Representation(a2, wires=[(), (), m2, m2])) assert cc1 == cc2 - r3 = (cc1.adjoint @ cc1).ansatz - cc3 = CircuitComponent(r3, wires=[m2, m2, m2, m1]) - cc4 = CircuitComponent(r3, wires=[m2, m2, m2, m2]) + a3 = (cc1.adjoint @ cc1).ansatz + cc3 = CircuitComponent(Representation(a3, wires=[m2, m2, m2, m1])) + cc4 = CircuitComponent(Representation(a3, wires=[m2, m2, m2, m2])) assert cc3.ansatz == cc4.ansatz.reorder([0, 1, 2, 3, 4, 5, 7, 6]) @pytest.mark.parametrize("x", [0.1, [0.2, 0.3]]) @@ -153,7 +155,9 @@ def test_dual(self): def test_light_copy(self): d1 = CircuitComponent( - PolyExpAnsatz(*displacement_gate_Abc(0.1, 0.1)), wires=[(), (), (1,), (1,)] + Representation( + PolyExpAnsatz(*displacement_gate_Abc(0.1, 0.1)), wires=[(), (), (1,), (1,)] + ) ) d1_cp = d1._light_copy() @@ -207,8 +211,10 @@ def test_to_fock_bargmann_Dgate(self): def test_to_fock_poly_exp(self): A, b, _ = Abc_triple(3) c = np.random.random((1, 5)) - barg = PolyExpAnsatz(A, b, c) - fock_cc = CircuitComponent(barg, wires=[(), (), (0, 1), ()]).to_fock(shape=(10, 10)) + polyexp = PolyExpAnsatz(A, b, c) + fock_cc = CircuitComponent(Representation(polyexp, wires=[(), (), (0, 1), ()])).to_fock( + shape=(10, 10) + ) poly = math.hermite_renormalized(A, b, 1, (10, 10, 5)) assert fock_cc.ansatz._original_abc_data is None assert np.allclose(fock_cc.ansatz.data, np.einsum("ijk,k", poly, c[0])) @@ -436,8 +442,10 @@ def test_rshift_scalar(self): assert math.allclose(result2.ansatz.c, 0.8 * d0.ansatz.c) def test_repr(self): - c1 = CircuitComponent(wires=Wires(modes_out_ket=(0, 1, 2))) - c2 = CircuitComponent(wires=Wires(modes_out_ket=(0, 1, 2)), name="my_component") + c1 = CircuitComponent(Representation(wires=Wires(modes_out_ket=(0, 1, 2)))) + c2 = CircuitComponent( + Representation(wires=Wires(modes_out_ket=(0, 1, 2))), name="my_component" + ) assert repr(c1) == "CircuitComponent(modes=[0, 1, 2], name=CC012)" assert repr(c2) == "CircuitComponent(modes=[0, 1, 2], name=my_component)" @@ -505,8 +513,8 @@ def test_ipython_repr_invalid_obj(self, mock_display): def test_serialize_default_behaviour(self): """Test the default serializer.""" name = "my_component" - rep = PolyExpAnsatz(*displacement_gate_Abc(0.1, 0.4)) - cc = CircuitComponent(rep, wires=[(), (), (1, 8), (1, 8)], name=name) + ansatz = PolyExpAnsatz(*displacement_gate_Abc(0.1, 0.4)) + cc = CircuitComponent(Representation(ansatz, wires=[(), (), (1, 8), (1, 8)]), name=name) kwargs, arrays = cc._serialize() assert kwargs == { "class": f"{CircuitComponent.__module__}.CircuitComponent", @@ -514,7 +522,7 @@ def test_serialize_default_behaviour(self): "rep_class": f"{PolyExpAnsatz.__module__}.PolyExpAnsatz", "name": name, } - assert arrays == {"A": rep.A, "b": rep.b, "c": rep.c} + assert arrays == {"A": ansatz.A, "b": ansatz.b, "c": ansatz.c} def test_serialize_fail_when_no_modes_input(self): """Test that the serializer fails if no modes or name+wires are present.""" @@ -522,8 +530,10 @@ def test_serialize_fail_when_no_modes_input(self): class MyComponent(CircuitComponent): """A dummy class without a valid modes kwarg.""" - def __init__(self, rep, custom_modes): - super().__init__(rep, wires=[custom_modes] * 4, name="my_component") + def __init__(self, ansatz, custom_modes): + super().__init__( + Representation(ansatz, wires=[custom_modes] * 4), name="my_component" + ) cc = MyComponent(PolyExpAnsatz(*displacement_gate_Abc(0.1, 0.4)), [0, 1]) with pytest.raises( diff --git a/tests/test_lab_dev/test_states/test_states_base.py b/tests/test_lab_dev/test_states/test_states_base.py index 31b7f6a08..2743ec431 100644 --- a/tests/test_lab_dev/test_states/test_states_base.py +++ b/tests/test_lab_dev/test_states/test_states_base.py @@ -41,6 +41,7 @@ Vacuum, ) from mrmustard.lab_dev.transformations import Attenuator, Dgate, Sgate +from mrmustard.physics.representations import Representation from mrmustard.physics.wires import Wires from mrmustard.widgets import state as state_widget @@ -344,7 +345,7 @@ def test_expectation_error(self): with pytest.raises(ValueError, match="Cannot calculate the expectation value"): ket.expectation(op1) - op2 = CircuitComponent(wires=[(), (), (1,), (0,)]) + op2 = CircuitComponent(Representation(wires=[(), (), (1,), (0,)])) with pytest.raises(ValueError, match="different modes"): ket.expectation(op2) @@ -810,7 +811,7 @@ def test_expectation_error(self): with pytest.raises(ValueError, match="Cannot calculate the expectation value"): dm.expectation(op1) - op2 = CircuitComponent(wires=[(), (), (1,), (0,)]) + op2 = CircuitComponent(Representation(wires=[(), (), (1,), (0,)])) with pytest.raises(ValueError, match="different modes"): dm.expectation(op2) From 4fd0cd034385f4606154c6dfdc1d648434e1bf46 Mon Sep 17 00:00:00 2001 From: Anthony Date: Thu, 10 Oct 2024 15:56:51 -0400 Subject: [PATCH 56/87] doc --- mrmustard/lab_dev/circuit_components.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index e0d69565c..5f29c8530 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -56,7 +56,7 @@ class CircuitComponent: defined by their ``representation``. See :class:`Representation` for more details. Args: - represetation: The representation of this circuit component. + representation: The representation of this circuit component. name: The name of this component. """ From 5f3d82ed4dee777d4c072aaca188e9ee9797cb31 Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 15 Oct 2024 11:49:37 -0400 Subject: [PATCH 57/87] states --- .../circuit_components_utils/b_to_q.py | 42 ------- mrmustard/lab_dev/states/base.py | 105 +++++++++++------- mrmustard/lab_dev/states/coherent.py | 12 +- .../lab_dev/states/displaced_squeezed.py | 12 +- mrmustard/lab_dev/states/number.py | 15 +-- .../lab_dev/states/quadrature_eigenstate.py | 14 ++- mrmustard/lab_dev/states/sauron.py | 13 ++- mrmustard/lab_dev/states/squeezed_vacuum.py | 12 +- mrmustard/lab_dev/states/thermal.py | 9 +- .../states/two_mode_squeezed_vacuum.py | 11 +- mrmustard/lab_dev/states/vacuum.py | 7 +- mrmustard/physics/representations.py | 4 +- .../test_states/test_states_base.py | 6 +- 13 files changed, 126 insertions(+), 136 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py index 5eee24a13..39a3fbf0c 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py @@ -19,16 +19,11 @@ from __future__ import annotations from typing import Sequence -import numpy as np -import numbers - from mrmustard.physics import triples from mrmustard.math.parameters import Constant from ..transformations.base import Operation from ...physics.ansatz import PolyExpAnsatz -from ...physics.representations import RepEnum -from ..circuit_components import CircuitComponent __all__ = ["BtoQ"] @@ -59,40 +54,3 @@ def __init__( name="BtoQ", ) self._add_parameter(Constant(phi, "phi")) - - def __custom_rrshift__(self, other: CircuitComponent | complex) -> CircuitComponent | complex: - if hasattr(other, "__custom_rrshift__"): - return other.__custom_rrshift__(self) - - if isinstance(other, (numbers.Number, np.ndarray)): - return self * other - - s_k = other.wires.ket - s_b = other.wires.bra - o_k = self.wires.ket - o_b = self.wires.bra - - only_ket = (not s_b and s_k) and (not o_b and o_k) - only_bra = (not s_k and s_b) and (not o_k and o_b) - both_sides = s_b and s_k and o_b and o_k - - self_needs_bra = (not s_b and s_k) and (o_b and o_k) - self_needs_ket = (not s_k and s_b) and (o_b and o_k) - - other_needs_bra = (s_b and s_k) and (not o_b and o_k) - other_needs_ket = (s_b and s_k) and (not o_k and o_b) - - if only_ket or only_bra or both_sides: - ret = other @ self - elif self_needs_bra or self_needs_ket: - ret = other.adjoint @ (other @ self) - elif other_needs_bra or other_needs_ket: - ret = (other @ self) @ self.adjoint - else: - msg = f"``>>`` not supported between {other} and {self} because it's not clear " - msg += "whether or where to add bra wires. Use ``@`` instead and specify all the components." - raise ValueError(msg) - - temp = dict.fromkeys(self.modes, RepEnum.QUADRATURE) - ret._representation._wire_reps.update(temp) - return self._rshift_return(ret) diff --git a/mrmustard/lab_dev/states/base.py b/mrmustard/lab_dev/states/base.py index fc26997c9..8a75231fa 100644 --- a/mrmustard/lab_dev/states/base.py +++ b/mrmustard/lab_dev/states/base.py @@ -250,6 +250,26 @@ def from_fock( modes. """ + @classmethod + @abstractmethod + def from_modes( + cls, + modes: Sequence[int], + ansatz: PolyExpAnsatz | ArrayAnsatz | None = None, + name: str | None = None, + ) -> State: + r""" + Initializes a state of type ``cls`` given modes and an ansatz. + + Args: + modes: The modes of this state. + ansatz: The ansatz of this state. + name: The name of this state. + + Returns: + A state. + """ + @classmethod @abstractmethod def from_phase_space( @@ -651,22 +671,6 @@ class DM(State): short_name = "DM" - def __init__( - self, - modes: Sequence[int] = (), - ansatz: PolyExpAnsatz | ArrayAnsatz | None = None, - name: str | None = None, - ): - modes = set(modes) - if ansatz and ansatz.num_vars != 2 * len(modes): - raise ValueError( - f"Expected a representation with {2*len(modes)} variables, found {ansatz.num_vars}." - ) - super().__init__( - Representation(ansatz=ansatz, wires=Wires(modes_out_bra=modes, modes_out_ket=modes)), - name=name, - ) - @property def is_positive(self) -> bool: r""" @@ -733,7 +737,7 @@ def from_bargmann( triple: tuple[ComplexMatrix, ComplexVector, complex], name: str | None = None, ) -> State: - return DM(modes, PolyExpAnsatz(*triple), name) + return DM.from_modes(modes, PolyExpAnsatz(*triple), name) @classmethod def from_fock( @@ -743,7 +747,22 @@ def from_fock( name: str | None = None, batched: bool = False, ) -> State: - return DM(modes, ArrayAnsatz(array, batched), name) + return DM.from_modes(modes, ArrayAnsatz(array, batched), name) + + @classmethod + def from_modes( + cls, + modes: Sequence[int], + ansatz: PolyExpAnsatz | ArrayAnsatz | None = None, + name: str | None = None, + ) -> State: + modes = set(modes) + if ansatz and ansatz.num_vars != 2 * len(modes): + raise ValueError( + f"Expected a representation with {2*len(modes)} variables, found {ansatz.num_vars}." + ) + wires = Wires(modes_out_bra=modes, modes_out_ket=modes) + return DM(Representation(ansatz, wires), name) @classmethod def from_phase_space( @@ -769,7 +788,7 @@ def from_phase_space( cov = math.astensor(cov) means = math.astensor(means) shape_check(cov, means, 2 * len(modes), "Phase space") - return coeff * DM( + return coeff * DM.from_modes( modes, PolyExpAnsatz.from_function(fn=wigner_to_bargmann_rho, cov=cov, means=means), name, @@ -801,8 +820,8 @@ def from_quadrature( with the number of modes. """ QtoB = BtoQ(modes, phi).inverse() - Q = DM(modes, PolyExpAnsatz(*triple)) - return DM(modes, (Q >> QtoB).ansatz, name) + Q = DM.from_modes(modes, PolyExpAnsatz(*triple)) + return DM.from_modes(modes, (Q >> QtoB).ansatz, name) @classmethod def random(cls, modes: Sequence[int], m: int | None = None, max_r: float = 1.0) -> DM: @@ -968,7 +987,7 @@ def __rshift__(self, other: CircuitComponent) -> CircuitComponent: w = result.wires if not w.input and w.bra.modes == w.ket.modes: - return DM(w.modes, result.ansatz) + return DM.from_modes(w.modes, result.ansatz) return result @@ -984,19 +1003,6 @@ class Ket(State): short_name = "Ket" - def __init__( - self, - modes: Sequence[int] = (), - ansatz: PolyExpAnsatz | ArrayAnsatz | None = None, - name: str | None = None, - ): - modes = set(modes) - if ansatz and ansatz.num_vars != len(modes): - raise ValueError( - f"Expected a representation with {len(modes)} variables, found {ansatz.num_vars}." - ) - super().__init__(Representation(ansatz=ansatz, wires=Wires(modes_out_ket=modes)), name=name) - @property def is_physical(self) -> bool: r""" @@ -1035,7 +1041,7 @@ def from_bargmann( triple: tuple[ComplexMatrix, ComplexVector, complex], name: str | None = None, ) -> State: - return Ket(modes, PolyExpAnsatz(*triple), name) + return Ket.from_modes(modes, PolyExpAnsatz(*triple), name) @classmethod def from_fock( @@ -1045,7 +1051,22 @@ def from_fock( name: str | None = None, batched: bool = False, ) -> State: - return Ket(modes, ArrayAnsatz(array, batched), name) + return Ket.from_modes(modes, ArrayAnsatz(array, batched), name) + + @classmethod + def from_modes( + cls, + modes: Sequence[int], + ansatz: PolyExpAnsatz | ArrayAnsatz | None = None, + name: str | None = None, + ) -> State: + modes = set(modes) + if ansatz and ansatz.num_vars != len(modes): + raise ValueError( + f"Expected a representation with {len(modes)} variables, found {ansatz.num_vars}." + ) + wires = Wires(modes_out_ket=modes) + return Ket(Representation(ansatz, wires), name) @classmethod def from_phase_space( @@ -1064,7 +1085,7 @@ def from_phase_space( if p < 1.0 - atol_purity: msg = f"Cannot initialize a Ket: purity is {p:.5f} (must be at least 1.0-{atol_purity})." raise ValueError(msg) - return Ket( + return Ket.from_modes( modes, coeff * PolyExpAnsatz.from_function(fn=wigner_to_bargmann_psi, cov=cov, means=means), name, @@ -1079,8 +1100,8 @@ def from_quadrature( name: str | None = None, ) -> State: QtoB = BtoQ(modes, phi).inverse() - Q = Ket(modes, PolyExpAnsatz(*triple)) - return Ket(modes, (Q >> QtoB).ansatz, name) + Q = Ket.from_modes(modes, PolyExpAnsatz(*triple)) + return Ket.from_modes(modes, (Q >> QtoB).ansatz, name) @classmethod def random(cls, modes: Sequence[int], max_r: float = 1.0) -> Ket: @@ -1259,7 +1280,7 @@ def __rshift__(self, other: CircuitComponent | Scalar) -> CircuitComponent | Bat if not result.wires.input: if not result.wires.bra: - return Ket(result.wires.modes, result.ansatz) + return Ket.from_modes(result.wires.modes, result.ansatz) elif result.wires.bra.modes == result.wires.ket.modes: - result = DM(result.wires.modes, result.ansatz) + result = DM.from_modes(result.wires.modes, result.ansatz) return result diff --git a/mrmustard/lab_dev/states/coherent.py b/mrmustard/lab_dev/states/coherent.py index 10a535fc5..41f98cbc3 100644 --- a/mrmustard/lab_dev/states/coherent.py +++ b/mrmustard/lab_dev/states/coherent.py @@ -20,7 +20,6 @@ from typing import Sequence -from mrmustard.physics.representations import Representation from mrmustard.physics.ansatz import PolyExpAnsatz from mrmustard.physics import triples from .base import Ket @@ -78,12 +77,13 @@ def __init__( x_bounds: tuple[float | None, float | None] = (None, None), y_bounds: tuple[float | None, float | None] = (None, None), ): - super().__init__(modes=modes, name="Coherent") + super().__init__(name="Coherent") + xs, ys = list(reshape_params(len(modes), x=x, y=y)) self._add_parameter(make_parameter(x_trainable, xs, "x", x_bounds)) self._add_parameter(make_parameter(y_trainable, ys, "y", y_bounds)) - self._representation = Representation( - PolyExpAnsatz.from_function(fn=triples.coherent_state_Abc, x=self.x, y=self.y), - self.wires, - ) + self._representation = self.from_modes( + modes=modes, + ansatz=PolyExpAnsatz.from_function(fn=triples.coherent_state_Abc, x=self.x, y=self.y), + ).representation diff --git a/mrmustard/lab_dev/states/displaced_squeezed.py b/mrmustard/lab_dev/states/displaced_squeezed.py index d5b946dd6..3b3031260 100644 --- a/mrmustard/lab_dev/states/displaced_squeezed.py +++ b/mrmustard/lab_dev/states/displaced_squeezed.py @@ -20,7 +20,6 @@ from typing import Sequence -from mrmustard.physics.representations import Representation from mrmustard.physics.ansatz import PolyExpAnsatz from mrmustard.physics import triples from .base import Ket @@ -77,7 +76,8 @@ def __init__( r_bounds: tuple[float | None, float | None] = (None, None), phi_bounds: tuple[float | None, float | None] = (None, None), ): - super().__init__(modes=modes, name="DisplacedSqueezed") + super().__init__(name="DisplacedSqueezed") + params = reshape_params(len(modes), x=x, y=y, r=r, phi=phi) xs, ys, rs, phis = list(params) self._add_parameter(make_parameter(x_trainable, xs, "x", x_bounds)) @@ -85,13 +85,13 @@ def __init__( self._add_parameter(make_parameter(r_trainable, rs, "r", r_bounds)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - self._representation = Representation( - PolyExpAnsatz.from_function( + self._representation = self.from_modes( + modes=modes, + ansatz=PolyExpAnsatz.from_function( fn=triples.displaced_squeezed_vacuum_state_Abc, x=self.x, y=self.y, r=self.r, phi=self.phi, ), - self.wires, - ) + ).representation diff --git a/mrmustard/lab_dev/states/number.py b/mrmustard/lab_dev/states/number.py index 3158e0f39..5d48af05a 100644 --- a/mrmustard/lab_dev/states/number.py +++ b/mrmustard/lab_dev/states/number.py @@ -20,7 +20,6 @@ from typing import Sequence -from mrmustard.physics.representations import Representation from mrmustard.physics.ansatz import ArrayAnsatz from mrmustard.physics.fock_utils import fock_state from .base import Ket @@ -68,15 +67,17 @@ def __init__( n: int | Sequence[int], cutoffs: int | Sequence[int] | None = None, ) -> None: - super().__init__(modes=modes, name="N") + super().__init__(name="N") + ns, cs = list(reshape_params(len(modes), n=n, cutoffs=n if cutoffs is None else cutoffs)) self._add_parameter(make_parameter(False, ns, "n", (None, None), dtype="int64")) self._add_parameter(make_parameter(False, cs, "cutoffs", (None, None))) + self._representation = self.from_modes( + modes=modes, + ansatz=ArrayAnsatz.from_function( + fock_state, n=self.n.value, cutoffs=self.cutoffs.value + ), + ).representation self.short_name = [str(int(n)) for n in self.n.value] for i, cutoff in enumerate(self.cutoffs.value): self.manual_shape[i] = int(cutoff) + 1 - - self._representation = Representation( - ArrayAnsatz.from_function(fock_state, n=self.n.value, cutoffs=self.cutoffs.value), - self.wires, - ) diff --git a/mrmustard/lab_dev/states/quadrature_eigenstate.py b/mrmustard/lab_dev/states/quadrature_eigenstate.py index 276e19811..38aa33e92 100644 --- a/mrmustard/lab_dev/states/quadrature_eigenstate.py +++ b/mrmustard/lab_dev/states/quadrature_eigenstate.py @@ -64,17 +64,19 @@ def __init__( x_bounds: tuple[float | None, float | None] = (None, None), phi_bounds: tuple[float | None, float | None] = (None, None), ): - super().__init__(modes=modes, name="QuadratureEigenstate") + super().__init__(name="QuadratureEigenstate") + xs, phis = list(reshape_params(len(modes), x=x, phi=phi)) self._add_parameter(make_parameter(x_trainable, xs, "x", x_bounds)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - self._representation = Representation( - PolyExpAnsatz.from_function( + self.manual_shape = (50,) + + self._representation = self.from_modes( + modes=modes, + ansatz=PolyExpAnsatz.from_function( fn=triples.quadrature_eigenstates_Abc, x=self.x, phi=self.phi ), - self.wires, - ) - self.manual_shape = (50,) + ).representation @property def L2_norm(self): diff --git a/mrmustard/lab_dev/states/sauron.py b/mrmustard/lab_dev/states/sauron.py index 846018389..388cbc4e2 100644 --- a/mrmustard/lab_dev/states/sauron.py +++ b/mrmustard/lab_dev/states/sauron.py @@ -16,7 +16,6 @@ from typing import Sequence from mrmustard.lab_dev.states.base import Ket -from mrmustard.physics.representations import Representation from mrmustard.physics.ansatz import PolyExpAnsatz from mrmustard.physics import triples @@ -40,12 +39,14 @@ class Sauron(Ket): """ def __init__(self, modes: Sequence[int], n: int, epsilon: float = 0.1): - super().__init__(name=f"Sauron-{n}", modes=modes) + super().__init__(name=f"Sauron-{n}") + self._add_parameter(make_parameter(False, n, "n", (None, None), dtype="int64")) self._add_parameter(make_parameter(False, epsilon, "epsilon", (None, None))) - self._representation = Representation( - PolyExpAnsatz.from_function( + + self._representation = self.from_modes( + modes=modes, + ansatz=PolyExpAnsatz.from_function( triples.sauron_state_Abc, n=self.n.value, epsilon=self.epsilon.value ), - self.wires, - ) + ).representation diff --git a/mrmustard/lab_dev/states/squeezed_vacuum.py b/mrmustard/lab_dev/states/squeezed_vacuum.py index fd713e7ca..730055b9e 100644 --- a/mrmustard/lab_dev/states/squeezed_vacuum.py +++ b/mrmustard/lab_dev/states/squeezed_vacuum.py @@ -65,13 +65,15 @@ def __init__( r_bounds: tuple[float | None, float | None] = (None, None), phi_bounds: tuple[float | None, float | None] = (None, None), ): - super().__init__(modes=modes, name="SqueezedVacuum") + super().__init__(name="SqueezedVacuum") + rs, phis = list(reshape_params(len(modes), r=r, phi=phi)) self._add_parameter(make_parameter(r_trainable, rs, "r", r_bounds)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - self._representation = Representation( - PolyExpAnsatz.from_function( + + self._representation = self.from_modes( + modes=modes, + ansatz=PolyExpAnsatz.from_function( fn=triples.squeezed_vacuum_state_Abc, r=self.r, phi=self.phi ), - self.wires, - ) + ).representation diff --git a/mrmustard/lab_dev/states/thermal.py b/mrmustard/lab_dev/states/thermal.py index 35da098b8..f8171da5a 100644 --- a/mrmustard/lab_dev/states/thermal.py +++ b/mrmustard/lab_dev/states/thermal.py @@ -59,9 +59,10 @@ def __init__( nbar_trainable: bool = False, nbar_bounds: tuple[float | None, float | None] = (0, None), ) -> None: - super().__init__(modes=modes, name="Thermal") + super().__init__(name="Thermal") (nbars,) = list(reshape_params(len(modes), nbar=nbar)) self._add_parameter(make_parameter(nbar_trainable, nbars, "nbar", nbar_bounds)) - self._representation = Representation( - PolyExpAnsatz.from_function(fn=triples.thermal_state_Abc, nbar=self.nbar), self.wires - ) + self._representation = self.from_modes( + modes=modes, + ansatz=PolyExpAnsatz.from_function(fn=triples.thermal_state_Abc, nbar=self.nbar), + ).representation diff --git a/mrmustard/lab_dev/states/two_mode_squeezed_vacuum.py b/mrmustard/lab_dev/states/two_mode_squeezed_vacuum.py index c8ed7a66b..511aee639 100644 --- a/mrmustard/lab_dev/states/two_mode_squeezed_vacuum.py +++ b/mrmustard/lab_dev/states/two_mode_squeezed_vacuum.py @@ -20,7 +20,6 @@ from typing import Sequence -from mrmustard.physics.representations import Representation from mrmustard.physics.ansatz import PolyExpAnsatz from mrmustard.physics import triples from .base import Ket @@ -63,13 +62,13 @@ def __init__( r_bounds: tuple[float | None, float | None] = (None, None), phi_bounds: tuple[float | None, float | None] = (None, None), ): - super().__init__(modes=modes, name="TwoModeSqueezedVacuum") + super().__init__(name="TwoModeSqueezedVacuum") rs, phis = list(reshape_params(int(len(modes) / 2), r=r, phi=phi)) self._add_parameter(make_parameter(r_trainable, rs, "r", r_bounds)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - self._representation = Representation( - PolyExpAnsatz.from_function( + self._representation = self.from_modes( + modes=modes, + ansatz=PolyExpAnsatz.from_function( fn=triples.two_mode_squeezed_vacuum_state_Abc, r=self.r, phi=self.phi ), - self.wires, - ) + ).representation diff --git a/mrmustard/lab_dev/states/vacuum.py b/mrmustard/lab_dev/states/vacuum.py index d50493a53..bae5cdbd3 100644 --- a/mrmustard/lab_dev/states/vacuum.py +++ b/mrmustard/lab_dev/states/vacuum.py @@ -60,8 +60,11 @@ def __init__( self, modes: Sequence[int], ) -> None: - ansatz = PolyExpAnsatz.from_function(fn=triples.vacuum_state_Abc, n_modes=len(modes)) - super().__init__(modes=modes, ansatz=ansatz, name="Vac") + super().__init__(name="Vac") + self._representation = self.from_modes( + modes=modes, + ansatz=PolyExpAnsatz.from_function(fn=triples.vacuum_state_Abc, n_modes=len(modes)), + ).representation for i in range(len(modes)): self.manual_shape[i] = 1 diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index 23d492051..bc28ca476 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -127,7 +127,9 @@ def __init__( self._ansatz = ansatz.reorder(tuple(perm)) self._wires = wires - self._wire_reps = wire_reps or dict.fromkeys(wires.indices, RepEnum.from_ansatz(ansatz)) + self._wire_reps = wire_reps or dict.fromkeys( + wires.indices, (RepEnum.from_ansatz(ansatz), None) + ) @property def adjoint(self) -> Representation: diff --git a/tests/test_lab_dev/test_states/test_states_base.py b/tests/test_lab_dev/test_states/test_states_base.py index 4fdeda44b..8e657ab1e 100644 --- a/tests/test_lab_dev/test_states/test_states_base.py +++ b/tests/test_lab_dev/test_states/test_states_base.py @@ -75,7 +75,7 @@ class TestKet: # pylint: disable=too-many-public-methods @pytest.mark.parametrize("name", [None, "my_ket"]) @pytest.mark.parametrize("modes", [[0], [0, 1], [3, 19, 2]]) def test_init(self, name, modes): - state = Ket(modes, None, name) + state = Ket.from_modes(modes, None, name) assert state.name in ("Ket0", "Ket01", "Ket2319") if not name else name assert list(state.modes) == sorted(modes) @@ -197,7 +197,7 @@ def test_probability(self): @pytest.mark.parametrize("modes", [[0], [0, 1], [3, 19, 2]]) def test_purity(self, modes): - state = Ket(modes, None, "my_ket") + state = Ket.from_modes(modes, None, "my_ket") assert state.purity == 1 assert state.is_pure @@ -481,7 +481,7 @@ class TestDM: # pylint:disable=too-many-public-methods @pytest.mark.parametrize("name", [None, "my_dm"]) @pytest.mark.parametrize("modes", [{0}, {0, 1}, {3, 19, 2}]) def test_init(self, name, modes): - state = DM(modes, None, name) + state = DM.from_modes(modes, None, name) assert state.name in ("DM0", "DM01", "DM2319") if not name else name assert list(state.modes) == sorted(modes) From 381e49568b7bd197b0b3672f9e24b73e23ce7cf0 Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 15 Oct 2024 13:15:39 -0400 Subject: [PATCH 58/87] transformations --- .../circuit_components_utils/b_to_ps.py | 8 +- .../circuit_components_utils/b_to_q.py | 15 +- .../lab_dev/transformations/amplifier.py | 11 +- .../lab_dev/transformations/attenuator.py | 12 +- mrmustard/lab_dev/transformations/base.py | 134 +++++++++++++----- mrmustard/lab_dev/transformations/bsgate.py | 12 +- mrmustard/lab_dev/transformations/cft.py | 14 +- mrmustard/lab_dev/transformations/dgate.py | 13 +- .../lab_dev/transformations/fockdamping.py | 11 +- mrmustard/lab_dev/transformations/ggate.py | 12 +- mrmustard/lab_dev/transformations/identity.py | 8 +- mrmustard/lab_dev/transformations/rgate.py | 11 +- mrmustard/lab_dev/transformations/s2gate.py | 12 +- mrmustard/lab_dev/transformations/sgate.py | 14 +- .../test_transformations_base.py | 4 +- 15 files changed, 181 insertions(+), 110 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py b/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py index d741f60c6..cca0ca8b3 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py @@ -43,13 +43,13 @@ def __init__( modes: Sequence[int], s: float, ): - super().__init__( - modes_out=modes, + super().__init__(name="BtoPS") + self._representation = self.from_modes( modes_in=modes, + modes_out=modes, ansatz=PolyExpAnsatz.from_function( fn=triples.displacement_map_s_parametrized_Abc, s=s, n_modes=len(modes) ), - name="BtoPS", - ) + ).representation self._add_parameter(Constant(s, "s")) self.s = s diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py index 39a3fbf0c..94acfcca5 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py @@ -44,13 +44,12 @@ def __init__( modes: Sequence[int], phi: float = 0.0, ): - ansatz = PolyExpAnsatz.from_function( - fn=triples.bargmann_to_quadrature_Abc, n_modes=len(modes), phi=phi - ) - super().__init__( - modes_out=modes, + super().__init__(name="BtoQ") + self._representation = self.from_modes( modes_in=modes, - ansatz=ansatz, - name="BtoQ", - ) + modes_out=modes, + ansatz=PolyExpAnsatz.from_function( + fn=triples.bargmann_to_quadrature_Abc, n_modes=len(modes), phi=phi + ), + ).representation self._add_parameter(Constant(phi, "phi")) diff --git a/mrmustard/lab_dev/transformations/amplifier.py b/mrmustard/lab_dev/transformations/amplifier.py index 8d6807e8e..209fc36dc 100644 --- a/mrmustard/lab_dev/transformations/amplifier.py +++ b/mrmustard/lab_dev/transformations/amplifier.py @@ -21,7 +21,6 @@ from typing import Sequence from .base import Channel -from ...physics.representations import Representation from ...physics.ansatz import PolyExpAnsatz from ...physics import triples from ..utils import make_parameter, reshape_params @@ -85,7 +84,7 @@ def __init__( gain_trainable: bool = False, gain_bounds: tuple[float | None, float | None] = (1.0, None), ): - super().__init__(modes_out=modes, modes_in=modes, name="Amp") + super().__init__(name="Amp") (gs,) = list(reshape_params(len(modes), gain=gain)) self._add_parameter( make_parameter( @@ -96,6 +95,8 @@ def __init__( None, ) ) - self._representation = Representation( - PolyExpAnsatz.from_function(fn=triples.amplifier_Abc, g=self.gain), self.wires - ) + self._representation = self.from_modes( + modes_in=modes, + modes_out=modes, + ansatz=PolyExpAnsatz.from_function(fn=triples.amplifier_Abc, g=self.gain), + ).representation diff --git a/mrmustard/lab_dev/transformations/attenuator.py b/mrmustard/lab_dev/transformations/attenuator.py index 8847e3387..449305b21 100644 --- a/mrmustard/lab_dev/transformations/attenuator.py +++ b/mrmustard/lab_dev/transformations/attenuator.py @@ -21,7 +21,6 @@ from typing import Sequence from .base import Channel -from ...physics.representations import Representation from ...physics.ansatz import PolyExpAnsatz from ...physics import triples from ..utils import make_parameter, reshape_params @@ -85,7 +84,7 @@ def __init__( transmissivity_trainable: bool = False, transmissivity_bounds: tuple[float | None, float | None] = (0.0, 1.0), ): - super().__init__(modes_out=modes, modes_in=modes, name="Att") + super().__init__(name="Att") (etas,) = list(reshape_params(len(modes), transmissivity=transmissivity)) self._add_parameter( make_parameter( @@ -96,7 +95,8 @@ def __init__( None, ) ) - self._representation = Representation( - PolyExpAnsatz.from_function(fn=triples.attenuator_Abc, eta=self.transmissivity), - self.wires, - ) + self._representation = self.from_modes( + modes_in=modes, + modes_out=modes, + ansatz=PolyExpAnsatz.from_function(fn=triples.attenuator_Abc, eta=self.transmissivity), + ).representation diff --git a/mrmustard/lab_dev/transformations/base.py b/mrmustard/lab_dev/transformations/base.py index c179d94b4..95343d8d7 100644 --- a/mrmustard/lab_dev/transformations/base.py +++ b/mrmustard/lab_dev/transformations/base.py @@ -31,6 +31,7 @@ from mrmustard import math, settings from mrmustard.physics.ansatz import PolyExpAnsatz, ArrayAnsatz from mrmustard.physics.representations import Representation +from mrmustard.physics.wires import Wires from mrmustard.utils.typing import ComplexMatrix from mrmustard.physics.bargmann_utils import au2Symplectic, symplectic2Au, XY_of_channel from ..circuit_components import CircuitComponent @@ -59,6 +60,28 @@ def from_bargmann( :math:`c * exp(0.5*z^T A z + b^T z)`. """ + @classmethod + @abstractmethod + def from_modes( + cls, + modes_out: Sequence[int], + modes_in: Sequence[int], + ansatz: PolyExpAnsatz | ArrayAnsatz | None = None, + name: str | None = None, + ) -> Transformation: + r""" + Initializes a transformation of type ``cls`` given modes and an ansatz. + + Args: + modes_out: The output modes of this transformation. + modes_in: The input modes of this transformation. + ansatz: The ansatz of this transformation. + name: The name of this transformation. + + Returns: + A transformation. + """ + @classmethod @abstractmethod def from_quadrature( @@ -118,17 +141,6 @@ class Operation(Transformation): short_name = "Op" - def __init__( - self, - modes_out: tuple[int, ...] = (), - modes_in: tuple[int, ...] = (), - ansatz: PolyExpAnsatz | ArrayAnsatz | None = None, - name: str | None = None, - ): - super().__init__( - Representation(ansatz=ansatz, wires=[(), (), modes_out, modes_in]), name=name - ) - @classmethod def from_bargmann( cls, @@ -137,7 +149,22 @@ def from_bargmann( triple: tuple, name: str | None = None, ) -> Transformation: - return Operation(modes_out, modes_in, PolyExpAnsatz(*triple), name) + return Operation.from_modes(modes_out, modes_in, PolyExpAnsatz(*triple), name) + + @classmethod + def from_modes( + cls, + modes_out: Sequence[int], + modes_in: Sequence[int], + ansatz: PolyExpAnsatz | ArrayAnsatz | None = None, + name: str | None = None, + ) -> Transformation: + modes_out = set(modes_out) + modes_in = set(modes_in) + return Operation( + representation=Representation(ansatz=ansatz, wires=Wires((), (), modes_out, modes_in)), + name=name, + ) @classmethod def from_quadrature( @@ -152,9 +179,9 @@ def from_quadrature( QtoB_out = BtoQ(modes_out, phi).inverse() QtoB_in = BtoQ(modes_in, phi).inverse().dual - QQ = Operation(modes_out, modes_in, PolyExpAnsatz(*triple)) + QQ = Operation.from_modes(modes_out, modes_in, PolyExpAnsatz(*triple)) BB = QtoB_in >> QQ >> QtoB_out - return Operation(modes_out, modes_in, BB.ansatz, name) + return Operation.from_modes(modes_out, modes_in, BB.ansatz, name) class Unitary(Operation): @@ -187,7 +214,22 @@ def from_bargmann( triple: tuple, name: str | None = None, ) -> Transformation: - return Unitary(modes_out, modes_in, PolyExpAnsatz(*triple), name) + return Unitary.from_modes(modes_out, modes_in, PolyExpAnsatz(*triple), name) + + @classmethod + def from_modes( + cls, + modes_out: Sequence[int], + modes_in: Sequence[int], + ansatz: PolyExpAnsatz | ArrayAnsatz | None = None, + name: str | None = None, + ) -> Transformation: + modes_out = set(modes_out) + modes_in = set(modes_in) + return Unitary( + representation=Representation(ansatz=ansatz, wires=Wires((), (), modes_out, modes_in)), + name=name, + ) @classmethod def from_quadrature( @@ -202,9 +244,9 @@ def from_quadrature( QtoB_out = BtoQ(modes_out, phi).inverse() QtoB_in = BtoQ(modes_in, phi).inverse().dual - QQ = Unitary(modes_out, modes_in, PolyExpAnsatz(*triple)) + QQ = Unitary.from_modes(modes_out, modes_in, PolyExpAnsatz(*triple)) BB = QtoB_in >> QQ >> QtoB_out - return Unitary(modes_out, modes_in, BB.ansatz, name) + return Unitary.from_modes(modes_out, modes_in, BB.ansatz, name) @classmethod def from_symplectic(cls, modes, S) -> Unitary: @@ -272,18 +314,6 @@ class Map(Transformation): short_name = "Map" - def __init__( - self, - modes_out: tuple[int, ...] = (), - modes_in: tuple[int, ...] = (), - ansatz: PolyExpAnsatz | ArrayAnsatz | None = None, - name: str | None = None, - ): - super().__init__( - Representation(ansatz=ansatz, wires=[modes_out, modes_in, modes_out, modes_in]), - name=name or self.__class__.__name__, - ) - @classmethod def from_bargmann( cls, @@ -292,7 +322,24 @@ def from_bargmann( triple: tuple, name: str | None = None, ) -> Transformation: - return Map(modes_out, modes_in, PolyExpAnsatz(*triple), name) + return Map.from_modes(modes_out, modes_in, PolyExpAnsatz(*triple), name) + + @classmethod + def from_modes( + cls, + modes_out: Sequence[int], + modes_in: Sequence[int], + ansatz: PolyExpAnsatz | ArrayAnsatz | None = None, + name: str | None = None, + ) -> Transformation: + modes_out = set(modes_out) + modes_in = set(modes_in) + return Map( + representation=Representation( + ansatz=ansatz, wires=Wires(modes_out, modes_in, modes_out, modes_in) + ), + name=name, + ) @classmethod def from_quadrature( @@ -307,9 +354,9 @@ def from_quadrature( QtoB_out = BtoQ(modes_out, phi).inverse() QtoB_in = BtoQ(modes_in, phi).inverse().dual - QQ = Map(modes_out, modes_in, PolyExpAnsatz(*triple)) + QQ = Map.from_modes(modes_out, modes_in, PolyExpAnsatz(*triple)) BB = QtoB_in >> QQ >> QtoB_out - return Map(modes_out, modes_in, BB.ansatz, name) + return Map.from_modes(modes_out, modes_in, BB.ansatz, name) class Channel(Map): @@ -380,7 +427,24 @@ def from_bargmann( triple: tuple, name: str | None = None, ) -> Transformation: - return Channel(modes_out, modes_in, PolyExpAnsatz(*triple), name) + return Channel.from_modes(modes_out, modes_in, PolyExpAnsatz(*triple), name) + + @classmethod + def from_modes( + cls, + modes_out: Sequence[int], + modes_in: Sequence[int], + ansatz: PolyExpAnsatz | ArrayAnsatz | None = None, + name: str | None = None, + ) -> Transformation: + modes_out = set(modes_out) + modes_in = set(modes_in) + return Channel( + representation=Representation( + ansatz=ansatz, wires=Wires(modes_out, modes_in, modes_out, modes_in) + ), + name=name, + ) @classmethod def from_quadrature( @@ -395,9 +459,9 @@ def from_quadrature( QtoB_out = BtoQ(modes_out, phi).inverse() QtoB_in = BtoQ(modes_in, phi).inverse().dual - QQ = Channel(modes_out, modes_in, PolyExpAnsatz(*triple)) + QQ = Channel.from_modes(modes_out, modes_in, PolyExpAnsatz(*triple)) BB = QtoB_in >> QQ >> QtoB_out - return Channel(modes_out, modes_in, BB.ansatz, name) + return Channel.from_modes(modes_out, modes_in, BB.ansatz, name) @classmethod def random(cls, modes: Sequence[int], max_r: float = 1.0) -> Channel: diff --git a/mrmustard/lab_dev/transformations/bsgate.py b/mrmustard/lab_dev/transformations/bsgate.py index 4953a1dee..39c06a473 100644 --- a/mrmustard/lab_dev/transformations/bsgate.py +++ b/mrmustard/lab_dev/transformations/bsgate.py @@ -21,7 +21,6 @@ from typing import Sequence from .base import Unitary -from ...physics.representations import Representation from ...physics.ansatz import PolyExpAnsatz from ...physics import triples from ..utils import make_parameter @@ -102,12 +101,13 @@ def __init__( if len(modes) != 2: raise ValueError(f"Expected a pair of modes, found {modes}.") - super().__init__(modes_out=modes, modes_in=modes, name="BSgate") + super().__init__(name="BSgate") self._add_parameter(make_parameter(theta_trainable, theta, "theta", theta_bounds)) self._add_parameter(make_parameter(phi_trainable, phi, "phi", phi_bounds)) - self._representation = Representation( - PolyExpAnsatz.from_function( + self._representation = self.from_modes( + modes_in=modes, + modes_out=modes, + ansatz=PolyExpAnsatz.from_function( fn=triples.beamsplitter_gate_Abc, theta=self.theta, phi=self.phi ), - self.wires, - ) + ).representation diff --git a/mrmustard/lab_dev/transformations/cft.py b/mrmustard/lab_dev/transformations/cft.py index 669f07a46..5d26bebbc 100644 --- a/mrmustard/lab_dev/transformations/cft.py +++ b/mrmustard/lab_dev/transformations/cft.py @@ -18,7 +18,6 @@ from typing import Sequence from mrmustard.lab_dev.transformations.base import Map -from mrmustard.physics.representations import Representation from mrmustard.physics.ansatz import PolyExpAnsatz from mrmustard.physics import triples @@ -43,14 +42,11 @@ def __init__( self, modes: Sequence[int], ): - super().__init__( - modes_out=modes, + super().__init__(name="CFT") + self._representation = self.from_modes( modes_in=modes, - name="CFT", - ) - self._representation = Representation( - PolyExpAnsatz.from_function( + modes_out=modes, + ansatz=PolyExpAnsatz.from_function( fn=triples.complex_fourier_transform_Abc, n_modes=len(modes) ), - self.wires, - ) + ).representation diff --git a/mrmustard/lab_dev/transformations/dgate.py b/mrmustard/lab_dev/transformations/dgate.py index eeb0616df..7ce9a071f 100644 --- a/mrmustard/lab_dev/transformations/dgate.py +++ b/mrmustard/lab_dev/transformations/dgate.py @@ -91,14 +91,17 @@ def __init__( x_bounds: tuple[float | None, float | None] = (None, None), y_bounds: tuple[float | None, float | None] = (None, None), ) -> None: - super().__init__(modes_out=modes, modes_in=modes, name="Dgate") + super().__init__(name="Dgate") xs, ys = list(reshape_params(len(modes), x=x, y=y)) self._add_parameter(make_parameter(x_trainable, xs, "x", x_bounds)) self._add_parameter(make_parameter(y_trainable, ys, "y", y_bounds)) - self._representation = Representation( - PolyExpAnsatz.from_function(fn=triples.displacement_gate_Abc, x=self.x, y=self.y), - self.wires, - ) + self._representation = self.from_modes( + modes_in=modes, + modes_out=modes, + ansatz=PolyExpAnsatz.from_function( + fn=triples.displacement_gate_Abc, x=self.x, y=self.y + ), + ).representation def fock_array(self, shape: int | Sequence[int] = None, batched=False) -> ComplexTensor: r""" diff --git a/mrmustard/lab_dev/transformations/fockdamping.py b/mrmustard/lab_dev/transformations/fockdamping.py index 9ff01a718..65bec5840 100644 --- a/mrmustard/lab_dev/transformations/fockdamping.py +++ b/mrmustard/lab_dev/transformations/fockdamping.py @@ -21,7 +21,6 @@ from typing import Sequence from .base import Operation -from ...physics.representations import Representation from ...physics.ansatz import PolyExpAnsatz from ...physics import triples from ..utils import make_parameter, reshape_params @@ -75,7 +74,7 @@ def __init__( damping_trainable: bool = False, damping_bounds: tuple[float | None, float | None] = (0.0, None), ): - super().__init__(modes_out=modes, modes_in=modes, name="FockDamping") + super().__init__(name="FockDamping") (betas,) = list(reshape_params(len(modes), damping=damping)) self._add_parameter( make_parameter( @@ -86,6 +85,8 @@ def __init__( None, ) ) - self._representation = Representation( - PolyExpAnsatz.from_function(fn=triples.fock_damping_Abc, beta=self.damping), self.wires - ) + self._representation = self.from_modes( + modes_in=modes, + modes_out=modes, + ansatz=PolyExpAnsatz.from_function(fn=triples.fock_damping_Abc, beta=self.damping), + ).representation diff --git a/mrmustard/lab_dev/transformations/ggate.py b/mrmustard/lab_dev/transformations/ggate.py index 5ecbf54e5..ec74411ab 100644 --- a/mrmustard/lab_dev/transformations/ggate.py +++ b/mrmustard/lab_dev/transformations/ggate.py @@ -22,7 +22,6 @@ from mrmustard.utils.typing import RealMatrix from .base import Unitary -from ...physics.representations import Representation from ...physics.ansatz import PolyExpAnsatz from ..utils import make_parameter @@ -55,16 +54,17 @@ def __init__( symplectic: RealMatrix, symplectic_trainable: bool = False, ): - super().__init__(modes_out=modes, modes_in=modes, name="Ggate") + super().__init__(name="Ggate") S = make_parameter(symplectic_trainable, symplectic, "symplectic", (None, None)) self.parameter_set.add_parameter(S) - self._representation = Representation( - PolyExpAnsatz.from_function( + self._representation = self.from_modes( + modes_in=modes, + modes_out=modes, + ansatz=PolyExpAnsatz.from_function( fn=lambda s: Unitary.from_symplectic(modes, s).bargmann_triple(), s=self.parameter_set.symplectic, ), - self.wires, - ) + ).representation @property def symplectic(self): diff --git a/mrmustard/lab_dev/transformations/identity.py b/mrmustard/lab_dev/transformations/identity.py index ab1ce6324..f433ac566 100644 --- a/mrmustard/lab_dev/transformations/identity.py +++ b/mrmustard/lab_dev/transformations/identity.py @@ -51,5 +51,9 @@ def __init__( self, modes: Sequence[int], ): - ansatz = PolyExpAnsatz.from_function(fn=triples.identity_Abc, n_modes=len(modes)) - super().__init__(modes_out=modes, modes_in=modes, ansatz=ansatz, name="Identity") + super().__init__(name="Identity") + self._representation = self.from_modes( + modes_in=modes, + modes_out=modes, + ansatz=PolyExpAnsatz.from_function(fn=triples.identity_Abc, n_modes=len(modes)), + ).representation diff --git a/mrmustard/lab_dev/transformations/rgate.py b/mrmustard/lab_dev/transformations/rgate.py index f2e769daa..4eeb183ed 100644 --- a/mrmustard/lab_dev/transformations/rgate.py +++ b/mrmustard/lab_dev/transformations/rgate.py @@ -21,7 +21,6 @@ from typing import Sequence from .base import Unitary -from ...physics.representations import Representation from ...physics.ansatz import PolyExpAnsatz from ...physics import triples from ..utils import make_parameter, reshape_params @@ -60,9 +59,11 @@ def __init__( phi_trainable: bool = False, phi_bounds: tuple[float | None, float | None] = (0.0, None), ): - super().__init__(modes_out=modes, modes_in=modes, name="Rgate") + super().__init__(name="Rgate") (phis,) = list(reshape_params(len(modes), phi=phi)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - self._representation = Representation( - PolyExpAnsatz.from_function(fn=triples.rotation_gate_Abc, theta=self.phi), self.wires - ) + self._representation = self.from_modes( + modes_in=modes, + modes_out=modes, + ansatz=PolyExpAnsatz.from_function(fn=triples.rotation_gate_Abc, theta=self.phi), + ).representation diff --git a/mrmustard/lab_dev/transformations/s2gate.py b/mrmustard/lab_dev/transformations/s2gate.py index b9661a22b..bdb4779c7 100644 --- a/mrmustard/lab_dev/transformations/s2gate.py +++ b/mrmustard/lab_dev/transformations/s2gate.py @@ -21,7 +21,6 @@ from typing import Sequence from .base import Unitary -from ...physics.representations import Representation from ...physics.ansatz import PolyExpAnsatz from ...physics import triples from ..utils import make_parameter @@ -85,12 +84,13 @@ def __init__( if len(modes) != 2: raise ValueError(f"Expected a pair of modes, found {modes}.") - super().__init__(modes_out=modes, modes_in=modes, name="S2gate") + super().__init__(name="S2gate") self._add_parameter(make_parameter(r_trainable, r, "r", r_bounds)) self._add_parameter(make_parameter(phi_trainable, phi, "phi", phi_bounds)) - self._representation = Representation( - PolyExpAnsatz.from_function( + self._representation = self.from_modes( + modes_in=modes, + modes_out=modes, + ansatz=PolyExpAnsatz.from_function( fn=triples.twomode_squeezing_gate_Abc, r=self.r, phi=self.phi ), - self.wires, - ) + ).representation diff --git a/mrmustard/lab_dev/transformations/sgate.py b/mrmustard/lab_dev/transformations/sgate.py index d5e687a70..4b0677499 100644 --- a/mrmustard/lab_dev/transformations/sgate.py +++ b/mrmustard/lab_dev/transformations/sgate.py @@ -21,7 +21,6 @@ from typing import Sequence from .base import Unitary -from ...physics.representations import Representation from ...physics.ansatz import PolyExpAnsatz from ...physics import triples from ..utils import make_parameter, reshape_params @@ -91,11 +90,14 @@ def __init__( r_bounds: tuple[float | None, float | None] = (0.0, None), phi_bounds: tuple[float | None, float | None] = (None, None), ): - super().__init__(modes_out=modes, modes_in=modes, name="Sgate") + super().__init__(name="Sgate") rs, phis = list(reshape_params(len(modes), r=r, phi=phi)) self._add_parameter(make_parameter(r_trainable, rs, "r", r_bounds)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - self._representation = Representation( - PolyExpAnsatz.from_function(fn=triples.squeezing_gate_Abc, r=self.r, delta=self.phi), - self.wires, - ) + self._representation = self.from_modes( + modes_in=modes, + modes_out=modes, + ansatz=PolyExpAnsatz.from_function( + fn=triples.squeezing_gate_Abc, r=self.r, delta=self.phi + ), + ).representation diff --git a/tests/test_lab_dev/test_transformations/test_transformations_base.py b/tests/test_lab_dev/test_transformations/test_transformations_base.py index aa6e384bf..a8308f16c 100644 --- a/tests/test_lab_dev/test_transformations/test_transformations_base.py +++ b/tests/test_lab_dev/test_transformations/test_transformations_base.py @@ -56,7 +56,7 @@ class TestUnitary: @pytest.mark.parametrize("name", [None, "my_unitary"]) @pytest.mark.parametrize("modes", [{0}, {0, 1}, {3, 19, 2}]) def test_init(self, name, modes): - gate = Unitary(modes, modes, name=name) + gate = Unitary.from_modes(modes, modes, name=name) assert gate.name[:1] == (name or "U")[:1] assert list(gate.modes) == sorted(modes) @@ -127,7 +127,7 @@ class TestChannel: @pytest.mark.parametrize("name", [None, "my_channel"]) @pytest.mark.parametrize("modes", [{0}, {0, 1}, {3, 19, 2}]) def test_init(self, name, modes): - gate = Channel(modes, modes, name=name) + gate = Channel.from_modes(modes, modes, name=name) assert gate.name[:2] == (name or "Ch")[:2] assert list(gate.modes) == sorted(modes) From a53a2be47c8b006ca8c99121b5ddea76376d82b7 Mon Sep 17 00:00:00 2001 From: Anthony Date: Wed, 16 Oct 2024 11:50:43 -0400 Subject: [PATCH 59/87] progress --- mrmustard/lab_dev/circuit_components.py | 33 ++++++++----------- .../circuit_components_utils/b_to_ps.py | 1 - .../circuit_components_utils/b_to_q.py | 3 ++ .../circuit_components_utils/trace_out.py | 4 ++- mrmustard/lab_dev/states/base.py | 14 ++++---- mrmustard/lab_dev/transformations/base.py | 18 +++++----- mrmustard/physics/representations.py | 19 +++++++---- tests/test_lab_dev/test_circuit_components.py | 8 ++--- tests/test_lab_dev/test_circuits.py | 2 +- .../test_states/test_states_base.py | 9 +++-- .../test_transformations_base.py | 12 +++---- 11 files changed, 65 insertions(+), 58 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index 5f29c8530..af89f5811 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -113,7 +113,7 @@ def _deserialize(cls, data: dict) -> CircuitComponent: if "rep_class" in data: rep_class, wires, name = map(data.pop, ["rep_class", "wires", "name"]) rep = locate(rep_class).from_dict(data) - return cls._from_attributes(rep, Wires(*map(set, wires)), name=name) + return cls._from_attributes(Representation(rep, Wires(*map(set, wires))), name=name) return cls(**data) @@ -239,7 +239,7 @@ def from_bargmann( """ ansatz = PolyExpAnsatz(*triple) wires = Wires(set(modes_out_bra), set(modes_in_bra), set(modes_out_ket), set(modes_in_ket)) - return cls._from_attributes(ansatz, wires, name) + return cls._from_attributes(Representation(ansatz, wires), name) @classmethod def from_quadrature( @@ -276,9 +276,9 @@ def from_quadrature( QtoB_ok = BtoQ(modes_out_ket, phi).inverse() # output ket QtoB_ik = BtoQ(modes_in_ket, phi).inverse().dual # input ket # NOTE: the representation is Bargmann here because we use the inverse of BtoQ on the B side - QQQQ = CircuitComponent._from_attributes(PolyExpAnsatz(*triple), wires) + QQQQ = CircuitComponent._from_attributes(Representation(PolyExpAnsatz(*triple), wires)) BBBB = QtoB_ib @ (QtoB_ik @ QQQQ @ QtoB_ok) @ QtoB_ob - return cls._from_attributes(BBBB.ansatz, wires, name) + return cls._from_attributes(Representation(BBBB.ansatz, wires), name) def to_quadrature(self, phi: float = 0.0) -> CircuitComponent: r""" @@ -346,8 +346,7 @@ def quadrature(self, quad: Batch[Vector], phi: float = 0.0) -> ComplexTensor: @classmethod def _from_attributes( cls, - ansatz: Ansatz, - wires: Wires, + representation: Representation, name: str | None = None, ) -> CircuitComponent: r""" @@ -376,14 +375,10 @@ def _from_attributes( A circuit component with the given attributes. """ types = {"Ket", "DM", "Unitary", "Operation", "Channel", "Map"} - rep = Representation(ansatz, wires) for tp in cls.mro(): if tp.__name__ in types: - ret = tp() - ret._name = name - ret._representation = rep - return ret - return CircuitComponent(rep, name) + return tp(representation=representation, name=name) + return CircuitComponent(representation, name) def auto_shape(self, **_) -> tuple[int, ...]: r""" @@ -496,7 +491,7 @@ def to_bargmann(self) -> CircuitComponent: ret = self._getitem_builtin(self.modes) ret._representation = rep except TypeError: - ret = self._from_attributes(rep.ansatz, rep.wires, self.name) + ret = self._from_attributes(rep, self.name) if "manual_shape" in ret.__dict__: del ret.manual_shape return ret @@ -527,7 +522,7 @@ def to_fock(self, shape: int | Sequence[int] | None = None) -> CircuitComponent: ret = self._getitem_builtin(self.modes) ret._representation = rep except TypeError: - ret = self._from_attributes(rep.ansatz, rep.wires, self.name) + ret = self._from_attributes(rep, self.name) if "manual_shape" in ret.__dict__: del ret.manual_shape return ret @@ -590,7 +585,7 @@ def __add__(self, other: CircuitComponent) -> CircuitComponent: raise ValueError("Cannot add components with different wires.") ansatz = self.ansatz + other.ansatz name = self.name if self.name == other.name else "" - return self._from_attributes(ansatz, self.wires, name) + return self._from_attributes(Representation(ansatz, self.wires), name) def __eq__(self, other) -> bool: r""" @@ -623,13 +618,13 @@ def __matmul__(self, other: CircuitComponent | Scalar) -> CircuitComponent: if isinstance(other, (numbers.Number, np.ndarray)): return self * other result = self._representation @ other._representation - return CircuitComponent._from_attributes(result.ansatz, result.wires, None) + return CircuitComponent._from_attributes(result, None) def __mul__(self, other: Scalar) -> CircuitComponent: r""" Implements the multiplication by a scalar from the right. """ - return self._from_attributes(self.ansatz * other, self.wires, self.name) + return self._from_attributes(Representation(self.ansatz * other, self.wires), self.name) def __repr__(self) -> str: ansatz = self.ansatz @@ -728,13 +723,13 @@ def __sub__(self, other: CircuitComponent) -> CircuitComponent: raise ValueError("Cannot subtract components with different wires.") ansatz = self.ansatz - other.ansatz name = self.name if self.name == other.name else "" - return self._from_attributes(ansatz, self.wires, name) + return self._from_attributes(Representation(ansatz, self.wires), name) def __truediv__(self, other: Scalar) -> CircuitComponent: r""" Implements the division by a scalar for circuit components. """ - return self._from_attributes(self.ansatz / other, self.wires, self.name) + return self._from_attributes(Representation(self.ansatz / other, self.wires), self.name) def _ipython_display_(self): # both reps might return None diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py b/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py index cca0ca8b3..b637f23da 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py @@ -52,4 +52,3 @@ def __init__( ), ).representation self._add_parameter(Constant(s, "s")) - self.s = s diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py index 94acfcca5..b66a94541 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py @@ -24,6 +24,7 @@ from ..transformations.base import Operation from ...physics.ansatz import PolyExpAnsatz +from ...physics.representations import RepEnum __all__ = ["BtoQ"] @@ -53,3 +54,5 @@ def __init__( ), ).representation self._add_parameter(Constant(phi, "phi")) + for i in self.wires.output.ids: + self.representation._wire_reps[i] = (RepEnum.QUADRATURE, float(self.phi.value), tuple()) diff --git a/mrmustard/lab_dev/circuit_components_utils/trace_out.py b/mrmustard/lab_dev/circuit_components_utils/trace_out.py index f689d40a7..0c61a2f57 100644 --- a/mrmustard/lab_dev/circuit_components_utils/trace_out.py +++ b/mrmustard/lab_dev/circuit_components_utils/trace_out.py @@ -92,5 +92,7 @@ def __custom_rrshift__(self, other: CircuitComponent | complex) -> CircuitCompon ansatz = other.ansatz.trace(idx_z, idx_zconj) wires, _ = other.wires @ self.wires - cpt = other._from_attributes(ansatz, wires) # pylint:disable=protected-access + cpt = other._from_attributes( + Representation(ansatz, wires) + ) # pylint:disable=protected-access return math.sum(cpt.ansatz.scalar) if len(cpt.wires) == 0 else cpt diff --git a/mrmustard/lab_dev/states/base.py b/mrmustard/lab_dev/states/base.py index 8a75231fa..b74126eab 100644 --- a/mrmustard/lab_dev/states/base.py +++ b/mrmustard/lab_dev/states/base.py @@ -968,7 +968,7 @@ def __getitem__(self, modes: int | Sequence[int]) -> State: ansatz = self.ansatz.trace(idxz, idxz_conj) return self.__class__._from_attributes( - ansatz, wires, self.name + Representation(ansatz, wires), self.name ) # pylint: disable=protected-access def __rshift__(self, other: CircuitComponent) -> CircuitComponent: @@ -985,9 +985,8 @@ def __rshift__(self, other: CircuitComponent) -> CircuitComponent: if not isinstance(result, CircuitComponent): return result # scalar case handled here - w = result.wires - if not w.input and w.bra.modes == w.ket.modes: - return DM.from_modes(w.modes, result.ansatz) + if not result.wires.input and result.wires.bra.modes == result.wires.ket.modes: + return DM(result.representation) return result @@ -1186,7 +1185,7 @@ def dm(self) -> DM: The ``DM`` object obtained from this ``Ket``. """ dm = self @ self.adjoint - ret = DM._from_attributes(dm.ansatz, dm.wires, self.name) + ret = DM._from_attributes(dm.representation, self.name) ret.manual_shape = self.manual_shape + self.manual_shape return ret @@ -1280,7 +1279,8 @@ def __rshift__(self, other: CircuitComponent | Scalar) -> CircuitComponent | Bat if not result.wires.input: if not result.wires.bra: - return Ket.from_modes(result.wires.modes, result.ansatz) + print("result", result.representation._wire_reps) + return Ket(result.representation) elif result.wires.bra.modes == result.wires.ket.modes: - result = DM.from_modes(result.wires.modes, result.ansatz) + return DM(result.representation) return result diff --git a/mrmustard/lab_dev/transformations/base.py b/mrmustard/lab_dev/transformations/base.py index 95343d8d7..3db118b8d 100644 --- a/mrmustard/lab_dev/transformations/base.py +++ b/mrmustard/lab_dev/transformations/base.py @@ -121,13 +121,16 @@ def inverse(self) -> Transformation: # compute the inverse A, b, _ = self.dual.ansatz.conj.triple # apply X(.)X almost_inverse = self._from_attributes( - PolyExpAnsatz(math.inv(A[0]), -math.inv(A[0]) @ b[0], 1 + 0j), self.wires + Representation( + PolyExpAnsatz(math.inv(A[0]), -math.inv(A[0]) @ b[0], 1 + 0j), self.wires + ) ) almost_identity = self @ almost_inverse invert_this_c = almost_identity.ansatz.c actual_inverse = self._from_attributes( - PolyExpAnsatz(math.inv(A[0]), -math.inv(A[0]) @ b[0], 1 / invert_this_c), - self.wires, + Representation( + PolyExpAnsatz(math.inv(A[0]), -math.inv(A[0]) @ b[0], 1 / invert_this_c), self.wires + ), self.name + "_inv", ) return actual_inverse @@ -276,8 +279,7 @@ def random(cls, modes, max_r=1): def inverse(self) -> Unitary: unitary_dual = self.dual return Unitary._from_attributes( - ansatz=unitary_dual.ansatz, - wires=unitary_dual.wires, + representation=unitary_dual.representation, name=unitary_dual.name, ) @@ -295,9 +297,9 @@ def __rshift__(self, other: CircuitComponent) -> CircuitComponent: ret = super().__rshift__(other) if isinstance(other, Unitary): - return Unitary._from_attributes(ret.ansatz, ret.wires) + return Unitary._from_attributes(ret.representation) elif isinstance(other, Channel): - return Channel._from_attributes(ret.ansatz, ret.wires) + return Channel._from_attributes(ret.representation) return ret @@ -490,5 +492,5 @@ def __rshift__(self, other: CircuitComponent) -> CircuitComponent: """ ret = super().__rshift__(other) if isinstance(other, (Channel, Unitary)): - return Channel._from_attributes(ret.ansatz, ret.wires) + return Channel._from_attributes(ret.representation) return ret diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index bc28ca476..f601240e2 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -39,7 +39,8 @@ class RepEnum(Enum): r""" - An enum to represent what representation a wire is in. + An enum to represent what representation a wire is in. Also keeps track + of representation conversions. """ NONETYPE = 0 @@ -128,7 +129,7 @@ def __init__( self._wires = wires self._wire_reps = wire_reps or dict.fromkeys( - wires.indices, (RepEnum.from_ansatz(ansatz), None) + wires.ids, (RepEnum.from_ansatz(ansatz), None, tuple()) ) @property @@ -141,7 +142,7 @@ def adjoint(self) -> Representation: kets = self.wires.ket.indices ansatz = self.ansatz.reorder(kets + bras).conj if self.ansatz else None wires = self.wires.adjoint - return Representation(ansatz, wires, self._wire_reps) + return Representation(ansatz, wires, None) @property def ansatz(self) -> Ansatz | None: @@ -162,7 +163,7 @@ def dual(self) -> Representation: ob = self.wires.bra.output.indices ansatz = self.ansatz.reorder(ib + ob + ik + ok).conj if self.ansatz else None wires = self.wires.dual - return Representation(ansatz, wires, self._wire_reps) + return Representation(ansatz, wires, None) @property def wires(self) -> Wires | None: @@ -289,7 +290,7 @@ def __eq__(self, other): return ( self.ansatz == other.ansatz and self.wires == other.wires - and self._wire_reps == other._wire_reps + # and self._wire_reps == other._wire_reps ) return False @@ -305,4 +306,10 @@ def __matmul__(self, other: Representation): rep = self_ansatz[idx_z] @ other_ansatz[idx_zconj] rep = rep.reorder(perm) if perm else rep - return Representation(rep, wires_result) + wire_reps = {} + for id in wires_result.ids: + try: + wire_reps[id] = self._wire_reps[id] + except KeyError: + wire_reps[id] = other._wire_reps[id] + return Representation(rep, wires_result, wire_reps) diff --git a/tests/test_lab_dev/test_circuit_components.py b/tests/test_lab_dev/test_circuit_components.py index 5aa3edc8c..fd116856c 100644 --- a/tests/test_lab_dev/test_circuit_components.py +++ b/tests/test_lab_dev/test_circuit_components.py @@ -99,9 +99,9 @@ def test_modes_init_out_of_order(self): def test_from_attributes(self, x, y): cc = Dgate([1, 8], x=x, y=y) - cc1 = Dgate._from_attributes(cc.ansatz, cc.wires, cc.name) - cc2 = Unitary._from_attributes(cc.ansatz, cc.wires, cc.name) - cc3 = CircuitComponent._from_attributes(cc.ansatz, cc.wires, cc.name) + cc1 = Dgate._from_attributes(cc.representation, cc.name) + cc2 = Unitary._from_attributes(cc.representation, cc.name) + cc3 = CircuitComponent._from_attributes(cc.representation, cc.name) assert cc1 == cc assert cc2 == cc @@ -113,7 +113,7 @@ def test_from_attributes(self, x, y): def test_from_to_quadrature(self): c = Dgate([0], x=0.1, y=0.2) >> Sgate([0], r=1.0, phi=0.1) - cc = CircuitComponent._from_attributes(c.ansatz, c.wires, c.name) + cc = CircuitComponent._from_attributes(c.representation, c.name) ccc = CircuitComponent.from_quadrature(tuple(), tuple(), (0,), (0,), cc.quadrature_triple()) assert cc == ccc diff --git a/tests/test_lab_dev/test_circuits.py b/tests/test_lab_dev/test_circuits.py index f25db3d69..6acaa8973 100644 --- a/tests/test_lab_dev/test_circuits.py +++ b/tests/test_lab_dev/test_circuits.py @@ -179,7 +179,7 @@ def test_repr(self): n12 = Number([0, 1], n=3) n2 = Number([2], n=3) cc = CircuitComponent._from_attributes( - bs01.ansatz, bs01.wires, "my_cc" + bs01.representation, "my_cc" ) # pylint: disable=protected-access assert repr(Circuit()) == "" diff --git a/tests/test_lab_dev/test_states/test_states_base.py b/tests/test_lab_dev/test_states/test_states_base.py index 8e657ab1e..4f1a6ca17 100644 --- a/tests/test_lab_dev/test_states/test_states_base.py +++ b/tests/test_lab_dev/test_states/test_states_base.py @@ -357,12 +357,11 @@ def test_rshift(self): ket = Coherent([0, 1], 1) unitary = Dgate([0], 1) u_component = CircuitComponent._from_attributes( - unitary.ansatz, unitary.wires, unitary.name + unitary.representation, unitary.name ) # pylint: disable=protected-access channel = Attenuator([1], 1) ch_component = CircuitComponent._from_attributes( - channel.ansatz, - channel.wires, + channel.representation, channel.name, ) # pylint: disable=protected-access @@ -823,11 +822,11 @@ def test_rshift(self): ket = Coherent([0, 1], 1) unitary = Dgate([0], 1) u_component = CircuitComponent._from_attributes( - unitary.ansatz, unitary.wires, unitary.name + unitary.representation, unitary.name ) # pylint: disable=protected-access channel = Attenuator([1], 1) ch_component = CircuitComponent._from_attributes( - channel.ansatz, channel.wires, channel.name + channel.representation, channel.name ) # pylint: disable=protected-access dm = ket >> channel diff --git a/tests/test_lab_dev/test_transformations/test_transformations_base.py b/tests/test_lab_dev/test_transformations/test_transformations_base.py index a8308f16c..7fdc48ad3 100644 --- a/tests/test_lab_dev/test_transformations/test_transformations_base.py +++ b/tests/test_lab_dev/test_transformations/test_transformations_base.py @@ -66,11 +66,11 @@ def test_rshift(self): unitary1 = Dgate([0, 1], 1) unitary2 = Dgate([1, 2], 2) u_component = CircuitComponent._from_attributes( - unitary1.ansatz, unitary1.wires, unitary1.name + unitary1.representation, unitary1.name ) # pylint: disable=protected-access channel = Attenuator([1], 1) ch_component = CircuitComponent._from_attributes( - channel.ansatz, channel.wires, channel.name + channel.representation, channel.name ) # pylint: disable=protected-access assert isinstance(unitary1 >> unitary2, Unitary) @@ -81,7 +81,7 @@ def test_rshift(self): def test_repr(self): unitary1 = Dgate([0, 1], 1) u_component = CircuitComponent._from_attributes( - unitary1.ansatz, unitary1.wires, unitary1.name + unitary1.representation, unitary1.name ) # pylint: disable=protected-access assert repr(unitary1) == "Dgate(modes=[0, 1], name=Dgate, repr=PolyExpAnsatz)" assert repr(unitary1.to_fock(5)) == "Dgate(modes=[0, 1], name=Dgate, repr=ArrayAnsatz)" @@ -149,12 +149,12 @@ def test_init_from_bargmann(self): def test_rshift(self): unitary = Dgate([0, 1], 1) u_component = CircuitComponent._from_attributes( - unitary.ansatz, unitary.wires, unitary.name + unitary.representation, unitary.name ) # pylint: disable=protected-access channel1 = Attenuator([1, 2], 0.9) channel2 = Attenuator([2, 3], 0.9) ch_component = CircuitComponent._from_attributes( - channel1.ansatz, channel1.wires, channel1.name + channel1.representation, channel1.name ) # pylint: disable=protected-access assert isinstance(channel1 >> unitary, Channel) @@ -165,7 +165,7 @@ def test_rshift(self): def test_repr(self): channel1 = Attenuator([0, 1], 0.9) ch_component = CircuitComponent._from_attributes( - channel1.ansatz, channel1.wires, channel1.name + channel1.representation, channel1.name ) # pylint: disable=protected-access assert repr(channel1) == "Attenuator(modes=[0, 1], name=Att, repr=PolyExpAnsatz)" From fabfda05002b5c04f01663d0d23c7c41d79806c2 Mon Sep 17 00:00:00 2001 From: Anthony Date: Thu, 17 Oct 2024 09:38:55 -0400 Subject: [PATCH 60/87] print rem --- mrmustard/lab_dev/states/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mrmustard/lab_dev/states/base.py b/mrmustard/lab_dev/states/base.py index b74126eab..cf0f8dff0 100644 --- a/mrmustard/lab_dev/states/base.py +++ b/mrmustard/lab_dev/states/base.py @@ -1279,7 +1279,6 @@ def __rshift__(self, other: CircuitComponent | Scalar) -> CircuitComponent | Bat if not result.wires.input: if not result.wires.bra: - print("result", result.representation._wire_reps) return Ket(result.representation) elif result.wires.bra.modes == result.wires.ket.modes: return DM(result.representation) From 5c61ed4387e79a572c3f0da24af462a708173ea1 Mon Sep 17 00:00:00 2001 From: Anthony Date: Thu, 17 Oct 2024 10:13:30 -0400 Subject: [PATCH 61/87] progress --- .../circuit_components_utils/b_to_ps.py | 3 ++ .../circuit_components_utils/b_to_q.py | 2 +- mrmustard/physics/representations.py | 33 ++++++++++++------- 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py b/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py index b637f23da..5561cb960 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py @@ -24,6 +24,7 @@ from ..transformations.base import Map from ...physics.ansatz import PolyExpAnsatz +from ...physics.representations import RepEnum __all__ = ["BtoPS"] @@ -52,3 +53,5 @@ def __init__( ), ).representation self._add_parameter(Constant(s, "s")) + for i in self.wires.output.indices: + self.representation._wire_reps[i] = (RepEnum.PHASESPACE, float(self.s.value), tuple()) diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py index b66a94541..e58e9816a 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py @@ -54,5 +54,5 @@ def __init__( ), ).representation self._add_parameter(Constant(phi, "phi")) - for i in self.wires.output.ids: + for i in self.wires.output.indices: self.representation._wire_reps[i] = (RepEnum.QUADRATURE, float(self.phi.value), tuple()) diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index f601240e2..1dd8d17f6 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -129,7 +129,7 @@ def __init__( self._wires = wires self._wire_reps = wire_reps or dict.fromkeys( - wires.ids, (RepEnum.from_ansatz(ansatz), None, tuple()) + wires.indices, (RepEnum.from_ansatz(ansatz), None, tuple()) ) @property @@ -142,7 +142,12 @@ def adjoint(self) -> Representation: kets = self.wires.ket.indices ansatz = self.ansatz.reorder(kets + bras).conj if self.ansatz else None wires = self.wires.adjoint - return Representation(ansatz, wires, None) + wire_reps = {} + for i, j in enumerate(kets): + wire_reps[i] = self._wire_reps[j] + for i, j in enumerate(bras): + wire_reps[i + len(kets)] = self._wire_reps[j] + return Representation(ansatz, wires, wire_reps) @property def ansatz(self) -> Ansatz | None: @@ -163,7 +168,16 @@ def dual(self) -> Representation: ob = self.wires.bra.output.indices ansatz = self.ansatz.reorder(ib + ob + ik + ok).conj if self.ansatz else None wires = self.wires.dual - return Representation(ansatz, wires, None) + wire_reps = {} + for i, j in enumerate(ib): + wire_reps[i] = self._wire_reps[j] + for i, j in enumerate(ob): + wire_reps[i + len(ib)] = self._wire_reps[j] + for i, j in enumerate(ik): + wire_reps[i + len(ib + ob)] = self._wire_reps[j] + for i, j in enumerate(ok): + wire_reps[i + len(ib + ob + ik)] = self._wire_reps[j] + return Representation(ansatz, wires, wire_reps) @property def wires(self) -> Wires | None: @@ -290,7 +304,7 @@ def __eq__(self, other): return ( self.ansatz == other.ansatz and self.wires == other.wires - # and self._wire_reps == other._wire_reps + and self._wire_reps == other._wire_reps ) return False @@ -306,10 +320,7 @@ def __matmul__(self, other: Representation): rep = self_ansatz[idx_z] @ other_ansatz[idx_zconj] rep = rep.reorder(perm) if perm else rep - wire_reps = {} - for id in wires_result.ids: - try: - wire_reps[id] = self._wire_reps[id] - except KeyError: - wire_reps[id] = other._wire_reps[id] - return Representation(rep, wires_result, wire_reps) + + # TODO: update wire reps + + return Representation(rep, wires_result) From 497c434b246236c3784e83f70dcebf71dba09458 Mon Sep 17 00:00:00 2001 From: Anthony Date: Thu, 17 Oct 2024 14:53:31 -0400 Subject: [PATCH 62/87] matmul working --- mrmustard/physics/representations.py | 18 +++++++++++++++--- mrmustard/physics/wires.py | 17 +++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index 1dd8d17f6..a53fd8336 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -321,6 +321,18 @@ def __matmul__(self, other: Representation): rep = self_ansatz[idx_z] @ other_ansatz[idx_zconj] rep = rep.reorder(perm) if perm else rep - # TODO: update wire reps - - return Representation(rep, wires_result) + wire_reps = {} + for id in wires_result.ids: + if id in self.wires.ids: + temp_rep = self + else: + temp_rep = other + for t in (0, 1, 2, 3, 4, 5): + try: + idx = temp_rep.wires.ids_index_dicts[t][id] + n_idx = wires_result.ids_index_dicts[t][id] + wire_reps[n_idx] = temp_rep._wire_reps[idx] + break + except KeyError: + continue + return Representation(rep, wires_result, wire_reps) diff --git a/mrmustard/physics/wires.py b/mrmustard/physics/wires.py index 422701342..c06c8b83e 100644 --- a/mrmustard/physics/wires.py +++ b/mrmustard/physics/wires.py @@ -308,6 +308,23 @@ def index_dicts(self) -> list[dict[int, int]]: for t, lst in enumerate(self.sorted_args) ] + @property + def ids_index_dicts(self) -> list[dict[int, int]]: + r""" + A list of dictionary mapping ids to indices, one for each of the subsets + (``output.bra``, ``input.bra``, ``output.ket``, ``input.ket``, + ``output.classical``, and ``input.classical``). + + If subsets are taken, ``ids_index_dicts`` refers to the parent object rather than to the + child. + """ + if self.original: + return self.original.ids_index_dicts + return [ + {v: self.index_dicts[t][k] for k, v in self.ids_dicts[t].items()} + for t in (0, 1, 2, 3, 4, 5) + ] + @cached_property def indices(self) -> tuple[int, ...]: r""" From d5085ce4a47340f4dae78fc03c8903d87344e44c Mon Sep 17 00:00:00 2001 From: Anthony Date: Thu, 17 Oct 2024 15:55:54 -0400 Subject: [PATCH 63/87] rename --- .../circuit_components_utils/b_to_ps.py | 2 +- .../circuit_components_utils/b_to_q.py | 2 +- mrmustard/physics/representations.py | 38 ++++++++++--------- 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py b/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py index 5561cb960..bcce01a7b 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py @@ -54,4 +54,4 @@ def __init__( ).representation self._add_parameter(Constant(s, "s")) for i in self.wires.output.indices: - self.representation._wire_reps[i] = (RepEnum.PHASESPACE, float(self.s.value), tuple()) + self.representation._idx_reps[i] = (RepEnum.PHASESPACE, float(self.s.value), tuple()) diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py index e58e9816a..68a700f5f 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py @@ -55,4 +55,4 @@ def __init__( ).representation self._add_parameter(Constant(phi, "phi")) for i in self.wires.output.indices: - self.representation._wire_reps[i] = (RepEnum.QUADRATURE, float(self.phi.value), tuple()) + self.representation._idx_reps[i] = (RepEnum.QUADRATURE, float(self.phi.value), tuple()) diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index a53fd8336..93f6e10a9 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -79,20 +79,23 @@ class Representation: A representation handles the underlying ansatz, wires and keeps track of each wire's representation. + The dictionary to keep track of representations maps the indices of the wires + to a tuple of the form ``(RepEnum, parameter, (coupled_indices, ...))``. + Args: ansatz: An ansatz for this representation. wires: The wires of this representation. Alternatively, can be a ``(modes_out_bra, modes_in_bra, modes_out_ket, modes_in_ket)`` sequence where if any of the modes are out of order the ansatz will be reordered. - wire_reps: An optional dictionary for keeping track of each wire's representation. + idx_reps: An optional dictionary for keeping track of each wire's representation. """ def __init__( self, ansatz: Ansatz | None = None, wires: Wires | Sequence[tuple[int]] | None = None, - wire_reps: dict | None = None, + idx_reps: dict | None = None, ) -> None: self._ansatz = ansatz @@ -128,7 +131,7 @@ def __init__( self._ansatz = ansatz.reorder(tuple(perm)) self._wires = wires - self._wire_reps = wire_reps or dict.fromkeys( + self._idx_reps = idx_reps or dict.fromkeys( wires.indices, (RepEnum.from_ansatz(ansatz), None, tuple()) ) @@ -142,12 +145,12 @@ def adjoint(self) -> Representation: kets = self.wires.ket.indices ansatz = self.ansatz.reorder(kets + bras).conj if self.ansatz else None wires = self.wires.adjoint - wire_reps = {} + idx_reps = {} for i, j in enumerate(kets): - wire_reps[i] = self._wire_reps[j] + idx_reps[i] = self._idx_reps[j] for i, j in enumerate(bras): - wire_reps[i + len(kets)] = self._wire_reps[j] - return Representation(ansatz, wires, wire_reps) + idx_reps[i + len(kets)] = self._idx_reps[j] + return Representation(ansatz, wires, idx_reps) @property def ansatz(self) -> Ansatz | None: @@ -168,16 +171,16 @@ def dual(self) -> Representation: ob = self.wires.bra.output.indices ansatz = self.ansatz.reorder(ib + ob + ik + ok).conj if self.ansatz else None wires = self.wires.dual - wire_reps = {} + idx_reps = {} for i, j in enumerate(ib): - wire_reps[i] = self._wire_reps[j] + idx_reps[i] = self._idx_reps[j] for i, j in enumerate(ob): - wire_reps[i + len(ib)] = self._wire_reps[j] + idx_reps[i + len(ib)] = self._idx_reps[j] for i, j in enumerate(ik): - wire_reps[i + len(ib + ob)] = self._wire_reps[j] + idx_reps[i + len(ib + ob)] = self._idx_reps[j] for i, j in enumerate(ok): - wire_reps[i + len(ib + ob + ik)] = self._wire_reps[j] - return Representation(ansatz, wires, wire_reps) + idx_reps[i + len(ib + ob + ik)] = self._idx_reps[j] + return Representation(ansatz, wires, idx_reps) @property def wires(self) -> Wires | None: @@ -304,13 +307,14 @@ def __eq__(self, other): return ( self.ansatz == other.ansatz and self.wires == other.wires - and self._wire_reps == other._wire_reps + and self._idx_reps == other._idx_reps ) return False def __matmul__(self, other: Representation): wires_result, perm = self.wires @ other.wires idx_z, idx_zconj = self._matmul_indices(other) + if type(self.ansatz) is type(other.ansatz): self_ansatz = self.ansatz other_ansatz = other.ansatz @@ -321,7 +325,7 @@ def __matmul__(self, other: Representation): rep = self_ansatz[idx_z] @ other_ansatz[idx_zconj] rep = rep.reorder(perm) if perm else rep - wire_reps = {} + idx_reps = {} for id in wires_result.ids: if id in self.wires.ids: temp_rep = self @@ -331,8 +335,8 @@ def __matmul__(self, other: Representation): try: idx = temp_rep.wires.ids_index_dicts[t][id] n_idx = wires_result.ids_index_dicts[t][id] - wire_reps[n_idx] = temp_rep._wire_reps[idx] + idx_reps[n_idx] = temp_rep._idx_reps[idx] break except KeyError: continue - return Representation(rep, wires_result, wire_reps) + return Representation(rep, wires_result, idx_reps) From 5bac7994f2127088402655d194d9da8366b3ad71 Mon Sep 17 00:00:00 2001 From: Anthony Date: Mon, 21 Oct 2024 17:23:09 -0400 Subject: [PATCH 64/87] initial test file --- tests/test_physics/test_representations.py | 54 ++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 tests/test_physics/test_representations.py diff --git a/tests/test_physics/test_representations.py b/tests/test_physics/test_representations.py new file mode 100644 index 000000000..0f6c344e2 --- /dev/null +++ b/tests/test_physics/test_representations.py @@ -0,0 +1,54 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the Representation class.""" + +# pylint: disable=missing-function-docstring + +from unittest.mock import patch + +from ipywidgets import HTML +import pytest + +from mrmustard.physics.representations import Representation, RepEnum +from mrmustard.physics.wires import Wires +from mrmustard.physics.ansatz import PolyExpAnsatz, ArrayAnsatz + +from ..random import Abc_triple + + +class TestRepresentation: + r""" + Tests for the Representation class. + """ + + Abc_n1 = Abc_triple(1) + Abc_n2 = Abc_triple(2) + Abc_n3 = Abc_triple(3) + + @pytest.mark.parametrize("triple", [Abc_n1, Abc_n2, Abc_n3]) + def test_init(self, triple): + empty_rep = Representation() + assert empty_rep.ansatz is None + assert empty_rep.wires == Wires() + assert empty_rep._idx_reps == {} + + ansatz = PolyExpAnsatz(*triple) + wires = Wires() + rep = Representation(ansatz, wires) + assert rep.ansatz == ansatz + assert rep.wires == wires + assert rep._idx_reps == dict.fromkeys( + wires.indices, (RepEnum.from_ansatz(ansatz), None, tuple()) + ) From f756139beeae89083b2e56450bc8d49213443c59 Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 22 Oct 2024 15:05:02 -0400 Subject: [PATCH 65/87] pylint --- .pylintrc | 2 +- mrmustard/lab_dev/circuit_components.py | 2 +- .../circuit_components_utils/b_to_ps.py | 1 - .../circuit_components_utils/b_to_q.py | 1 - .../circuit_components_utils/trace_out.py | 4 +-- mrmustard/lab_dev/states/base.py | 6 ++--- mrmustard/physics/__init__.py | 1 - mrmustard/physics/ansatz/array_ansatz.py | 2 +- mrmustard/physics/ansatz/polyexp_ansatz.py | 2 +- mrmustard/physics/wires.py | 12 ++++----- mrmustard/widgets/__init__.py | 1 - tests/test_lab/test_gates_fock.py | 1 - tests/test_lab/test_states.py | 10 +++---- tests/test_lab_dev/test_circuit_components.py | 6 ++--- .../test_b_to_ps.py | 6 ++--- .../test_b_to_q.py | 2 +- .../test_trace_out.py | 2 +- tests/test_lab_dev/test_circuits.py | 6 ++--- tests/test_lab_dev/test_samplers.py | 14 +++------- .../test_lab_dev/test_states/test_coherent.py | 2 +- .../test_states/test_displaced_squeezed.py | 2 +- tests/test_lab_dev/test_states/test_number.py | 2 +- .../test_states/test_quadrature_eigenstate.py | 2 +- .../test_states/test_squeezed_vacuum.py | 2 +- .../test_states/test_states_base.py | 20 +++++--------- .../test_states/test_states_visualization.py | 2 +- .../test_lab_dev/test_states/test_thermal.py | 2 +- .../test_two_mode_squeezed_vacuum.py | 2 +- tests/test_lab_dev/test_states/test_vacuum.py | 2 +- .../test_transformations/test_amplifier.py | 2 +- .../test_transformations/test_attenuator.py | 2 +- .../test_transformations/test_bsgate.py | 2 +- .../test_transformations/test_dgate.py | 2 +- .../test_transformations/test_fockdamping.py | 2 +- .../test_transformations/test_identity.py | 2 +- .../test_transformations/test_rgate.py | 2 +- .../test_transformations/test_s2gate.py | 2 +- .../test_transformations/test_sgate.py | 2 +- .../test_transformations_base.py | 26 +++++-------------- tests/test_math/test_backend_manager.py | 2 +- .../test_ansatz/test_array_ansatz.py | 6 ++--- .../test_ansatz/test_polyexp_ansatz.py | 6 ++--- tests/test_physics/test_representations.py | 5 +--- tests/test_physics/test_wires.py | 8 +++--- 44 files changed, 75 insertions(+), 115 deletions(-) diff --git a/.pylintrc b/.pylintrc index ba760f8c4..47a5c2046 100644 --- a/.pylintrc +++ b/.pylintrc @@ -28,4 +28,4 @@ ignored-classes=numpy,tensorflow,scipy,networkx,strawberryfields,thewalrus # can either give multiple identifier separated by comma (,) or put this option # multiple time (only on the command line, not in the configuration file where # it should appear only once). -disable=fixme,no-member,line-too-long,invalid-name,too-many-lines,redefined-builtin,too-many-locals,duplicate-code,too-many-arguments,too-few-public-methods,no-else-return,isinstance-second-argument-not-valid-type,no-self-argument, arguments-differ +disable=fixme,no-member,line-too-long,invalid-name,too-many-lines,redefined-builtin,too-many-locals,duplicate-code,too-many-arguments,too-few-public-methods,no-else-return,isinstance-second-argument-not-valid-type,no-self-argument, arguments-differ, protected-access diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index c63674567..88f598ab3 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -16,7 +16,7 @@ A base class for the components of quantum circuits. """ -# pylint: disable=super-init-not-called, protected-access, import-outside-toplevel +# pylint: disable=super-init-not-called, import-outside-toplevel from __future__ import annotations from inspect import signature diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py b/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py index a1eb2bd90..3798661a7 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py @@ -15,7 +15,6 @@ """ The class representing an operation that changes Bargmann into phase space. """ -# pylint: disable=protected-access from __future__ import annotations from typing import Sequence diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py index 29c1e0ad3..81db9afd5 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py @@ -15,7 +15,6 @@ """ The class representing an operation that changes Bargmann into quadrature. """ -# pylint: disable=protected-access from __future__ import annotations from typing import Sequence diff --git a/mrmustard/lab_dev/circuit_components_utils/trace_out.py b/mrmustard/lab_dev/circuit_components_utils/trace_out.py index 0c61a2f57..07cc83801 100644 --- a/mrmustard/lab_dev/circuit_components_utils/trace_out.py +++ b/mrmustard/lab_dev/circuit_components_utils/trace_out.py @@ -92,7 +92,5 @@ def __custom_rrshift__(self, other: CircuitComponent | complex) -> CircuitCompon ansatz = other.ansatz.trace(idx_z, idx_zconj) wires, _ = other.wires @ self.wires - cpt = other._from_attributes( - Representation(ansatz, wires) - ) # pylint:disable=protected-access + cpt = other._from_attributes(Representation(ansatz, wires)) return math.sum(cpt.ansatz.scalar) if len(cpt.wires) == 0 else cpt diff --git a/mrmustard/lab_dev/states/base.py b/mrmustard/lab_dev/states/base.py index cf0f8dff0..d80300136 100644 --- a/mrmustard/lab_dev/states/base.py +++ b/mrmustard/lab_dev/states/base.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# pylint: disable=abstract-method, chained-comparison, use-dict-literal, protected-access, inconsistent-return-statements +# pylint: disable=abstract-method, chained-comparison, use-dict-literal, inconsistent-return-statements """ This module contains the base classes for the available quantum states. @@ -967,9 +967,7 @@ def __getitem__(self, modes: int | Sequence[int]) -> State: idxz_conj = [i + len(self.modes) for i, m in enumerate(self.modes) if m not in modes] ansatz = self.ansatz.trace(idxz, idxz_conj) - return self.__class__._from_attributes( - Representation(ansatz, wires), self.name - ) # pylint: disable=protected-access + return self.__class__._from_attributes(Representation(ansatz, wires), self.name) def __rshift__(self, other: CircuitComponent) -> CircuitComponent: r""" diff --git a/mrmustard/physics/__init__.py b/mrmustard/physics/__init__.py index 03bcd4422..6e447f399 100644 --- a/mrmustard/physics/__init__.py +++ b/mrmustard/physics/__init__.py @@ -23,7 +23,6 @@ from mrmustard.physics import fock_utils, gaussian -# pylint: disable=protected-access def fidelity(A, B) -> float: r"""Calculates the fidelity between two quantum states. diff --git a/mrmustard/physics/ansatz/array_ansatz.py b/mrmustard/physics/ansatz/array_ansatz.py index 4a11cfc86..a8501725a 100644 --- a/mrmustard/physics/ansatz/array_ansatz.py +++ b/mrmustard/physics/ansatz/array_ansatz.py @@ -84,7 +84,7 @@ def batch_size(self): @property def conj(self): ret = ArrayAnsatz(math.conj(self.array), batched=True) - ret._contract_idxs = self._contract_idxs # pylint: disable=protected-access + ret._contract_idxs = self._contract_idxs return ret @property diff --git a/mrmustard/physics/ansatz/polyexp_ansatz.py b/mrmustard/physics/ansatz/polyexp_ansatz.py index 4b07b1bf7..8631b42be 100644 --- a/mrmustard/physics/ansatz/polyexp_ansatz.py +++ b/mrmustard/physics/ansatz/polyexp_ansatz.py @@ -159,7 +159,7 @@ def c(self, value): @property def conj(self): ret = PolyExpAnsatz(math.conj(self.A), math.conj(self.b), math.conj(self.c)) - ret._contract_idxs = self._contract_idxs # pylint: disable=protected-access + ret._contract_idxs = self._contract_idxs return ret @property diff --git a/mrmustard/physics/wires.py b/mrmustard/physics/wires.py index c06c8b83e..db14a9f48 100644 --- a/mrmustard/physics/wires.py +++ b/mrmustard/physics/wires.py @@ -207,7 +207,7 @@ def bra(self) -> Wires: New ``Wires`` object with only bra wires. """ ret = Wires(modes_out_bra=self.args[0], modes_in_bra=self.args[1]) - ret._original = self.original or self # pylint: disable=protected-access + ret._original = self.original or self return ret @cached_property @@ -216,7 +216,7 @@ def classical(self) -> Wires: New ``Wires`` object with only classical wires. """ ret = Wires(classical_out=self.args[4], classical_in=self.args[5]) - ret._original = self.original or self # pylint: disable=protected-access + ret._original = self.original or self return ret @cached_property @@ -230,7 +230,7 @@ def quantum(self) -> Wires: modes_out_ket=self.args[2], modes_in_ket=self.args[3], ) - ret._original = self.original or self # pylint: disable=protected-access + ret._original = self.original or self return ret @cached_property @@ -348,7 +348,7 @@ def input(self) -> Wires: New ``Wires`` object without output wires. """ ret = Wires(set(), self.args[1], set(), self.args[3], set(), self.args[5]) - ret._original = self.original or self # pylint: disable=protected-access + ret._original = self.original or self return ret @cached_property @@ -357,7 +357,7 @@ def ket(self) -> Wires: New ``Wires`` object with only ket wires. """ ret = Wires(modes_out_ket=self.args[2], modes_in_ket=self.args[3]) - ret._original = self.original or self # pylint: disable=protected-access + ret._original = self.original or self return ret @cached_property @@ -380,7 +380,7 @@ def output(self) -> Wires: New ``Wires`` object with only output wires. """ ret = Wires(self.args[0], set(), self.args[2], set(), self.args[4], set()) - ret._original = self.original or self # pylint: disable=protected-access + ret._original = self.original or self return ret @cached_property diff --git a/mrmustard/widgets/__init__.py b/mrmustard/widgets/__init__.py index 0d560e3bf..4c9026dd0 100644 --- a/mrmustard/widgets/__init__.py +++ b/mrmustard/widgets/__init__.py @@ -130,7 +130,6 @@ def get_abc_str(A, b, c, round_val): ) # Replace config to hide the Plotly mode bar # See: https://github.com/plotly/plotly.py/issues/1074#issuecomment-1471486307 - # pylint:disable=protected-access eigvals_w._config = eigvals_w._config | {"displayModeBar": False} eigvals_w.add_shape( type="circle", diff --git a/tests/test_lab/test_gates_fock.py b/tests/test_lab/test_gates_fock.py index c499d8b9e..42c89ff05 100644 --- a/tests/test_lab/test_gates_fock.py +++ b/tests/test_lab/test_gates_fock.py @@ -360,7 +360,6 @@ def test_schwinger_bs_equals_vanilla_bs_for_small_cutoffs(theta, phi): assert np.allclose(U_vanilla, U_schwinger, atol=1e-6) -# pylint: disable=protected-access @given(phase_stdev=medium_float.filter(lambda x: x > 0)) def test_phasenoise_creates_dm(phase_stdev): """test that the phase noise gate is correctly applied""" diff --git a/tests/test_lab/test_states.py b/tests/test_lab/test_states.py index 7538bd7c5..e05eae9d3 100644 --- a/tests/test_lab/test_states.py +++ b/tests/test_lab/test_states.py @@ -35,8 +35,6 @@ hbar0 = settings.HBAR -# pylint: disable=protected-access - @st.composite def xy_arrays(draw): @@ -298,7 +296,7 @@ def test_padding_ket(): "Test that padding a ket works correctly." state = State(ket=SqueezedVacuum(r=1.0).ket(cutoffs=[20])) assert len(state.ket(cutoffs=[10])) == 10 - assert len(state._ket) == 20 # pylint: disable=protected-access + assert len(state._ket) == 20 def test_padding_dm(): @@ -308,18 +306,18 @@ def test_padding_dm(): assert tuple(int(c) for c in state._dm.shape) == ( 20, 20, - ) # pylint: disable=protected-access + ) def test_state_repr_small_prob(): "test that small probabilities are displayed correctly" state = State(ket=np.array([0.0001, 0.0001])) - table = state._repr_markdown_() # pylint: disable=protected-access + table = state._repr_markdown_() assert "2.000e-06 %" in table def test_state_repr_big_prob(): "test that big probabilities are displayed correctly" state = State(ket=np.array([0.5, 0.5])) - table = state._repr_markdown_() # pylint: disable=protected-access + table = state._repr_markdown_() assert "50.000%" in table diff --git a/tests/test_lab_dev/test_circuit_components.py b/tests/test_lab_dev/test_circuit_components.py index cc96bb31b..d77d33693 100644 --- a/tests/test_lab_dev/test_circuit_components.py +++ b/tests/test_lab_dev/test_circuit_components.py @@ -14,7 +14,7 @@ """Tests for circuit components.""" -# pylint: disable=fixme, missing-function-docstring, protected-access, pointless-statement +# pylint: disable=fixme, missing-function-docstring, pointless-statement from unittest.mock import patch @@ -522,7 +522,7 @@ def test_ipython_repr(self, mock_display, is_fock, widget_cls): dgate = Dgate([1], x=0.1, y=0.1) if is_fock: dgate = dgate.to_fock() - dgate._ipython_display_() # pylint:disable=protected-access + dgate._ipython_display_() [box] = mock_display.call_args.args assert isinstance(box, Box) [wires_widget, rep_widget] = box.children @@ -533,7 +533,7 @@ def test_ipython_repr(self, mock_display, is_fock, widget_cls): def test_ipython_repr_invalid_obj(self, mock_display): """Test the IPython repr function.""" dgate = Dgate([1, 2], x=0.1, y=0.1).to_fock() - dgate._ipython_display_() # pylint:disable=protected-access + dgate._ipython_display_() [box] = mock_display.call_args.args assert isinstance(box, VBox) [title_widget, wires_widget] = box.children diff --git a/tests/test_lab_dev/test_circuit_components_utils/test_b_to_ps.py b/tests/test_lab_dev/test_circuit_components_utils/test_b_to_ps.py index 8e9688502..ccff92b7e 100644 --- a/tests/test_lab_dev/test_circuit_components_utils/test_b_to_ps.py +++ b/tests/test_lab_dev/test_circuit_components_utils/test_b_to_ps.py @@ -14,7 +14,7 @@ """Tests for BtoPS.""" -# pylint: disable=fixme, missing-function-docstring, protected-access, pointless-statement +# pylint: disable=fixme, missing-function-docstring, pointless-statement import numpy as np import pytest @@ -90,7 +90,7 @@ def testBtoPS_contraction_with_state(self): state_bargmann_triple = state.bargmann_triple() # get new triple by right shift - state_after = state >> BtoPS(modes=[0], s=0) # pylint: disable=protected-access + state_after = state >> BtoPS(modes=[0], s=0) A1, b1, c1 = state_after.bargmann_triple(batched=True) # get new triple by contraction @@ -118,7 +118,7 @@ def testBtoPS_contraction_with_state(self): state_bargmann_triple = state.bargmann_triple() # get new triple by right shift - state_after = state >> BtoPS(modes=[0, 1], s=0) # pylint: disable=protected-access + state_after = state >> BtoPS(modes=[0, 1], s=0) A1, b1, c1 = state_after.bargmann_triple(batched=True) # get new triple by contraction diff --git a/tests/test_lab_dev/test_circuit_components_utils/test_b_to_q.py b/tests/test_lab_dev/test_circuit_components_utils/test_b_to_q.py index cb0a14676..e02c47a05 100644 --- a/tests/test_lab_dev/test_circuit_components_utils/test_b_to_q.py +++ b/tests/test_lab_dev/test_circuit_components_utils/test_b_to_q.py @@ -14,7 +14,7 @@ """Tests for BtoQ.""" -# pylint: disable=fixme, missing-function-docstring, protected-access, pointless-statement +# pylint: disable=fixme, missing-function-docstring, pointless-statement import numpy as np diff --git a/tests/test_lab_dev/test_circuit_components_utils/test_trace_out.py b/tests/test_lab_dev/test_circuit_components_utils/test_trace_out.py index 4ef146670..e73f34212 100644 --- a/tests/test_lab_dev/test_circuit_components_utils/test_trace_out.py +++ b/tests/test_lab_dev/test_circuit_components_utils/test_trace_out.py @@ -14,7 +14,7 @@ """Tests for trace out.""" -# pylint: disable=fixme, missing-function-docstring, protected-access, pointless-statement +# pylint: disable=fixme, missing-function-docstring, pointless-statement import numpy as np import pytest diff --git a/tests/test_lab_dev/test_circuits.py b/tests/test_lab_dev/test_circuits.py index 6acaa8973..580ace6e9 100644 --- a/tests/test_lab_dev/test_circuits.py +++ b/tests/test_lab_dev/test_circuits.py @@ -14,7 +14,7 @@ """Tests for the ``Circuit`` class.""" -# pylint: disable=protected-access, missing-function-docstring, expression-not-assigned +# pylint: disable=missing-function-docstring, expression-not-assigned import pytest @@ -178,9 +178,7 @@ def test_repr(self): bs12 = BSgate([1, 2]) n12 = Number([0, 1], n=3) n2 = Number([2], n=3) - cc = CircuitComponent._from_attributes( - bs01.representation, "my_cc" - ) # pylint: disable=protected-access + cc = CircuitComponent._from_attributes(bs01.representation, "my_cc") assert repr(Circuit()) == "" diff --git a/tests/test_lab_dev/test_samplers.py b/tests/test_lab_dev/test_samplers.py index c5d130dad..5fdc44c5c 100644 --- a/tests/test_lab_dev/test_samplers.py +++ b/tests/test_lab_dev/test_samplers.py @@ -78,32 +78,26 @@ class TestHomodyneSampler: def test_init(self): sampler = HomodyneSampler(phi=0.5, bounds=(-5, 5), num=100) assert sampler.povms is None - assert sampler._phi == 0.5 # pylint: disable=protected-access + assert sampler._phi == 0.5 assert math.allclose(sampler.meas_outcomes, list(np.linspace(-5, 5, 100))) def test_povm_error(self): sampler = HomodyneSampler() with pytest.raises(ValueError, match="no POVMs"): - sampler._get_povm(0, 0) # pylint: disable=protected-access + sampler._get_povm(0, 0) def test_probabilties(self): sampler = HomodyneSampler() state = Coherent([0], x=[0.1]) - exp_probs = ( - state.quadrature_distribution(sampler.meas_outcomes) - * sampler._step # pylint: disable=protected-access - ) + exp_probs = state.quadrature_distribution(sampler.meas_outcomes) * sampler._step assert math.allclose(sampler.probabilities(state), exp_probs) sampler2 = HomodyneSampler(phi=np.pi / 2) exp_probs = ( - state.quadrature_distribution( - sampler2.meas_outcomes, sampler2._phi # pylint: disable=protected-access - ) - * sampler2._step # pylint: disable=protected-access + state.quadrature_distribution(sampler2.meas_outcomes, sampler2._phi) * sampler2._step ) assert math.allclose(sampler2.probabilities(state), exp_probs) diff --git a/tests/test_lab_dev/test_states/test_coherent.py b/tests/test_lab_dev/test_states/test_coherent.py index c1ec6dac6..7a38934ea 100644 --- a/tests/test_lab_dev/test_states/test_coherent.py +++ b/tests/test_lab_dev/test_states/test_coherent.py @@ -14,7 +14,7 @@ """Tests for the Coherent class.""" -# pylint: disable=protected-access, unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement +# pylint: disable=unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement import numpy as np import pytest diff --git a/tests/test_lab_dev/test_states/test_displaced_squeezed.py b/tests/test_lab_dev/test_states/test_displaced_squeezed.py index 276ec5be1..c0f5f88ec 100644 --- a/tests/test_lab_dev/test_states/test_displaced_squeezed.py +++ b/tests/test_lab_dev/test_states/test_displaced_squeezed.py @@ -14,7 +14,7 @@ """Tests for the DisplacedSqueezed class.""" -# pylint: disable=protected-access, unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement +# pylint: disable=unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement import pytest diff --git a/tests/test_lab_dev/test_states/test_number.py b/tests/test_lab_dev/test_states/test_number.py index 998923a62..ec5d90d55 100644 --- a/tests/test_lab_dev/test_states/test_number.py +++ b/tests/test_lab_dev/test_states/test_number.py @@ -14,7 +14,7 @@ """Tests for the ``Number`` class.""" -# pylint: disable=protected-access, unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement +# pylint: disable=unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement import pytest diff --git a/tests/test_lab_dev/test_states/test_quadrature_eigenstate.py b/tests/test_lab_dev/test_states/test_quadrature_eigenstate.py index b203545ea..81d2748ea 100644 --- a/tests/test_lab_dev/test_states/test_quadrature_eigenstate.py +++ b/tests/test_lab_dev/test_states/test_quadrature_eigenstate.py @@ -14,7 +14,7 @@ """Tests for the ``QuadratureEigenstate`` class.""" -# pylint: disable=protected-access, unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement +# pylint: disable=unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement import numpy as np import pytest diff --git a/tests/test_lab_dev/test_states/test_squeezed_vacuum.py b/tests/test_lab_dev/test_states/test_squeezed_vacuum.py index ccb0d6233..57b3740db 100644 --- a/tests/test_lab_dev/test_states/test_squeezed_vacuum.py +++ b/tests/test_lab_dev/test_states/test_squeezed_vacuum.py @@ -14,7 +14,7 @@ """Tests for the ``SqueezedVacuum`` class.""" -# pylint: disable=protected-access, unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement +# pylint: disable=unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement import pytest diff --git a/tests/test_lab_dev/test_states/test_states_base.py b/tests/test_lab_dev/test_states/test_states_base.py index 4f1a6ca17..052b4870a 100644 --- a/tests/test_lab_dev/test_states/test_states_base.py +++ b/tests/test_lab_dev/test_states/test_states_base.py @@ -14,7 +14,7 @@ """Tests for the base state subpackage.""" -# pylint: disable=protected-access, unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement +# pylint: disable=unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement from itertools import product import numpy as np @@ -356,14 +356,12 @@ def test_expectation_error(self): def test_rshift(self): ket = Coherent([0, 1], 1) unitary = Dgate([0], 1) - u_component = CircuitComponent._from_attributes( - unitary.representation, unitary.name - ) # pylint: disable=protected-access + u_component = CircuitComponent._from_attributes(unitary.representation, unitary.name) channel = Attenuator([1], 1) ch_component = CircuitComponent._from_attributes( channel.representation, channel.name, - ) # pylint: disable=protected-access + ) # gates assert isinstance(ket >> unitary, Ket) @@ -529,7 +527,7 @@ def test_bargmann_Abc_to_phasespace_cov_means(self): state_cov = np.array([[0.32210229, -0.99732956], [-0.99732956, 6.1926484]]) state_means = np.array([0.2, 0.3]) state = DM.from_bargmann([0], wigner_to_bargmann_rho(state_cov, state_means)) - state_after = state >> BtoPS(modes=[0], s=0) # pylint: disable=protected-access + state_after = state >> BtoPS(modes=[0], s=0) A1, b1, c1 = state_after.bargmann_triple() ( new_state_cov, @@ -552,7 +550,7 @@ def test_bargmann_Abc_to_phasespace_cov_means(self): A, b, c = wigner_to_bargmann_rho(state_cov, state_means) state = DM.from_bargmann(modes=[0, 1], triple=(A, b, c)) - state_after = state >> BtoPS(modes=[0, 1], s=0) # pylint: disable=protected-access + state_after = state >> BtoPS(modes=[0, 1], s=0) A1, b1, c1 = state_after.bargmann_triple() ( new_state_cov1, @@ -821,13 +819,9 @@ def test_expectation_error(self): def test_rshift(self): ket = Coherent([0, 1], 1) unitary = Dgate([0], 1) - u_component = CircuitComponent._from_attributes( - unitary.representation, unitary.name - ) # pylint: disable=protected-access + u_component = CircuitComponent._from_attributes(unitary.representation, unitary.name) channel = Attenuator([1], 1) - ch_component = CircuitComponent._from_attributes( - channel.representation, channel.name - ) # pylint: disable=protected-access + ch_component = CircuitComponent._from_attributes(channel.representation, channel.name) dm = ket >> channel diff --git a/tests/test_lab_dev/test_states/test_states_visualization.py b/tests/test_lab_dev/test_states/test_states_visualization.py index 1d17d28e8..a480d6f82 100644 --- a/tests/test_lab_dev/test_states/test_states_visualization.py +++ b/tests/test_lab_dev/test_states/test_states_visualization.py @@ -14,7 +14,7 @@ """Tests for the state visualization.""" -# pylint: disable=protected-access, unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement +# pylint: disable=unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement import json from pathlib import Path diff --git a/tests/test_lab_dev/test_states/test_thermal.py b/tests/test_lab_dev/test_states/test_thermal.py index d73ffe4dc..6d1cadc6f 100644 --- a/tests/test_lab_dev/test_states/test_thermal.py +++ b/tests/test_lab_dev/test_states/test_thermal.py @@ -14,7 +14,7 @@ """Tests for the ``Thermal`` class.""" -# pylint: disable=protected-access, unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement +# pylint: disable=unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement import pytest diff --git a/tests/test_lab_dev/test_states/test_two_mode_squeezed_vacuum.py b/tests/test_lab_dev/test_states/test_two_mode_squeezed_vacuum.py index 4846761b7..f4b249080 100644 --- a/tests/test_lab_dev/test_states/test_two_mode_squeezed_vacuum.py +++ b/tests/test_lab_dev/test_states/test_two_mode_squeezed_vacuum.py @@ -14,7 +14,7 @@ """Tests for the ``TwoModeSqueezedVacuum`` class.""" -# pylint: disable=protected-access, unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement +# pylint: disable=unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement import pytest diff --git a/tests/test_lab_dev/test_states/test_vacuum.py b/tests/test_lab_dev/test_states/test_vacuum.py index 19a750296..3a9de3314 100644 --- a/tests/test_lab_dev/test_states/test_vacuum.py +++ b/tests/test_lab_dev/test_states/test_vacuum.py @@ -14,7 +14,7 @@ """Tests for the ``Vacuum`` class.""" -# pylint: disable=protected-access, unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement +# pylint: disable=unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement import numpy as np import pytest diff --git a/tests/test_lab_dev/test_transformations/test_amplifier.py b/tests/test_lab_dev/test_transformations/test_amplifier.py index 2862d9dda..a492805fe 100644 --- a/tests/test_lab_dev/test_transformations/test_amplifier.py +++ b/tests/test_lab_dev/test_transformations/test_amplifier.py @@ -14,7 +14,7 @@ """Tests for the ``Amplifier`` class.""" -# pylint: disable=protected-access, missing-function-docstring, expression-not-assigned +# pylint: disable=missing-function-docstring, expression-not-assigned import numpy as np import pytest diff --git a/tests/test_lab_dev/test_transformations/test_attenuator.py b/tests/test_lab_dev/test_transformations/test_attenuator.py index ef7ebaff0..d78343f9c 100644 --- a/tests/test_lab_dev/test_transformations/test_attenuator.py +++ b/tests/test_lab_dev/test_transformations/test_attenuator.py @@ -14,7 +14,7 @@ """Tests for the ``Attenuator`` class.""" -# pylint: disable=protected-access, missing-function-docstring, expression-not-assigned +# pylint: disable=missing-function-docstring, expression-not-assigned import numpy as np import pytest diff --git a/tests/test_lab_dev/test_transformations/test_bsgate.py b/tests/test_lab_dev/test_transformations/test_bsgate.py index 62a2aa50d..748e6c0e2 100644 --- a/tests/test_lab_dev/test_transformations/test_bsgate.py +++ b/tests/test_lab_dev/test_transformations/test_bsgate.py @@ -14,7 +14,7 @@ """Tests for the ``BSgate`` class.""" -# pylint: disable=protected-access, missing-function-docstring, expression-not-assigned +# pylint: disable=missing-function-docstring, expression-not-assigned import numpy as np import pytest diff --git a/tests/test_lab_dev/test_transformations/test_dgate.py b/tests/test_lab_dev/test_transformations/test_dgate.py index b2ff27a5e..2fa57c950 100644 --- a/tests/test_lab_dev/test_transformations/test_dgate.py +++ b/tests/test_lab_dev/test_transformations/test_dgate.py @@ -14,7 +14,7 @@ """Tests for the ``Dgate`` class.""" -# pylint: disable=protected-access, missing-function-docstring, expression-not-assigned +# pylint: disable=missing-function-docstring, expression-not-assigned import pytest import numpy as np diff --git a/tests/test_lab_dev/test_transformations/test_fockdamping.py b/tests/test_lab_dev/test_transformations/test_fockdamping.py index a59308d0f..e99cb038c 100644 --- a/tests/test_lab_dev/test_transformations/test_fockdamping.py +++ b/tests/test_lab_dev/test_transformations/test_fockdamping.py @@ -14,7 +14,7 @@ """Tests for the ``FockDamping`` class.""" -# pylint: disable=protected-access, missing-function-docstring, expression-not-assigned +# pylint: disable=missing-function-docstring, expression-not-assigned import numpy as np import pytest diff --git a/tests/test_lab_dev/test_transformations/test_identity.py b/tests/test_lab_dev/test_transformations/test_identity.py index 1d5691959..452752cc2 100644 --- a/tests/test_lab_dev/test_transformations/test_identity.py +++ b/tests/test_lab_dev/test_transformations/test_identity.py @@ -14,7 +14,7 @@ """Tests for the ``Identity`` class.""" -# pylint: disable=protected-access, missing-function-docstring, expression-not-assigned +# pylint: disable=missing-function-docstring, expression-not-assigned import numpy as np import pytest diff --git a/tests/test_lab_dev/test_transformations/test_rgate.py b/tests/test_lab_dev/test_transformations/test_rgate.py index d1b9b9786..a742d59d8 100644 --- a/tests/test_lab_dev/test_transformations/test_rgate.py +++ b/tests/test_lab_dev/test_transformations/test_rgate.py @@ -14,7 +14,7 @@ """Tests for the ``Rgate`` class.""" -# pylint: disable=protected-access, missing-function-docstring, expression-not-assigned +# pylint: disable=missing-function-docstring, expression-not-assigned import numpy as np import pytest diff --git a/tests/test_lab_dev/test_transformations/test_s2gate.py b/tests/test_lab_dev/test_transformations/test_s2gate.py index 030ff51bd..ac712ea3b 100644 --- a/tests/test_lab_dev/test_transformations/test_s2gate.py +++ b/tests/test_lab_dev/test_transformations/test_s2gate.py @@ -14,7 +14,7 @@ """Tests for the ``S2gate`` class.""" -# pylint: disable=protected-access, missing-function-docstring, expression-not-assigned +# pylint: disable=missing-function-docstring, expression-not-assigned import numpy as np import pytest diff --git a/tests/test_lab_dev/test_transformations/test_sgate.py b/tests/test_lab_dev/test_transformations/test_sgate.py index 73a927044..5018374eb 100644 --- a/tests/test_lab_dev/test_transformations/test_sgate.py +++ b/tests/test_lab_dev/test_transformations/test_sgate.py @@ -14,7 +14,7 @@ """Tests for the ``Sgate`` class.""" -# pylint: disable=protected-access, missing-function-docstring, expression-not-assigned +# pylint: disable=missing-function-docstring, expression-not-assigned import numpy as np import pytest diff --git a/tests/test_lab_dev/test_transformations/test_transformations_base.py b/tests/test_lab_dev/test_transformations/test_transformations_base.py index 7fdc48ad3..1ff69bc14 100644 --- a/tests/test_lab_dev/test_transformations/test_transformations_base.py +++ b/tests/test_lab_dev/test_transformations/test_transformations_base.py @@ -14,7 +14,7 @@ """Tests for the base transformation subpackage.""" -# pylint: disable=protected-access, missing-function-docstring, expression-not-assigned +# pylint: disable=missing-function-docstring, expression-not-assigned import numpy as np import pytest @@ -65,13 +65,9 @@ def test_init(self, name, modes): def test_rshift(self): unitary1 = Dgate([0, 1], 1) unitary2 = Dgate([1, 2], 2) - u_component = CircuitComponent._from_attributes( - unitary1.representation, unitary1.name - ) # pylint: disable=protected-access + u_component = CircuitComponent._from_attributes(unitary1.representation, unitary1.name) channel = Attenuator([1], 1) - ch_component = CircuitComponent._from_attributes( - channel.representation, channel.name - ) # pylint: disable=protected-access + ch_component = CircuitComponent._from_attributes(channel.representation, channel.name) assert isinstance(unitary1 >> unitary2, Unitary) assert isinstance(unitary1 >> channel, Channel) @@ -80,9 +76,7 @@ def test_rshift(self): def test_repr(self): unitary1 = Dgate([0, 1], 1) - u_component = CircuitComponent._from_attributes( - unitary1.representation, unitary1.name - ) # pylint: disable=protected-access + u_component = CircuitComponent._from_attributes(unitary1.representation, unitary1.name) assert repr(unitary1) == "Dgate(modes=[0, 1], name=Dgate, repr=PolyExpAnsatz)" assert repr(unitary1.to_fock(5)) == "Dgate(modes=[0, 1], name=Dgate, repr=ArrayAnsatz)" assert repr(u_component) == "CircuitComponent(modes=[0, 1], name=Dgate, repr=PolyExpAnsatz)" @@ -148,14 +142,10 @@ def test_init_from_bargmann(self): def test_rshift(self): unitary = Dgate([0, 1], 1) - u_component = CircuitComponent._from_attributes( - unitary.representation, unitary.name - ) # pylint: disable=protected-access + u_component = CircuitComponent._from_attributes(unitary.representation, unitary.name) channel1 = Attenuator([1, 2], 0.9) channel2 = Attenuator([2, 3], 0.9) - ch_component = CircuitComponent._from_attributes( - channel1.representation, channel1.name - ) # pylint: disable=protected-access + ch_component = CircuitComponent._from_attributes(channel1.representation, channel1.name) assert isinstance(channel1 >> unitary, Channel) assert isinstance(channel1 >> channel2, Channel) @@ -164,9 +154,7 @@ def test_rshift(self): def test_repr(self): channel1 = Attenuator([0, 1], 0.9) - ch_component = CircuitComponent._from_attributes( - channel1.representation, channel1.name - ) # pylint: disable=protected-access + ch_component = CircuitComponent._from_attributes(channel1.representation, channel1.name) assert repr(channel1) == "Attenuator(modes=[0, 1], name=Att, repr=PolyExpAnsatz)" assert repr(ch_component) == "CircuitComponent(modes=[0, 1], name=Att, repr=PolyExpAnsatz)" diff --git a/tests/test_math/test_backend_manager.py b/tests/test_math/test_backend_manager.py index 6b30771e4..5562c6596 100644 --- a/tests/test_math/test_backend_manager.py +++ b/tests/test_math/test_backend_manager.py @@ -26,7 +26,7 @@ from ..conftest import skip_np -# pylint: disable=protected-access, too-many-public-methods +# pylint: disable=too-many-public-methods class TestBackendManager: r""" Tests the BackendManager. diff --git a/tests/test_physics/test_ansatz/test_array_ansatz.py b/tests/test_physics/test_ansatz/test_array_ansatz.py index 5fecae9bf..02cc5a46e 100644 --- a/tests/test_physics/test_ansatz/test_array_ansatz.py +++ b/tests/test_physics/test_ansatz/test_array_ansatz.py @@ -224,7 +224,7 @@ def test_truediv_a_scalar(self): def test_ipython_repr(self, mock_display, shape): """Test the IPython repr function.""" rep = ArrayAnsatz(np.random.random(shape), batched=True) - rep._ipython_display_() # pylint:disable=protected-access + rep._ipython_display_() [hbox] = mock_display.call_args.args assert isinstance(hbox, HBox) @@ -247,12 +247,12 @@ def test_ipython_repr(self, mock_display, shape): def test_ipython_repr_expects_batch_1(self, mock_display): """Test the IPython repr function does nothing with real batch.""" rep = ArrayAnsatz(np.random.random((2, 8)), batched=True) - rep._ipython_display_() # pylint:disable=protected-access + rep._ipython_display_() mock_display.assert_not_called() @patch("mrmustard.physics.ansatz.array_ansatz.display") def test_ipython_repr_expects_3_dims_or_less(self, mock_display): """Test the IPython repr function does nothing with 4+ dims.""" rep = ArrayAnsatz(np.random.random((1, 4, 4, 4)), batched=True) - rep._ipython_display_() # pylint:disable=protected-access + rep._ipython_display_() mock_display.assert_not_called() diff --git a/tests/test_physics/test_ansatz/test_polyexp_ansatz.py b/tests/test_physics/test_ansatz/test_polyexp_ansatz.py index 7e15c1d68..678750a8b 100644 --- a/tests/test_physics/test_ansatz/test_polyexp_ansatz.py +++ b/tests/test_physics/test_ansatz/test_polyexp_ansatz.py @@ -281,7 +281,7 @@ def test_inconsistent_poly_shapes(self): def test_ipython_repr(self, mock_display): """Test the IPython repr function.""" rep = PolyExpAnsatz(*Abc_triple(2)) - rep._ipython_display_() # pylint:disable=protected-access + rep._ipython_display_() [box] = mock_display.call_args.args assert isinstance(box, Box) assert box.layout.max_width == "50%" @@ -308,7 +308,7 @@ def test_ipython_repr_batched(self, mock_display): A1, b1, c1 = Abc_triple(2) A2, b2, c2 = Abc_triple(2) rep = PolyExpAnsatz(np.array([A1, A2]), np.array([b1, b2]), np.array([c1, c2])) - rep._ipython_display_() # pylint:disable=protected-access + rep._ipython_display_() [vbox] = mock_display.call_args.args assert isinstance(vbox, VBox) @@ -360,7 +360,7 @@ def test_order_batch(self): b=[np.array([1]), np.array([0])], c=[1, 2], ) - ansatz._order_batch() # pylint: disable=protected-access + ansatz._order_batch() assert np.allclose(ansatz.A[0], np.array([[1]])) assert np.allclose(ansatz.b[0], np.array([0])) diff --git a/tests/test_physics/test_representations.py b/tests/test_physics/test_representations.py index 0f6c344e2..50466efa8 100644 --- a/tests/test_physics/test_representations.py +++ b/tests/test_physics/test_representations.py @@ -16,14 +16,11 @@ # pylint: disable=missing-function-docstring -from unittest.mock import patch - -from ipywidgets import HTML import pytest from mrmustard.physics.representations import Representation, RepEnum from mrmustard.physics.wires import Wires -from mrmustard.physics.ansatz import PolyExpAnsatz, ArrayAnsatz +from mrmustard.physics.ansatz import PolyExpAnsatz from ..random import Abc_triple diff --git a/tests/test_physics/test_wires.py b/tests/test_physics/test_wires.py index 4a9ca81a5..1b55794ef 100644 --- a/tests/test_physics/test_wires.py +++ b/tests/test_physics/test_wires.py @@ -153,15 +153,15 @@ def test_getitem(self): w0 = Wires({0}, {0}) assert w[0] == w0 - assert w._mode_cache == {(0,): w0} # pylint: disable=protected-access + assert w._mode_cache == {(0,): w0} w1 = Wires({1}) assert w[1] == w1 - assert w._mode_cache == {(0,): w0, (1,): w1} # pylint: disable=protected-access + assert w._mode_cache == {(0,): w0, (1,): w1} w2 = Wires(set(), {2}) assert w[2] == w2 - assert w._mode_cache == { # pylint: disable=protected-access + assert w._mode_cache == { (0,): w0, (1,): w1, (2,): w2, @@ -217,6 +217,6 @@ def test_matmul_error(self): def test_ipython_repr(self, mock_display): """Test the IPython repr function.""" wires = Wires({0}, {}, {3}, {3, 4}) - wires._ipython_display_() # pylint:disable=protected-access + wires._ipython_display_() [widget] = mock_display.call_args.args assert isinstance(widget, HTML) From 00b280e9a7addeb730a6d85e6974b37768adf887 Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 22 Oct 2024 15:15:12 -0400 Subject: [PATCH 66/87] codefactor --- mrmustard/lab_dev/states/quadrature_eigenstate.py | 1 - mrmustard/lab_dev/states/squeezed_vacuum.py | 1 - mrmustard/lab_dev/states/thermal.py | 1 - mrmustard/physics/ansatz/array_ansatz.py | 14 +++++++------- mrmustard/physics/ansatz/base.py | 6 +++--- mrmustard/physics/representations.py | 4 ++-- 6 files changed, 12 insertions(+), 15 deletions(-) diff --git a/mrmustard/lab_dev/states/quadrature_eigenstate.py b/mrmustard/lab_dev/states/quadrature_eigenstate.py index 38aa33e92..cc2259a8e 100644 --- a/mrmustard/lab_dev/states/quadrature_eigenstate.py +++ b/mrmustard/lab_dev/states/quadrature_eigenstate.py @@ -22,7 +22,6 @@ import numpy as np -from mrmustard.physics.representations import Representation from mrmustard.physics.ansatz import PolyExpAnsatz from mrmustard.physics import triples from .base import Ket diff --git a/mrmustard/lab_dev/states/squeezed_vacuum.py b/mrmustard/lab_dev/states/squeezed_vacuum.py index 730055b9e..5774477e8 100644 --- a/mrmustard/lab_dev/states/squeezed_vacuum.py +++ b/mrmustard/lab_dev/states/squeezed_vacuum.py @@ -20,7 +20,6 @@ from typing import Sequence -from mrmustard.physics.representations import Representation from mrmustard.physics.ansatz import PolyExpAnsatz from mrmustard.physics import triples from .base import Ket diff --git a/mrmustard/lab_dev/states/thermal.py b/mrmustard/lab_dev/states/thermal.py index f8171da5a..e01efd8c8 100644 --- a/mrmustard/lab_dev/states/thermal.py +++ b/mrmustard/lab_dev/states/thermal.py @@ -20,7 +20,6 @@ from typing import Sequence -from mrmustard.physics.representations import Representation from mrmustard.physics.ansatz import PolyExpAnsatz from mrmustard.physics import triples from .base import DM diff --git a/mrmustard/physics/ansatz/array_ansatz.py b/mrmustard/physics/ansatz/array_ansatz.py index a8501725a..7455a0af1 100644 --- a/mrmustard/physics/ansatz/array_ansatz.py +++ b/mrmustard/physics/ansatz/array_ansatz.py @@ -187,18 +187,18 @@ def sum_batch(self) -> ArrayAnsatz: def to_dict(self) -> dict[str, ArrayLike]: return {"array": self.data} - def trace(self, idxs1: tuple[int, ...], idxs2: tuple[int, ...]) -> ArrayAnsatz: - if len(idxs1) != len(idxs2) or not set(idxs1).isdisjoint(idxs2): + def trace(self, idx_z: tuple[int, ...], idx_zconj: tuple[int, ...]) -> ArrayAnsatz: + if len(idx_z) != len(idx_zconj) or not set(idx_z).isdisjoint(idx_zconj): raise ValueError("The idxs must be of equal length and disjoint.") order = ( [0] - + [i + 1 for i in range(len(self.array.shape) - 1) if i not in idxs1 + idxs2] - + [i + 1 for i in idxs1] - + [i + 1 for i in idxs2] + + [i + 1 for i in range(len(self.array.shape) - 1) if i not in idx_z + idx_zconj] + + [i + 1 for i in idx_z] + + [i + 1 for i in idx_zconj] ) new_array = math.transpose(self.array, order) - n = np.prod(new_array.shape[-len(idxs2) :]) - new_array = math.reshape(new_array, new_array.shape[: -2 * len(idxs1)] + (n, n)) + n = np.prod(new_array.shape[-len(idx_zconj) :]) + new_array = math.reshape(new_array, new_array.shape[: -2 * len(idx_z)] + (n, n)) trace = math.trace(new_array) return ArrayAnsatz([trace] if trace.shape == () else trace, batched=True) diff --git a/mrmustard/physics/ansatz/base.py b/mrmustard/physics/ansatz/base.py index 9285eb9bd..26086bc7d 100644 --- a/mrmustard/physics/ansatz/base.py +++ b/mrmustard/physics/ansatz/base.py @@ -118,13 +118,13 @@ def to_dict(self) -> dict[str, ArrayLike]: """ @abstractmethod - def trace(self, idxs1: tuple[int, ...], idxs2: tuple[int, ...]) -> Ansatz: + def trace(self, idx_z: tuple[int, ...], idx_zconj: tuple[int, ...]) -> Ansatz: r""" Implements the partial trace over the given index pairs. Args: - idxs1: The first part of the pairs of indices to trace over. - idxs2: The second part. + idx_z: The first part of the pairs of indices to trace over. + idx_zconj: The second part. Returns: The traced-over ansatz. diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index dd3b1e81a..f7a057b87 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -246,11 +246,11 @@ def fock_array(self, shape: int | Sequence[int], batched=False) -> ComplexTensor ) for A, b, c in zip(As, bs, cs) ] - except AttributeError: + except AttributeError as e: if len(shape) != num_vars: raise ValueError( f"Expected Fock shape of length {num_vars}, got length {len(shape)}" - ) + ) from e arrays = self.ansatz.reduce(shape).array array = math.sum(arrays, axes=[0]) arrays = math.expand_dims(array, 0) if batched else array From 20c63e0490b0b2c53173e2232b76b4b1531236ae Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 22 Oct 2024 15:21:06 -0400 Subject: [PATCH 67/87] fix --- tests/test_physics/test_ansatz/test_array_ansatz.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_physics/test_ansatz/test_array_ansatz.py b/tests/test_physics/test_ansatz/test_array_ansatz.py index 02cc5a46e..17f32703c 100644 --- a/tests/test_physics/test_ansatz/test_array_ansatz.py +++ b/tests/test_physics/test_ansatz/test_array_ansatz.py @@ -198,7 +198,7 @@ def test_sum_batch(self): def test_trace(self): array1 = math.astensor(np.random.random((2, 5, 5, 1, 7, 4, 1, 7, 3))) fock1 = ArrayAnsatz(array1, batched=True) - fock2 = fock1.trace(idxs1=[0, 3], idxs2=[1, 6]) + fock2 = fock1.trace([0, 3], [1, 6]) assert fock2.array.shape == (2, 1, 4, 1, 3) assert np.allclose(fock2.array, np.einsum("bccefghfj -> beghj", array1)) From 9c846dd1442dbdea87b34a7470dafdb7c23fc988 Mon Sep 17 00:00:00 2001 From: Anthony Date: Wed, 23 Oct 2024 12:19:23 -0400 Subject: [PATCH 68/87] patch 100 --- .codecov.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.codecov.yml b/.codecov.yml index e1bb03ea7..d1515b335 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -5,5 +5,5 @@ coverage: target: 89% patch: default: - target: auto + target: 100% threshold: 0% From 7253b75a205d6ace7e5a08469d6d77fb5c3cdb43 Mon Sep 17 00:00:00 2001 From: Anthony Date: Wed, 23 Oct 2024 14:21:15 -0400 Subject: [PATCH 69/87] some tests --- tests/test_physics/test_representations.py | 69 +++++++++++++++++++++- 1 file changed, 67 insertions(+), 2 deletions(-) diff --git a/tests/test_physics/test_representations.py b/tests/test_physics/test_representations.py index 50466efa8..f6cf96598 100644 --- a/tests/test_physics/test_representations.py +++ b/tests/test_physics/test_representations.py @@ -18,9 +18,12 @@ import pytest +from mrmustard import math + from mrmustard.physics.representations import Representation, RepEnum from mrmustard.physics.wires import Wires -from mrmustard.physics.ansatz import PolyExpAnsatz +from mrmustard.physics.ansatz import ArrayAnsatz, PolyExpAnsatz +from mrmustard.physics.triples import displacement_gate_Abc, bargmann_to_quadrature_Abc from ..random import Abc_triple @@ -34,6 +37,23 @@ class TestRepresentation: Abc_n2 = Abc_triple(2) Abc_n3 = Abc_triple(3) + @pytest.fixture + def d_gate_rep(self): + ansatz = PolyExpAnsatz.from_function(fn=displacement_gate_Abc, x=0.1, y=0.1) + wires = Wires((), (), set([0]), set([0])) + return Representation(ansatz, wires) + + @pytest.fixture + def btoq_rep(self): + ansatz = PolyExpAnsatz.from_function(fn=bargmann_to_quadrature_Abc, n_modes=1, phi=0.2) + wires = Wires((), (), set([0]), set([0])) + idx_reps = {} + for i in wires.input.indices: + idx_reps[i] = (RepEnum.BARGMANN, None, tuple()) + for i in wires.output.indices: + idx_reps[i] = (RepEnum.QUADRATURE, float(0.2), tuple()) + return Representation(ansatz, wires, idx_reps) + @pytest.mark.parametrize("triple", [Abc_n1, Abc_n2, Abc_n3]) def test_init(self, triple): empty_rep = Representation() @@ -42,10 +62,55 @@ def test_init(self, triple): assert empty_rep._idx_reps == {} ansatz = PolyExpAnsatz(*triple) - wires = Wires() + wires = Wires(set([0, 1])) rep = Representation(ansatz, wires) assert rep.ansatz == ansatz assert rep.wires == wires assert rep._idx_reps == dict.fromkeys( wires.indices, (RepEnum.from_ansatz(ansatz), None, tuple()) ) + + @pytest.mark.parametrize("triple", [Abc_n2]) + def test_adjoint_idx_reps(self, triple): + ansatz = PolyExpAnsatz(*triple) + wires = Wires(modes_out_bra=set([0]), modes_out_ket=set([0])) + idx_reps = {0: (RepEnum.BARGMANN, None, tuple()), 1: (RepEnum.QUADRATURE, 0.1, tuple())} + rep = Representation(ansatz, wires, idx_reps) + adj_rep = rep.adjoint + assert adj_rep._idx_reps == { + 1: (RepEnum.BARGMANN, None, tuple()), + 0: (RepEnum.QUADRATURE, 0.1, tuple()), + } + + @pytest.mark.parametrize("triple", [Abc_n2]) + def test_dual_idx_reps(self, triple): + ansatz = PolyExpAnsatz(*triple) + wires = Wires(modes_out_bra=set([0]), modes_in_bra=set([0])) + idx_reps = {0: (RepEnum.BARGMANN, None, tuple()), 1: (RepEnum.QUADRATURE, 0.1, tuple())} + rep = Representation(ansatz, wires, idx_reps) + adj_rep = rep.dual + assert adj_rep._idx_reps == { + 1: (RepEnum.BARGMANN, None, tuple()), + 0: (RepEnum.QUADRATURE, 0.1, tuple()), + } + + def test_matmul_btoq(self, d_gate_rep, btoq_rep): + q_dgate = d_gate_rep @ btoq_rep + assert q_dgate._idx_reps == { + 0: (RepEnum.QUADRATURE, 0.2, ()), + 1: (RepEnum.BARGMANN, None, ()), + } + + def test_to_bargmann(self, d_gate_rep): + d_fock = d_gate_rep.to_fock(shape=(4, 6)) + d_barg = d_fock.to_bargmann() + assert d_fock.ansatz._original_abc_data == d_gate_rep.ansatz.triple + assert d_barg == d_gate_rep + assert all([k[0] == RepEnum.BARGMANN for k in d_barg._idx_reps.values()]) + + def test_to_fock(self, d_gate_rep): + d_fock = d_gate_rep.to_fock(shape=(4, 6)) + assert d_fock.ansatz == ArrayAnsatz( + math.hermite_renormalized(*displacement_gate_Abc(x=0.1, y=0.1), shape=(4, 6)) + ) + assert all([k[0] == RepEnum.FOCK for k in d_fock._idx_reps.values()]) From e0f66fd8f8e1d9b65be03f53c5633c57e52e8188 Mon Sep 17 00:00:00 2001 From: Anthony Date: Wed, 23 Oct 2024 14:37:09 -0400 Subject: [PATCH 70/87] rem --- mrmustard/lab_dev/circuit_components.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index 88f598ab3..08fe8fb9c 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -483,18 +483,15 @@ def to_bargmann(self) -> CircuitComponent: >>> assert d_bargmann.wires == d.wires >>> assert isinstance(d_bargmann.ansatz, PolyExpAnsatz) """ - if isinstance(self.ansatz, PolyExpAnsatz): - return self - else: - rep = self._representation.to_bargmann() - try: - ret = self._getitem_builtin(self.modes) - ret._representation = rep - except TypeError: - ret = self._from_attributes(rep, self.name) - if "manual_shape" in ret.__dict__: - del ret.manual_shape - return ret + rep = self._representation.to_bargmann() + try: + ret = self._getitem_builtin(self.modes) + ret._representation = rep + except TypeError: + ret = self._from_attributes(rep, self.name) + if "manual_shape" in ret.__dict__: + del ret.manual_shape + return ret def to_fock(self, shape: int | Sequence[int] | None = None) -> CircuitComponent: r""" From 348a702e2c584edf0cf3c543287da9526ee68636 Mon Sep 17 00:00:00 2001 From: Anthony Date: Wed, 23 Oct 2024 14:46:35 -0400 Subject: [PATCH 71/87] cleanup _from_attribute --- mrmustard/lab_dev/circuit_components.py | 4 ++-- mrmustard/lab_dev/states/base.py | 4 ++-- mrmustard/lab_dev/transformations/base.py | 8 ++++---- tests/test_lab_dev/test_circuit_components.py | 2 +- tests/test_lab_dev/test_circuits.py | 2 +- tests/test_lab_dev/test_states/test_states_base.py | 8 ++++---- .../test_transformations_base.py | 12 ++++++------ 7 files changed, 20 insertions(+), 20 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index 08fe8fb9c..e2d8549f8 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -276,7 +276,7 @@ def from_quadrature( QtoB_ok = BtoQ(modes_out_ket, phi).inverse() # output ket QtoB_ik = BtoQ(modes_in_ket, phi).inverse().dual # input ket # NOTE: the representation is Bargmann here because we use the inverse of BtoQ on the B side - QQQQ = CircuitComponent._from_attributes(Representation(PolyExpAnsatz(*triple), wires)) + QQQQ = CircuitComponent(Representation(PolyExpAnsatz(*triple), wires)) BBBB = QtoB_ib @ (QtoB_ik @ QQQQ @ QtoB_ok) @ QtoB_ob return cls._from_attributes(Representation(BBBB.ansatz, wires), name) @@ -615,7 +615,7 @@ def __matmul__(self, other: CircuitComponent | Scalar) -> CircuitComponent: if isinstance(other, (numbers.Number, np.ndarray)): return self * other result = self._representation @ other._representation - return CircuitComponent._from_attributes(result, None) + return CircuitComponent(result, None) def __mul__(self, other: Scalar) -> CircuitComponent: r""" diff --git a/mrmustard/lab_dev/states/base.py b/mrmustard/lab_dev/states/base.py index d80300136..fe7dcadd0 100644 --- a/mrmustard/lab_dev/states/base.py +++ b/mrmustard/lab_dev/states/base.py @@ -967,7 +967,7 @@ def __getitem__(self, modes: int | Sequence[int]) -> State: idxz_conj = [i + len(self.modes) for i, m in enumerate(self.modes) if m not in modes] ansatz = self.ansatz.trace(idxz, idxz_conj) - return self.__class__._from_attributes(Representation(ansatz, wires), self.name) + return self._from_attributes(Representation(ansatz, wires), self.name) def __rshift__(self, other: CircuitComponent) -> CircuitComponent: r""" @@ -1183,7 +1183,7 @@ def dm(self) -> DM: The ``DM`` object obtained from this ``Ket``. """ dm = self @ self.adjoint - ret = DM._from_attributes(dm.representation, self.name) + ret = DM(dm.representation, self.name) ret.manual_shape = self.manual_shape + self.manual_shape return ret diff --git a/mrmustard/lab_dev/transformations/base.py b/mrmustard/lab_dev/transformations/base.py index 3db118b8d..644b999aa 100644 --- a/mrmustard/lab_dev/transformations/base.py +++ b/mrmustard/lab_dev/transformations/base.py @@ -278,7 +278,7 @@ def random(cls, modes, max_r=1): def inverse(self) -> Unitary: unitary_dual = self.dual - return Unitary._from_attributes( + return Unitary( representation=unitary_dual.representation, name=unitary_dual.name, ) @@ -297,9 +297,9 @@ def __rshift__(self, other: CircuitComponent) -> CircuitComponent: ret = super().__rshift__(other) if isinstance(other, Unitary): - return Unitary._from_attributes(ret.representation) + return Unitary(ret.representation) elif isinstance(other, Channel): - return Channel._from_attributes(ret.representation) + return Channel(ret.representation) return ret @@ -492,5 +492,5 @@ def __rshift__(self, other: CircuitComponent) -> CircuitComponent: """ ret = super().__rshift__(other) if isinstance(other, (Channel, Unitary)): - return Channel._from_attributes(ret.representation) + return Channel(ret.representation) return ret diff --git a/tests/test_lab_dev/test_circuit_components.py b/tests/test_lab_dev/test_circuit_components.py index d77d33693..2ca9202a0 100644 --- a/tests/test_lab_dev/test_circuit_components.py +++ b/tests/test_lab_dev/test_circuit_components.py @@ -118,7 +118,7 @@ def test_from_attributes(self, x, y): def test_from_to_quadrature(self): c = Dgate([0], x=0.1, y=0.2) >> Sgate([0], r=1.0, phi=0.1) - cc = CircuitComponent._from_attributes(c.representation, c.name) + cc = CircuitComponent(c.representation, c.name) ccc = CircuitComponent.from_quadrature(tuple(), tuple(), (0,), (0,), cc.quadrature_triple()) assert cc == ccc diff --git a/tests/test_lab_dev/test_circuits.py b/tests/test_lab_dev/test_circuits.py index 580ace6e9..9e426db3e 100644 --- a/tests/test_lab_dev/test_circuits.py +++ b/tests/test_lab_dev/test_circuits.py @@ -178,7 +178,7 @@ def test_repr(self): bs12 = BSgate([1, 2]) n12 = Number([0, 1], n=3) n2 = Number([2], n=3) - cc = CircuitComponent._from_attributes(bs01.representation, "my_cc") + cc = CircuitComponent(bs01.representation, "my_cc") assert repr(Circuit()) == "" diff --git a/tests/test_lab_dev/test_states/test_states_base.py b/tests/test_lab_dev/test_states/test_states_base.py index 052b4870a..d13ffd850 100644 --- a/tests/test_lab_dev/test_states/test_states_base.py +++ b/tests/test_lab_dev/test_states/test_states_base.py @@ -356,9 +356,9 @@ def test_expectation_error(self): def test_rshift(self): ket = Coherent([0, 1], 1) unitary = Dgate([0], 1) - u_component = CircuitComponent._from_attributes(unitary.representation, unitary.name) + u_component = CircuitComponent(unitary.representation, unitary.name) channel = Attenuator([1], 1) - ch_component = CircuitComponent._from_attributes( + ch_component = CircuitComponent( channel.representation, channel.name, ) @@ -819,9 +819,9 @@ def test_expectation_error(self): def test_rshift(self): ket = Coherent([0, 1], 1) unitary = Dgate([0], 1) - u_component = CircuitComponent._from_attributes(unitary.representation, unitary.name) + u_component = CircuitComponent(unitary.representation, unitary.name) channel = Attenuator([1], 1) - ch_component = CircuitComponent._from_attributes(channel.representation, channel.name) + ch_component = CircuitComponent(channel.representation, channel.name) dm = ket >> channel diff --git a/tests/test_lab_dev/test_transformations/test_transformations_base.py b/tests/test_lab_dev/test_transformations/test_transformations_base.py index 1ff69bc14..c36d2b220 100644 --- a/tests/test_lab_dev/test_transformations/test_transformations_base.py +++ b/tests/test_lab_dev/test_transformations/test_transformations_base.py @@ -65,9 +65,9 @@ def test_init(self, name, modes): def test_rshift(self): unitary1 = Dgate([0, 1], 1) unitary2 = Dgate([1, 2], 2) - u_component = CircuitComponent._from_attributes(unitary1.representation, unitary1.name) + u_component = CircuitComponent(unitary1.representation, unitary1.name) channel = Attenuator([1], 1) - ch_component = CircuitComponent._from_attributes(channel.representation, channel.name) + ch_component = CircuitComponent(channel.representation, channel.name) assert isinstance(unitary1 >> unitary2, Unitary) assert isinstance(unitary1 >> channel, Channel) @@ -76,7 +76,7 @@ def test_rshift(self): def test_repr(self): unitary1 = Dgate([0, 1], 1) - u_component = CircuitComponent._from_attributes(unitary1.representation, unitary1.name) + u_component = CircuitComponent(unitary1.representation, unitary1.name) assert repr(unitary1) == "Dgate(modes=[0, 1], name=Dgate, repr=PolyExpAnsatz)" assert repr(unitary1.to_fock(5)) == "Dgate(modes=[0, 1], name=Dgate, repr=ArrayAnsatz)" assert repr(u_component) == "CircuitComponent(modes=[0, 1], name=Dgate, repr=PolyExpAnsatz)" @@ -142,10 +142,10 @@ def test_init_from_bargmann(self): def test_rshift(self): unitary = Dgate([0, 1], 1) - u_component = CircuitComponent._from_attributes(unitary.representation, unitary.name) + u_component = CircuitComponent(unitary.representation, unitary.name) channel1 = Attenuator([1, 2], 0.9) channel2 = Attenuator([2, 3], 0.9) - ch_component = CircuitComponent._from_attributes(channel1.representation, channel1.name) + ch_component = CircuitComponent(channel1.representation, channel1.name) assert isinstance(channel1 >> unitary, Channel) assert isinstance(channel1 >> channel2, Channel) @@ -154,7 +154,7 @@ def test_rshift(self): def test_repr(self): channel1 = Attenuator([0, 1], 0.9) - ch_component = CircuitComponent._from_attributes(channel1.representation, channel1.name) + ch_component = CircuitComponent(channel1.representation, channel1.name) assert repr(channel1) == "Attenuator(modes=[0, 1], name=Att, repr=PolyExpAnsatz)" assert repr(ch_component) == "CircuitComponent(modes=[0, 1], name=Att, repr=PolyExpAnsatz)" From ab2c99f16c70f273b304eca1a00a66b968c24f50 Mon Sep 17 00:00:00 2001 From: Anthony Date: Wed, 23 Oct 2024 16:27:11 -0400 Subject: [PATCH 72/87] some cov --- mrmustard/physics/representations.py | 4 ---- tests/test_lab_dev/test_circuit_components.py | 22 ++++++++++++++----- .../test_transformations_base.py | 7 +++--- tests/test_physics/test_wires.py | 15 +++++++++++++ 4 files changed, 36 insertions(+), 12 deletions(-) diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index f7a057b87..c1afd0902 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -64,10 +64,6 @@ def from_ansatz(cls, ansatz: Ansatz): else: return cls(0) - @classmethod - def _missing_(cls, value): - return cls.NONETYPE - def __repr__(self) -> str: return self.name diff --git a/tests/test_lab_dev/test_circuit_components.py b/tests/test_lab_dev/test_circuit_components.py index 2ca9202a0..074cd116d 100644 --- a/tests/test_lab_dev/test_circuit_components.py +++ b/tests/test_lab_dev/test_circuit_components.py @@ -41,6 +41,8 @@ Unitary, Sgate, Channel, + Operation, + Map, ) from mrmustard.physics.wires import Wires from mrmustard.physics.representations import Representation @@ -499,22 +501,32 @@ def test_quadrature_ket(self): back = Ket.from_quadrature([0], ket.quadrature_triple()) assert ket == back + def test_quadrature_channel(self): + C = Sgate([0], 0.5, 0.4) >> Dgate([0], 0.3, 0.2) >> Attenuator([0], 0.9) + back = Channel.from_quadrature([0], [0], C.quadrature_triple()) + assert C == back + def test_quadrature_dm(self): "tests that transforming to quadrature and back gives the same density matrix" dm = SqueezedVacuum([0], 0.4, 0.5) >> Dgate([0], 0.3, 0.2) >> Attenuator([0], 0.9) back = DM.from_quadrature([0], dm.quadrature_triple()) assert dm == back + def test_quadrature_map(self): + C = Sgate([0], 0.5, 0.4) >> Dgate([0], 0.3, 0.2) >> Attenuator([0], 0.9) + back = Map.from_quadrature([0], [0], C.quadrature_triple()) + assert C == back + + def test_quadrature_operation(self): + U = Sgate([0], 0.5, 0.4) >> Dgate([0], 0.3, 0.2) + back = Operation.from_quadrature([0], [0], U.quadrature_triple()) + assert U == back + def test_quadrature_unitary(self): U = Sgate([0], 0.5, 0.4) >> Dgate([0], 0.3, 0.2) back = Unitary.from_quadrature([0], [0], U.quadrature_triple()) assert U == back - def test_quadrature_channel(self): - C = Sgate([0], 0.5, 0.4) >> Dgate([0], 0.3, 0.2) >> Attenuator([0], 0.9) - back = Channel.from_quadrature([0], [0], C.quadrature_triple()) - assert C == back - @pytest.mark.parametrize("is_fock,widget_cls", [(False, Box), (True, HBox)]) @patch("mrmustard.lab_dev.circuit_components.display") def test_ipython_repr(self, mock_display, is_fock, widget_cls): diff --git a/tests/test_lab_dev/test_transformations/test_transformations_base.py b/tests/test_lab_dev/test_transformations/test_transformations_base.py index c36d2b220..3bbe1c51a 100644 --- a/tests/test_lab_dev/test_transformations/test_transformations_base.py +++ b/tests/test_lab_dev/test_transformations/test_transformations_base.py @@ -20,18 +20,19 @@ import pytest from mrmustard import math -from mrmustard.lab_dev.circuit_components import CircuitComponent -from mrmustard.lab_dev.transformations import ( +from mrmustard.lab_dev import ( Attenuator, + CircuitComponent, Channel, Dgate, Sgate, Identity, Unitary, Operation, + Vacuum, + BtoQ, ) from mrmustard.physics.wires import Wires -from mrmustard.lab_dev.states import Vacuum class TestOperation: diff --git a/tests/test_physics/test_wires.py b/tests/test_physics/test_wires.py index 1b55794ef..d74238d27 100644 --- a/tests/test_physics/test_wires.py +++ b/tests/test_physics/test_wires.py @@ -115,6 +115,21 @@ def test_ids_dicts(self): assert w.input.ids_dicts == d assert w.input.bra.ids_dicts == d + def test_ids_index_dicts(self): + w = Wires({0, 2, 1}, {6, 7, 8}, {3, 4}, {4}, {5}, {9}) + d = [ + {w.id: 0, w.id + 1: 1, w.id + 2: 2}, + {w.id + 3: 3, w.id + 4: 4, w.id + 5: 5}, + {w.id + 6: 6, w.id + 7: 7}, + {w.id + 8: 8}, + {w.id + 9: 9}, + {w.id + 10: 10}, + ] + + assert w.ids_index_dicts == d + assert w.input.ids_index_dicts == d + assert w.input.bra.ids_index_dicts == d + def test_adjoint(self): w = Wires({0, 1, 2}, {3, 4, 5}, {6, 7}, {8}) w_adj = w.adjoint From b8cfc7a603355b70c731266936d7d16b551d9048 Mon Sep 17 00:00:00 2001 From: Anthony Date: Wed, 23 Oct 2024 16:29:51 -0400 Subject: [PATCH 73/87] codefactor --- .../test_transformations/test_transformations_base.py | 1 - tests/test_physics/test_representations.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_lab_dev/test_transformations/test_transformations_base.py b/tests/test_lab_dev/test_transformations/test_transformations_base.py index 3bbe1c51a..4f512b553 100644 --- a/tests/test_lab_dev/test_transformations/test_transformations_base.py +++ b/tests/test_lab_dev/test_transformations/test_transformations_base.py @@ -30,7 +30,6 @@ Unitary, Operation, Vacuum, - BtoQ, ) from mrmustard.physics.wires import Wires diff --git a/tests/test_physics/test_representations.py b/tests/test_physics/test_representations.py index f6cf96598..5fc785797 100644 --- a/tests/test_physics/test_representations.py +++ b/tests/test_physics/test_representations.py @@ -106,11 +106,11 @@ def test_to_bargmann(self, d_gate_rep): d_barg = d_fock.to_bargmann() assert d_fock.ansatz._original_abc_data == d_gate_rep.ansatz.triple assert d_barg == d_gate_rep - assert all([k[0] == RepEnum.BARGMANN for k in d_barg._idx_reps.values()]) + assert all((k[0] == RepEnum.BARGMANN for k in d_barg._idx_reps.values())) def test_to_fock(self, d_gate_rep): d_fock = d_gate_rep.to_fock(shape=(4, 6)) assert d_fock.ansatz == ArrayAnsatz( math.hermite_renormalized(*displacement_gate_Abc(x=0.1, y=0.1), shape=(4, 6)) ) - assert all([k[0] == RepEnum.FOCK for k in d_fock._idx_reps.values()]) + assert all((k[0] == RepEnum.FOCK for k in d_fock._idx_reps.values())) From b5c43340fd8986f63c9148d2dfefeb3877a228e1 Mon Sep 17 00:00:00 2001 From: Anthony Date: Thu, 24 Oct 2024 10:27:35 -0400 Subject: [PATCH 74/87] cleanup imports --- mrmustard/lab_dev/states/base.py | 12 +++++++----- mrmustard/lab_dev/states/dm.py | 17 +++++++++-------- mrmustard/lab_dev/states/ket.py | 16 ++++++++-------- 3 files changed, 24 insertions(+), 21 deletions(-) diff --git a/mrmustard/lab_dev/states/base.py b/mrmustard/lab_dev/states/base.py index 5cc30074b..94cae80c1 100644 --- a/mrmustard/lab_dev/states/base.py +++ b/mrmustard/lab_dev/states/base.py @@ -36,6 +36,10 @@ import plotly.graph_objects as go from mrmustard import math, settings +from mrmustard.physics.ansatz import PolyExpAnsatz, ArrayAnsatz +from mrmustard.physics.bargmann_utils import ( + bargmann_Abc_to_phasespace_cov_means, +) from mrmustard.physics.fock_utils import quadrature_distribution from mrmustard.physics.wigner import wigner_discretized from mrmustard.utils.typing import ( @@ -44,11 +48,9 @@ ComplexVector, RealVector, ) -from mrmustard.physics.ansatz import PolyExpAnsatz, ArrayAnsatz -from mrmustard.physics.bargmann_utils import ( - bargmann_Abc_to_phasespace_cov_means, -) -from mrmustard.lab_dev import CircuitComponent, BtoPS + +from ..circuit_components import CircuitComponent +from ..circuit_components_utils import BtoPS __all__ = ["State"] diff --git a/mrmustard/lab_dev/states/dm.py b/mrmustard/lab_dev/states/dm.py index 202656f15..04fa9d28f 100644 --- a/mrmustard/lab_dev/states/dm.py +++ b/mrmustard/lab_dev/states/dm.py @@ -23,19 +23,20 @@ import warnings import numpy as np from IPython.display import display + from mrmustard import math, settings, widgets -from mrmustard.utils.typing import ComplexMatrix, ComplexVector, ComplexTensor, RealVector -from mrmustard.lab_dev.circuit_components import CircuitComponent -from mrmustard.lab_dev.states.base import State, _validate_operator, OperatorType -from mrmustard.physics.bargmann_utils import wigner_to_bargmann_rho -from mrmustard.lab_dev.circuit_components_utils import BtoQ, TraceOut -from mrmustard.lab_dev.utils import shape_check from mrmustard.math.lattice.strategies.vanilla import autoshape_numba - from mrmustard.physics.ansatz import ArrayAnsatz, PolyExpAnsatz +from mrmustard.physics.bargmann_utils import wigner_to_bargmann_rho +from mrmustard.physics.gaussian_integrals import complex_gaussian_integral_2 from mrmustard.physics.representations import Representation from mrmustard.physics.wires import Wires -from mrmustard.physics.gaussian_integrals import complex_gaussian_integral_2 +from mrmustard.utils.typing import ComplexMatrix, ComplexVector, ComplexTensor, RealVector + +from .base import State, _validate_operator, OperatorType +from ..circuit_components import CircuitComponent +from ..circuit_components_utils import BtoQ, TraceOut +from ..utils import shape_check __all__ = ["DM"] diff --git a/mrmustard/lab_dev/states/ket.py b/mrmustard/lab_dev/states/ket.py index d6d0e4a49..a8cee4d75 100644 --- a/mrmustard/lab_dev/states/ket.py +++ b/mrmustard/lab_dev/states/ket.py @@ -23,8 +23,14 @@ import warnings import numpy as np from IPython.display import display + from mrmustard import math, settings, widgets +from mrmustard.math.lattice.strategies.vanilla import autoshape_numba +from mrmustard.physics.ansatz import ArrayAnsatz, PolyExpAnsatz +from mrmustard.physics.bargmann_utils import wigner_to_bargmann_psi from mrmustard.physics.gaussian import purity +from mrmustard.physics.representations import Representation +from mrmustard.physics.wires import Wires from mrmustard.utils.typing import ( ComplexMatrix, ComplexVector, @@ -33,18 +39,12 @@ Scalar, Batch, ) -from mrmustard.lab_dev.states.base import _validate_operator, OperatorType -from mrmustard.physics.bargmann_utils import wigner_to_bargmann_psi -from mrmustard.lab_dev.utils import shape_check -from mrmustard.math.lattice.strategies.vanilla import autoshape_numba -from mrmustard.physics.ansatz import ArrayAnsatz, PolyExpAnsatz -from mrmustard.physics.wires import Wires -from mrmustard.physics.representations import Representation -from .base import State +from .base import State, _validate_operator, OperatorType from .dm import DM from ..circuit_components import CircuitComponent from ..circuit_components_utils import BtoQ, TraceOut +from ..utils import shape_check __all__ = ["Ket"] From 5e98362373684616182ad6f3ee19946729e2ca3f Mon Sep 17 00:00:00 2001 From: Anthony Date: Thu, 24 Oct 2024 10:32:03 -0400 Subject: [PATCH 75/87] cleanup --- mrmustard/lab_dev/states/dm.py | 2 +- tests/test_lab_dev/test_states/test_dm.py | 10 +++------- .../test_transformations/test_gaussrandnoise.py | 2 +- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/mrmustard/lab_dev/states/dm.py b/mrmustard/lab_dev/states/dm.py index 04fa9d28f..8a033e571 100644 --- a/mrmustard/lab_dev/states/dm.py +++ b/mrmustard/lab_dev/states/dm.py @@ -399,7 +399,7 @@ def __getitem__(self, modes: int | Sequence[int]) -> State: idxz_conj = [i + len(self.modes) for i, m in enumerate(self.modes) if m not in modes] ansatz = self.ansatz.trace(idxz, idxz_conj) - return self._from_attributes(Representation(ansatz, wires), self.name) + return DM(Representation(ansatz, wires), self.name) def __rshift__(self, other: CircuitComponent) -> CircuitComponent: r""" diff --git a/tests/test_lab_dev/test_states/test_dm.py b/tests/test_lab_dev/test_states/test_dm.py index 12b6ff2bd..471321fcf 100644 --- a/tests/test_lab_dev/test_states/test_dm.py +++ b/tests/test_lab_dev/test_states/test_dm.py @@ -14,7 +14,7 @@ """Tests for the density matrix.""" -# pylint: disable=protected-access, unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement +# pylint: disable=unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement from itertools import product import numpy as np @@ -346,13 +346,9 @@ def test_expectation_error(self): def test_rshift(self): ket = Coherent([0, 1], 1) unitary = Dgate([0], 1) - u_component = CircuitComponent._from_attributes( - unitary.representation, unitary.name - ) # pylint: disable=protected-access + u_component = CircuitComponent(unitary.representation, unitary.name) channel = Attenuator([1], 1) - ch_component = CircuitComponent._from_attributes( - channel.representation, channel.name - ) # pylint: disable=protected-access + ch_component = CircuitComponent(channel.representation, channel.name) dm = ket >> channel diff --git a/tests/test_lab_dev/test_transformations/test_gaussrandnoise.py b/tests/test_lab_dev/test_transformations/test_gaussrandnoise.py index 6a9d76f1e..9a955adc3 100644 --- a/tests/test_lab_dev/test_transformations/test_gaussrandnoise.py +++ b/tests/test_lab_dev/test_transformations/test_gaussrandnoise.py @@ -14,7 +14,7 @@ """Tests for the ``GaussRandNoise`` class.""" -# pylint: disable=protected-access, missing-function-docstring, expression-not-assigned +# pylint: disable=missing-function-docstring, expression-not-assigned import numpy as np From 01f9502e1de69bb71786346f869e9af85ce4fb13 Mon Sep 17 00:00:00 2001 From: Anthony Date: Thu, 24 Oct 2024 10:41:27 -0400 Subject: [PATCH 76/87] doc cleanup --- mrmustard/lab_dev/circuit_components.py | 17 ++++++----------- mrmustard/lab_dev/states/dm.py | 5 ----- mrmustard/lab_dev/states/ket.py | 5 ----- mrmustard/lab_dev/transformations/base.py | 19 ------------------- 4 files changed, 6 insertions(+), 40 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index e2d8549f8..7a6c1e732 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -350,16 +350,12 @@ def _from_attributes( name: str | None = None, ) -> CircuitComponent: r""" - Initializes a circuit component from a ``Representation``, a set of ``Wires``, a name. - It differs from the __init__ in that it takes a set of wires directly. - Note there are deliberately no checks to ensure types and wires are compatible - in the standard way (e.g. one could pass a representation for a single mode ket - and wires for a two-mode one). - - The return type is the closest parent among the types ``Ket``, ``DM``, ``Unitary``, - ``Operation``, ``Channel``, and ``Map``. This is to ensure the right properties - are used when calling methods on the returned object, e.g. when adding two - coherent states we don't get a generic ``CircuitComponent`` but a ``Ket``: + Initializes a circuit component from a ``Representation`` and a name. + It differs from the __init__ in that the return type is the closest parent + among the types ``Ket``, ``DM``, ``Unitary``, ``Operation``, ``Channel``, + and ``Map``. This is to ensure the right properties are used when calling + methods on the returned object, e.g. when adding two coherent states we + don't get a generic ``CircuitComponent`` but a ``Ket``: .. code-block:: >>> from mrmustard.lab_dev import Coherent, Ket @@ -368,7 +364,6 @@ def _from_attributes( Args: representation: A representation for this circuit component. - wires: The wires of this component. name: The name for this component (optional). Returns: diff --git a/mrmustard/lab_dev/states/dm.py b/mrmustard/lab_dev/states/dm.py index 8a033e571..71a0e30c3 100644 --- a/mrmustard/lab_dev/states/dm.py +++ b/mrmustard/lab_dev/states/dm.py @@ -44,11 +44,6 @@ class DM(State): r""" Base class for density matrices. - - Args: - modes: The modes of this density matrix. - ansatz: The ansatz of this density matrix. - name: The name of this density matrix. """ short_name = "DM" diff --git a/mrmustard/lab_dev/states/ket.py b/mrmustard/lab_dev/states/ket.py index a8cee4d75..1a5f99254 100644 --- a/mrmustard/lab_dev/states/ket.py +++ b/mrmustard/lab_dev/states/ket.py @@ -52,11 +52,6 @@ class Ket(State): r""" Base class for all Hilbert space vectors. - - Arguments: - modes: The modes of this ket. - ansatz: The ansatz of this ket. - name: The name of this ket. """ short_name = "Ket" diff --git a/mrmustard/lab_dev/transformations/base.py b/mrmustard/lab_dev/transformations/base.py index 644b999aa..afc3980c0 100644 --- a/mrmustard/lab_dev/transformations/base.py +++ b/mrmustard/lab_dev/transformations/base.py @@ -190,13 +190,6 @@ def from_quadrature( class Unitary(Operation): r""" Base class for all unitary transformations. - Note the default initializer is in the parent class ``Operation``. - - Arguments: - modes_out: The output modes of this Unitary. - modes_in: The input modes of this Unitary. - ansatz: The ansatz of this Unitary. - name: The name of this Unitary. """ short_name = "U" @@ -306,12 +299,6 @@ def __rshift__(self, other: CircuitComponent) -> CircuitComponent: class Map(Transformation): r""" A CircuitComponent more general than Channels, which are CPTP Maps. - - Arguments: - modes_out: The output modes of this Map. - modes_in: The input modes of this Map. - ansatz: The ansatz of this Map. - name: The name of this Map. """ short_name = "Map" @@ -364,12 +351,6 @@ def from_quadrature( class Channel(Map): r""" Base class for all CPTP channels. - - Arguments: - modes_out: The output modes of this Channel. - modes_in: The input modes of this Channel. - ansatz: The ansatz of this Channel. - name: The name of this Channel """ short_name = "Ch" From 00b714ed8d47766fae93cd22cb64ccff9a4e44ae Mon Sep 17 00:00:00 2001 From: Anthony Date: Thu, 24 Oct 2024 10:59:32 -0400 Subject: [PATCH 77/87] AI cleanup --- mrmustard/lab_dev/circuit_components.py | 8 ++-- mrmustard/physics/ansatz/polyexp_ansatz.py | 27 +++++++----- mrmustard/physics/representations.py | 50 ++++++++++++++-------- 3 files changed, 51 insertions(+), 34 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index 7a6c1e732..fba4fb48c 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -561,12 +561,12 @@ def _light_copy(self, wires: Wires | None = None) -> CircuitComponent: return instance def _rshift_return( - self, ret: CircuitComponent | np.ndarray | complex + self, result: CircuitComponent | np.ndarray | complex ) -> CircuitComponent | np.ndarray | complex: "internal convenience method for right-shift, to return the right type of object" - if len(ret.wires) > 0: - return ret - scalar = ret.ansatz.scalar + if len(result.wires) > 0: + return result + scalar = result.ansatz.scalar return math.sum(scalar) if not settings.UNSAFE_ZIP_BATCH else scalar def __add__(self, other: CircuitComponent) -> CircuitComponent: diff --git a/mrmustard/physics/ansatz/polyexp_ansatz.py b/mrmustard/physics/ansatz/polyexp_ansatz.py index 8631b42be..f14c5fc6a 100644 --- a/mrmustard/physics/ansatz/polyexp_ansatz.py +++ b/mrmustard/physics/ansatz/polyexp_ansatz.py @@ -202,10 +202,10 @@ def from_dict(cls, data: dict[str, ArrayLike]) -> PolyExpAnsatz: @classmethod def from_function(cls, fn: Callable, **kwargs: Any) -> PolyExpAnsatz: - ret = cls(None, None, None) - ret._fn = fn - ret._kwargs = kwargs - return ret + ansatz = cls(None, None, None) + ansatz._fn = fn + ansatz._kwargs = kwargs + return ansatz def decompose_ansatz(self) -> PolyExpAnsatz: r""" @@ -432,16 +432,21 @@ def _call_none(self, z: Batch[Vector]) -> PolyExpAnsatz: batch_abc = self.batch_size batch_arg = z.shape[0] - Abc = [] if batch_abc == 1 and batch_arg > 1: - for i in range(batch_arg): - Abc.append(self._call_none_single(self.A[0], self.b[0], self.c[0], z[i])) + Abc = [ + self._call_none_single(self.A[0], self.b[0], self.c[0], z[i]) + for i in range(batch_arg) + ] elif batch_arg == 1 and batch_abc > 1: - for i in range(batch_abc): - Abc.append(self._call_none_single(self.A[i], self.b[i], self.c[i], z[0])) + Abc = [ + self._call_none_single(self.A[i], self.b[i], self.c[i], z[0]) + for i in range(batch_abc) + ] elif batch_abc == batch_arg: - for i in range(batch_abc): - Abc.append(self._call_none_single(self.A[i], self.b[i], self.c[i], z[i])) + Abc = [ + self._call_none_single(self.A[i], self.b[i], self.c[i], z[i]) + for i in range(batch_abc) + ] else: raise ValueError( "Batch size of the ansatz and argument must match or one of the batch sizes must be 1." diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index c1afd0902..2cc447d0f 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -95,10 +95,12 @@ def __init__( ) -> None: self._ansatz = ansatz - if not isinstance(wires, Wires): - modes_out_bra, modes_in_bra, modes_out_ket, modes_in_ket = ( - [tuple(elem) for elem in wires] if wires else [(), (), (), ()] - ) + if wires is None: + wires = Wires(set(), set(), set(), set()) + elif not isinstance(wires, Wires): + modes_out_bra, modes_in_bra, modes_out_ket, modes_in_ket = [ + tuple(elem) for elem in wires + ] wires = Wires( set(modes_out_bra), set(modes_in_bra), @@ -298,6 +300,30 @@ def _matmul_indices(self, other: Representation) -> tuple[tuple[int, ...], tuple idx_zconj += other.wires.ket.input[ket_modes].indices return idx_z, idx_zconj + def _get_idx_reps(self, wires_result: Wires, other: Representation): + r""" + Returns the new representation mappings when contracting ``self`` and ``other``. + + Args: + wires_result: The resulting wires after contraction. + other: The representation contracting with. + """ + idx_reps = {} + for id in wires_result.ids: + if id in other.wires.ids: + temp_rep = other + else: + temp_rep = self + for t in (0, 1, 2, 3, 4, 5): + try: + idx = temp_rep.wires.ids_index_dicts[t][id] + n_idx = wires_result.ids_index_dicts[t][id] + idx_reps[n_idx] = temp_rep._idx_reps[idx] + break + except KeyError: + continue + return idx_reps + def __eq__(self, other): if isinstance(other, Representation): return ( @@ -320,19 +346,5 @@ def __matmul__(self, other: Representation): rep = self_ansatz[idx_z] @ other_ansatz[idx_zconj] rep = rep.reorder(perm) if perm else rep - - idx_reps = {} - for id in wires_result.ids: - if id in other.wires.ids: - temp_rep = other - else: - temp_rep = self - for t in (0, 1, 2, 3, 4, 5): - try: - idx = temp_rep.wires.ids_index_dicts[t][id] - n_idx = wires_result.ids_index_dicts[t][id] - idx_reps[n_idx] = temp_rep._idx_reps[idx] - break - except KeyError: - continue + idx_reps = self._get_idx_reps(wires_result, other) return Representation(rep, wires_result, idx_reps) From b11e463f2518c5e95ba7344bd026d4826470298f Mon Sep 17 00:00:00 2001 From: Anthony Date: Thu, 24 Oct 2024 11:08:32 -0400 Subject: [PATCH 78/87] remove tuple --- .../circuit_components_utils/b_to_ps.py | 4 ++-- .../circuit_components_utils/b_to_q.py | 4 ++-- mrmustard/physics/representations.py | 4 ++-- tests/test_physics/test_representations.py | 24 +++++++++---------- 4 files changed, 17 insertions(+), 19 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py b/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py index 3798661a7..260bee7aa 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py @@ -54,9 +54,9 @@ def __init__( ), ).representation for i in self.wires.input.indices: - self.representation._idx_reps[i] = (RepEnum.BARGMANN, None, tuple()) + self.representation._idx_reps[i] = (RepEnum.BARGMANN, None) for i in self.wires.output.indices: - self.representation._idx_reps[i] = (RepEnum.PHASESPACE, float(self.s.value), tuple()) + self.representation._idx_reps[i] = (RepEnum.PHASESPACE, float(self.s.value)) def inverse(self): ret = BtoPS(self.modes, self.s) diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py index 81db9afd5..0c75b0d2d 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py @@ -54,9 +54,9 @@ def __init__( ), ).representation for i in self.wires.input.indices: - self.representation._idx_reps[i] = (RepEnum.BARGMANN, None, tuple()) + self.representation._idx_reps[i] = (RepEnum.BARGMANN, None) for i in self.wires.output.indices: - self.representation._idx_reps[i] = (RepEnum.QUADRATURE, float(self.phi.value), tuple()) + self.representation._idx_reps[i] = (RepEnum.QUADRATURE, float(self.phi.value)) def inverse(self): ret = BtoQ(self.modes, self.phi) diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index 2cc447d0f..c903705a2 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -76,7 +76,7 @@ class Representation: of each wire's representation. The dictionary to keep track of representations maps the indices of the wires - to a tuple of the form ``(RepEnum, parameter, (coupled_indices, ...))``. + to a tuple of the form ``(RepEnum, parameter)``. Args: ansatz: An ansatz for this representation. @@ -130,7 +130,7 @@ def __init__( self._wires = wires self._idx_reps = idx_reps or dict.fromkeys( - wires.indices, (RepEnum.from_ansatz(ansatz), None, tuple()) + wires.indices, (RepEnum.from_ansatz(ansatz), None) ) @property diff --git a/tests/test_physics/test_representations.py b/tests/test_physics/test_representations.py index 5fc785797..291175e8a 100644 --- a/tests/test_physics/test_representations.py +++ b/tests/test_physics/test_representations.py @@ -49,9 +49,9 @@ def btoq_rep(self): wires = Wires((), (), set([0]), set([0])) idx_reps = {} for i in wires.input.indices: - idx_reps[i] = (RepEnum.BARGMANN, None, tuple()) + idx_reps[i] = (RepEnum.BARGMANN, None) for i in wires.output.indices: - idx_reps[i] = (RepEnum.QUADRATURE, float(0.2), tuple()) + idx_reps[i] = (RepEnum.QUADRATURE, float(0.2)) return Representation(ansatz, wires, idx_reps) @pytest.mark.parametrize("triple", [Abc_n1, Abc_n2, Abc_n3]) @@ -66,39 +66,37 @@ def test_init(self, triple): rep = Representation(ansatz, wires) assert rep.ansatz == ansatz assert rep.wires == wires - assert rep._idx_reps == dict.fromkeys( - wires.indices, (RepEnum.from_ansatz(ansatz), None, tuple()) - ) + assert rep._idx_reps == dict.fromkeys(wires.indices, (RepEnum.from_ansatz(ansatz), None)) @pytest.mark.parametrize("triple", [Abc_n2]) def test_adjoint_idx_reps(self, triple): ansatz = PolyExpAnsatz(*triple) wires = Wires(modes_out_bra=set([0]), modes_out_ket=set([0])) - idx_reps = {0: (RepEnum.BARGMANN, None, tuple()), 1: (RepEnum.QUADRATURE, 0.1, tuple())} + idx_reps = {0: (RepEnum.BARGMANN, None), 1: (RepEnum.QUADRATURE, 0.1)} rep = Representation(ansatz, wires, idx_reps) adj_rep = rep.adjoint assert adj_rep._idx_reps == { - 1: (RepEnum.BARGMANN, None, tuple()), - 0: (RepEnum.QUADRATURE, 0.1, tuple()), + 1: (RepEnum.BARGMANN, None), + 0: (RepEnum.QUADRATURE, 0.1), } @pytest.mark.parametrize("triple", [Abc_n2]) def test_dual_idx_reps(self, triple): ansatz = PolyExpAnsatz(*triple) wires = Wires(modes_out_bra=set([0]), modes_in_bra=set([0])) - idx_reps = {0: (RepEnum.BARGMANN, None, tuple()), 1: (RepEnum.QUADRATURE, 0.1, tuple())} + idx_reps = {0: (RepEnum.BARGMANN, None), 1: (RepEnum.QUADRATURE, 0.1)} rep = Representation(ansatz, wires, idx_reps) adj_rep = rep.dual assert adj_rep._idx_reps == { - 1: (RepEnum.BARGMANN, None, tuple()), - 0: (RepEnum.QUADRATURE, 0.1, tuple()), + 1: (RepEnum.BARGMANN, None), + 0: (RepEnum.QUADRATURE, 0.1), } def test_matmul_btoq(self, d_gate_rep, btoq_rep): q_dgate = d_gate_rep @ btoq_rep assert q_dgate._idx_reps == { - 0: (RepEnum.QUADRATURE, 0.2, ()), - 1: (RepEnum.BARGMANN, None, ()), + 0: (RepEnum.QUADRATURE, 0.2), + 1: (RepEnum.BARGMANN, None), } def test_to_bargmann(self, d_gate_rep): From a06cc3ba889125fdff43810988abbbdb7399897c Mon Sep 17 00:00:00 2001 From: Anthony Date: Thu, 24 Oct 2024 11:55:59 -0400 Subject: [PATCH 79/87] coverage --- .codecov.yml | 2 +- tests/test_lab_dev/test_circuit_components.py | 5 +++++ .../test_transformations_base.py | 15 +++++++++++++++ .../test_physics/test_ansatz/test_array_ansatz.py | 5 +++++ 4 files changed, 26 insertions(+), 1 deletion(-) diff --git a/.codecov.yml b/.codecov.yml index d1515b335..5423eb7b4 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -5,5 +5,5 @@ coverage: target: 89% patch: default: - target: 100% + target: 99% threshold: 0% diff --git a/tests/test_lab_dev/test_circuit_components.py b/tests/test_lab_dev/test_circuit_components.py index 074cd116d..24620c7ad 100644 --- a/tests/test_lab_dev/test_circuit_components.py +++ b/tests/test_lab_dev/test_circuit_components.py @@ -191,6 +191,11 @@ def test_on_error(self): with pytest.raises(ValueError): Vacuum([1, 2]).on([3]) + def test_to_bargmann_unitary(self): + d = Dgate([1], x=0.1, y=0.1) + fock = Unitary(d.representation.to_fock(shape=(4, 6))) + assert fock.to_bargmann() == d + def test_to_fock_ket(self): vac = Vacuum([1, 2]) vac_fock = vac.to_fock(shape=[1, 2]) diff --git a/tests/test_lab_dev/test_transformations/test_transformations_base.py b/tests/test_lab_dev/test_transformations/test_transformations_base.py index 4f512b553..6e2380ac5 100644 --- a/tests/test_lab_dev/test_transformations/test_transformations_base.py +++ b/tests/test_lab_dev/test_transformations/test_transformations_base.py @@ -29,6 +29,7 @@ Identity, Unitary, Operation, + Map, Vacuum, ) from mrmustard.physics.wires import Wires @@ -113,6 +114,20 @@ def test_random(self): assert (u >> u.dual) == Identity(modes) +class TestMap: + r""" + Tests the Map class. + """ + + def test_init_from_bargmann(self): + A = np.arange(16).reshape(4, 4) + b = np.array([0, 1, 2, 3]) + c = 1 + map = Map.from_bargmann([0], [0], (A, b, c), "my_map") + assert np.allclose(map.ansatz.A[None, ...], A) + assert np.allclose(map.ansatz.b[None, ...], b) + + class TestChannel: r""" Tests for the ``Channel`` class. diff --git a/tests/test_physics/test_ansatz/test_array_ansatz.py b/tests/test_physics/test_ansatz/test_array_ansatz.py index 17f32703c..7c7d8adc1 100644 --- a/tests/test_physics/test_ansatz/test_array_ansatz.py +++ b/tests/test_physics/test_ansatz/test_array_ansatz.py @@ -195,6 +195,11 @@ def test_sum_batch(self): assert fock_collapsed.array.shape == (1, 5, 7, 8) assert np.allclose(fock_collapsed.array, np.sum(self.array2578, axis=0)) + def test_to_from_dict(self): + array1 = math.astensor(np.random.random((2, 5, 5, 1, 7, 4, 1, 7, 3))) + fock1 = ArrayAnsatz(array1, batched=True) + assert ArrayAnsatz.from_dict(fock1.to_dict()) == fock1 + def test_trace(self): array1 = math.astensor(np.random.random((2, 5, 5, 1, 7, 4, 1, 7, 3))) fock1 = ArrayAnsatz(array1, batched=True) From 53eba3265bb8ceb37d08ea69c6a8a33d28e31ff0 Mon Sep 17 00:00:00 2001 From: Anthony Date: Thu, 24 Oct 2024 12:28:12 -0400 Subject: [PATCH 80/87] doc --- mrmustard/physics/representations.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index c903705a2..61165eff6 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -39,8 +39,7 @@ class RepEnum(Enum): r""" - An enum to represent what representation a wire is in. Also keeps track - of representation conversions. + An enum to represent what representation a wire is in. """ NONETYPE = 0 From fbb50832de55f1b6727eb41dab82c9ed36d911f1 Mon Sep 17 00:00:00 2001 From: Anthony Date: Thu, 24 Oct 2024 12:29:35 -0400 Subject: [PATCH 81/87] rename --- mrmustard/physics/representations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index 61165eff6..f1aabbec6 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -299,7 +299,7 @@ def _matmul_indices(self, other: Representation) -> tuple[tuple[int, ...], tuple idx_zconj += other.wires.ket.input[ket_modes].indices return idx_z, idx_zconj - def _get_idx_reps(self, wires_result: Wires, other: Representation): + def _matmul_idx_reps(self, wires_result: Wires, other: Representation): r""" Returns the new representation mappings when contracting ``self`` and ``other``. @@ -345,5 +345,5 @@ def __matmul__(self, other: Representation): rep = self_ansatz[idx_z] @ other_ansatz[idx_zconj] rep = rep.reorder(perm) if perm else rep - idx_reps = self._get_idx_reps(wires_result, other) + idx_reps = self._matmul_idx_reps(wires_result, other) return Representation(rep, wires_result, idx_reps) From 22e7696839a779ac6716f5d65766c82580b68ff7 Mon Sep 17 00:00:00 2001 From: Anthony Date: Mon, 28 Oct 2024 13:56:29 -0400 Subject: [PATCH 82/87] cr --- mrmustard/lab_dev/circuit_components.py | 12 ++++++------ tests/test_lab_dev/test_circuit_components.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index fba4fb48c..d0e1297df 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -82,10 +82,10 @@ def _serialize(self) -> tuple[dict[str, Any], dict[str, ArrayLike]]: serializable = {"class": f"{cls.__module__}.{cls.__qualname__}"} params = signature(cls).parameters if "name" in params: # assume abstract type, serialize the representation - rep_cls = type(self.ansatz) + ansatz_cls = type(self.ansatz) serializable["name"] = self.name serializable["wires"] = self.wires.sorted_args - serializable["rep_class"] = f"{rep_cls.__module__}.{rep_cls.__qualname__}" + serializable["ansatz_cls"] = f"{ansatz_cls.__module__}.{ansatz_cls.__qualname__}" return serializable, self.ansatz.to_dict() # handle modes parameter @@ -110,10 +110,10 @@ def _deserialize(cls, data: dict) -> CircuitComponent: r""" Deserialization when within a circuit. """ - if "rep_class" in data: - rep_class, wires, name = map(data.pop, ["rep_class", "wires", "name"]) - rep = locate(rep_class).from_dict(data) - return cls._from_attributes(Representation(rep, Wires(*map(set, wires))), name=name) + if "ansatz_cls" in data: + ansatz_cls, wires, name = map(data.pop, ["ansatz_cls", "wires", "name"]) + ansatz = locate(ansatz_cls).from_dict(data) + return cls._from_attributes(Representation(ansatz, Wires(*map(set, wires))), name=name) return cls(**data) diff --git a/tests/test_lab_dev/test_circuit_components.py b/tests/test_lab_dev/test_circuit_components.py index 24620c7ad..3abefac46 100644 --- a/tests/test_lab_dev/test_circuit_components.py +++ b/tests/test_lab_dev/test_circuit_components.py @@ -566,7 +566,7 @@ def test_serialize_default_behaviour(self): assert kwargs == { "class": f"{CircuitComponent.__module__}.CircuitComponent", "wires": cc.wires.sorted_args, - "rep_class": f"{PolyExpAnsatz.__module__}.PolyExpAnsatz", + "ansatz_cls": f"{PolyExpAnsatz.__module__}.PolyExpAnsatz", "name": name, } assert arrays == {"A": ansatz.A, "b": ansatz.b, "c": ansatz.c} From c3f29c2fad889043e6bfb273d351f1e809459dc7 Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 29 Oct 2024 09:36:23 -0400 Subject: [PATCH 83/87] cr --- .../circuit_components_utils/b_to_ps.py | 2 +- .../circuit_components_utils/b_to_q.py | 2 +- mrmustard/lab_dev/samplers.py | 4 +-- mrmustard/lab_dev/states/base.py | 2 +- mrmustard/lab_dev/states/coherent.py | 2 +- .../lab_dev/states/displaced_squeezed.py | 2 +- mrmustard/lab_dev/states/dm.py | 14 ++++---- mrmustard/lab_dev/states/ket.py | 16 ++++----- mrmustard/lab_dev/states/number.py | 2 +- .../lab_dev/states/quadrature_eigenstate.py | 2 +- mrmustard/lab_dev/states/sauron.py | 2 +- mrmustard/lab_dev/states/squeezed_vacuum.py | 2 +- mrmustard/lab_dev/states/thermal.py | 2 +- .../states/two_mode_squeezed_vacuum.py | 2 +- mrmustard/lab_dev/states/vacuum.py | 2 +- .../lab_dev/transformations/amplifier.py | 2 +- .../lab_dev/transformations/attenuator.py | 2 +- mrmustard/lab_dev/transformations/base.py | 34 +++++++++---------- mrmustard/lab_dev/transformations/bsgate.py | 2 +- mrmustard/lab_dev/transformations/cft.py | 2 +- mrmustard/lab_dev/transformations/dgate.py | 2 +- .../lab_dev/transformations/fockdamping.py | 2 +- .../lab_dev/transformations/gaussrandnoise.py | 2 +- mrmustard/lab_dev/transformations/ggate.py | 2 +- mrmustard/lab_dev/transformations/identity.py | 2 +- mrmustard/lab_dev/transformations/rgate.py | 2 +- mrmustard/lab_dev/transformations/s2gate.py | 2 +- mrmustard/lab_dev/transformations/sgate.py | 2 +- mrmustard/physics/ansatz/__init__.py | 2 +- mrmustard/physics/representations.py | 2 +- tests/test_lab_dev/test_states/test_dm.py | 2 +- tests/test_lab_dev/test_states/test_ket.py | 4 +-- .../test_transformations_base.py | 4 +-- 33 files changed, 65 insertions(+), 65 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py b/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py index 260bee7aa..22e96dfbf 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py @@ -46,7 +46,7 @@ def __init__( ): super().__init__(name="BtoPS") self._add_parameter(make_parameter(False, s, "s", (None, None))) - self._representation = self.from_modes( + self._representation = self.from_ansatz( modes_in=modes, modes_out=modes, ansatz=PolyExpAnsatz.from_function( diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py index 0c75b0d2d..3cf4d2d5a 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py @@ -46,7 +46,7 @@ def __init__( ): super().__init__(name="BtoQ") self._add_parameter(make_parameter(False, phi, "phi", (None, None))) - self._representation = self.from_modes( + self._representation = self.from_ansatz( modes_in=modes, modes_out=modes, ansatz=PolyExpAnsatz.from_function( diff --git a/mrmustard/lab_dev/samplers.py b/mrmustard/lab_dev/samplers.py index c088061e8..8069b9622 100644 --- a/mrmustard/lab_dev/samplers.py +++ b/mrmustard/lab_dev/samplers.py @@ -229,8 +229,8 @@ def sample(self, state: State, n_samples: int = 1000, seed: int | None = None) - for unique_sample, counts in zip(unique_samples, counts): quad = np.array([[unique_sample] + [None] * (state.n_modes - 1)]) quad = quad if isinstance(state, Ket) else math.tile(quad, (1, 2)) - reduced_rep = (state >> BtoQ([initial_mode], phi=self._phi)).ansatz(quad) - reduced_state = state.__class__.from_bargmann(state.modes[1:], reduced_rep.triple) + reduced_ansatz = (state >> BtoQ([initial_mode], phi=self._phi)).ansatz(quad) + reduced_state = state.from_bargmann(state.modes[1:], reduced_ansatz.triple) prob = probs[initial_samples.tolist().index(unique_sample)] / self._step norm = math.sqrt(prob) if isinstance(state, Ket) else prob normalized_reduced_state = reduced_state / norm diff --git a/mrmustard/lab_dev/states/base.py b/mrmustard/lab_dev/states/base.py index 94cae80c1..69d7b3ac4 100644 --- a/mrmustard/lab_dev/states/base.py +++ b/mrmustard/lab_dev/states/base.py @@ -239,7 +239,7 @@ def from_fock( @classmethod @abstractmethod - def from_modes( + def from_ansatz( cls, modes: Sequence[int], ansatz: PolyExpAnsatz | ArrayAnsatz | None = None, diff --git a/mrmustard/lab_dev/states/coherent.py b/mrmustard/lab_dev/states/coherent.py index 97bef7c29..a56ad5732 100644 --- a/mrmustard/lab_dev/states/coherent.py +++ b/mrmustard/lab_dev/states/coherent.py @@ -83,7 +83,7 @@ def __init__( self._add_parameter(make_parameter(x_trainable, xs, "x", x_bounds)) self._add_parameter(make_parameter(y_trainable, ys, "y", y_bounds)) - self._representation = self.from_modes( + self._representation = self.from_ansatz( modes=modes, ansatz=PolyExpAnsatz.from_function(fn=triples.coherent_state_Abc, x=self.x, y=self.y), ).representation diff --git a/mrmustard/lab_dev/states/displaced_squeezed.py b/mrmustard/lab_dev/states/displaced_squeezed.py index ba850acfc..7691fa48a 100644 --- a/mrmustard/lab_dev/states/displaced_squeezed.py +++ b/mrmustard/lab_dev/states/displaced_squeezed.py @@ -85,7 +85,7 @@ def __init__( self._add_parameter(make_parameter(r_trainable, rs, "r", r_bounds)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - self._representation = self.from_modes( + self._representation = self.from_ansatz( modes=modes, ansatz=PolyExpAnsatz.from_function( fn=triples.displaced_squeezed_vacuum_state_Abc, diff --git a/mrmustard/lab_dev/states/dm.py b/mrmustard/lab_dev/states/dm.py index 71a0e30c3..15c2a7a91 100644 --- a/mrmustard/lab_dev/states/dm.py +++ b/mrmustard/lab_dev/states/dm.py @@ -114,7 +114,7 @@ def from_bargmann( triple: tuple[ComplexMatrix, ComplexVector, complex], name: str | None = None, ) -> State: - return DM.from_modes(modes, PolyExpAnsatz(*triple), name) + return DM.from_ansatz(modes, PolyExpAnsatz(*triple), name) @classmethod def from_fock( @@ -124,10 +124,10 @@ def from_fock( name: str | None = None, batched: bool = False, ) -> State: - return DM.from_modes(modes, ArrayAnsatz(array, batched), name) + return DM.from_ansatz(modes, ArrayAnsatz(array, batched), name) @classmethod - def from_modes( + def from_ansatz( cls, modes: Sequence[int], ansatz: PolyExpAnsatz | ArrayAnsatz | None = None, @@ -136,7 +136,7 @@ def from_modes( modes = set(modes) if ansatz and ansatz.num_vars != 2 * len(modes): raise ValueError( - f"Expected a representation with {2*len(modes)} variables, found {ansatz.num_vars}." + f"Expected an ansatz with {2*len(modes)} variables, found {ansatz.num_vars}." ) wires = Wires(modes_out_bra=modes, modes_out_ket=modes) return DM(Representation(ansatz, wires), name) @@ -165,7 +165,7 @@ def from_phase_space( cov = math.astensor(cov) means = math.astensor(means) shape_check(cov, means, 2 * len(modes), "Phase space") - return coeff * DM.from_modes( + return coeff * DM.from_ansatz( modes, PolyExpAnsatz.from_function(fn=wigner_to_bargmann_rho, cov=cov, means=means), name, @@ -197,8 +197,8 @@ def from_quadrature( with the number of modes. """ QtoB = BtoQ(modes, phi).inverse() - Q = DM.from_modes(modes, PolyExpAnsatz(*triple)) - return DM.from_modes(modes, (Q >> QtoB).ansatz, name) + Q = DM.from_ansatz(modes, PolyExpAnsatz(*triple)) + return DM.from_ansatz(modes, (Q >> QtoB).ansatz, name) @classmethod def random(cls, modes: Sequence[int], m: int | None = None, max_r: float = 1.0) -> DM: diff --git a/mrmustard/lab_dev/states/ket.py b/mrmustard/lab_dev/states/ket.py index 1a5f99254..aa5c41cb4 100644 --- a/mrmustard/lab_dev/states/ket.py +++ b/mrmustard/lab_dev/states/ket.py @@ -94,7 +94,7 @@ def from_bargmann( triple: tuple[ComplexMatrix, ComplexVector, complex], name: str | None = None, ) -> State: - return Ket.from_modes(modes, PolyExpAnsatz(*triple), name) + return Ket.from_ansatz(modes, PolyExpAnsatz(*triple), name) @classmethod def from_fock( @@ -104,10 +104,10 @@ def from_fock( name: str | None = None, batched: bool = False, ) -> State: - return Ket.from_modes(modes, ArrayAnsatz(array, batched), name) + return Ket.from_ansatz(modes, ArrayAnsatz(array, batched), name) @classmethod - def from_modes( + def from_ansatz( cls, modes: Sequence[int], ansatz: PolyExpAnsatz | ArrayAnsatz | None = None, @@ -116,7 +116,7 @@ def from_modes( modes = set(modes) if ansatz and ansatz.num_vars != len(modes): raise ValueError( - f"Expected a representation with {len(modes)} variables, found {ansatz.num_vars}." + f"Expected an ansatz with {len(modes)} variables, found {ansatz.num_vars}." ) wires = Wires(modes_out_ket=modes) return Ket(Representation(ansatz, wires), name) @@ -138,7 +138,7 @@ def from_phase_space( if p < 1.0 - atol_purity: msg = f"Cannot initialize a Ket: purity is {p:.5f} (must be at least 1.0-{atol_purity})." raise ValueError(msg) - return Ket.from_modes( + return Ket.from_ansatz( modes, coeff * PolyExpAnsatz.from_function(fn=wigner_to_bargmann_psi, cov=cov, means=means), name, @@ -153,8 +153,8 @@ def from_quadrature( name: str | None = None, ) -> State: QtoB = BtoQ(modes, phi).inverse() - Q = Ket.from_modes(modes, PolyExpAnsatz(*triple)) - return Ket.from_modes(modes, (Q >> QtoB).ansatz, name) + Q = Ket.from_ansatz(modes, PolyExpAnsatz(*triple)) + return Ket.from_ansatz(modes, (Q >> QtoB).ansatz, name) @classmethod def random(cls, modes: Sequence[int], max_r: float = 1.0) -> Ket: @@ -188,7 +188,7 @@ def random(cls, modes: Sequence[int], max_r: float = 1.0) -> Ket: S = math.conj(math.transpose(transformation)) @ S @ transformation S_1 = S[:m, :m] S_2 = S[:m, m:] - A = S_2 @ math.conj(math.inv(S_1)) # use solve for inverse + A = math.transpose(math.solve(math.dagger(S_1), math.transpose(S_2))) b = math.zeros(m, dtype=A.dtype) psi = cls.from_bargmann(modes, [[A], [b], [complex(1)]]) return psi.normalize() diff --git a/mrmustard/lab_dev/states/number.py b/mrmustard/lab_dev/states/number.py index f85dc1a41..9f55c72b5 100644 --- a/mrmustard/lab_dev/states/number.py +++ b/mrmustard/lab_dev/states/number.py @@ -72,7 +72,7 @@ def __init__( ns, cs = list(reshape_params(len(modes), n=n, cutoffs=n if cutoffs is None else cutoffs)) self._add_parameter(make_parameter(False, ns, "n", (None, None), dtype="int64")) self._add_parameter(make_parameter(False, cs, "cutoffs", (None, None))) - self._representation = self.from_modes( + self._representation = self.from_ansatz( modes=modes, ansatz=ArrayAnsatz.from_function( fock_state, n=self.n.value, cutoffs=self.cutoffs.value diff --git a/mrmustard/lab_dev/states/quadrature_eigenstate.py b/mrmustard/lab_dev/states/quadrature_eigenstate.py index 4ec9b999c..29acac3ad 100644 --- a/mrmustard/lab_dev/states/quadrature_eigenstate.py +++ b/mrmustard/lab_dev/states/quadrature_eigenstate.py @@ -70,7 +70,7 @@ def __init__( self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) self.manual_shape = (50,) - self._representation = self.from_modes( + self._representation = self.from_ansatz( modes=modes, ansatz=PolyExpAnsatz.from_function( fn=triples.quadrature_eigenstates_Abc, x=self.x, phi=self.phi diff --git a/mrmustard/lab_dev/states/sauron.py b/mrmustard/lab_dev/states/sauron.py index 81fae9e0f..e4476f2a8 100644 --- a/mrmustard/lab_dev/states/sauron.py +++ b/mrmustard/lab_dev/states/sauron.py @@ -44,7 +44,7 @@ def __init__(self, modes: Sequence[int], n: int, epsilon: float = 0.1): self._add_parameter(make_parameter(False, n, "n", (None, None), dtype="int64")) self._add_parameter(make_parameter(False, epsilon, "epsilon", (None, None))) - self._representation = self.from_modes( + self._representation = self.from_ansatz( modes=modes, ansatz=PolyExpAnsatz.from_function( triples.sauron_state_Abc, n=self.n.value, epsilon=self.epsilon.value diff --git a/mrmustard/lab_dev/states/squeezed_vacuum.py b/mrmustard/lab_dev/states/squeezed_vacuum.py index 4078813e7..9ed0523cc 100644 --- a/mrmustard/lab_dev/states/squeezed_vacuum.py +++ b/mrmustard/lab_dev/states/squeezed_vacuum.py @@ -70,7 +70,7 @@ def __init__( self._add_parameter(make_parameter(r_trainable, rs, "r", r_bounds)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - self._representation = self.from_modes( + self._representation = self.from_ansatz( modes=modes, ansatz=PolyExpAnsatz.from_function( fn=triples.squeezed_vacuum_state_Abc, r=self.r, phi=self.phi diff --git a/mrmustard/lab_dev/states/thermal.py b/mrmustard/lab_dev/states/thermal.py index 27d88ca86..d7148807b 100644 --- a/mrmustard/lab_dev/states/thermal.py +++ b/mrmustard/lab_dev/states/thermal.py @@ -61,7 +61,7 @@ def __init__( super().__init__(name="Thermal") (nbars,) = list(reshape_params(len(modes), nbar=nbar)) self._add_parameter(make_parameter(nbar_trainable, nbars, "nbar", nbar_bounds)) - self._representation = self.from_modes( + self._representation = self.from_ansatz( modes=modes, ansatz=PolyExpAnsatz.from_function(fn=triples.thermal_state_Abc, nbar=self.nbar), ).representation diff --git a/mrmustard/lab_dev/states/two_mode_squeezed_vacuum.py b/mrmustard/lab_dev/states/two_mode_squeezed_vacuum.py index 458ebe4a7..9f93dc7ac 100644 --- a/mrmustard/lab_dev/states/two_mode_squeezed_vacuum.py +++ b/mrmustard/lab_dev/states/two_mode_squeezed_vacuum.py @@ -66,7 +66,7 @@ def __init__( rs, phis = list(reshape_params(int(len(modes) / 2), r=r, phi=phi)) self._add_parameter(make_parameter(r_trainable, rs, "r", r_bounds)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - self._representation = self.from_modes( + self._representation = self.from_ansatz( modes=modes, ansatz=PolyExpAnsatz.from_function( fn=triples.two_mode_squeezed_vacuum_state_Abc, r=self.r, phi=self.phi diff --git a/mrmustard/lab_dev/states/vacuum.py b/mrmustard/lab_dev/states/vacuum.py index abddd9244..7b0e88da7 100644 --- a/mrmustard/lab_dev/states/vacuum.py +++ b/mrmustard/lab_dev/states/vacuum.py @@ -61,7 +61,7 @@ def __init__( modes: Sequence[int], ) -> None: super().__init__(name="Vac") - self._representation = self.from_modes( + self._representation = self.from_ansatz( modes=modes, ansatz=PolyExpAnsatz.from_function(fn=triples.vacuum_state_Abc, n_modes=len(modes)), ).representation diff --git a/mrmustard/lab_dev/transformations/amplifier.py b/mrmustard/lab_dev/transformations/amplifier.py index 209fc36dc..7707a8950 100644 --- a/mrmustard/lab_dev/transformations/amplifier.py +++ b/mrmustard/lab_dev/transformations/amplifier.py @@ -95,7 +95,7 @@ def __init__( None, ) ) - self._representation = self.from_modes( + self._representation = self.from_ansatz( modes_in=modes, modes_out=modes, ansatz=PolyExpAnsatz.from_function(fn=triples.amplifier_Abc, g=self.gain), diff --git a/mrmustard/lab_dev/transformations/attenuator.py b/mrmustard/lab_dev/transformations/attenuator.py index 449305b21..15b61c515 100644 --- a/mrmustard/lab_dev/transformations/attenuator.py +++ b/mrmustard/lab_dev/transformations/attenuator.py @@ -95,7 +95,7 @@ def __init__( None, ) ) - self._representation = self.from_modes( + self._representation = self.from_ansatz( modes_in=modes, modes_out=modes, ansatz=PolyExpAnsatz.from_function(fn=triples.attenuator_Abc, eta=self.transmissivity), diff --git a/mrmustard/lab_dev/transformations/base.py b/mrmustard/lab_dev/transformations/base.py index afc3980c0..e2f37b4ae 100644 --- a/mrmustard/lab_dev/transformations/base.py +++ b/mrmustard/lab_dev/transformations/base.py @@ -62,7 +62,7 @@ def from_bargmann( @classmethod @abstractmethod - def from_modes( + def from_ansatz( cls, modes_out: Sequence[int], modes_in: Sequence[int], @@ -152,10 +152,10 @@ def from_bargmann( triple: tuple, name: str | None = None, ) -> Transformation: - return Operation.from_modes(modes_out, modes_in, PolyExpAnsatz(*triple), name) + return Operation.from_ansatz(modes_out, modes_in, PolyExpAnsatz(*triple), name) @classmethod - def from_modes( + def from_ansatz( cls, modes_out: Sequence[int], modes_in: Sequence[int], @@ -182,9 +182,9 @@ def from_quadrature( QtoB_out = BtoQ(modes_out, phi).inverse() QtoB_in = BtoQ(modes_in, phi).inverse().dual - QQ = Operation.from_modes(modes_out, modes_in, PolyExpAnsatz(*triple)) + QQ = Operation.from_ansatz(modes_out, modes_in, PolyExpAnsatz(*triple)) BB = QtoB_in >> QQ >> QtoB_out - return Operation.from_modes(modes_out, modes_in, BB.ansatz, name) + return Operation.from_ansatz(modes_out, modes_in, BB.ansatz, name) class Unitary(Operation): @@ -210,10 +210,10 @@ def from_bargmann( triple: tuple, name: str | None = None, ) -> Transformation: - return Unitary.from_modes(modes_out, modes_in, PolyExpAnsatz(*triple), name) + return Unitary.from_ansatz(modes_out, modes_in, PolyExpAnsatz(*triple), name) @classmethod - def from_modes( + def from_ansatz( cls, modes_out: Sequence[int], modes_in: Sequence[int], @@ -240,9 +240,9 @@ def from_quadrature( QtoB_out = BtoQ(modes_out, phi).inverse() QtoB_in = BtoQ(modes_in, phi).inverse().dual - QQ = Unitary.from_modes(modes_out, modes_in, PolyExpAnsatz(*triple)) + QQ = Unitary.from_ansatz(modes_out, modes_in, PolyExpAnsatz(*triple)) BB = QtoB_in >> QQ >> QtoB_out - return Unitary.from_modes(modes_out, modes_in, BB.ansatz, name) + return Unitary.from_ansatz(modes_out, modes_in, BB.ansatz, name) @classmethod def from_symplectic(cls, modes, S) -> Unitary: @@ -311,10 +311,10 @@ def from_bargmann( triple: tuple, name: str | None = None, ) -> Transformation: - return Map.from_modes(modes_out, modes_in, PolyExpAnsatz(*triple), name) + return Map.from_ansatz(modes_out, modes_in, PolyExpAnsatz(*triple), name) @classmethod - def from_modes( + def from_ansatz( cls, modes_out: Sequence[int], modes_in: Sequence[int], @@ -343,9 +343,9 @@ def from_quadrature( QtoB_out = BtoQ(modes_out, phi).inverse() QtoB_in = BtoQ(modes_in, phi).inverse().dual - QQ = Map.from_modes(modes_out, modes_in, PolyExpAnsatz(*triple)) + QQ = Map.from_ansatz(modes_out, modes_in, PolyExpAnsatz(*triple)) BB = QtoB_in >> QQ >> QtoB_out - return Map.from_modes(modes_out, modes_in, BB.ansatz, name) + return Map.from_ansatz(modes_out, modes_in, BB.ansatz, name) class Channel(Map): @@ -410,10 +410,10 @@ def from_bargmann( triple: tuple, name: str | None = None, ) -> Transformation: - return Channel.from_modes(modes_out, modes_in, PolyExpAnsatz(*triple), name) + return Channel.from_ansatz(modes_out, modes_in, PolyExpAnsatz(*triple), name) @classmethod - def from_modes( + def from_ansatz( cls, modes_out: Sequence[int], modes_in: Sequence[int], @@ -442,9 +442,9 @@ def from_quadrature( QtoB_out = BtoQ(modes_out, phi).inverse() QtoB_in = BtoQ(modes_in, phi).inverse().dual - QQ = Channel.from_modes(modes_out, modes_in, PolyExpAnsatz(*triple)) + QQ = Channel.from_ansatz(modes_out, modes_in, PolyExpAnsatz(*triple)) BB = QtoB_in >> QQ >> QtoB_out - return Channel.from_modes(modes_out, modes_in, BB.ansatz, name) + return Channel.from_ansatz(modes_out, modes_in, BB.ansatz, name) @classmethod def random(cls, modes: Sequence[int], max_r: float = 1.0) -> Channel: diff --git a/mrmustard/lab_dev/transformations/bsgate.py b/mrmustard/lab_dev/transformations/bsgate.py index 39c06a473..83a2beeb7 100644 --- a/mrmustard/lab_dev/transformations/bsgate.py +++ b/mrmustard/lab_dev/transformations/bsgate.py @@ -104,7 +104,7 @@ def __init__( super().__init__(name="BSgate") self._add_parameter(make_parameter(theta_trainable, theta, "theta", theta_bounds)) self._add_parameter(make_parameter(phi_trainable, phi, "phi", phi_bounds)) - self._representation = self.from_modes( + self._representation = self.from_ansatz( modes_in=modes, modes_out=modes, ansatz=PolyExpAnsatz.from_function( diff --git a/mrmustard/lab_dev/transformations/cft.py b/mrmustard/lab_dev/transformations/cft.py index 5d26bebbc..697288c3e 100644 --- a/mrmustard/lab_dev/transformations/cft.py +++ b/mrmustard/lab_dev/transformations/cft.py @@ -43,7 +43,7 @@ def __init__( modes: Sequence[int], ): super().__init__(name="CFT") - self._representation = self.from_modes( + self._representation = self.from_ansatz( modes_in=modes, modes_out=modes, ansatz=PolyExpAnsatz.from_function( diff --git a/mrmustard/lab_dev/transformations/dgate.py b/mrmustard/lab_dev/transformations/dgate.py index 7ce9a071f..aa5f76a0f 100644 --- a/mrmustard/lab_dev/transformations/dgate.py +++ b/mrmustard/lab_dev/transformations/dgate.py @@ -95,7 +95,7 @@ def __init__( xs, ys = list(reshape_params(len(modes), x=x, y=y)) self._add_parameter(make_parameter(x_trainable, xs, "x", x_bounds)) self._add_parameter(make_parameter(y_trainable, ys, "y", y_bounds)) - self._representation = self.from_modes( + self._representation = self.from_ansatz( modes_in=modes, modes_out=modes, ansatz=PolyExpAnsatz.from_function( diff --git a/mrmustard/lab_dev/transformations/fockdamping.py b/mrmustard/lab_dev/transformations/fockdamping.py index 65bec5840..61bb9f47f 100644 --- a/mrmustard/lab_dev/transformations/fockdamping.py +++ b/mrmustard/lab_dev/transformations/fockdamping.py @@ -85,7 +85,7 @@ def __init__( None, ) ) - self._representation = self.from_modes( + self._representation = self.from_ansatz( modes_in=modes, modes_out=modes, ansatz=PolyExpAnsatz.from_function(fn=triples.fock_damping_Abc, beta=self.damping), diff --git a/mrmustard/lab_dev/transformations/gaussrandnoise.py b/mrmustard/lab_dev/transformations/gaussrandnoise.py index 19dada7cb..7c708ef73 100644 --- a/mrmustard/lab_dev/transformations/gaussrandnoise.py +++ b/mrmustard/lab_dev/transformations/gaussrandnoise.py @@ -76,7 +76,7 @@ def __init__( super().__init__(name="GRN") self._add_parameter(make_parameter(Y_trainable, value=Y, name="Y", bounds=(None, None))) - self._representation = self.from_modes( + self._representation = self.from_ansatz( modes_in=modes, modes_out=modes, ansatz=PolyExpAnsatz.from_function(fn=triples.gaussian_random_noise_Abc, Y=self.Y), diff --git a/mrmustard/lab_dev/transformations/ggate.py b/mrmustard/lab_dev/transformations/ggate.py index ec74411ab..514deb28c 100644 --- a/mrmustard/lab_dev/transformations/ggate.py +++ b/mrmustard/lab_dev/transformations/ggate.py @@ -57,7 +57,7 @@ def __init__( super().__init__(name="Ggate") S = make_parameter(symplectic_trainable, symplectic, "symplectic", (None, None)) self.parameter_set.add_parameter(S) - self._representation = self.from_modes( + self._representation = self.from_ansatz( modes_in=modes, modes_out=modes, ansatz=PolyExpAnsatz.from_function( diff --git a/mrmustard/lab_dev/transformations/identity.py b/mrmustard/lab_dev/transformations/identity.py index f433ac566..31edb0f41 100644 --- a/mrmustard/lab_dev/transformations/identity.py +++ b/mrmustard/lab_dev/transformations/identity.py @@ -52,7 +52,7 @@ def __init__( modes: Sequence[int], ): super().__init__(name="Identity") - self._representation = self.from_modes( + self._representation = self.from_ansatz( modes_in=modes, modes_out=modes, ansatz=PolyExpAnsatz.from_function(fn=triples.identity_Abc, n_modes=len(modes)), diff --git a/mrmustard/lab_dev/transformations/rgate.py b/mrmustard/lab_dev/transformations/rgate.py index 4eeb183ed..3241bf1b4 100644 --- a/mrmustard/lab_dev/transformations/rgate.py +++ b/mrmustard/lab_dev/transformations/rgate.py @@ -62,7 +62,7 @@ def __init__( super().__init__(name="Rgate") (phis,) = list(reshape_params(len(modes), phi=phi)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - self._representation = self.from_modes( + self._representation = self.from_ansatz( modes_in=modes, modes_out=modes, ansatz=PolyExpAnsatz.from_function(fn=triples.rotation_gate_Abc, theta=self.phi), diff --git a/mrmustard/lab_dev/transformations/s2gate.py b/mrmustard/lab_dev/transformations/s2gate.py index bdb4779c7..dfb76549b 100644 --- a/mrmustard/lab_dev/transformations/s2gate.py +++ b/mrmustard/lab_dev/transformations/s2gate.py @@ -87,7 +87,7 @@ def __init__( super().__init__(name="S2gate") self._add_parameter(make_parameter(r_trainable, r, "r", r_bounds)) self._add_parameter(make_parameter(phi_trainable, phi, "phi", phi_bounds)) - self._representation = self.from_modes( + self._representation = self.from_ansatz( modes_in=modes, modes_out=modes, ansatz=PolyExpAnsatz.from_function( diff --git a/mrmustard/lab_dev/transformations/sgate.py b/mrmustard/lab_dev/transformations/sgate.py index 4b0677499..23a674c73 100644 --- a/mrmustard/lab_dev/transformations/sgate.py +++ b/mrmustard/lab_dev/transformations/sgate.py @@ -94,7 +94,7 @@ def __init__( rs, phis = list(reshape_params(len(modes), r=r, phi=phi)) self._add_parameter(make_parameter(r_trainable, rs, "r", r_bounds)) self._add_parameter(make_parameter(phi_trainable, phis, "phi", phi_bounds)) - self._representation = self.from_modes( + self._representation = self.from_ansatz( modes_in=modes, modes_out=modes, ansatz=PolyExpAnsatz.from_function( diff --git a/mrmustard/physics/ansatz/__init__.py b/mrmustard/physics/ansatz/__init__.py index dbb2cee1b..5fd71825b 100644 --- a/mrmustard/physics/ansatz/__init__.py +++ b/mrmustard/physics/ansatz/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. r""" -The classes for representations in circuit components. +The classes for Ansatze in circuit components. """ from .base import * diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index f1aabbec6..509267b84 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -1,4 +1,4 @@ -# Copyright 2024 Xanadu Quantum Technologies Inc. +# Copyright 2023 Xanadu Quantum Technologies Inc. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_lab_dev/test_states/test_dm.py b/tests/test_lab_dev/test_states/test_dm.py index 471321fcf..b5927998e 100644 --- a/tests/test_lab_dev/test_states/test_dm.py +++ b/tests/test_lab_dev/test_states/test_dm.py @@ -56,7 +56,7 @@ class TestDM: # pylint:disable=too-many-public-methods @pytest.mark.parametrize("name", [None, "my_dm"]) @pytest.mark.parametrize("modes", [{0}, {0, 1}, {3, 19, 2}]) def test_init(self, name, modes): - state = DM.from_modes(modes, None, name) + state = DM.from_ansatz(modes, None, name) assert state.name in ("DM0", "DM01", "DM2319") if not name else name assert list(state.modes) == sorted(modes) diff --git a/tests/test_lab_dev/test_states/test_ket.py b/tests/test_lab_dev/test_states/test_ket.py index 98087c844..a9523727e 100644 --- a/tests/test_lab_dev/test_states/test_ket.py +++ b/tests/test_lab_dev/test_states/test_ket.py @@ -73,7 +73,7 @@ class TestKet: # pylint: disable=too-many-public-methods @pytest.mark.parametrize("name", [None, "my_ket"]) @pytest.mark.parametrize("modes", [[0], [0, 1], [3, 19, 2]]) def test_init(self, name, modes): - state = Ket.from_modes(modes, None, name) + state = Ket.from_ansatz(modes, None, name) assert state.name in ("Ket0", "Ket01", "Ket2319") if not name else name assert list(state.modes) == sorted(modes) @@ -195,7 +195,7 @@ def test_probability(self): @pytest.mark.parametrize("modes", [[0], [0, 1], [3, 19, 2]]) def test_purity(self, modes): - state = Ket.from_modes(modes, None, "my_ket") + state = Ket.from_ansatz(modes, None, "my_ket") assert state.purity == 1 assert state.is_pure diff --git a/tests/test_lab_dev/test_transformations/test_transformations_base.py b/tests/test_lab_dev/test_transformations/test_transformations_base.py index 6e2380ac5..5c33e0335 100644 --- a/tests/test_lab_dev/test_transformations/test_transformations_base.py +++ b/tests/test_lab_dev/test_transformations/test_transformations_base.py @@ -57,7 +57,7 @@ class TestUnitary: @pytest.mark.parametrize("name", [None, "my_unitary"]) @pytest.mark.parametrize("modes", [{0}, {0, 1}, {3, 19, 2}]) def test_init(self, name, modes): - gate = Unitary.from_modes(modes, modes, name=name) + gate = Unitary.from_ansatz(modes, modes, name=name) assert gate.name[:1] == (name or "U")[:1] assert list(gate.modes) == sorted(modes) @@ -136,7 +136,7 @@ class TestChannel: @pytest.mark.parametrize("name", [None, "my_channel"]) @pytest.mark.parametrize("modes", [{0}, {0, 1}, {3, 19, 2}]) def test_init(self, name, modes): - gate = Channel.from_modes(modes, modes, name=name) + gate = Channel.from_ansatz(modes, modes, name=name) assert gate.name[:2] == (name or "Ch")[:2] assert list(gate.modes) == sorted(modes) From 0c5188d0133d33764a194ee56caaa618a526f5c7 Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 29 Oct 2024 09:38:36 -0400 Subject: [PATCH 84/87] codefactor --- tests/test_lab_dev/test_states/test_ket.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/test_lab_dev/test_states/test_ket.py b/tests/test_lab_dev/test_states/test_ket.py index a9523727e..684507a26 100644 --- a/tests/test_lab_dev/test_states/test_ket.py +++ b/tests/test_lab_dev/test_states/test_ket.py @@ -26,13 +26,9 @@ from mrmustard import math, settings from mrmustard.lab_dev.circuit_components import CircuitComponent from mrmustard.math.parameters import Constant, Variable -from mrmustard.physics.bargmann_utils import ( - bargmann_Abc_to_phasespace_cov_means, - wigner_to_bargmann_rho, -) from mrmustard.physics.gaussian import vacuum_cov, vacuum_means, squeezed_vacuum_cov from mrmustard.physics.triples import coherent_state_Abc -from mrmustard.lab_dev.circuit_components_utils import BtoPS, TraceOut +from mrmustard.lab_dev.circuit_components_utils import TraceOut from mrmustard.lab_dev.states import ( Coherent, DisplacedSqueezed, From a4cb8e622b063c85dbb164cae5d344fe2a1a9881 Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 29 Oct 2024 13:46:09 -0400 Subject: [PATCH 85/87] codefactor --- tests/test_math/test_backend_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_math/test_backend_manager.py b/tests/test_math/test_backend_manager.py index b6d78c1fb..8b736d99b 100644 --- a/tests/test_math/test_backend_manager.py +++ b/tests/test_math/test_backend_manager.py @@ -17,11 +17,11 @@ """ from unittest.mock import MagicMock, patch +import math import numpy as np import pytest import tensorflow as tf -import math from mrmustard import math from ..conftest import skip_np From b70bfff2a5e56529f36b277754d52125315c1eaa Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 29 Oct 2024 13:47:45 -0400 Subject: [PATCH 86/87] import math --- tests/test_math/test_backend_manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_math/test_backend_manager.py b/tests/test_math/test_backend_manager.py index 8b736d99b..aa57518aa 100644 --- a/tests/test_math/test_backend_manager.py +++ b/tests/test_math/test_backend_manager.py @@ -17,7 +17,6 @@ """ from unittest.mock import MagicMock, patch -import math import numpy as np import pytest import tensorflow as tf From 4426b498cdb90e8ed420fe5c4033f2f6ca0997e8 Mon Sep 17 00:00:00 2001 From: Anthony Date: Tue, 29 Oct 2024 13:57:01 -0400 Subject: [PATCH 87/87] fock_array --- tests/test_lab_dev/test_circuit_components.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_lab_dev/test_circuit_components.py b/tests/test_lab_dev/test_circuit_components.py index ac044abb8..200479614 100644 --- a/tests/test_lab_dev/test_circuit_components.py +++ b/tests/test_lab_dev/test_circuit_components.py @@ -597,7 +597,7 @@ def test_hermite_renormalized_with_custom_shape(self): # made up, means nothing def cost(): - ket = S.fock(shape=[3]) + ket = S.fock_array(shape=[3]) return -math.real(ket[2]) circuit = Circuit([S])