Skip to content

Commit

Permalink
Remove duplicate rules keeping the one with highest probability
Browse files Browse the repository at this point in the history
  • Loading branch information
Zapiano committed Nov 15, 2024
1 parent f01f4b4 commit 24f2bf8
Showing 1 changed file with 59 additions and 7 deletions.
66 changes: 59 additions & 7 deletions src/analysis/rule_extraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ function print_rules(rules::Vector{Rule{Vector{Vector},Vector{Float64}}})::Nothi
end

"""
cluster_rules(clusters::Vector{T}, X::DataFrame, max_rules::T; seed::Int64=123, kwargs...) where {T<:Integer,F<:Real}
cluster_rules(clusters::Union{BitVector,Vector{Bool}}, X::DataFrame, max_rules::T; kwargs...) where {T<:Int64}
cluster_rules(clusters::Vector{T}, X::DataFrame, max_rules::T;seed::Int64=123, remove_duplicates=true, kwargs...)::Vector{Rule{Vector{Vector},Vector{Float64}}} where {T<:Int64}
cluster_rules(clusters::Union{BitVector,Vector{Bool}}, X::DataFrame, max_rules::T;kwargs...)::Vector{Rule{Vector{Vector},Vector{Float64}}} where {T<:Int64}
Use SIRUS package to extract rules from time series clusters based on some summary metric
(default is median). More information about the keyword arguments accepted can be found in
Expand All @@ -136,8 +136,10 @@ A StableRules object (implemented by SIRUS).
Electron. J. Statist. 15 (1) 427 - 505.
https://doi.org/10.1214/20-EJS1792
"""
function cluster_rules(clusters::Vector{T}, X::DataFrame, max_rules::T;
seed::Int64=123, kwargs...) where {T<:Int64}
function cluster_rules(
clusters::Vector{T}, X::DataFrame, max_rules::T;
seed::Int64=123, remove_duplicates=true, kwargs...
)::Vector{Rule{Vector{Vector},Vector{Float64}}} where {T<:Int64}
# Set seed and Random Number Generator
rng = StableRNG(seed)

Expand All @@ -155,13 +157,63 @@ function cluster_rules(clusters::Vector{T}, X::DataFrame, max_rules::T;
error("Failed fitting SIRUS. Try increasing the number of scenarios/samples.")
end

return rules(mach.fitresult)
if remove_duplicates
return _remove_duplicates(rules(mach.fitresult))
else
return rules(mach.fitresult)
end
end
function cluster_rules(clusters::Union{BitVector,Vector{Bool}}, X::DataFrame, max_rules::T;
kwargs...) where {T<:Int64}
function cluster_rules(
    clusters::Union{BitVector,Vector{Bool}}, X::DataFrame, max_rules::T;
    kwargs...
)::Vector{Rule{Vector{Vector},Vector{Float64}}} where {T<:Int64}
    # Boolean flags are turned into 0/1 integer labels so the call below dispatches to
    # the integer-cluster method; keyword arguments are forwarded unchanged.
    int_clusters = Int64.(clusters)
    return cluster_rules(int_clusters, X, max_rules; kwargs...)
end

"""
_remove_duplicates(rules)::Vector{Rule{Vector{Vector},Vector{Float64}}}
Returns `Vector{Rule}` without duplicate rules (if there are any). Two rules are treated as
duplicates when their conditions have the same factor names and directions, regardless of
the values. Among duplicates, the rule with the highest consequent probability is kept. If
more than one rule has the same highest probability, the first one is chosen.
"""
function _remove_duplicates(rules)::Vector{Rule{Vector{Vector},Vector{Float64}}}
    # Single pass over `rules`, grouping them by condition signature (the condition with
    # its values stripped).
    #
    # `kept` holds, in order of first appearance, the index into `rules` of the rule
    # currently selected for each distinct signature. `slot_of` maps a signature to its
    # position in `kept`. This replaces rescanning every signature for each unique one.
    kept = Int[]
    slot_of = Dict{String,Int}()

    for idx in eachindex(rules)
        rule = rules[idx]
        signature = join(_strip_value.(rule.condition), "_&_")

        slot = get(slot_of, signature, 0)
        if slot == 0
            # First rule seen with this signature
            push!(kept, idx)
            slot_of[signature] = length(kept)
        elseif isless(rules[kept[slot]].consequent[1], rule.consequent[1])
            # Replace only on a strictly higher probability, so the first rule wins ties.
            # `isless` gives the same ordering `findmax` uses (NaN sorts above all numbers).
            kept[slot] = idx
        end
    end

    # Return the input itself if there are no duplicate rules
    length(kept) == length(rules) && return rules

    # Indexing the input keeps its element type instead of going through an abstractly
    # typed `Vector{Rule}` buffer.
    return rules[kept]
end

"""
_strip_value(condition_subclause::Vector)
Helper function that extracts factor name and direction from a rule condition subclause
"""
function _strip_value(condition_subclause::Vector)
    # Only the first two entries of the subclause are kept; whatever follows them is
    # dropped. They are concatenated with a double underscore in between.
    leading_pair = view(condition_subclause, 1:2)
    return join(leading_pair, "__")
end

"""
maximum_probability(rules::SIRUS.StableRules{Int64})
Expand Down

0 comments on commit 24f2bf8

Please sign in to comment.