Skip to content

Commit

Permalink
Merge pull request #53 from dstansby/pluto-data
Browse files Browse the repository at this point in the history
Fix PLUTO interpolation
  • Loading branch information
dstansby authored Mar 24, 2022
2 parents d611db3 + 8eda800 commit 2a63b47
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 7 deletions.
15 changes: 14 additions & 1 deletion psipy/model/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,21 @@ def test_radial_normalised(mas_model):
([1, 2] * u.deg, [1, 2] * u.deg, [30, 31] * const.R_sun),
([1, 0] * u.deg, [1, 0] * u.deg, [30, 31] * const.R_sun),
])
def test_sample_at_coords(mas_model, lon, lat, r):
def test_sample_at_coords_mas(mas_model, lon, lat, r):
# Check scalar coords
rho = mas_model['rho'].sample_at_coords(lon=lon, lat=lat, r=r)
assert rho.unit == mas_model['rho'].unit
assert u.allclose(rho[0], [447.02795493] * u.cm**-3)


@pytest.mark.parametrize(
'lon, lat, r',
[(1*u.deg, 1*u.deg, 1*const.R_sun),
([1, 2] * u.deg, [1, 2] * u.deg, [1, 1.01] * const.R_sun),
([1, 0] * u.deg, [1, 0] * u.deg, [1, 1.01] * const.R_sun),
])
def test_sample_at_coords_pluto(pluto_model, lon, lat, r):
# Check scalar coords
rho = pluto_model['rho'].sample_at_coords(lon=lon, lat=lat, r=r)
assert rho.unit == pluto_model['rho'].unit
assert u.allclose(rho[0], [13.50442343])
24 changes: 18 additions & 6 deletions psipy/model/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(self, data, name, unit):
# Convert from xarray Dataset to DataArray
self._data = data[name]
# Sort the data once now for any interpolation later
self._data = self._data.transpose(*['phi', 'theta', 'r', 'time'])
self._data = self._data.sortby(['phi', 'theta', 'r', 'time'])
self.name = name
self._unit = unit
Expand Down Expand Up @@ -385,8 +386,8 @@ def sample_at_coords(self, lon: u.deg, lat: u.deg, r: u.m, t=None):
Linear interpolation is used to interpoalte between cells. See the
docstring of `scipy.interpolate.interpn` for more information.
"""
points = [self.data.coords[dim].values for dim in
['phi', 'theta', 'r', 'time']]
dims = ['phi', 'theta', 'r', 'time']
points = [self.data.coords[dim].values for dim in dims]
values = self.data.values

# Check that coordinates are increasing
Expand All @@ -402,8 +403,13 @@ def sample_at_coords(self, lon: u.deg, lat: u.deg, r: u.m, t=None):

# Pad phi points so it's possible to interpolate all the way from
# 0 to 360 deg
points[0] = np.append(points[0], points[0][0] + 2 * np.pi)
pcoords = points[0]
pcoords = np.append(pcoords, pcoords[0] + 2 * np.pi)
pcoords = np.insert(pcoords, 0, pcoords[-2] - 2 * np.pi)
points[0] = pcoords

values = np.append(values, values[0:1, :, :, :], axis=0)
values = np.insert(values, 0, values[-2:-1, :, :, :], axis=0)

if len(points[3]) == 1:
# Only one timestep
Expand All @@ -413,10 +419,16 @@ def sample_at_coords(self, lon: u.deg, lat: u.deg, r: u.m, t=None):
values = values[:, :, :, 0]
points = points[:-1]
else:
xi = np.column_stack([t,
lon.to_value(u.rad),
xi = np.column_stack([lon.to_value(u.rad),
lat.to_value(u.rad),
r.to_value(const.R_sun)])
r.to_value(const.R_sun),
t])

for i, dim in enumerate(dims[:-1]):
bounds = np.min(points[i]), np.max(points[i])
if not (np.all(bounds[0] <= xi[:, i]) and np.all(xi[:, i] <= bounds[1])):
raise ValueError(f"At least one point is outside bounds {bounds} in {dim} dimension.")


values_x = interpolate.interpn(points, values, xi)
return values_x * self._unit

0 comments on commit 2a63b47

Please sign in to comment.