Skip to content

Commit

Permalink
Merge pull request #98 from gaelforget/v0p3p10b
Browse files Browse the repository at this point in the history
updates related to recent changes in DiffEqBase
  • Loading branch information
gaelforget authored Aug 29, 2022
2 parents 25e11d9 + 165e722 commit 09f9d54
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 36 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "IndividualDisplacements"
uuid = "b92f0c32-5b7e-11e9-1d7b-238b2da8b0e6"
authors = ["gaelforget <gforget@mit.edu>"]
version = "0.3.8"
version = "0.3.9"

[deps]
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
Expand All @@ -21,7 +21,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
CFTime = "0.1"
CSV = "0.6, 0.7, 0.8, 0.9, 0.10"
CyclicArrays = "0.2, 0.3, 0.4, 0.5"
DataFrames = "0.21, 0.22, 1.0, 1.1"
DataFrames = "0.21, 0.22, 1"
MeshArrays = "0.2.19, 0.2"
NetCDF = "0.10, 0.11"
OrdinaryDiffEq = "5, 6"
Expand Down
6 changes: 2 additions & 4 deletions examples/jupyter/flow_fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,10 @@ function global_ocean_circulation(;k=1,ny=2)
γ=GridSpec("LatLonCap",MeshArrays.GRID_LLC90)
Γ=GridLoad(γ;option="full")
Γ=merge(Γ,MeshArrays.NeighborTileIndices_cs(Γ))

func=(u -> MeshArrays.update_location_llc!(u,𝐷))
Γ=merge(Γ,(; update_location! = func))
func=(u -> MeshArrays.update_location_llc!(u,Γ))

#initialize u0,u1 etc
𝑃,𝐷=set_up_FlowFields(k,Γ,ECCOclim_path);
𝑃,𝐷=set_up_FlowFields(k,Γ,func,ECCOclim_path);

#add parameters for use in reset!
tmp=(frac=r_reset, Γ=Γ)
Expand Down
2 changes: 1 addition & 1 deletion examples/jupyter/global_ocean_circulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ fieldnames(typeof(𝑃))
#
# - initial particle positions randomly over Global Ocean

np=100
np=10

#xy = init_global_randn(np,𝐷)
#df=DataFrame(x=xy[1,:],y=xy[2,:],f=xy[3,:])
Expand Down
19 changes: 9 additions & 10 deletions examples/jupyter/helper_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,26 +68,25 @@ file location (`pth`).
_Note: the initial implementation approximates month durations to
365 days / 12 months for simplicity and sets 𝑃.𝑇 to [-mon/2,mon/2]_
"""
function set_up_FlowFields(k::Int::NamedTuple,pth::String)
function set_up_FlowFields(k::Int::NamedTuple,func::Function,pth::String)
XC=exchange.XC) #add 1 lon point at each edge
YC=exchange.YC) #add 1 lat point at each edge
iDXC=1. ./Γ.DXC
iDYC=1. ./Γ.DYC
γ=Γ.XC.grid
mon=86400.0*365.0/12.0
func=Γ.update_location!

if k==0
msk=Γ.hFacC
(_,nr)=size(msk)
𝑃=FlowFields(MeshArray(γ,Float64,nr),MeshArray(γ,Float64,nr),
MeshArray(γ,Float64,nr),MeshArray(γ,Float64,nr),
MeshArray(γ,Float64,nr+1),MeshArray(γ,Float64,nr+1),
𝑃=FlowFields(MeshArray(γ,Float32,nr),MeshArray(γ,Float32,nr),
MeshArray(γ,Float32,nr),MeshArray(γ,Float32,nr),
MeshArray(γ,Float32,nr+1),MeshArray(γ,Float32,nr+1),
[-mon/2,mon/2],func)
else
msk=Γ.hFacC[:, k]
𝑃=FlowFields(MeshArray(γ,Float64),MeshArray(γ,Float64),
MeshArray(γ,Float64),MeshArray(γ,Float64),[-mon/2,mon/2],func)
𝑃=FlowFields(MeshArray(γ,Float32),MeshArray(γ,Float32),
MeshArray(γ,Float32),MeshArray(γ,Float32),[-mon/2,mon/2],func)
end

𝐷 = (🔄 = update_FlowFields!, pth=pth,
Expand All @@ -109,7 +108,7 @@ _Note: for now, it is assumed that (1) the time interval `dt` between
consecutive records is diff(𝑃.𝑇), (2) monthly climatologies are used
with a periodicity of 12 months, (3) vertical 𝑃.k is selected_
"""
function update_FlowFields!(𝑃::𝐹_MeshArray2D,𝐷::NamedTuple,t::Float64)
function update_FlowFields!(𝑃::𝐹_MeshArray2D,𝐷::NamedTuple,t::AbstractFloat)
dt=𝑃.𝑇[2]-𝑃.𝑇[1]

m0=Int(floor((t+dt/2.0)/dt))
Expand Down Expand Up @@ -208,11 +207,11 @@ function update_FlowFields!(𝑃::𝐹_MeshArray3D,𝐷::NamedTuple,t::Float64)

θ0=IndividualDisplacements.read_nctiles(𝐷.pth*"THETA/THETA","THETA",𝑃.u0.grid,I=(:,:,:,m0))
θ0[findall(isnan.(θ0))]=0.0 #mask with 0s rather than NaNs
𝐷.θ0[:,:]=θ0[:,:]
𝐷.θ0[:,:]=float32.(θ0[:,:])

θ1=IndividualDisplacements.read_nctiles(𝐷.pth*"THETA/THETA","THETA",𝑃.u0.grid,I=(:,:,:,m1))
θ1[findall(isnan.(θ1))]=0.0 #mask with 0s rather than NaNs
𝐷.θ1[:,:]=θ1[:,:]
𝐷.θ1[:,:]=float32.(θ1[:,:])

𝑃.𝑇[:]=[t0,t1]
end
2 changes: 1 addition & 1 deletion examples/jupyter/particle_cloud.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ x=vec([x-0.5 for x in ii1, y in ii2])
y=vec([y-0.5 for x in ii1, y in ii2])
xy = permutedims([[x[i];y[i];1.0] for i in eachindex(x)])

solv(prob) = solve(prob,Tsit5(),reltol=1e-6,abstol=1e-6)
solv(prob) = IndividualDisplacements.ensemble_solver(prob,solver=Tsit5(),reltol=1e-6,abstol=1e-6)
tr = DataFrame(ID=Int[], x=Float64[], y=Float64[], t=Float64[])

#𝐼 = Individuals{Float64,2}(📌=xy[:,:], 🔴=tr, 🆔=collect(1:size(xy,2)),
Expand Down
3 changes: 2 additions & 1 deletion examples/jupyter/three_dimensional_ocean.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ function custom🔧(sol,𝑃::𝐹_MeshArray3D,𝐷::NamedTuple;id=missing,𝑇=
df.year=df.t ./86400/365

#add depth (i.e. the 3rd, vertical, coordinate)
k=[sol[1,i,j][3] for i in 1:size(sol,2), j in 1:size(sol,3)]
k=[[sol[i][3,1] for i in 1:size(sol,3)];[sol[i][3,end] for i in 1:size(sol,3)]]

nz=length(𝐼.𝑃.u1)
df.k=min.(max.(k[:],Ref(0.0)),Ref(nz)) #level
k=Int.(floor.(df.k)); w=(df.k-k);
Expand Down
26 changes: 20 additions & 6 deletions src/API.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,15 @@ end
"""

default_solver(prob) = solve(prob,Tsit5(),reltol=1e-8,abstol=1e-8)

function ensemble_solver(prob;solver=Tsit5(),reltol=1e-8,abstol=1e-8)
u0 = prob.u0
prob_func(prob,i,repeat) = remake(prob,u0=u0[i])
indiv_prob = ODEProblem(prob.f,u0[1],prob.tspan,prob.p)
ensemble_prob = EnsembleProblem(indiv_prob,prob_func=prob_func)
solve(ensemble_prob, solver, reltol=reltol, abstol=abstol, trajectories=length(u0))
end

a=fill(0.0,1,1)
default_flowfields = 𝐹_Array2D{Float64}(a,a,a,a,[0. 1.])
default_recorder = DataFrame(ID=Int[], x=Float64[], y=Float64[], t=Float64[])
Expand Down Expand Up @@ -216,7 +225,7 @@ function Individuals(𝐹::𝐹_Array2D,x,y, NT::NamedTuple = NamedTuple())
🆔=collect(1:size(📌,2))
haskey(NT,:🆔) ? 🆔=NT.🆔 : nothing

=default_solver
=ensemble_solver
haskey(NT,:∫) ?=NT.∫ : nothing

𝐷=NamedTuple()
Expand Down Expand Up @@ -244,7 +253,7 @@ function Individuals(𝐹::𝐹_Array3D,x,y,z, NT::NamedTuple = NamedTuple())
🆔=collect(1:size(📌,2))
haskey(NT,:🆔) ? 🆔=NT.🆔 : nothing

=default_solver
=ensemble_solver
haskey(NT,:∫) ?=NT.∫ : nothing

𝐷=NamedTuple()
Expand All @@ -267,7 +276,7 @@ function Individuals(𝐹::𝐹_MeshArray2D,x,y,fid, NT::NamedTuple = NamedTuple
🆔=collect(1:size(📌,2))
haskey(NT,:🆔) ? 🆔=NT.🆔 : nothing

=default_solver
=ensemble_solver
haskey(NT,:∫) ?=NT.∫ : nothing

𝐷=NamedTuple()
Expand Down Expand Up @@ -295,7 +304,7 @@ function Individuals(𝐹::𝐹_MeshArray3D,x,y,z,fid, NT::NamedTuple = NamedTup
🆔=collect(1:size(📌,2))
haskey(NT,:🆔) ? 🆔=NT.🆔 : nothing

=default_solver
=ensemble_solver
haskey(NT,:∫) ?=NT.∫ : nothing

𝐷=NamedTuple()
Expand Down Expand Up @@ -324,8 +333,13 @@ function ∫!(𝐼::Individuals,𝑇::Tuple)
isempty(🔴) ? np =0 : np=length(🆔)
append!(🔴,tmp[np+1:end,:])

nd=length(size(sol))
nd==3 ? 📌[:,:] = deepcopy(sol[:,:,end]) : 📌[:] = deepcopy(sol[:,end])
if isa(sol,EnsembleSolution)
np=length(sol)
📌[:] = deepcopy([sol[i].u[end] for i in 1:np])
else
nd=length(size(sol))
nd==3 ? 📌[:,:] = deepcopy(sol[:,:,end]) : 📌[:] = deepcopy(sol[:,end])
end

end

Expand Down
4 changes: 2 additions & 2 deletions src/compute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ Interpolate velocity from gridded fields (2D; NO halos) to position `u`
using IndividualDisplacements, Statistics
p=dirname(pathof(IndividualDisplacements))
include(joinpath(p,"../examples/jupyter/particle_cloud.jl"))
ref=[29.381183342468674 19.890831699436823]
prod(isapprox.([mean(𝐼.🔴.x) mean(𝐼.🔴.y)],ref,atol=1.0))
ref=[28. 22.]
prod(isapprox.([median(𝐼.🔴.x) median(𝐼.🔴.y)],ref,atol=10.0))
# output
Expand Down
34 changes: 25 additions & 9 deletions src/data_wrangling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,25 @@ end
Copy `sol` to a `DataFrame` & map position to lon,lat coordinates
using "exchanged" 𝐷.XC, 𝐷.YC via `add_lonlat!`
"""
function postprocess_MeshArray(sol::ODESolution,𝑃::FlowFields, 𝐷::NamedTuple; id=missing, 𝑇=missing)
function postprocess_MeshArray(sol,𝑃::FlowFields, 𝐷::NamedTuple; id=missing, 𝑇=missing)
ismissing(id) ? id=collect(1:size(sol,2)) : nothing
ismissing(𝑇) ? 𝑇=𝑃.𝑇 : nothing

nd=length(size(sol))
nt=size(sol,nd)
nf=size(sol,nd-1)
id=id*ones(1,size(sol,nd))
if (size(sol,1)>1)&&(nd>2)
t=[ceil(i/nf)-1 for i in 1:nt*nf]
t=𝑇[1] .+ (𝑇[2]-𝑇[1])/t[end].*t

if isa(sol,EnsembleSolution)
np=length(sol)
x=[[sol[i][1,1] for i in 1:np];[sol[i][1,end] for i in 1:np]]
y=[[sol[i][2,1] for i in 1:np];[sol[i][2,end] for i in 1:np]]
fIndex=[[sol[i][nd,end] for i in 1:np];[sol[i][nd,end] for i in 1:np]];
t=[fill(𝑇[1],np);fill(𝑇[2],np)]
id=[id[:,1];id[:,1]]
elseif (size(sol,1)>1)&&(nd>2)
x=sol[1,:,:]
y=sol[2,:,:]
fIndex=sol[end,:,:]
Expand All @@ -53,9 +63,6 @@ function postprocess_MeshArray(sol::ODESolution,𝑃::FlowFields, 𝐷::NamedTup
end

𝑃.u0.grid.nFaces==1 ? fIndex=ones(size(x)) : nothing

t=[ceil(i/nf)-1 for i in 1:nt*nf]
t=𝑇[1] .+ (𝑇[2]-𝑇[1])/t[end].*t

df = DataFrame(ID=Int.(id[:]), x=x[:], y=y[:], fid=Int.(fIndex[:]), t=t[:])
return df
Expand Down Expand Up @@ -126,7 +133,19 @@ function postprocess_xy(sol,𝑃::FlowFields,𝐷::NamedTuple; id=missing, 𝑇=
nd=length(size(sol))

id=id*ones(1,size(sol,nd))
if (size(sol,1)>1)&&(nd>2)
t=[ceil(i/nf)-1 for i in 1:nt*nf]
#size(𝐷.XC,1)>1 ? fIndex=sol[3,:,:] : fIndex=fill(1.0,size(x))
t=𝑇[1] .+ (𝑇[2]-𝑇[1])/t[end].*t

if isa(sol,EnsembleSolution)
np=length(sol)
x=[mod.([sol[i][1,1] for i in 1:np],Ref(nx));
mod.([sol[i][1,end] for i in 1:np],Ref(nx))];
y=[mod.([sol[i][2,1] for i in 1:np],Ref(ny));
mod.([sol[i][2,end] for i in 1:np],Ref(ny))]
t=[fill(𝑇[1],np);fill(𝑇[2],np)]
id=[id[:,1];id[:,1]]
elseif (size(sol,1)>1)&&(nd>2)
x=mod.(sol[1,:,:],Ref(nx))
y=mod.(sol[2,:,:],Ref(ny))
elseif (nd>2)
Expand All @@ -136,9 +155,6 @@ function postprocess_xy(sol,𝑃::FlowFields,𝐷::NamedTuple; id=missing, 𝑇=
x=mod.(sol[1,:],Ref(nx))
y=mod.(sol[2,:],Ref(ny))
end
t=[ceil(i/nf)-1 for i in 1:nt*nf]
#size(𝐷.XC,1)>1 ? fIndex=sol[3,:,:] : fIndex=fill(1.0,size(x))
t=𝑇[1] .+ (𝑇[2]-𝑇[1])/t[end].*t

return DataFrame(ID=Int.(id[:]), t=t[:], x=x[:], y=y[:])
end
Expand Down

0 comments on commit 09f9d54

Please sign in to comment.