Skip to content

Commit

Permalink
Merge pull request #844 from open-AIMS/fix-rsa-test
Browse files Browse the repository at this point in the history
Overly constrained argument types causing tests to fail
  • Loading branch information
DanTanAtAims authored Sep 10, 2024
2 parents f19656b + cf5ba86 commit 7273f1e
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 28 deletions.
48 changes: 24 additions & 24 deletions src/analysis/sensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -544,9 +544,9 @@ end
rsa(X::DataFrame, y::AbstractVector{<:Real}, model_spec::DataFrame; S::Int64=10)::Dataset
rsa(r_s::YAXArray, X_q::AbstractArray, X_i::AbstractArray, y::AbstractVecOrMat{<:Real}, sel::BitVector)::YAXArray
rsa(X::Vector{Float64}, y::AbstractVector{<:Real}, foi_spec::DataFrame; S::Int64=10)::YAXArray
rsa(rs::ResultSet, y::YAXArray{Float64,1}; S::Int64=10)::Dataset
rsa(rs::ResultSet, y::YAXArray{Float64,1}, factors::Vector{Symbol}; S::Int64=10)::Dataset
rsa(rs::ResultSet, y::YAXArray{Float64,1}, factor::Symbol; S::Int64=10)::YAXArray
rsa(rs::ResultSet, y::AbstractVector{T}; S::Int64=10)::Dataset where {T<:Real}
rsa(rs::ResultSet, y::AbstractVector{T}, factors::Vector{Symbol}; S::Int64=10)::Dataset where {T<:Real}
rsa(rs::ResultSet, y::AbstractVector{T}, factor::Symbol; S::Int64=10)::YAXArray where {T<:Real}
Perform Regional Sensitivity Analysis.
Expand Down Expand Up @@ -605,10 +605,10 @@ ADRIA.sensitivity.rsa(X, y; S=10)
"""
function rsa(
X::DataFrame,
y::AbstractVector{<:Real},
y::AbstractVector{T},
model_spec::DataFrame;
S::Int64=10
)::Dataset
)::Dataset where {T<:Real}
factors = Symbol.(names(X))
N, D = size(X)

Expand Down Expand Up @@ -670,10 +670,10 @@ function rsa(
end
function rsa(
X::Vector{Float64},
y::AbstractVector{<:Real},
y::AbstractVector{T},
foi_spec::DataFrame;
S::Int64=10
)::YAXArray
)::YAXArray where {T<:Real}
factor = foi_spec.fieldname[1]
N = length(X)
sel = trues(N)
Expand All @@ -694,15 +694,15 @@ function rsa(
r_s, X_q, X, y, sel
)
end
function rsa(rs::ResultSet, y::YAXArray{Float64,1}; S::Int64=10)::Dataset
function rsa(rs::ResultSet, y::AbstractVector{T}; S::Int64=10)::Dataset where {T<:Real}
return rsa(rs.inputs[!, Not(:RCP)], vec(y), rs.model_spec; S=S)
end
function rsa(
rs::ResultSet,
y::YAXArray{Float64,1},
y::AbstractVector{T},
factors::Vector{Symbol};
S::Int64=10
)::Dataset
)::Dataset where {T<:Real}
return rsa(
rs.inputs[!, Not(:RCP)][!, factors],
vec(y),
Expand All @@ -712,10 +712,10 @@ function rsa(
end
function rsa(
rs::ResultSet,
y::YAXArray{Float64,1},
y::AbstractVector{T},
factor::Symbol;
S::Int64=10
)::YAXArray
)::YAXArray where {T<:Real}
return rsa(
rs.inputs[!, Not(:RCP)][!, factor],
vec(y),
Expand All @@ -728,10 +728,10 @@ end
outcome_map(p::YAXArray, X_q::AbstractArray, X_f::AbstractArray, y::AbstractVecOrMat{<:Real}, behave::BitVector; n_boot::Int64=100, conf::Float64=0.95)::YAXArray
outcome_map(X::DataFrame, y::AbstractVecOrMat{<:Real}, rule::Union{Function,BitVector,Vector{Int64}}, target_factors::Vector{Symbol}, model_spec::DataFrame; S::Int64=10, n_boot::Int64=100, conf::Float64=0.95)::Dataset
outcome_map(X::DataFrame, y::AbstractVecOrMat{<:Real}, rule::Union{Function,BitVector,Vector{Int64}}, target_factor::Symbol, model_spec::DataFrame; S::Int64=20, n_boot::Int64=100, conf::Float64=0.95)::YAXArray
outcome_map(X::DataFrame, y::YAXArray{Float64,1}, rule::Union{Function,BitVector,Vector{Int64}}; S::Int64=20, n_boot::Int64=100, conf::Float64=0.95)::Dataset
outcome_map(rs::ResultSet, y::YAXArray{Float64,1}, rule::Union{Function,BitVector,Vector{Int64}}, target_factors::Vector{Symbol}; S::Int64=20, n_boot::Int64=100, conf::Float64=0.95)::Dataset
outcome_map(rs::ResultSet, y::YAXArray{Float64,1}, rule::Union{Function,BitVector,Vector{Int64}}, target_factor::Symbol; S::Int64=20, n_boot::Int64=100, conf::Float64=0.95)::YAXArray
outcome_map(rs::ResultSet, y::YAXArray{Float64,1}, rule::Union{Function,BitVector,Vector{Int64}}; S::Int64=20, n_boot::Int64=100, conf::Float64=0.95)::Dataset
outcome_map(X::DataFrame, y::AbstractVector{T}, rule::Union{Function,BitVector,Vector{Int64}}; S::Int64=20, n_boot::Int64=100, conf::Float64=0.95)::Dataset where {T<:Real}
outcome_map(rs::ResultSet, y::AbstractVector{T}, rule::Union{Function,BitVector,Vector{Int64}}, target_factors::Vector{Symbol}; S::Int64=20, n_boot::Int64=100, conf::Float64=0.95)::Dataset where {T<:Real}
outcome_map(rs::ResultSet, y::AbstractVector{T}, rule::Union{Function,BitVector,Vector{Int64}}, target_factor::Symbol; S::Int64=20, n_boot::Int64=100, conf::Float64=0.95)::YAXArray where {T<:Real}
outcome_map(rs::ResultSet, y::AbstractVector{T}, rule::Union{Function,BitVector,Vector{Int64}}; S::Int64=20, n_boot::Int64=100, conf::Float64=0.95)::Dataset where {T<:Real}
Map normalized outcomes (defined by `rule`) to factor values discretized into `S` bins.
Expand Down Expand Up @@ -876,23 +876,23 @@ function outcome_map(
end
function outcome_map(
X::DataFrame,
y::YAXArray{Float64,1},
y::AbstractVector{T},
rule::Union{Function,BitVector,Vector{Int64}};
S::Int64=20,
n_boot::Int64=100,
conf::Float64=0.95
)::Dataset
)::Dataset where {T<:Real}
return outcome_map(X, vec(y), rule, names(X); S, n_boot, conf)
end
function outcome_map(
rs::ResultSet,
y::YAXArray{Float64,1},
y::AbstractVector{T},
rule::Union{Function,BitVector,Vector{Int64}},
target_factors::Vector{Symbol};
S::Int64=20,
n_boot::Int64=100,
conf::Float64=0.95
)::Dataset
)::Dataset where {T<:Real}
return outcome_map(
rs.inputs[:, Not(:RCP)],
vec(y),
Expand All @@ -906,13 +906,13 @@ function outcome_map(
end
function outcome_map(
rs::ResultSet,
y::YAXArray{Float64,1},
y::AbstractVector{T},
rule::Union{Function,BitVector,Vector{Int64}},
target_factor::Symbol;
S::Int64=20,
n_boot::Int64=100,
conf::Float64=0.95
)::YAXArray
)::YAXArray where {T<:Real}
return outcome_map(
rs.inputs[:, Not(:RCP)],
vec(y),
Expand All @@ -926,12 +926,12 @@ function outcome_map(
end
function outcome_map(
rs::ResultSet,
y::YAXArray{Float64,1},
y::AbstractVector{T},
rule::Union{Function,BitVector,Vector{Int64}};
S::Int64=20,
n_boot::Int64=100,
conf::Float64=0.95
)::Dataset
)::Dataset where {T<:Real}
return outcome_map(
rs.inputs[:, Not(:RCP)],
vec(y),
Expand Down
13 changes: 9 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -329,10 +329,15 @@ function test_rs_w_fig()

### Regional Sensitivity Analysis

foi =
Symbol.([
"dhw_scenario", "wave_scenario", "N_seed_TA", "N_seed_CA", "fogging", "SRM"
])
foi = [
:dhw_scenario,
:wave_scenario,
:N_seed_TA,
:N_seed_CA,
:fogging,
:SRM
]

tac_rs = ADRIA.sensitivity.rsa(rs, mean_s_tac; S=10)
rsa_fig = ADRIA.viz.rsa(
rs,
Expand Down

0 comments on commit 7273f1e

Please sign in to comment.