From 737bda673517daaac2226375f26ee1c70e5e66bd Mon Sep 17 00:00:00 2001 From: David Stansby Date: Thu, 24 Mar 2022 15:50:12 +0000 Subject: [PATCH 1/3] Correctly re-rorder dimensions for interpolation --- psipy/model/variable.py | 1 + 1 file changed, 1 insertion(+) diff --git a/psipy/model/variable.py b/psipy/model/variable.py index c4fd7a1..4ed6bbf 100644 --- a/psipy/model/variable.py +++ b/psipy/model/variable.py @@ -46,6 +46,7 @@ def __init__(self, data, name, unit): # Convert from xarray Dataset to DataArray self._data = data[name] # Sort the data once now for any interpolation later + self._data = self._data.transpose(*['phi', 'theta', 'r', 'time']) self._data = self._data.sortby(['phi', 'theta', 'r', 'time']) self.name = name self._unit = unit From 42a50ce4762273f18fc885253ae97891ece5bde9 Mon Sep 17 00:00:00 2001 From: David Stansby Date: Thu, 24 Mar 2022 15:50:52 +0000 Subject: [PATCH 2/3] Add phi slice at the beginning --- psipy/model/variable.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/psipy/model/variable.py b/psipy/model/variable.py index 4ed6bbf..59a16b0 100644 --- a/psipy/model/variable.py +++ b/psipy/model/variable.py @@ -403,8 +403,13 @@ def sample_at_coords(self, lon: u.deg, lat: u.deg, r: u.m, t=None): # Pad phi points so it's possible to interpolate all the way from # 0 to 360 deg - points[0] = np.append(points[0], points[0][0] + 2 * np.pi) + pcoords = points[0] + pcoords = np.append(pcoords, pcoords[0] + 2 * np.pi) + pcoords = np.insert(pcoords, 0, pcoords[-2] - 2 * np.pi) + points[0] = pcoords + values = np.append(values, values[0:1, :, :, :], axis=0) + values = np.insert(values, 0, values[-2:-1, :, :, :], axis=0) if len(points[3]) == 1: # Only one timestep From 8eda80051680ca336b7f189e6c5416193f61ef2b Mon Sep 17 00:00:00 2001 From: David Stansby Date: Thu, 24 Mar 2022 16:00:35 +0000 Subject: [PATCH 3/3] Fix interpolation and add nicer error messages --- psipy/model/tests/test_variable.py | 15 ++++++++++++++- psipy/model/variable.py | 16 +++++++++++----- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/psipy/model/tests/test_variable.py b/psipy/model/tests/test_variable.py index 70bf19e..fea8032 100644 --- a/psipy/model/tests/test_variable.py +++ b/psipy/model/tests/test_variable.py @@ -23,8 +23,21 @@ def test_radial_normalised(mas_model): ([1, 2] * u.deg, [1, 2] * u.deg, [30, 31] * const.R_sun), ([1, 0] * u.deg, [1, 0] * u.deg, [30, 31] * const.R_sun), ]) -def test_sample_at_coords(mas_model, lon, lat, r): +def test_sample_at_coords_mas(mas_model, lon, lat, r): # Check scalar coords rho = mas_model['rho'].sample_at_coords(lon=lon, lat=lat, r=r) assert rho.unit == mas_model['rho'].unit assert u.allclose(rho[0], [447.02795493] * u.cm**-3) + + +@pytest.mark.parametrize( + 'lon, lat, r', + [(1*u.deg, 1*u.deg, 1*const.R_sun), + ([1, 2] * u.deg, [1, 2] * u.deg, [1, 1.01] * const.R_sun), + ([1, 0] * u.deg, [1, 0] * u.deg, [1, 1.01] * const.R_sun), + ]) +def test_sample_at_coords_pluto(pluto_model, lon, lat, r): + # Check scalar coords + rho = pluto_model['rho'].sample_at_coords(lon=lon, lat=lat, r=r) + assert rho.unit == pluto_model['rho'].unit + assert u.allclose(rho[0], [13.50442343]) diff --git a/psipy/model/variable.py b/psipy/model/variable.py index 59a16b0..5c99adc 100644 --- a/psipy/model/variable.py +++ b/psipy/model/variable.py @@ -386,8 +386,8 @@ def sample_at_coords(self, lon: u.deg, lat: u.deg, r: u.m, t=None): Linear interpolation is used to interpoalte between cells. See the docstring of `scipy.interpolate.interpn` for more information. """ - points = [self.data.coords[dim].values for dim in - ['phi', 'theta', 'r', 'time']] + dims = ['phi', 'theta', 'r', 'time'] + points = [self.data.coords[dim].values for dim in dims] values = self.data.values # Check that coordinates are increasing @@ -419,10 +419,16 @@ def sample_at_coords(self, lon: u.deg, lat: u.deg, r: u.m, t=None): values = values[:, :, :, 0] points = points[:-1] else: - xi = np.column_stack([t, - lon.to_value(u.rad), + xi = np.column_stack([lon.to_value(u.rad), lat.to_value(u.rad), - r.to_value(const.R_sun)]) + r.to_value(const.R_sun), + t]) + + for i, dim in enumerate(dims[:-1]): + bounds = np.min(points[i]), np.max(points[i]) + if not (np.all(bounds[0] <= xi[:, i]) and np.all(xi[:, i] <= bounds[1])): + raise ValueError(f"At least one point is outside bounds {bounds} in {dim} dimension.") + values_x = interpolate.interpn(points, values, xi) return values_x * self._unit