Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Representation + Wires refactor #498

Merged
merged 103 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
103 commits
Select commit Hold shift + click to select a range
a121ebe
rename
apchytr Sep 9, 2024
5d3db1e
rename
apchytr Sep 9, 2024
1ac933c
init
apchytr Sep 10, 2024
b6feb3b
init
apchytr Sep 10, 2024
61a83cf
fock working
apchytr Sep 10, 2024
9f60a21
Representation
apchytr Sep 10, 2024
803cd88
tests passing
apchytr Sep 10, 2024
3474c85
too-many-instance-attributes
apchytr Sep 10, 2024
014a07b
protected-access
apchytr Sep 10, 2024
807f555
some cleanup
apchytr Sep 10, 2024
a2289f6
representation done
apchytr Sep 10, 2024
110e58a
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
apchytr Sep 10, 2024
74d4cef
merge
apchytr Sep 10, 2024
3bf79a3
codefactor
apchytr Sep 10, 2024
8091738
doc
apchytr Sep 10, 2024
902e455
Merge branch 'develop' into refactorRepAnsatz
ziofil Sep 23, 2024
31bafed
move wires to physics
apchytr Sep 30, 2024
a27838c
move wires to physics
apchytr Sep 30, 2024
0f8bfc9
init
apchytr Oct 3, 2024
d92fb82
tests passing
apchytr Oct 3, 2024
8c45168
unused imports
apchytr Oct 3, 2024
be5dfe1
some progress for btoq
apchytr Oct 3, 2024
9688a01
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
apchytr Oct 3, 2024
119a0f1
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
apchytr Oct 4, 2024
6569ce8
merge
apchytr Oct 4, 2024
e261cbc
some ansatz tests ported
apchytr Oct 7, 2024
40ef6f1
ansatz tests
apchytr Oct 7, 2024
fabf1d8
some codefactor
apchytr Oct 7, 2024
e2012c0
some codefactor
apchytr Oct 7, 2024
3080a6c
docs
apchytr Oct 7, 2024
640259b
docs
apchytr Oct 7, 2024
dfe55bc
workflow
apchytr Oct 7, 2024
ba536ad
some coverage
apchytr Oct 7, 2024
7840995
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
apchytr Oct 7, 2024
d5b347d
coverage
apchytr Oct 7, 2024
ed1dbb6
widgets
apchytr Oct 7, 2024
b4c4c07
widgets
apchytr Oct 7, 2024
5bfb17a
cleanup
apchytr Oct 7, 2024
6335d37
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
apchytr Oct 7, 2024
b7930f3
unused imports
apchytr Oct 7, 2024
78b73e3
Merge branch 'refactorRepAnsatz' of https://github.com/XanaduAI/MrMus…
apchytr Oct 7, 2024
d8c2713
ansatz representation
apchytr Oct 8, 2024
e779692
cleanup
apchytr Oct 8, 2024
e4ca9f5
docs
apchytr Oct 8, 2024
ad6d5bf
cleanup
apchytr Oct 8, 2024
122a300
workflow
apchytr Oct 8, 2024
6034e85
some rename
apchytr Oct 8, 2024
a79284a
cleanup
apchytr Oct 8, 2024
d5f9f3a
cleanup
apchytr Oct 8, 2024
f0a969c
docs
apchytr Oct 8, 2024
1331c93
cleanup
apchytr Oct 8, 2024
2192754
docs
apchytr Oct 8, 2024
0bb2369
docs
apchytr Oct 8, 2024
fb9cf8e
docs
apchytr Oct 8, 2024
76646d8
docs
apchytr Oct 8, 2024
1c1c3fa
docs
apchytr Oct 8, 2024
d61a626
docs
apchytr Oct 8, 2024
fd40f69
moving adjoint and dual
apchytr Oct 10, 2024
62bd944
some more cleanup
apchytr Oct 10, 2024
5389c02
some more cleanup
apchytr Oct 10, 2024
19afefe
some more cleanup
apchytr Oct 10, 2024
6781fe2
CC -> representation arg
apchytr Oct 10, 2024
4fd0cd0
doc
apchytr Oct 10, 2024
a940112
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
apchytr Oct 11, 2024
e04a908
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
apchytr Oct 15, 2024
1b1c76c
Merge branch 'refactorRepAnsatz' of https://github.com/XanaduAI/MrMus…
apchytr Oct 15, 2024
5f3d82e
states
apchytr Oct 15, 2024
381e495
transformations
apchytr Oct 15, 2024
a53a2be
progress
apchytr Oct 16, 2024
fabfda0
print rem
apchytr Oct 17, 2024
5c61ed4
progress
apchytr Oct 17, 2024
497c434
matmul working
apchytr Oct 17, 2024
d5085ce
rename
apchytr Oct 17, 2024
4fd828a
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
apchytr Oct 21, 2024
51e5b8c
Merge branch 'refactorRepAnsatz' of https://github.com/XanaduAI/MrMus…
apchytr Oct 21, 2024
5bac799
initial test file
apchytr Oct 21, 2024
f756139
pylint
apchytr Oct 22, 2024
00b280e
codefactor
apchytr Oct 22, 2024
20c63e0
fix
apchytr Oct 22, 2024
3ce5e3c
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
apchytr Oct 22, 2024
9c846dd
patch 100
apchytr Oct 23, 2024
7253b75
some tests
apchytr Oct 23, 2024
e0f66fd
rem
apchytr Oct 23, 2024
348a702
cleanup _from_attribute
apchytr Oct 23, 2024
ab2c99f
some cov
apchytr Oct 23, 2024
b8cfc7a
codefactor
apchytr Oct 23, 2024
40f6bcc
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
apchytr Oct 24, 2024
b5c4334
cleanup imports
apchytr Oct 24, 2024
5e98362
cleanup
apchytr Oct 24, 2024
01f9502
doc cleanup
apchytr Oct 24, 2024
00b714e
AI cleanup
apchytr Oct 24, 2024
b11e463
remove tuple
apchytr Oct 24, 2024
a06cc3b
coverage
apchytr Oct 24, 2024
53eba32
doc
apchytr Oct 24, 2024
fbb5083
rename
apchytr Oct 24, 2024
22e7696
cr
apchytr Oct 28, 2024
c3f29c2
cr
apchytr Oct 29, 2024
0c5188d
codefactor
apchytr Oct 29, 2024
62ecbda
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
apchytr Oct 29, 2024
bc305b7
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
apchytr Oct 29, 2024
a4cb8e6
codefactor
apchytr Oct 29, 2024
b70bfff
import math
apchytr Oct 29, 2024
4426b49
fock_array
apchytr Oct 29, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ coverage:
target: 89%
patch:
default:
target: auto
target: 99%
threshold: 0%
3 changes: 1 addition & 2 deletions .github/workflows/tests_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,5 @@ jobs:
- name: Run tests
run: |
python -m pytest --doctest-modules mrmustard/math/parameter_set.py
python -m pytest --doctest-modules mrmustard/physics/ansatze.py
python -m pytest --doctest-modules mrmustard/physics/representations.py
python -m pytest --doctest-modules mrmustard/physics/ansatz
python -m pytest --doctest-modules mrmustard/lab_dev
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ ignored-classes=numpy,tensorflow,scipy,networkx,strawberryfields,thewalrus
# can either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once).
disable=fixme,no-member,line-too-long,invalid-name,too-many-lines,redefined-builtin,too-many-locals,duplicate-code,too-many-arguments,too-few-public-methods,no-else-return,isinstance-second-argument-not-valid-type,no-self-argument, arguments-differ
disable=fixme,no-member,line-too-long,invalid-name,too-many-lines,redefined-builtin,too-many-locals,duplicate-code,too-many-arguments,too-few-public-methods,no-else-return,isinstance-second-argument-not-valid-type,no-self-argument, arguments-differ, protected-access
1 change: 0 additions & 1 deletion doc/code/lab_dev.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ mrmustard.lab_dev
.. toctree::
:maxdepth: 1

lab_dev/wires
lab_dev/circuit_components
lab_dev/states
lab_dev/transformations
Expand Down
8 changes: 0 additions & 8 deletions doc/code/lab_dev/wires.rst

This file was deleted.

2 changes: 1 addition & 1 deletion doc/code/physics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ mrmustard.physics
.. toctree::
:maxdepth: 1

physics/ansatze
physics/wires
physics/representations

.. toctree::
Expand Down
8 changes: 0 additions & 8 deletions doc/code/physics/ansatze.rst

This file was deleted.

4 changes: 2 additions & 2 deletions doc/code/physics/utils/bargmann_calculations.rst
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Calculations on Bargmann objects
================================

.. currentmodule:: mrmustard.physics.bargmann
.. currentmodule:: mrmustard.physics.bargmann_utils

.. automodapi:: mrmustard.physics.bargmann
.. automodapi:: mrmustard.physics.bargmann_utils
:no-heading:
:include-all-objects:
4 changes: 2 additions & 2 deletions doc/code/physics/utils/fock_calculations.rst
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Calculations on Fock objects
============================

.. currentmodule:: mrmustard.physics.fock
.. currentmodule:: mrmustard.physics.fock_utils

.. automodapi:: mrmustard.physics.fock
.. automodapi:: mrmustard.physics.fock_utils
:no-heading:
:include-all-objects:
8 changes: 8 additions & 0 deletions doc/code/physics/wires.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
mrmustard.physics.wires
=======================

.. currentmodule:: mrmustard.physics.wires

.. automodapi:: mrmustard.physics.wires
:no-heading:
:include-all-objects:
50 changes: 25 additions & 25 deletions mrmustard/lab/abstract/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from mrmustard import math, settings
from mrmustard.math.parameters import Constant, Variable
from mrmustard.physics import bargmann, fock, gaussian
from mrmustard.physics import bargmann_utils, fock_utils, gaussian
from mrmustard.physics.wigner import wigner_discretized
from mrmustard.utils.typing import (
ComplexMatrix,
Expand Down Expand Up @@ -157,7 +157,7 @@ def purity(self) -> float:
if self.is_gaussian:
self._purity = gaussian.purity(self.cov)
else:
self._purity = fock.purity(self._dm)
self._purity = fock_utils.purity(self._dm)
return self._purity

@property
Expand Down Expand Up @@ -187,15 +187,15 @@ def number_stdev(self) -> RealVector:
return math.sqrt(math.diag_part(self.number_cov))

return math.sqrt(
fock.number_variances(self.fock, is_dm=len(self.fock.shape) == self.num_modes * 2)
fock_utils.number_variances(self.fock, is_dm=len(self.fock.shape) == self.num_modes * 2)
)

@property
def cutoffs(self) -> list[int]:
r"""Returns the Hilbert space dimension of each mode."""
if self._cutoffs is None:
if self._ket is None and self._dm is None:
self._cutoffs = fock.autocutoffs(
self._cutoffs = fock_utils.autocutoffs(
self.cov, self.means, settings.AUTOSHAPE_PROBABILITY
)
else:
Expand Down Expand Up @@ -223,7 +223,7 @@ def shape(self) -> list[int]:
def fock(self) -> ComplexTensor:
r"""Returns the Fock representation of the state."""
if self._dm is None and self._ket is None:
_fock = fock.wigner_to_fock_state(
_fock = fock_utils.wigner_to_fock_state(
self.cov,
self.means,
shape=self.shape,
Expand All @@ -243,7 +243,7 @@ def number_means(self) -> RealVector:
if self.is_gaussian:
return gaussian.number_means(self.cov, self.means)

return fock.number_means(tensor=self.fock, is_dm=self.is_mixed)
return fock_utils.number_means(tensor=self.fock, is_dm=self.is_mixed)

@property
def number_cov(self) -> RealMatrix:
Expand All @@ -258,7 +258,7 @@ def norm(self) -> float:
r"""Returns the norm of the state."""
if self.is_gaussian:
return self._norm
return fock.norm(self.fock, not self.is_hilbert_vector)
return fock_utils.norm(self.fock, not self.is_hilbert_vector)

@property
def probability(self) -> float:
Expand Down Expand Up @@ -296,7 +296,7 @@ def ket(
cutoffs = [c if c is not None else self.cutoffs[i] for i, c in enumerate(cutoffs)]

if self.is_gaussian:
self._ket = fock.wigner_to_fock_state(
self._ket = fock_utils.wigner_to_fock_state(
self.cov,
self.means,
shape=cutoffs,
Expand All @@ -308,12 +308,12 @@ def ket(
if self._ket is None:
# if state is pure and has a density matrix, calculate the ket
if self.is_pure:
self._ket = fock.dm_to_ket(self._dm)
self._ket = fock_utils.dm_to_ket(self._dm)
current_cutoffs = [int(s) for s in self._ket.shape]
if cutoffs != current_cutoffs:
paddings = [(0, max(0, new - old)) for new, old in zip(cutoffs, current_cutoffs)]
if any(p != (0, 0) for p in paddings):
padded = fock.math.pad(self._ket, paddings, mode="constant")
padded = fock_utils.math.pad(self._ket, paddings, mode="constant")
else:
padded = self._ket
return padded[tuple(slice(s) for s in cutoffs)]
Expand All @@ -336,16 +336,16 @@ def dm(self, cutoffs: list[int] | None = None) -> ComplexTensor:
if self.is_pure:
ket = self.ket(cutoffs=cutoffs)
if ket is not None:
return fock.ket_to_dm(ket)
return fock_utils.ket_to_dm(ket)
else:
if self.is_gaussian:
self._dm = fock.wigner_to_fock_state(
self._dm = fock_utils.wigner_to_fock_state(
self.cov, self.means, shape=cutoffs + cutoffs, return_dm=True
)
elif cutoffs != (current_cutoffs := list(self._dm.shape[: self.num_modes])):
paddings = [(0, max(0, new - old)) for new, old in zip(cutoffs, current_cutoffs)]
if any(p != (0, 0) for p in paddings):
padded = fock.math.pad(self._dm, paddings + paddings, mode="constant")
padded = fock_utils.math.pad(self._dm, paddings + paddings, mode="constant")
else:
padded = self._dm
return padded[tuple(slice(s) for s in cutoffs + cutoffs)]
Expand All @@ -366,10 +366,10 @@ def fock_probabilities(self, cutoffs: Sequence[int]) -> RealTensor:
if self._fock_probabilities is None:
if self.is_mixed:
dm = self.dm(cutoffs=cutoffs)
self._fock_probabilities = fock.dm_to_probs(dm)
self._fock_probabilities = fock_utils.dm_to_probs(dm)
else:
ket = self.ket(cutoffs=cutoffs)
self._fock_probabilities = fock.ket_to_probs(ket)
self._fock_probabilities = fock_utils.ket_to_probs(ket)
return self._fock_probabilities

def primal(self, other: State | Transformation) -> State:
Expand Down Expand Up @@ -427,9 +427,9 @@ def _project_onto_fock(self, other: State) -> State | float:

# return the probability (norm) of the state when there are no modes left
return (
fock.math.abs(out_fock) ** 2
fock_utils.math.abs(out_fock) ** 2
if other.is_pure and self.is_pure
else fock.math.abs(out_fock)
else fock_utils.math.abs(out_fock)
)

def _contract_with_other(self, other):
Expand All @@ -441,7 +441,7 @@ def _contract_with_other(self, other):
else:
# matching other's cutoffs
self_cutoffs = [other.cutoffs[other.indices(m)] for m in self.modes]
out_fock = fock.contract_states(
out_fock = fock_utils.contract_states(
stateA=(other.ket(other_cutoffs) if other.is_pure else other.dm(other_cutoffs)),
stateB=(self.ket(self_cutoffs) if self.is_pure else self.dm(self_cutoffs)),
a_is_dm=other.is_mixed,
Expand Down Expand Up @@ -496,7 +496,7 @@ def __and__(self, other: State) -> State:
if self.is_mixed or other.is_mixed:
self_fock = self.dm()
other_fock = other.dm()
dm = fock.math.tensordot(self_fock, other_fock, [[], []])
dm = fock_utils.math.tensordot(self_fock, other_fock, [[], []])
# e.g. self has shape [1,3,1,3] and other has shape [2,2]
# we want self & other to have shape [1,3,2,1,3,2]
# before transposing shape is [1,3,1,3]+[2,2]
Expand All @@ -516,7 +516,7 @@ def __and__(self, other: State) -> State:
self_fock = self.ket()
other_fock = other.ket()
return State(
ket=fock.math.tensordot(self_fock, other_fock, [[], []]),
ket=fock_utils.math.tensordot(self_fock, other_fock, [[], []]),
modes=self.modes + [m + max(self.modes) + 1 for m in other.modes],
)
cov = gaussian.join_covs([self.cov, other.cov])
Expand Down Expand Up @@ -548,9 +548,9 @@ def bargmann(self, numpy=False) -> tuple[ComplexMatrix, ComplexVector, complex]
"""
if self.is_gaussian:
if self.is_pure:
A, B, C = bargmann.wigner_to_bargmann_psi(self.cov, self.means)
A, B, C = bargmann_utils.wigner_to_bargmann_psi(self.cov, self.means)
else:
A, B, C = bargmann.wigner_to_bargmann_rho(self.cov, self.means)
A, B, C = bargmann_utils.wigner_to_bargmann_rho(self.cov, self.means)
else:
return None
if numpy:
Expand Down Expand Up @@ -579,7 +579,7 @@ def get_modes(self, item) -> State:
means, _ = gaussian.partition_means(self.means, item_idx)
return State(cov=cov, means=means, modes=item)

fock_partitioned = fock.trace(self.dm(self.cutoffs), keep=item_idx)
fock_partitioned = fock_utils.trace(self.dm(self.cutoffs), keep=item_idx)
return State(dm=fock_partitioned, modes=item)

def __eq__(self, other) -> bool: # pylint: disable=too-many-return-statements
Expand Down Expand Up @@ -733,8 +733,8 @@ def mikkel_plot(
if plot_args["ytick_labels"] is None:
plot_args["ytick_labels"] = plot_args["yticks"]

q, ProbX = fock.quadrature_distribution(rho)
p, ProbP = fock.quadrature_distribution(rho, np.pi / 2)
q, ProbX = fock_utils.quadrature_distribution(rho)
p, ProbP = fock_utils.quadrature_distribution(rho, np.pi / 2)

xvec = np.linspace(*xbounds, plot_args["resolution"])
pvec = np.linspace(*ybounds, plot_args["resolution"])
Expand Down
26 changes: 15 additions & 11 deletions mrmustard/lab/abstract/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from mrmustard.math.parameter_set import ParameterSet
from mrmustard.math.parameters import Constant, Variable
from mrmustard.math.tensor_networks import Tensor
from mrmustard.physics import bargmann, fock, gaussian
from mrmustard.physics import bargmann_utils, fock_utils, gaussian
from mrmustard.utils.typing import RealMatrix, RealVector

from .state import State
Expand Down Expand Up @@ -172,9 +172,9 @@ def d_vector_dual(self) -> RealVector | None:
def bargmann(self, numpy=False):
X, Y, d = self.XYd(allow_none=False)
if self.is_unitary:
A, B, C = bargmann.wigner_to_bargmann_U(X, d)
A, B, C = bargmann_utils.wigner_to_bargmann_U(X, d)
else:
A, B, C = bargmann.wigner_to_bargmann_Choi(X, Y, d)
A, B, C = bargmann_utils.wigner_to_bargmann_Choi(X, Y, d)
if numpy:
return math.asnumpy(A), math.asnumpy(B), math.asnumpy(C)
return A, B, C
Expand Down Expand Up @@ -208,11 +208,11 @@ def choi(
U = self.U(shape[: self.num_modes])
Udual = self.U(shape[self.num_modes :])
if dual:
return fock.U_to_choi(U=Udual, Udual=U)
return fock.U_to_choi(U=U, Udual=Udual)
return fock_utils.U_to_choi(U=Udual, Udual=U)
return fock_utils.U_to_choi(U=U, Udual=Udual)

X, Y, d = self.XYd(allow_none=False)
choi = fock.wigner_to_fock_Choi(X, Y, d, shape=shape)
choi = fock_utils.wigner_to_fock_Choi(X, Y, d, shape=shape)
if dual:
n = len(shape) // 4
N0 = list(range(0, n))
Expand Down Expand Up @@ -382,8 +382,10 @@ def _transform_fock(self, state: State, dual=False) -> State:
op_idx = [state.modes.index(m) for m in self.modes]
U = self.U(cutoffs=[state.cutoffs[i] for i in op_idx])
if state.is_hilbert_vector:
return State(ket=fock.apply_kraus_to_ket(U, state.ket(), op_idx), modes=state.modes)
return State(dm=fock.apply_kraus_to_dm(U, state.dm(), op_idx), modes=state.modes)
return State(
ket=fock_utils.apply_kraus_to_ket(U, state.ket(), op_idx), modes=state.modes
)
return State(dm=fock_utils.apply_kraus_to_dm(U, state.dm(), op_idx), modes=state.modes)

def U(
self,
Expand Down Expand Up @@ -411,7 +413,7 @@ def U(
raise ValueError(f"len(cutoffs) must be {self.num_modes} (got {len(cutoffs)})")
shape = shape or tuple(cutoffs) * 2
X, _, d = self.XYd(allow_none=False)
return fock.wigner_to_fock_U(X, d, shape=shape)
return fock_utils.wigner_to_fock_U(X, d, shape=shape)

def __eq__(self, other):
r"""Returns ``True`` if the two transformations are equal."""
Expand Down Expand Up @@ -453,8 +455,10 @@ def _transform_fock(self, state: State, dual: bool = False) -> State:
op_idx = [state.modes.index(m) for m in self.modes]
choi = self.choi(cutoffs=[state.cutoffs[i] for i in op_idx], dual=dual)
if state.is_hilbert_vector:
return State(dm=fock.apply_choi_to_ket(choi, state.ket(), op_idx), modes=state.modes)
return State(dm=fock.apply_choi_to_dm(choi, state.dm(), op_idx), modes=state.modes)
return State(
dm=fock_utils.apply_choi_to_ket(choi, state.ket(), op_idx), modes=state.modes
)
return State(dm=fock_utils.apply_choi_to_dm(choi, state.dm(), op_idx), modes=state.modes)

def value(self, shape: tuple[int]):
return self.choi(shape=shape)
Expand Down
6 changes: 3 additions & 3 deletions mrmustard/lab/detectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing import Iterable

from mrmustard import settings
from mrmustard.physics import fock, gaussian
from mrmustard.physics import fock_utils, gaussian
from mrmustard.utils.typing import RealMatrix, RealVector

from mrmustard import math
Expand Down Expand Up @@ -407,7 +407,7 @@ def _measure_fock(self, other) -> State | float:
reduced_state = other.get_modes(self.modes)

# build pdf and sample homodyne outcome
x_outcome, probability = fock.sample_homodyne(
x_outcome, probability = fock_utils.sample_homodyne(
state=reduced_state.ket() if reduced_state.is_pure else reduced_state.dm(),
quadrature_angle=self.quadrature_angle,
)
Expand All @@ -427,7 +427,7 @@ def _measure_fock(self, other) -> State | float:
other_cutoffs = [
None if m not in self.modes else other.cutoffs[other.indices(m)] for m in other.modes
]
out_fock = fock.contract_states(
out_fock = fock_utils.contract_states(
stateA=(other.ket(other_cutoffs) if other.is_pure else other.dm(other_cutoffs)),
stateB=self.state.ket(self_cutoffs),
a_is_dm=other.is_mixed,
Expand Down
Loading
Loading