From 51ed52680144505936d0268d6e41a8fac9049721 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Thu, 12 Sep 2024 20:27:06 +0200
Subject: [PATCH 01/60] new first commit
---
Project.toml | 1 +
src/LaplaceRedux.jl | 3 +
src/baselaplace/predicting.jl | 2 +-
src/direct_mlj.jl | 142 ++++++++++++++++++++++++++++++++++
4 files changed, 147 insertions(+), 1 deletion(-)
create mode 100644 src/direct_mlj.jl
diff --git a/Project.toml b/Project.toml
index 451430e3..2028e0ba 100644
--- a/Project.toml
+++ b/Project.toml
@@ -4,6 +4,7 @@ authors = ["Patrick Altmeyer"]
version = "1.1.1"
[deps]
+CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
diff --git a/src/LaplaceRedux.jl b/src/LaplaceRedux.jl
index 9a36d18e..4d8df9ec 100644
--- a/src/LaplaceRedux.jl
+++ b/src/LaplaceRedux.jl
@@ -30,4 +30,7 @@ export empirical_frequency_binary_classification,
sharpness_regression,
extract_mean_and_variance,
sigma_scaling, rescale_stddev
+
+
+ include("direct_mlj.jl")
end
diff --git a/src/baselaplace/predicting.jl b/src/baselaplace/predicting.jl
index 4a0a1d6c..71b74a4b 100644
--- a/src/baselaplace/predicting.jl
+++ b/src/baselaplace/predicting.jl
@@ -93,7 +93,7 @@ Computes the Bayesian predictive distribution from a neural network with a Lapl
- `link_approx::Symbol=:probit`: Link function approximation. Options are `:probit` and `:plugin`.
- `predict_proba::Bool=true`: If `true` (default) apply a sigmoid or a softmax function to the output of the Flux model.
- `return_distr::Bool=false`: if `false` (default), the function outputs either the direct output of the chain or pseudo-probabilities (if `predict_proba=true`).
- if `true` predict return a Bernoulli distribution in binary classification tasks and a categorical distribution in multiclassification tasks.
+ if `true`, `predict` returns a probability distribution.
# Returns
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
new file mode 100644
index 00000000..ed046ff5
--- /dev/null
+++ b/src/direct_mlj.jl
@@ -0,0 +1,142 @@
+using Flux
+using ProgressMeter: Progress, next!, BarGlyphs
+using Random
+using Tables
+using LinearAlgebra
+using LaplaceRedux
+using ComputationalResources
+using MLJBase: MLJBase
+import MLJModelInterface as MMI
+using Optimisers: Optimisers
+
+"""
+ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic
+
+A mutable struct representing a Laplace regression model.
+It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.
+It has the following Hyperparameters:
+- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
+- `subnetwork_indices`: the indices of the subnetworks.
+- `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`.
+- `backend`: the backend to use, either `:GGN` or `:EmpiricalFisher`.
+- `σ`: the standard deviation of the prior distribution.
+- `μ₀`: the mean of the prior distribution.
+- `P₀`: the covariance matrix of the prior distribution.
+- `ret_distr`: a boolean that tells `predict` whether to return distribution objects from Distributions.jl (`true`) or just the probabilities (`false`).
+- `fit_prior_nsteps`: the number of steps used to fit the priors.
+"""
+MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic
+ subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
+ subnetwork_indices = nothing
+ hessian_structure::Union{HessianStructure,Symbol,String} =
+ :full::(_ in (:full, :diagonal))
+ backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))
+ σ::Float64 = 1.0
+ μ₀::Float64 = 0.0
+ P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
+ ret_distr::Bool = false::(_ in (true, false))
+ fit_prior_nsteps::Int = 100::(_ > 0)
+end
+
+
+function MLJModelInterface.fit(m::LaplaceRegressor, verbosity, X, y, w=nothing)
+
+ X = MLJBase.matrix(X)
+
+
+
+
+ cache = nothing
+ return (fitresult, cache, report)
+end
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+"""
+ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic
+
+A mutable struct representing a Laplace Classification model.
+It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.
+
+
+The model also has the following parameters:
+
+- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
+- `subnetwork_indices`: the indices of the subnetworks.
+- `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`.
+- `backend`: the backend to use, either `:GGN` or `:EmpiricalFisher`.
+- `σ`: the standard deviation of the prior distribution.
+- `μ₀`: the mean of the prior distribution.
+- `P₀`: the covariance matrix of the prior distribution.
+- `link_approx`: the link approximation to use, either `:probit` or `:plugin`.
+- `predict_proba`: a boolean that selects whether to predict probabilities or not.
+- `ret_distr`: a boolean that tells `predict` whether to return distribution objects from Distributions.jl (`true`) or just the probabilities (`false`).
+- `fit_prior_nsteps`: the number of steps used to fit the priors.
+"""
+MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic
+ subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
+ subnetwork_indices::Vector{Vector{Int}} = Vector{Vector{Int}}([])
+ hessian_structure::Union{HessianStructure,Symbol,String} =
+ :full::(_ in (:full, :diagonal))
+ backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))
+ σ::Float64 = 1.0
+ μ₀::Float64 = 0.0
+ P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
+ link_approx::Symbol = :probit::(_ in (:probit, :plugin))
+ predict_proba::Bool = true::(_ in (true, false))
+ ret_distr::Bool = false::(_ in (true, false))
+ fit_prior_nsteps::Int = 100::(_ > 0)
+end
+
+
+
+function MLJModelInterface.fit(m::LaplaceClassifier, verbosity, X, y, w=nothing)
+
+
+
+
+
+ cache = nothing
+ return (fitresult, cache, report)
+end
+
+
+
+
+MLJBase.metadata_model(
+ LaplaceClassifier;
+ input_scitype=Union{
+ AbstractMatrix{<:Union{MLJBase.Finite, MLJBase.Continuous}}, # matrix with mixed types
+ MLJBase.Table(MLJBase.Finite, MLJBase.Contintuous), # table with mixed types
+ },
+ target_scitype=AbstractArray{<:MLJBase.Finite}, # ordered factor or multiclass
+ load_path="LaplaceRedux.LaplaceClassification",
+)
+# metadata for each model,
+MLJBase.metadata_model(
+ LaplaceRegressor;
+ input_scitype=Union{
+ AbstractMatrix{<:Union{MLJBase.Finite, MLJBase.Continuous}}, # matrix with mixed types
+ MLJBase.Table(MLJBase.Finite, MLJBase.Contintuous), # table with mixed types
+ },
+ target_scitype=AbstractArray{MLJBase.Continuous},
+ load_path="LaplaceRedux.MLJFlux.LaplaceRegression",
+)
\ No newline at end of file
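A minimal usage sketch for the struct definitions above (hedged: at this stage the model types are not yet exported, so they are reached through the module). `MLJBase.@mlj_model` generates a keyword constructor that applies the declared defaults and validates the declared ranges.

    using LaplaceRedux

    model = LaplaceRedux.LaplaceRegressor(; subset_of_weights=:last_layer, fit_prior_nsteps=200)
    model.backend    # :GGN, the declared default
    # A value that fails a declared check, e.g. fit_prior_nsteps = 0 against `(_ > 0)`,
    # is reset to the default with a warning by the constructor machinery @mlj_model generates.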
From e8e96d13f956b7a74a00bb7246cf4f9b82f9274c Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Fri, 13 Sep 2024 00:40:22 +0200
Subject: [PATCH 02/60] various stuff
---
src/direct_mlj.jl | 99 ++++++++++++++++++++++++++++++++++-------------
1 file changed, 73 insertions(+), 26 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index ed046ff5..b1d0f01e 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -1,13 +1,11 @@
using Flux
-using ProgressMeter: Progress, next!, BarGlyphs
using Random
using Tables
using LinearAlgebra
using LaplaceRedux
-using ComputationalResources
-using MLJBase: MLJBase
+using MLJBase
import MLJModelInterface as MMI
-using Optimisers: Optimisers
+using Distributions: Normal
"""
MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic
@@ -15,6 +13,7 @@ using Optimisers: Optimisers
A mutable struct representing a Laplace regression model.
It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.
It has the following Hyperparameters:
+- flux_model????
- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
- `subnetwork_indices`: the indices of the subnetworks.
- `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`.
@@ -26,6 +25,7 @@ It has the following Hyperparameters:
- `fit_prior_nsteps`: the number of steps used to fit the priors.
"""
MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic
+ flux_model::Flux
subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
subnetwork_indices = nothing
hessian_structure::Union{HessianStructure,Symbol,String} =
@@ -39,13 +39,30 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilist
end
-function MLJModelInterface.fit(m::LaplaceRegressor, verbosity, X, y, w=nothing)
+function MMI.fit(m::LaplaceRegressor, verbosity, X, y, w=nothing)
+ features = Tables.schema(X).names
X = MLJBase.matrix(X)
-
-
-
+ la = LaplaceRedux.Laplace(
+ m.flux_model;
+ likelihood=:regression,
+ subset_of_weights=model.subset_of_weights,
+ subnetwork_indices=model.subnetwork_indices,
+ hessian_structure=model.hessian_structure,
+ backend=model.backend,
+ σ=m.σ,
+ μ₀=m.μ₀,
+ P₀=m.P₀,
+ )
+
+ # fit the Laplace model:
+ LaplaceRedux.fit!(la, zip(X, y))
+ optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps)
+
+
+ fitresult=la
+ report = (status="success", message="Model fitted successfully")
cache = nothing
return (fitresult, cache, report)
end
@@ -53,18 +70,12 @@ end
-
-
-
-
-
-
-
-
-
-
-
-
+function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
+ Xnew = MLJBase.matrix(Xnew) |> permutedims
+ la = fitresult[1]
+ yhat = LaplaceRedux.predict(la, Xnew; ret_distr=model.ret_distr)
+ return [Normal(μᵢ, σ̂) for (μᵢ,σ) ∈ yhat]
+end
@@ -79,6 +90,7 @@ It uses Laplace approximation to estimate the posterior distribution of the weig
The model also has the following parameters:
+- `model`: Flux ???
- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
- `subnetwork_indices`: the indices of the subnetworks.
- `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`.
@@ -92,6 +104,8 @@ The model also has the following parameters:
- `fit_prior_nsteps`: the number of steps used to fit the priors.
"""
MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic
+
+ model::Flux
subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
subnetwork_indices::Vector{Vector{Int}} = Vector{Vector{Int}}([])
hessian_structure::Union{HessianStructure,Symbol,String} =
@@ -108,14 +122,47 @@ end
-function MLJModelInterface.fit(m::LaplaceClassifier, verbosity, X, y, w=nothing)
-
-
-
+function MMI.fit(m::LaplaceClassifier, verbosity, X, y, w=nothing)
+ features = Tables.schema(X).names
+ Xmatrix = MLJBase.matrix(X)
+ decode = y[1]
+ y_plain = MLJBase.int(y) .- 1 # 0, 1 of type Int
+
+ la = LaplaceRedux.Laplace(
+ m.flux_model;
+ likelihood=:classification,
+ subset_of_weights=model.subset_of_weights,
+ subnetwork_indices=model.subnetwork_indices,
+ hessian_structure=model.hessian_structure,
+ backend=model.backend,
+ σ=m.σ,
+ μ₀=m.μ₀,
+ P₀=m.P₀,
+ )
+
+ # fit the Laplace model:
+ LaplaceRedux.fit!(la, zip(X, y_plain))
+ optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps)
+
+
+ fitresult=la
+ report = (status="success", message="Model fitted successfully")
+ cache = nothing
+ return (fitresult, decode), cache, report
+end
- cache = nothing
- return (fitresult, cache, report)
+function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew)
+ la = fitresult
+ Xnew = MLJBase.matrix(Xnew) |> permutedims
+ predictions = LaplaceRedux.predict(
+ la,
+ Xnew;
+ link_approx=model.link_approx,
+ predict_proba=model.predict_proba,
+ ret_distr=model.ret_distr,
+ )
+ return [MLJBase.UnivariateFinite(MLJBase.classes(decode), prediction) for prediction in predictions]
end
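For reference, the conversion the regressor's `predict` is working towards, wrapping per-observation means and variances into `Normal` distributions, looks like this (illustrative values only, not the package API):

    using Distributions: Normal

    means     = [0.5, 1.2, -0.3]     # predictive means, one per observation (illustrative)
    variances = [0.10, 0.20, 0.05]   # predictive variances (illustrative)
    dists = [Normal(μ, sqrt(σ²)) for (μ, σ²) in zip(means, variances)]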
From 255cf19dfb00255ef099317501a8d6bc86f769a6 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Fri, 13 Sep 2024 01:28:13 +0200
Subject: [PATCH 03/60] fixes
---
src/direct_mlj.jl | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index b1d0f01e..d341f4a1 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -25,7 +25,7 @@ It has the following Hyperparameters:
- `fit_prior_nsteps`: the number of steps used to fit the priors.
"""
MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic
- flux_model::Flux
+ flux_model::Flux.Chain= nothing
subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
subnetwork_indices = nothing
hessian_structure::Union{HessianStructure,Symbol,String} =
@@ -105,7 +105,7 @@ The model also has the following parameters:
"""
MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic
- model::Flux
+ #model::Flux
subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
subnetwork_indices::Vector{Vector{Int}} = Vector{Vector{Int}}([])
hessian_structure::Union{HessianStructure,Symbol,String} =
@@ -172,18 +172,18 @@ MLJBase.metadata_model(
LaplaceClassifier;
input_scitype=Union{
AbstractMatrix{<:Union{MLJBase.Finite, MLJBase.Continuous}}, # matrix with mixed types
- MLJBase.Table(MLJBase.Finite, MLJBase.Contintuous), # table with mixed types
+ MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types
},
target_scitype=AbstractArray{<:MLJBase.Finite}, # ordered factor or multiclass
- load_path="LaplaceRedux.LaplaceClassification",
+ load_path="LaplaceRedux.LaplaceClassifier",
)
# metadata for each model,
MLJBase.metadata_model(
LaplaceRegressor;
input_scitype=Union{
AbstractMatrix{<:Union{MLJBase.Finite, MLJBase.Continuous}}, # matrix with mixed types
- MLJBase.Table(MLJBase.Finite, MLJBase.Contintuous), # table with mixed types
+ MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types
},
target_scitype=AbstractArray{MLJBase.Continuous},
- load_path="LaplaceRedux.MLJFlux.LaplaceRegression",
+ load_path="LaplaceRedux.LaplaceRegressor",
)
\ No newline at end of file
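A hedged sketch of what the corrected metadata enables: the declared scitypes become queryable traits that MLJ uses to check data compatibility (assumes the metadata above is evaluated as written):

    using MLJBase, LaplaceRedux

    X = MLJBase.table(rand(10, 3))
    scitype(X) <: MLJBase.input_scitype(LaplaceRedux.LaplaceRegressor)  # true for an all-Continuous table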
From 35ac2d8c25ebca59518754889b3c2a22f7796e29 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Fri, 13 Sep 2024 06:08:06 +0200
Subject: [PATCH 04/60] changes
---
src/direct_mlj.jl | 16 +++++++++-------
1 file changed, 9 insertions(+), 7 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index d341f4a1..cb4c3c17 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -25,7 +25,7 @@ It has the following Hyperparameters:
- `fit_prior_nsteps`: the number of steps used to fit the priors.
"""
MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic
- flux_model::Flux.Chain= nothing
+ flux_model::Flux.Chain = nothing
subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
subnetwork_indices = nothing
hessian_structure::Union{HessianStructure,Symbol,String} =
@@ -39,7 +39,7 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilist
end
-function MMI.fit(m::LaplaceRegressor, verbosity, X, y, w=nothing)
+function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
features = Tables.schema(X).names
X = MLJBase.matrix(X)
@@ -47,15 +47,17 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y, w=nothing)
la = LaplaceRedux.Laplace(
m.flux_model;
likelihood=:regression,
- subset_of_weights=model.subset_of_weights,
- subnetwork_indices=model.subnetwork_indices,
- hessian_structure=model.hessian_structure,
- backend=model.backend,
+ subset_of_weights=m.subset_of_weights,
+ subnetwork_indices=m.subnetwork_indices,
+ hessian_structure=m.hessian_structure,
+ backend=m.backend,
σ=m.σ,
μ₀=m.μ₀,
P₀=m.P₀,
)
+ println(la)
+
# fit the Laplace model:
LaplaceRedux.fit!(la, zip(X, y))
optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps)
@@ -122,7 +124,7 @@ end
-function MMI.fit(m::LaplaceClassifier, verbosity, X, y, w=nothing)
+function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
features = Tables.schema(X).names
Xmatrix = MLJBase.matrix(X)
decode = y[1]
From 44469a20bdc62e489c5c2173b3a7498f1dbf263d Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Wed, 18 Sep 2024 08:13:25 +0200
Subject: [PATCH 05/60] there is still a problem with the classifier
---
src/direct_mlj.jl | 49 +++++++++++++++++++++++++++++++++++++----------
1 file changed, 39 insertions(+), 10 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index cb4c3c17..b8da8f70 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -13,7 +13,8 @@ using Distributions: Normal
A mutable struct representing a Laplace regression model.
It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.
It has the following Hyperparameters:
-- flux_model????
+- `flux_model`: A flux model provided by the user and compatible with the dataset.
+- `epochs`: The number of training epochs.
- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
- `subnetwork_indices`: the indices of the subnetworks.
- `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`.
@@ -25,7 +26,9 @@ It has the following Hyperparameters:
- `fit_prior_nsteps`: the number of steps used to fit the priors.
"""
MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic
+
flux_model::Flux.Chain = nothing
+ epochs::Integer = 1000::(_ > 0)
subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
subnetwork_indices = nothing
hessian_structure::Union{HessianStructure,Symbol,String} =
@@ -42,7 +45,19 @@ end
function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
features = Tables.schema(X).names
- X = MLJBase.matrix(X)
+ X = MLJBase.matrix(X) |> permutedims
+ y = reshape(y, 1,:)
+ data_loader = Flux.DataLoader((X,y), batchsize=10)
+ opt_state = Flux.setup(Adam(), m.flux_model)
+ loss(y_hat, y) = mean(Flux.Losses.mse.(y_hat, y))
+
+ for epoch in 1:m.epochs
+ Flux.train!(m.flux_model,data_loader, opt_state) do model, X, y
+ loss(model(X), y)
+
+ end
+ end
+
la = LaplaceRedux.Laplace(
m.flux_model;
@@ -56,11 +71,10 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
P₀=m.P₀,
)
- println(la)
# fit the Laplace model:
- LaplaceRedux.fit!(la, zip(X, y))
- optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps)
+ LaplaceRedux.fit!(la, data_loader )
+ optimize_prior!(la; verbose= false, n_steps=m.fit_prior_nsteps)
fitresult=la
@@ -92,7 +106,8 @@ It uses Laplace approximation to estimate the posterior distribution of the weig
The model also has the following parameters:
-- `model`: Flux ???
+- `flux_model`: A flux model provided by the user and compatible with the dataset.
+- `epochs`: The number of training epochs.
- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
- `subnetwork_indices`: the indices of the subnetworks.
- `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`.
@@ -107,7 +122,8 @@ The model also has the following parameters:
"""
MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic
- #model::Flux
+ flux_model::Flux.Chain = nothing
+ epochs::Integer = 1000::(_ > 0)
subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
subnetwork_indices::Vector{Vector{Int}} = Vector{Vector{Int}}([])
hessian_structure::Union{HessianStructure,Symbol,String} =
@@ -130,6 +146,17 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
decode = y[1]
y_plain = MLJBase.int(y) .- 1 # 0, 1 of type Int
+
+
+ for epoch in 1:m.epochs
+ Flux.train!(m.flux_model,data_loader, opt_state) do model, X, y
+ loss(model(X), y)
+
+ end
+ end
+
+
+
la = LaplaceRedux.Laplace(
m.flux_model;
likelihood=:classification,
@@ -142,15 +169,17 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
P₀=m.P₀,
)
+
+
+
# fit the Laplace model:
LaplaceRedux.fit!(la, zip(X, y_plain))
optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps)
-
- fitresult=la
+
report = (status="success", message="Model fitted successfully")
cache = nothing
- return (fitresult, decode), cache, report
+ return (la, decode), cache, report
end
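The training loop introduced here follows the explicit-state Flux pattern (`Flux.setup` plus `Flux.train!` over a `DataLoader`); a self-contained sketch with illustrative sizes, using Flux's features × observations convention:

    using Flux

    X = rand(Float32, 4, 100)               # 4 features, 100 observations
    y = reshape(rand(Float32, 100), 1, :)   # 1 × 100 target row

    chain     = Chain(Dense(4, 16, relu), Dense(16, 1))
    loader    = Flux.DataLoader((X, y); batchsize=32)
    opt_state = Flux.setup(Flux.Adam(), chain)

    for epoch in 1:100
        Flux.train!(chain, loader, opt_state) do m, xb, yb
            Flux.Losses.mse(m(xb), yb)
        end
    end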
From 53152645ebec793d86d37fc1d4e818aaf75f45c4 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Wed, 18 Sep 2024 11:03:37 +0200
Subject: [PATCH 06/60] almost fixed
---
src/direct_mlj.jl | 71 ++++++++++++++++++++++++++++-------------------
1 file changed, 43 insertions(+), 28 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index b8da8f70..d30b6c71 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -15,6 +15,7 @@ It uses Laplace approximation to estimate the posterior distribution of the weig
It has the following Hyperparameters:
- `flux_model`: A flux model provided by the user and compatible with the dataset.
- `epochs`: The number of training epochs.
+- `batch_size`: The batch size.
- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
- `subnetwork_indices`: the indices of the subnetworks.
- `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`.
@@ -29,6 +30,7 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilist
flux_model::Flux.Chain = nothing
epochs::Integer = 1000::(_ > 0)
+ batch_size::Integer= 32::(_ > 0)
subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
subnetwork_indices = nothing
hessian_structure::Union{HessianStructure,Symbol,String} =
@@ -47,7 +49,7 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
X = MLJBase.matrix(X) |> permutedims
y = reshape(y, 1,:)
- data_loader = Flux.DataLoader((X,y), batchsize=10)
+ data_loader = Flux.DataLoader((X,y), batchsize=m.batch_size)
opt_state = Flux.setup(Adam(), m.flux_model)
loss(y_hat, y) = mean(Flux.Losses.mse.(y_hat, y))
@@ -88,9 +90,13 @@ end
function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
Xnew = MLJBase.matrix(Xnew) |> permutedims
- la = fitresult[1]
- yhat = LaplaceRedux.predict(la, Xnew; ret_distr=model.ret_distr)
- return [Normal(μᵢ, σ̂) for (μᵢ,σ) ∈ yhat]
+ la = fitresult
+ yhat = LaplaceRedux.predict(la, Xnew; ret_distr=m.ret_distr)
+ # Extract mean and variance matrices
+ means, variances = yhat
+
+ # Create Normal distributions from the means and variances
+ return [Normal(μ, sqrt(σ)) for (μ, σ) in zip(means, variances)]
end
@@ -108,6 +114,7 @@ The model also has the following parameters:
- `flux_model`: A flux model provided by the user and compatible with the dataset.
- `epochs`: The number of training epochs.
+- `batch_size`: The batch size.
- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
- `subnetwork_indices`: the indices of the subnetworks.
- `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`.
@@ -124,33 +131,41 @@ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilis
flux_model::Flux.Chain = nothing
epochs::Integer = 1000::(_ > 0)
+ batch_size::Integer= 32::(_ > 0)
subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
- subnetwork_indices::Vector{Vector{Int}} = Vector{Vector{Int}}([])
+ subnetwork_indices = nothing
hessian_structure::Union{HessianStructure,Symbol,String} =
:full::(_ in (:full, :diagonal))
backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))
σ::Float64 = 1.0
μ₀::Float64 = 0.0
P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
- link_approx::Symbol = :probit::(_ in (:probit, :plugin))
- predict_proba::Bool = true::(_ in (true, false))
ret_distr::Bool = false::(_ in (true, false))
fit_prior_nsteps::Int = 100::(_ > 0)
+ link_approx::Symbol = :probit::(_ in (:probit, :plugin))
end
-
+#link_approx::Symbol = :probit::(_ in (:probit, :plugin))
function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
features = Tables.schema(X).names
- Xmatrix = MLJBase.matrix(X)
+ X = MLJBase.matrix(X) |> permutedims
+ y = reshape(y, 1,:)
+
decode = y[1]
y_plain = MLJBase.int(y) .- 1 # 0, 1 of type Int
+ loss(y_hat, y) = mean(Flux.Losses.crossentropy(y_hat, y))
+
+ data_loader = Flux.DataLoader((X,y_plain), batchsize=m.batch_size)
+ opt_state = Flux.setup(Adam(), m.flux_model)
+
+
for epoch in 1:m.epochs
- Flux.train!(m.flux_model,data_loader, opt_state) do model, X, y
- loss(model(X), y)
+ Flux.train!(m.flux_model,data_loader, opt_state) do model, X, y_plain
+ loss(model(X), y_plain)
end
end
@@ -160,10 +175,10 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
la = LaplaceRedux.Laplace(
m.flux_model;
likelihood=:classification,
- subset_of_weights=model.subset_of_weights,
- subnetwork_indices=model.subnetwork_indices,
- hessian_structure=model.hessian_structure,
- backend=model.backend,
+ subset_of_weights=m.subset_of_weights,
+ subnetwork_indices=m.subnetwork_indices,
+ hessian_structure=m.hessian_structure,
+ backend=m.backend,
σ=m.σ,
μ₀=m.μ₀,
P₀=m.P₀,
@@ -173,27 +188,27 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
# fit the Laplace model:
- LaplaceRedux.fit!(la, zip(X, y_plain))
- optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps)
-
+ LaplaceRedux.fit!(la, data_loader )
+ optimize_prior!(la; verbose= false, n_steps=m.fit_prior_nsteps)
+
+ fitresult=la
report = (status="success", message="Model fitted successfully")
cache = nothing
- return (la, decode), cache, report
+ return ((fitresult,decode), cache, report)
end
function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew)
- la = fitresult
+ #la = fitresult
Xnew = MLJBase.matrix(Xnew) |> permutedims
- predictions = LaplaceRedux.predict(
- la,
- Xnew;
- link_approx=model.link_approx,
- predict_proba=model.predict_proba,
- ret_distr=model.ret_distr,
- )
- return [MLJBase.UnivariateFinite(MLJBase.classes(decode), prediction) for prediction in predictions]
+ #predictions = LaplaceRedux.predict(
+ #la,
+ #Xnew;
+ #link_approx=model.link_approx,
+ #predict_proba=model.predict_proba,
+ #ret_distr=model.ret_distr)
+ #return [MLJBase.UnivariateFinite(MLJBase.classes(decode), prediction) for prediction in predictions]
end
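The classifier's target handling reduces to: categorical labels → integer codes → one-hot matrix; a hedged sketch of that pipeline outside the package:

    using MLJBase, Flux

    y = categorical(["a", "b", "a", "c"])
    y_int    = MLJBase.int(y) .- 1                      # zero-based integer codes
    y_onehot = Flux.onehotbatch(y_int, unique(y_int))   # 3 × 4 one-hot matrix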
From f6e7a00c4c16288f085657bd53caea79e70eb256 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Wed, 18 Sep 2024 11:20:36 +0200
Subject: [PATCH 07/60] works but I have to fix the hyperparameters
---
src/direct_mlj.jl | 21 +++++++++------------
1 file changed, 9 insertions(+), 12 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index d30b6c71..207a101c 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -156,7 +156,7 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
y_plain = MLJBase.int(y) .- 1 # 0, 1 of type Int
- loss(y_hat, y) = mean(Flux.Losses.crossentropy(y_hat, y))
+ loss(y_hat, y) = Flux.Losses.logitcrossentropy(y_hat, y)
data_loader = Flux.DataLoader((X,y_plain), batchsize=m.batch_size)
opt_state = Flux.setup(Adam(), m.flux_model)
@@ -191,24 +191,21 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
LaplaceRedux.fit!(la, data_loader )
optimize_prior!(la; verbose= false, n_steps=m.fit_prior_nsteps)
-
- fitresult=la
report = (status="success", message="Model fitted successfully")
cache = nothing
- return ((fitresult,decode), cache, report)
+ return ((la,decode), cache, report)
end
function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew)
- #la = fitresult
+ la = fitresult
Xnew = MLJBase.matrix(Xnew) |> permutedims
- #predictions = LaplaceRedux.predict(
- #la,
- #Xnew;
- #link_approx=model.link_approx,
- #predict_proba=model.predict_proba,
- #ret_distr=model.ret_distr)
- #return [MLJBase.UnivariateFinite(MLJBase.classes(decode), prediction) for prediction in predictions]
+ predictions = LaplaceRedux.predict(
+ la,
+ Xnew;
+ link_approx=m.link_approx,
+ ret_distr=m.ret_distr)
+ return [MLJBase.UnivariateFinite(MLJBase.classes(decode), prediction,augment=true) for prediction in predictions]
end
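With `augment=true`, `UnivariateFinite` is given probabilities only for the classes after the first and fills in the first class with the balance; a hedged sketch for the binary case:

    using MLJBase

    classes_ = categorical(["neg", "pos"])
    p_pos    = reshape([0.2, 0.9, 0.5], :, 1)          # P(pos) for three observations, as a column
    dists    = UnivariateFinite(classes_, p_pos; augment=true)
    pdf.(dists, classes_[1])                           # the balance: [0.8, 0.1, 0.5]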
From d7c4f7b65cbf00404f43cf51ca851696f3643dda Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Wed, 18 Sep 2024 11:46:17 +0200
Subject: [PATCH 08/60] question on parameters....
---
src/direct_mlj.jl | 19 ++++++++++---------
1 file changed, 10 insertions(+), 9 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 207a101c..46f826b5 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -29,6 +29,7 @@ It has the following Hyperparameters:
MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic
flux_model::Flux.Chain = nothing
+ flux_loss = Flux.Losses.mse
epochs::Integer = 1000::(_ > 0)
batch_size::Integer= 32::(_ > 0)
subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
@@ -39,7 +40,7 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilist
σ::Float64 = 1.0
μ₀::Float64 = 0.0
P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
- ret_distr::Bool = false::(_ in (true, false))
+ #ret_distr::Bool = false::(_ in (true, false))
fit_prior_nsteps::Int = 100::(_ > 0)
end
@@ -51,11 +52,10 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
y = reshape(y, 1,:)
data_loader = Flux.DataLoader((X,y), batchsize=m.batch_size)
opt_state = Flux.setup(Adam(), m.flux_model)
- loss(y_hat, y) = mean(Flux.Losses.mse.(y_hat, y))
for epoch in 1:m.epochs
Flux.train!(m.flux_model,data_loader, opt_state) do model, X, y
- loss(model(X), y)
+ m.flux_loss(model(X), y)
end
end
@@ -91,7 +91,7 @@ end
function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
Xnew = MLJBase.matrix(Xnew) |> permutedims
la = fitresult
- yhat = LaplaceRedux.predict(la, Xnew; ret_distr=m.ret_distr)
+ yhat = LaplaceRedux.predict(la, Xnew; ret_distr= false)
# Extract mean and variance matrices
means, variances = yhat
@@ -130,6 +130,7 @@ The model also has the following parameters:
MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic
flux_model::Flux.Chain = nothing
+ flux_loss = Flux.Losses.logitcrossentropy
epochs::Integer = 1000::(_ > 0)
batch_size::Integer= 32::(_ > 0)
subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
@@ -140,12 +141,12 @@ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilis
σ::Float64 = 1.0
μ₀::Float64 = 0.0
P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
- ret_distr::Bool = false::(_ in (true, false))
+ #ret_distr::Bool = false::(_ in (true, false))
fit_prior_nsteps::Int = 100::(_ > 0)
link_approx::Symbol = :probit::(_ in (:probit, :plugin))
end
-#link_approx::Symbol = :probit::(_ in (:probit, :plugin))
+
function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
features = Tables.schema(X).names
@@ -156,7 +157,7 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
y_plain = MLJBase.int(y) .- 1 # 0, 1 of type Int
- loss(y_hat, y) = Flux.Losses.logitcrossentropy(y_hat, y)
+ #loss(y_hat, y) = Flux.Losses.logitcrossentropy(y_hat, y)
data_loader = Flux.DataLoader((X,y_plain), batchsize=m.batch_size)
opt_state = Flux.setup(Adam(), m.flux_model)
@@ -165,7 +166,7 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
for epoch in 1:m.epochs
Flux.train!(m.flux_model,data_loader, opt_state) do model, X, y_plain
- loss(model(X), y_plain)
+ m.flux_loss(model(X), y_plain)
end
end
@@ -204,7 +205,7 @@ function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew)
la,
Xnew;
link_approx=m.link_approx,
- ret_distr=m.ret_distr)
+ ret_distr=false)
return [MLJBase.UnivariateFinite(MLJBase.classes(decode), prediction,augment=true) for prediction in predictions]
end
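Since the loss is now a hyperparameter rather than hard-coded, it can be swapped at construction time; a hedged sketch (field names as declared above, layer sizes illustrative):

    using Flux, LaplaceRedux

    model = LaplaceRedux.LaplaceRegressor(;
        flux_model = Chain(Dense(4, 16, relu), Dense(16, 1)),
        flux_loss  = Flux.Losses.huber_loss,
        epochs     = 200,
    )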
From bbab4600302eb8c8b3836eab6b86cc0ab4ac74f7 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Wed, 18 Sep 2024 13:17:15 +0200
Subject: [PATCH 09/60] there is some problem with the one hot encoding
---
src/direct_mlj.jl | 19 ++++++++++---------
1 file changed, 10 insertions(+), 9 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 46f826b5..52e34a2d 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -46,7 +46,7 @@ end
function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
- features = Tables.schema(X).names
+ #features = Tables.schema(X).names
X = MLJBase.matrix(X) |> permutedims
y = reshape(y, 1,:)
@@ -149,24 +149,25 @@ end
function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
- features = Tables.schema(X).names
+ #features = Tables.schema(X).names
X = MLJBase.matrix(X) |> permutedims
- y = reshape(y, 1,:)
-
- decode = y[1]
+ decode = MMI.decoder(y[1])
+ #y = reshape(y, 1,:)
y_plain = MLJBase.int(y) .- 1 # 0, 1 of type Int
+ y_onehot = Flux.onehotbatch(y_plain, unique(y_plain) )
+
#loss(y_hat, y) = Flux.Losses.logitcrossentropy(y_hat, y)
- data_loader = Flux.DataLoader((X,y_plain), batchsize=m.batch_size)
+ data_loader = Flux.DataLoader((X,y_onehot), batchsize=m.batch_size)
opt_state = Flux.setup(Adam(), m.flux_model)
for epoch in 1:m.epochs
- Flux.train!(m.flux_model,data_loader, opt_state) do model, X, y_plain
- m.flux_loss(model(X), y_plain)
+ Flux.train!(m.flux_model,data_loader, opt_state) do model, X, y_onehot
+ m.flux_loss(model(X), y_onehot)
end
end
@@ -206,7 +207,7 @@ function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew)
Xnew;
link_approx=m.link_approx,
ret_distr=false)
- return [MLJBase.UnivariateFinite(MLJBase.classes(decode), prediction,augment=true) for prediction in predictions]
+ return [MLJBase.UnivariateFinite(MLJBase.classes(decode), prediction, pool=decode,augment=true) for prediction in predictions]
end
From 8af38aed4b5c1895fd2f08793b85f12c1130e2ae Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Thu, 19 Sep 2024 12:22:04 +0200
Subject: [PATCH 10/60] fixed error in univariatefinite
---
src/direct_mlj.jl | 25 ++++++++-----------------
1 file changed, 8 insertions(+), 17 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 52e34a2d..d2b42368 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -149,22 +149,13 @@ end
function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
- #features = Tables.schema(X).names
X = MLJBase.matrix(X) |> permutedims
- decode = MMI.decoder(y[1])
- #y = reshape(y, 1,:)
- y_plain = MLJBase.int(y) .- 1 # 0, 1 of type Int
+ decode = y[1]
+ y_plain = MLJBase.int(y) .- 1
y_onehot = Flux.onehotbatch(y_plain, unique(y_plain) )
-
-
-
- #loss(y_hat, y) = Flux.Losses.logitcrossentropy(y_hat, y)
-
data_loader = Flux.DataLoader((X,y_onehot), batchsize=m.batch_size)
opt_state = Flux.setup(Adam(), m.flux_model)
-
-
for epoch in 1:m.epochs
Flux.train!(m.flux_model,data_loader, opt_state) do model, X, y_onehot
m.flux_loss(model(X), y_onehot)
@@ -173,7 +164,6 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
end
-
la = LaplaceRedux.Laplace(
m.flux_model;
likelihood=:classification,
@@ -186,9 +176,6 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
P₀=m.P₀,
)
-
-
-
# fit the Laplace model:
LaplaceRedux.fit!(la, data_loader )
optimize_prior!(la; verbose= false, n_steps=m.fit_prior_nsteps)
@@ -206,8 +193,12 @@ function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew)
la,
Xnew;
link_approx=m.link_approx,
- ret_distr=false)
- return [MLJBase.UnivariateFinite(MLJBase.classes(decode), prediction, pool=decode,augment=true) for prediction in predictions]
+ ret_distr=false) |>permutedims
+
+
+
+
+ return [MLJBase.UnivariateFinite(MLJBase.classes(decode), prediction) for prediction in eachrow(predictions)]
end
From fe19d4d405c631632677ecea1d1cc5ae88b70359 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Thu, 19 Sep 2024 12:26:25 +0200
Subject: [PATCH 11/60] performance improvement
---
src/direct_mlj.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index d2b42368..20fa29a6 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -198,7 +198,7 @@ function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew)
- return [MLJBase.UnivariateFinite(MLJBase.classes(decode), prediction) for prediction in eachrow(predictions)]
+ return MLJBase.UnivariateFinite(MLJBase.classes(decode), predictions)
end
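The one-line change replaces a per-row comprehension with the vectorized constructor, which returns a single `UnivariateFiniteArray` rather than a vector of scalar distributions; a hedged comparison of the two forms:

    using MLJBase

    classes_ = categorical(["a", "b", "c"])
    probs    = [0.7 0.2 0.1;
                0.1 0.8 0.1]                            # one row per observation
    slow = [UnivariateFinite(classes_, p) for p in eachrow(probs)]
    fast = UnivariateFinite(classes_, probs)            # single vectorized container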
From d809afbbfa65670d0f2e2ea2ffc1af796ee7ad66 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Sat, 21 Sep 2024 03:47:34 +0200
Subject: [PATCH 12/60] JuliaFormatter
---
dev/issues/predict_slow.jl | 6 +--
docs/make.jl | 2 +-
src/LaplaceRedux.jl | 6 +--
src/baselaplace/predicting.jl | 1 -
src/calibration_functions.jl | 9 ++--
src/data/functions.jl | 4 +-
src/direct_mlj.jl | 85 ++++++++++++-----------------------
test/calibration.jl | 4 +-
test/data.jl | 2 +-
9 files changed, 45 insertions(+), 74 deletions(-)
diff --git a/dev/issues/predict_slow.jl b/dev/issues/predict_slow.jl
index ae713cbd..95c7c1f8 100644
--- a/dev/issues/predict_slow.jl
+++ b/dev/issues/predict_slow.jl
@@ -4,16 +4,16 @@ using Optimisers
X = MLJBase.table(rand(Float32, 100, 3));
y = coerce(rand("abc", 100), Multiclass);
-model = LaplaceClassification(optimiser=Optimisers.Adam(0.1), epochs=100);
+model = LaplaceClassification(; optimiser=Optimisers.Adam(0.1), epochs=100);
fitresult, _, _ = MLJBase.fit(model, 2, X, y);
la = fitresult[1];
Xmat = matrix(X) |> permutedims;
# Single test sample:
-Xtest = Xmat[:,1:10];
+Xtest = Xmat[:, 1:10];
Xtest_tab = MLJBase.table(Xtest');
MLJBase.predict(model, fitresult, Xtest_tab); # warm up
LaplaceRedux.predict(la, Xmat); # warm up
@time MLJBase.predict(model, fitresult, Xtest_tab);
@time LaplaceRedux.predict(la, Xtest);
-@time glm_predictive_distribution(la, Xtest);
\ No newline at end of file
+@time glm_predictive_distribution(la, Xtest);
diff --git a/docs/make.jl b/docs/make.jl
index 8aedd777..f24ea50b 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -24,7 +24,7 @@ makedocs(;
"Calibrated forecasts" => "tutorials/calibration.md",
],
"Reference" => "reference.md",
- "MLJ interface"=> "mlj_interface.md",
+ "MLJ interface" => "mlj_interface.md",
"Additional Resources" => "resources/_resources.md",
],
)
diff --git a/src/LaplaceRedux.jl b/src/LaplaceRedux.jl
index 4d8df9ec..15ab3d4f 100644
--- a/src/LaplaceRedux.jl
+++ b/src/LaplaceRedux.jl
@@ -29,8 +29,8 @@ export empirical_frequency_binary_classification,
empirical_frequency_regression,
sharpness_regression,
extract_mean_and_variance,
- sigma_scaling, rescale_stddev
+ sigma_scaling,
+ rescale_stddev
-
- include("direct_mlj.jl")
+include("direct_mlj.jl")
end
diff --git a/src/baselaplace/predicting.jl b/src/baselaplace/predicting.jl
index 71b74a4b..d26c9a07 100644
--- a/src/baselaplace/predicting.jl
+++ b/src/baselaplace/predicting.jl
@@ -144,7 +144,6 @@ function predict(
else
return fμ, pred_var
end
-
end
# Classification:
diff --git a/src/calibration_functions.jl b/src/calibration_functions.jl
index 1d313248..d6aa5062 100644
--- a/src/calibration_functions.jl
+++ b/src/calibration_functions.jl
@@ -78,7 +78,7 @@ Outputs: \
"""
function sharpness_classification(y_binary, distributions::Vector{Bernoulli{Float64}})
mean_class_one = mean(mean.(distributions[findall(y_binary .== 1)]))
- mean_class_zero = mean( 1 .- mean.(distributions[findall(y_binary .== 0)]))
+ mean_class_zero = mean(1 .- mean.(distributions[findall(y_binary .== 0)]))
return mean_class_one, mean_class_zero
end
@@ -176,8 +176,9 @@ Inputs: \
Outputs: \
- `sigma`: the scalar that maximize the likelihood.
"""
-function sigma_scaling(distr::Vector{Normal{T}}, y_cal::Vector{<:AbstractFloat}
- ) where T <: AbstractFloat
+function sigma_scaling(
+ distr::Vector{Normal{T}}, y_cal::Vector{<:AbstractFloat}
+) where {T<:AbstractFloat}
means, variances = extract_mean_and_variance(distr)
sigma = sqrt(1 / length(y_cal) * sum(norm.(y_cal .- means) ./ variances))
@@ -198,4 +199,4 @@ Outputs: \
function rescale_stddev(distr::Vector{Normal{T}}, s::T) where {T<:AbstractFloat}
rescaled_distr = [Normal(mean(d), std(d) * s) for d in distr]
return rescaled_distr
-end
\ No newline at end of file
+end
diff --git a/src/data/functions.jl b/src/data/functions.jl
index 99e4440c..5ae83805 100644
--- a/src/data/functions.jl
+++ b/src/data/functions.jl
@@ -1,4 +1,4 @@
-import Random
+using Random: Random
"""
toy_data_linear(N=100)
@@ -42,7 +42,7 @@ toy_data_non_linear()
```
"""
-function toy_data_non_linear( N=100; seed=nothing)
+function toy_data_non_linear(N=100; seed=nothing)
#set seed if available
if seed !== nothing
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 20fa29a6..73758254 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -27,11 +27,10 @@ It has the following Hyperparameters:
- `fit_prior_nsteps`: the number of steps used to fit the priors.
"""
MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic
-
flux_model::Flux.Chain = nothing
flux_loss = Flux.Losses.mse
epochs::Integer = 1000::(_ > 0)
- batch_size::Integer= 32::(_ > 0)
+ batch_size::Integer = 32::(_ > 0)
subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
subnetwork_indices = nothing
hessian_structure::Union{HessianStructure,Symbol,String} =
@@ -44,22 +43,19 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilist
fit_prior_nsteps::Int = 100::(_ > 0)
end
-
function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
#features = Tables.schema(X).names
X = MLJBase.matrix(X) |> permutedims
- y = reshape(y, 1,:)
- data_loader = Flux.DataLoader((X,y), batchsize=m.batch_size)
+ y = reshape(y, 1, :)
+ data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size)
opt_state = Flux.setup(Adam(), m.flux_model)
- for epoch in 1:m.epochs
- Flux.train!(m.flux_model,data_loader, opt_state) do model, X, y
+ for epoch in 1:(m.epochs)
+ Flux.train!(m.flux_model, data_loader, opt_state) do model, X, y
m.flux_loss(model(X), y)
-
end
- end
-
+ end
la = LaplaceRedux.Laplace(
m.flux_model;
@@ -73,25 +69,20 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
P₀=m.P₀,
)
-
# fit the Laplace model:
- LaplaceRedux.fit!(la, data_loader )
- optimize_prior!(la; verbose= false, n_steps=m.fit_prior_nsteps)
+ LaplaceRedux.fit!(la, data_loader)
+ optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
-
- fitresult=la
+ fitresult = la
report = (status="success", message="Model fitted successfully")
- cache = nothing
+ cache = nothing
return (fitresult, cache, report)
end
-
-
-
function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
Xnew = MLJBase.matrix(Xnew) |> permutedims
la = fitresult
- yhat = LaplaceRedux.predict(la, Xnew; ret_distr= false)
+ yhat = LaplaceRedux.predict(la, Xnew; ret_distr=false)
# Extract mean and variance matrices
means, variances = yhat
@@ -99,10 +90,6 @@ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
return [Normal(μ, sqrt(σ)) for (μ, σ) in zip(means, variances)]
end
-
-
-
-
"""
MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic
@@ -128,11 +115,10 @@ The model also has the following parameters:
- `fit_prior_nsteps`: the number of steps used to fit the priors.
"""
MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic
-
flux_model::Flux.Chain = nothing
flux_loss = Flux.Losses.logitcrossentropy
epochs::Integer = 1000::(_ > 0)
- batch_size::Integer= 32::(_ > 0)
+ batch_size::Integer = 32::(_ > 0)
subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
subnetwork_indices = nothing
hessian_structure::Union{HessianStructure,Symbol,String} =
@@ -146,23 +132,19 @@ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilis
link_approx::Symbol = :probit::(_ in (:probit, :plugin))
end
-
-
function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
X = MLJBase.matrix(X) |> permutedims
decode = y[1]
- y_plain = MLJBase.int(y) .- 1
- y_onehot = Flux.onehotbatch(y_plain, unique(y_plain) )
- data_loader = Flux.DataLoader((X,y_onehot), batchsize=m.batch_size)
+ y_plain = MLJBase.int(y) .- 1
+ y_onehot = Flux.onehotbatch(y_plain, unique(y_plain))
+ data_loader = Flux.DataLoader((X, y_onehot); batchsize=m.batch_size)
opt_state = Flux.setup(Adam(), m.flux_model)
- for epoch in 1:m.epochs
- Flux.train!(m.flux_model,data_loader, opt_state) do model, X, y_onehot
+ for epoch in 1:(m.epochs)
+ Flux.train!(m.flux_model, data_loader, opt_state) do model, X, y_onehot
m.flux_loss(model(X), y_onehot)
-
end
- end
-
+ end
la = LaplaceRedux.Laplace(
m.flux_model;
@@ -177,37 +159,28 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
)
# fit the Laplace model:
- LaplaceRedux.fit!(la, data_loader )
- optimize_prior!(la; verbose= false, n_steps=m.fit_prior_nsteps)
+ LaplaceRedux.fit!(la, data_loader)
+ optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
report = (status="success", message="Model fitted successfully")
- cache = nothing
- return ((la,decode), cache, report)
+ cache = nothing
+ return ((la, decode), cache, report)
end
-
function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew)
la = fitresult
Xnew = MLJBase.matrix(Xnew) |> permutedims
- predictions = LaplaceRedux.predict(
- la,
- Xnew;
- link_approx=m.link_approx,
- ret_distr=false) |>permutedims
-
-
-
+ predictions =
+ LaplaceRedux.predict(la, Xnew; link_approx=m.link_approx, ret_distr=false) |>
+ permutedims
return MLJBase.UnivariateFinite(MLJBase.classes(decode), predictions)
end
-
-
-
MLJBase.metadata_model(
LaplaceClassifier;
input_scitype=Union{
- AbstractMatrix{<:Union{MLJBase.Finite, MLJBase.Continuous}}, # matrix with mixed types
+ AbstractMatrix{<:Union{MLJBase.Finite,MLJBase.Continuous}}, # matrix with mixed types
MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types
},
target_scitype=AbstractArray{<:MLJBase.Finite}, # ordered factor or multiclass
@@ -217,9 +190,9 @@ MLJBase.metadata_model(
MLJBase.metadata_model(
LaplaceRegressor;
input_scitype=Union{
- AbstractMatrix{<:Union{MLJBase.Finite, MLJBase.Continuous}}, # matrix with mixed types
+ AbstractMatrix{<:Union{MLJBase.Finite,MLJBase.Continuous}}, # matrix with mixed types
MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types
},
- target_scitype=AbstractArray{MLJBase.Continuous},
+ target_scitype=AbstractArray{MLJBase.Continuous},
load_path="LaplaceRedux.LaplaceRegressor",
-)
\ No newline at end of file
+)
diff --git a/test/calibration.jl b/test/calibration.jl
index 847543ce..988ff1ef 100644
--- a/test/calibration.jl
+++ b/test/calibration.jl
@@ -3,8 +3,6 @@ using LaplaceRedux
using Distributions
using Trapz
-
-
# Test for `sharpness_regression` function
@testset "sharpness_regression distributions tests" begin
@info " testing sharpness_regression with distributions"
@@ -201,4 +199,4 @@ end
distributions = [Normal(0, 1), Normal(2, 1), Normal(4, 1)]
rescaled_distr = rescale_stddev(distributions, 2.0)
@test rescaled_distr == [Normal(0, 2), Normal(2, 2), Normal(4, 2)]
-end
\ No newline at end of file
+end
diff --git a/test/data.jl b/test/data.jl
index e6d61881..05a5b30d 100644
--- a/test/data.jl
+++ b/test/data.jl
@@ -15,7 +15,7 @@ for fun in fun_list
# Generate data with the same seed
Random.seed!(seed)
xs1, ys1 = fun(N; seed=seed)
-
+
Random.seed!(seed)
xs2, ys2 = fun(N; seed=seed)
From 33d84f570c1fcbbcd91c3ab3c7c0c80d10bdd8eb Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Sat, 21 Sep 2024 07:22:10 +0200
Subject: [PATCH 13/60] juliaformatter+docstrings
---
src/direct_mlj.jl | 83 +++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 83 insertions(+)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 73758254..77409e16 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -43,6 +43,34 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilist
fit_prior_nsteps::Int = 100::(_ > 0)
end
+@doc """
+ MMI.fit(m::LaplaceRegressor, verbosity, X, y)
+
+Fit a LaplaceRegressor model using the provided features and target values.
+
+# Arguments
+- `m::LaplaceRegressor`: The LaplaceRegressor model to be fitted.
+- `verbosity`: Verbosity level for logging.
+- `X`: Input features, expected to be in a format compatible with MLJBase.matrix.
+- `y`: Target values.
+
+# Returns
+- `fitresult`: The fitted Laplace model.
+- `cache`: Currently unused, returns `nothing`.
+- `report`: A tuple containing the status and message of the fitting process.
+
+# Description
+This function performs the following steps:
+1. Converts the input features `X` to a matrix and transposes it.
+2. Reshapes the target values `y` to shape (1,:).
+3. Creates a data loader for batching the data.
+4. Sets up the optimizer state using the Adam optimizer.
+5. Trains the model for a specified number of epochs.
+6. Initializes a Laplace model with the trained Flux model and specified parameters.
+7. Fits the Laplace model using the data loader.
+8. Optimizes the prior of the Laplace model.
+9. Returns the fitted Laplace model, a cache (currently `nothing`), and a report indicating success.
+"""
function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
#features = Tables.schema(X).names
@@ -79,6 +107,22 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
return (fitresult, cache, report)
end
+@doc """
+function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
+
+ Predicts the response for new data using a fitted LaplaceRegressor model.
+
+ # Arguments
+ - `m::LaplaceRegressor`: The LaplaceRegressor model.
+ - `fitresult`: The result of fitting the LaplaceRegressor model.
+ - `Xnew`: The new data for which predictions are to be made.
+
+ # Returns
+ - An array of Normal distributions, each with the predicted mean and variance for the corresponding input in `Xnew`.
+
+ The function first converts `Xnew` to a matrix and permutes its dimensions. It then uses the `LaplaceRedux.predict` function to obtain the predicted means and variances.
+Finally, it creates Normal distributions from these means and variances and returns them as an array.
+"""
function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
Xnew = MLJBase.matrix(Xnew) |> permutedims
la = fitresult
@@ -132,6 +176,31 @@ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilis
link_approx::Symbol = :probit::(_ in (:probit, :plugin))
end
+@doc """
+
+ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
+
+ Description:
+ This function fits a LaplaceClassifier model using the provided data. It first preprocesses the input data `X` and target labels `y`,
+ then trains a neural network model using the Flux library. After training, it fits a Laplace approximation to the trained model.
+
+ Arguments:
+ - `m::LaplaceClassifier`: The LaplaceClassifier model to be fitted.
+ - `verbosity`: Verbosity level for logging.
+ - `X`: Input data features.
+ - `y`: Target labels.
+
+ Returns:
+ - A tuple containing:
+ - `(la, decode)`: The fitted Laplace model and the decode function for the target labels.
+ - `cache`: A placeholder for any cached data (currently `nothing`).
+ - `report`: A report dictionary containing the status and message of the fitting process.
+
+ Notes:
+ - The function uses the Flux library for neural network training and the LaplaceRedux library for fitting the Laplace approximation.
+ - The `optimize_prior!` function is called to optimize the prior parameters of the Laplace model.
+
+"""
function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
X = MLJBase.matrix(X) |> permutedims
decode = y[1]
@@ -167,6 +236,20 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
return ((la, decode), cache, report)
end
+@doc """
+Predicts the class probabilities for new data using a Laplace classifier.
+
+ # Arguments
+ - `m::LaplaceClassifier`: The Laplace classifier model.
+ - `(fitresult, decode)`: A tuple containing the fitted model result and the decode function.
+ - `Xnew`: The new data for which predictions are to be made.
+
+ # Returns
+ - `MLJBase.UnivariateFinite`: The predicted class probabilities for the new data.
+
+The function transforms the new data `Xnew` into a matrix, applies the LaplaceRedux
+prediction function, and then returns the predictions as a `MLJBase.UnivariateFinite` object.
+"""
function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew)
la = fitresult
Xnew = MLJBase.matrix(Xnew) |> permutedims
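An end-to-end sketch matching the documented steps, using the bare model-interface calls (the same pattern as dev/issues/predict_slow.jl); data, layer sizes and the unexported constructor path are illustrative assumptions:

    using Flux, MLJBase, LaplaceRedux

    X = MLJBase.table(rand(Float32, 100, 3))
    y = vec(sum(MLJBase.matrix(X); dims=2)) .+ 0.1f0 .* randn(Float32, 100)

    model = LaplaceRedux.LaplaceRegressor(;
        flux_model = Chain(Dense(3, 16, relu), Dense(16, 1)),
        epochs     = 50,
    )

    fitresult, cache, report = MLJBase.fit(model, 0, X, y)
    yhat = MLJBase.predict(model, fitresult, X)   # a vector of Normal distributions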
From 9731297057c1324f4c6e9990d477435a0cb44fe3 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Sat, 21 Sep 2024 09:38:12 +0200
Subject: [PATCH 14/60] removed predict_proba and ret_distr from the struct
---
src/direct_mlj.jl | 22 ++++++++++------------
1 file changed, 10 insertions(+), 12 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 77409e16..3e166288 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -13,7 +13,9 @@ using Distributions: Normal
A mutable struct representing a Laplace regression model.
It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.
It has the following Hyperparameters:
-- `flux_model`: A flux model provided by the user and compatible with the dataset.
+- `flux_model`: A Flux model provided by the user and compatible with the dataset.
+- `flux_loss`: a Flux loss function
+- `optimiser`: a Flux optimiser
- `epochs`: The number of training epochs.
- `batch_size`: The batch size.
- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
@@ -23,12 +25,12 @@ It has the following Hyperparameters:
- `σ`: the standard deviation of the prior distribution.
- `μ₀`: the mean of the prior distribution.
- `P₀`: the covariance matrix of the prior distribution.
-- `ret_distr`: a boolean that tells `predict` whether to return distribution objects from Distributions.jl (`true`) or just the probabilities (`false`).
- `fit_prior_nsteps`: the number of steps used to fit the priors.
"""
MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic
flux_model::Flux.Chain = nothing
flux_loss = Flux.Losses.mse
+ optimiser = Adam()
epochs::Integer = 1000::(_ > 0)
batch_size::Integer = 32::(_ > 0)
subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
@@ -39,7 +41,6 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilist
σ::Float64 = 1.0
μ₀::Float64 = 0.0
P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
- #ret_distr::Bool = false::(_ in (true, false))
fit_prior_nsteps::Int = 100::(_ > 0)
end
@@ -72,12 +73,11 @@ This function performs the following steps:
9. Returns the fitted Laplace model, a cache (currently `nothing`), and a report indicating success.
"""
function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
- #features = Tables.schema(X).names
X = MLJBase.matrix(X) |> permutedims
y = reshape(y, 1, :)
data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size)
- opt_state = Flux.setup(Adam(), m.flux_model)
+ opt_state = Flux.setup(m.optimiser, m.flux_model)
for epoch in 1:(m.epochs)
Flux.train!(m.flux_model, data_loader, opt_state) do model, X, y
@@ -139,11 +139,11 @@ end
A mutable struct representing a Laplace Classification model.
It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.
-
-
The model also has the following parameters:
-- `flux_model`: A flux model provided by the user and compatible with the dataset.
+- `flux_model`: A Flux model provided by the user and compatible with the dataset.
+- `flux_loss`: a Flux loss function
+- `optimiser`: a Flux optimiser
- `epochs`: The number of training epochs.
- `batch_size`: The batch size.
- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
@@ -154,13 +154,12 @@ The model also has the following parameters:
- `μ₀`: the mean of the prior distribution.
- `P₀`: the covariance matrix of the prior distribution.
- `link_approx`: the link approximation to use, either `:probit` or `:plugin`.
-- `predict_proba`: a boolean that selects whether to predict probabilities or not.
-- `ret_distr`: a boolean that tells `predict` whether to return distribution objects from Distributions.jl (`true`) or just the probabilities (`false`).
- `fit_prior_nsteps`: the number of steps used to fit the priors.
"""
MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic
flux_model::Flux.Chain = nothing
flux_loss = Flux.Losses.logitcrossentropy
+ optimiser = Adam()
epochs::Integer = 1000::(_ > 0)
batch_size::Integer = 32::(_ > 0)
subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
@@ -171,7 +170,6 @@ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilis
σ::Float64 = 1.0
μ₀::Float64 = 0.0
P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
- #ret_distr::Bool = false::(_ in (true, false))
fit_prior_nsteps::Int = 100::(_ > 0)
link_approx::Symbol = :probit::(_ in (:probit, :plugin))
end
@@ -207,7 +205,7 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
y_plain = MLJBase.int(y) .- 1
y_onehot = Flux.onehotbatch(y_plain, unique(y_plain))
data_loader = Flux.DataLoader((X, y_onehot); batchsize=m.batch_size)
- opt_state = Flux.setup(Adam(), m.flux_model)
+ opt_state = Flux.setup(m.optimiser, m.flux_model)
for epoch in 1:(m.epochs)
Flux.train!(m.flux_model, data_loader, opt_state) do model, X, y_onehot
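With `optimiser` promoted to a hyperparameter on both structs, any Flux/Optimisers.jl rule can be passed in place of the default `Adam()`. A minimal sketch of the intended usage follows (the chain, learning rate and variable names are illustrative; the keyword names `flux_model` and `optimiser` are the field names as they stand at this point in the series, and `LaplaceRegressor` is assumed to be in scope via this file). Note that the regressor's `fit` above still calls `m.optimiser()`, whereas `Flux.setup` expects the rule itself, as in the classifier's version; a later patch aligns the two.

```
using Flux
using Optimisers: Descent

# toy network for a 4-feature regression target
flux_model = Chain(Dense(4, 16, relu), Dense(16, 1))

# override the default Adam() with plain gradient descent
model = LaplaceRegressor(flux_model = flux_model, optimiser = Descent(0.01))

# inside fit, the rule is attached to the model parameters with
opt_state = Flux.setup(model.optimiser, model.flux_model)
```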
From f70d239adb8307492f40a6a172f02868c7de8ac8 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Sat, 21 Sep 2024 11:47:51 +0200
Subject: [PATCH 15/60] mlj docstring in progress
---
src/direct_mlj.jl | 139 ++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 139 insertions(+)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 3e166288..37e0b4b0 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -277,3 +277,142 @@ MLJBase.metadata_model(
target_scitype=AbstractArray{MLJBase.Continuous},
load_path="LaplaceRedux.LaplaceRegressor",
)
+
+const DOC_LAPLACE_REDUX = "[Laplace Redux – Effortless Bayesian Deep Learning]"
+ "(https://proceedings.neurips.cc/paper/2021/hash/a3923dbe2f702eff254d67b48ae2f06e-Abstract.html), originally published in "
+ "Daxberger, E., Kristiadi, A., Immer, A., Eschenhagen, R., Bauer, M., Hennig, P. (2021): \"Laplace Redux – Effortless Bayesian Deep Learning.\", *NIPS'21: Proceedings of the 35th International Conference on Neural Information Processing Systems*, Article No. 1537, pp. 20089–20103";
+
+ """
+$(MMI.doc_header(LaplaceClassifier))
+
+`LaplaceClassifier` implements the $DOC_CART.
+
+# Training data
+
+In MLJ or MLJBase, bind an instance `model` to data with
+
+ mach = machine(model, X, y)
+
+where
+
+- `X`: any table of input features (eg, a `DataFrame`) whose columns
+ each have one of the following element scitypes: `Continuous`,
+ `Count`, or `<:OrderedFactor`; check column scitypes with `schema(X)`
+
+- `y`: is the target, which can be any `AbstractVector` whose element
+ scitype is `<:OrderedFactor` or `<:Multiclass`; check the scitype
+ with `scitype(y)`
+
+Train the machine using `fit!(mach, rows=...)`.
+
+
+# Hyperparameters
+
+- `max_depth=-1`: max depth of the decision tree (-1=any)
+
+- `min_samples_leaf=1`: max number of samples each leaf needs to have
+
+- `min_samples_split=2`: min number of samples needed for a split
+
+- `min_purity_increase=0`: min purity needed for a split
+
+- `n_subfeatures=0`: number of features to select at random (0 for all)
+
+- `post_prune=false`: set to `true` for post-fit pruning
+
+- `merge_purity_threshold=1.0`: (post-pruning) merge leaves having
+ combined purity `>= merge_purity_threshold`
+
+- `display_depth=5`: max depth to show when displaying the tree
+
+- `feature_importance`: method to use for computing feature importances. One of `(:impurity,
+ :split)`
+
+- `rng=Random.GLOBAL_RNG`: random number generator or seed
+
+
+# Operations
+
+- `predict(mach, Xnew)`: return predictions of the target given
+ features `Xnew` having the same scitype as `X` above. Predictions
+ are probabilistic, but uncalibrated.
+
+- `predict_mode(mach, Xnew)`: instead return the mode of each
+ prediction above.
+
+
+# Fitted parameters
+
+The fields of `fitted_params(mach)` are:
+
+- `raw_tree`: the raw `Node`, `Leaf` or `Root` object returned by the core DecisionTree.jl
+ algorithm
+
+- `tree`: a visualizable, wrapped version of `raw_tree` implementing the AbstractTrees.jl
+ interface; see "Examples" below
+
+- `encoding`: dictionary of target classes keyed on integers used
+ internally by DecisionTree.jl
+
+- `features`: the names of the features encountered in training, in an
+ order consistent with the output of `print_tree` (see below)
+
+
+# Report
+
+The fields of `report(mach)` are:
+
+- `classes_seen`: list of target classes actually observed in training
+
+- `print_tree`: alternative method to print the fitted
+ tree, with single argument the tree depth; interpretation requires
+ internal integer-class encoding (see "Fitted parameters" above).
+
+- `features`: the names of the features encountered in training, in an
+ order consistent with the output of `print_tree` (see below)
+
+# Accessor functions
+
+- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs;
+ the type of importance is determined by the hyperparameter `feature_importance` (see
+ above)
+
+# Examples
+
+```
+using MLJ
+DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree
+model = DecisionTreeClassifier(max_depth=3, min_samples_split=3)
+
+X, y = @load_iris
+mach = machine(model, X, y) |> fit!
+
+Xnew = (sepal_length = [6.4, 7.2, 7.4],
+ sepal_width = [2.8, 3.0, 2.8],
+ petal_length = [5.6, 5.8, 6.1],
+ petal_width = [2.1, 1.6, 1.9],)
+yhat = predict(mach, Xnew) # probabilistic predictions
+predict_mode(mach, Xnew) # point predictions
+pdf.(yhat, "virginica") # probabilities for the "virginica" class
+
+julia> tree = fitted_params(mach).tree
+petal_length < 2.45
+├─ setosa (50/50)
+└─ petal_width < 1.75
+ ├─ petal_length < 4.95
+ │ ├─ versicolor (47/48)
+ │ └─ virginica (4/6)
+ └─ petal_length < 4.85
+ ├─ virginica (2/3)
+ └─ virginica (43/43)
+
+using Plots, TreeRecipe
+plot(tree) # for a graphical representation of the tree
+
+feature_importances(mach)
+```
+
+See also [LaplaceRedux.jl](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl).
+
+"""
+LaplaceClassifier
\ No newline at end of file
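Note that the three string literals in the `DOC_LAPLACE_REDUX` definition above are separate expressions, so only the first one is actually bound to the constant; the next patch joins them with `*`. A quick sketch of the underlying Julia behaviour:

```
# Adjacent string literals on separate lines are independent expressions,
# not an implicit concatenation:
a = "Laplace Redux"
    " – Effortless Bayesian Deep Learning"   # evaluated and discarded
a == "Laplace Redux"                          # true

# Concatenation needs `*` (or `string(...)`):
b = "Laplace Redux" * " – Effortless Bayesian Deep Learning"
b == "Laplace Redux – Effortless Bayesian Deep Learning"     # true
```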
From 80c65539e09487d9f826b3fc8824a53cfea1a16c Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Sat, 21 Sep 2024 11:53:29 +0200
Subject: [PATCH 16/60] ah, fixed constant, added prototype for regression
---
src/direct_mlj.jl | 145 ++++++++++++++++++++++++++++++++++++++++++++--
1 file changed, 140 insertions(+), 5 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 37e0b4b0..4ecc1702 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -278,14 +278,14 @@ MLJBase.metadata_model(
load_path="LaplaceRedux.LaplaceRegressor",
)
-const DOC_LAPLACE_REDUX = "[Laplace Redux – Effortless Bayesian Deep Learning]"
- "(https://proceedings.neurips.cc/paper/2021/hash/a3923dbe2f702eff254d67b48ae2f06e-Abstract.html), originally published in "
- "Daxberger, E., Kristiadi, A., Immer, A., Eschenhagen, R., Bauer, M., Hennig, P. (2021): \"Laplace Redux – Effortless Bayesian Deep Learning.\", *NIPS'21: Proceedings of the 35th International Conference on Neural Information Processing Systems*, Article No. 1537, pp. 20089–20103";
+const DOC_LAPLACE_REDUX = "[Laplace Redux – Effortless Bayesian Deep Learning]"*
+    "(https://proceedings.neurips.cc/paper/2021/hash/a3923dbe2f702eff254d67b48ae2f06e-Abstract.html), originally published in "*
+    "Daxberger, E., Kristiadi, A., Immer, A., Eschenhagen, R., Bauer, M., Hennig, P. (2021): \"Laplace Redux – Effortless Bayesian Deep Learning.\", *NIPS'21: Proceedings of the 35th International Conference on Neural Information Processing Systems*, Article No. 1537, pp. 20089–20103"
"""
$(MMI.doc_header(LaplaceClassifier))
-`LaplaceClassifier` implements the $DOC_CART.
+`LaplaceClassifier` implements the $DOC_LAPLACE_REDUX for classification models.
# Training data
@@ -415,4 +415,139 @@ feature_importances(mach)
See also [LaplaceRedux.jl](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl).
"""
-LaplaceClassifier
\ No newline at end of file
+LaplaceClassifier
+
+"""
+$(MMI.doc_header(LaplaceRegressor))
+
+`LaplaceRegressor` implements the $DOC_LAPLACE_REDUX for regression models.
+
+# Training data
+
+In MLJ or MLJBase, bind an instance `model` to data with
+
+ mach = machine(model, X, y)
+
+where
+
+- `X`: any table of input features (eg, a `DataFrame`) whose columns
+ each have one of the following element scitypes: `Continuous`,
+ `Count`, or `<:OrderedFactor`; check column scitypes with `schema(X)`
+
+- `y`: is the target, which can be any `AbstractVector` whose element
+ scitype is `<:OrderedFactor` or `<:Multiclass`; check the scitype
+ with `scitype(y)`
+
+Train the machine using `fit!(mach, rows=...)`.
+
+
+# Hyperparameters
+
+- `max_depth=-1`: max depth of the decision tree (-1=any)
+
+- `min_samples_leaf=1`: max number of samples each leaf needs to have
+
+- `min_samples_split=2`: min number of samples needed for a split
+
+- `min_purity_increase=0`: min purity needed for a split
+
+- `n_subfeatures=0`: number of features to select at random (0 for all)
+
+- `post_prune=false`: set to `true` for post-fit pruning
+
+- `merge_purity_threshold=1.0`: (post-pruning) merge leaves having
+ combined purity `>= merge_purity_threshold`
+
+- `display_depth=5`: max depth to show when displaying the tree
+
+- `feature_importance`: method to use for computing feature importances. One of `(:impurity,
+ :split)`
+
+- `rng=Random.GLOBAL_RNG`: random number generator or seed
+
+
+# Operations
+
+- `predict(mach, Xnew)`: return predictions of the target given
+ features `Xnew` having the same scitype as `X` above. Predictions
+ are probabilistic, but uncalibrated.
+
+- `predict_mode(mach, Xnew)`: instead return the mode of each
+ prediction above.
+
+
+# Fitted parameters
+
+The fields of `fitted_params(mach)` are:
+
+- `raw_tree`: the raw `Node`, `Leaf` or `Root` object returned by the core DecisionTree.jl
+ algorithm
+
+- `tree`: a visualizable, wrapped version of `raw_tree` implementing the AbstractTrees.jl
+ interface; see "Examples" below
+
+- `encoding`: dictionary of target classes keyed on integers used
+ internally by DecisionTree.jl
+
+- `features`: the names of the features encountered in training, in an
+ order consistent with the output of `print_tree` (see below)
+
+
+# Report
+
+The fields of `report(mach)` are:
+
+- `classes_seen`: list of target classes actually observed in training
+
+- `print_tree`: alternative method to print the fitted
+ tree, with single argument the tree depth; interpretation requires
+ internal integer-class encoding (see "Fitted parameters" above).
+
+- `features`: the names of the features encountered in training, in an
+ order consistent with the output of `print_tree` (see below)
+
+# Accessor functions
+
+- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs;
+ the type of importance is determined by the hyperparameter `feature_importance` (see
+ above)
+
+# Examples
+
+```
+using MLJ
+DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree
+model = DecisionTreeClassifier(max_depth=3, min_samples_split=3)
+
+X, y = @load_iris
+mach = machine(model, X, y) |> fit!
+
+Xnew = (sepal_length = [6.4, 7.2, 7.4],
+ sepal_width = [2.8, 3.0, 2.8],
+ petal_length = [5.6, 5.8, 6.1],
+ petal_width = [2.1, 1.6, 1.9],)
+yhat = predict(mach, Xnew) # probabilistic predictions
+predict_mode(mach, Xnew) # point predictions
+pdf.(yhat, "virginica") # probabilities for the "verginica" class
+
+julia> tree = fitted_params(mach).tree
+petal_length < 2.45
+├─ setosa (50/50)
+└─ petal_width < 1.75
+ ├─ petal_length < 4.95
+ │ ├─ versicolor (47/48)
+ │ └─ virginica (4/6)
+ └─ petal_length < 4.85
+ ├─ virginica (2/3)
+ └─ virginica (43/43)
+
+using Plots, TreeRecipe
+plot(tree) # for a graphical representation of the tree
+
+feature_importances(mach)
+```
+
+See also [LaplaceRedux.jl](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl).
+
+"""
+LaplaceRegressor
\ No newline at end of file
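Both docstrings end with the documented name on its own (`LaplaceClassifier`, `LaplaceRegressor` above): a triple-quoted string followed by a bare binding attaches the documentation to that existing binding, with interpolations such as `$(MMI.doc_header(...))` and `$DOC_LAPLACE_REDUX` resolved when the string is built. A self-contained sketch of the same mechanism, with placeholder names standing in for the real ones:

```
module DocstringSketch

struct Widget end                 # stands in for LaplaceClassifier / LaplaceRegressor

const REF = "Doe et al. (2021)"   # stands in for DOC_LAPLACE_REDUX

"""
    Widget

A placeholder type; based on $REF.
"""
Widget                            # the bare name attaches the docstring to the type

end # module

# `?DocstringSketch.Widget` in the REPL now shows the interpolated docstring.
```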
From d1c895c128cba4088bd9a4af4da1da938c9b3f6c Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Sat, 21 Sep 2024 16:13:44 +0200
Subject: [PATCH 17/60] small stuff here and there in the docstring plus fixed
a small mistake in the optimiser
---
src/direct_mlj.jl | 58 ++++++++++++++++++-----------------------------
1 file changed, 22 insertions(+), 36 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 4ecc1702..bc1826ab 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -77,7 +77,7 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
X = MLJBase.matrix(X) |> permutedims
y = reshape(y, 1, :)
data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size)
- opt_state = Flux.setup(m.optimiser(), m.flux_model)
+ opt_state = Flux.setup(m.optimiser, m.flux_model)
for epoch in 1:(m.epochs)
Flux.train!(m.flux_model, data_loader, opt_state) do model, X, y
@@ -381,10 +381,21 @@ The fields of `report(mach)` are:
```
using MLJ
-DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree
-model = DecisionTreeClassifier(max_depth=3, min_samples_split=3)
+LaplaceClassifier = @load LaplaceClassifier pkg=LaplaceRedux
X, y = @load_iris
+
+# Define the Flux Chain model
+using Flux
+flux_model = Chain(
+ Dense(4, 10, relu),
+ Dense(10, 10, relu),
+ Dense(10, 3)
+)
+
+#Define the LaplaceClassifier
+model = LaplaceClassifier(flux_model=flux_model)
+
mach = machine(model, X, y) |> fit!
Xnew = (sepal_length = [6.4, 7.2, 7.4],
@@ -395,19 +406,8 @@ yhat = predict(mach, Xnew) # probabilistic predictions
predict_mode(mach, Xnew) # point predictions
pdf.(yhat, "virginica") # probabilities for the "virginica" class
-julia> tree = fitted_params(mach).tree
-petal_length < 2.45
-├─ setosa (50/50)
-└─ petal_width < 1.75
- ├─ petal_length < 4.95
- │ ├─ versicolor (47/48)
- │ └─ virginica (4/6)
- └─ petal_length < 4.85
- ├─ virginica (2/3)
- └─ virginica (43/43)
+julia> tree = fitted_params(mach)
-using Plots, TreeRecipe
-plot(tree) # for a graphical representation of the tree
feature_importances(mach)
```
@@ -516,33 +516,19 @@ The fields of `report(mach)` are:
```
using MLJ
-DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree
-model = DecisionTreeClassifier(max_depth=3, min_samples_split=3)
+LaplaceRegressor = @load LaplaceRegressor pkg=LaplaceRedux
+model = LaplaceRegressor(flux_model=flux_model)
-X, y = @load_iris
+X, y = make_regression(100, 2; noise=0.5, sparse=0.2, outliers=0.1)
mach = machine(model, X, y) |> fit!
-Xnew = (sepal_length = [6.4, 7.2, 7.4],
- sepal_width = [2.8, 3.0, 2.8],
- petal_length = [5.6, 5.8, 6.1],
- petal_width = [2.1, 1.6, 1.9],)
+Xnew, _ = make_regression(3, 2; rng=123)
yhat = predict(mach, Xnew) # probabilistic predictions
predict_mode(mach, Xnew) # point predictions
-pdf.(yhat, "virginica") # probabilities for the "verginica" class
-julia> tree = fitted_params(mach).tree
-petal_length < 2.45
-├─ setosa (50/50)
-└─ petal_width < 1.75
- ├─ petal_length < 4.95
- │ ├─ versicolor (47/48)
- │ └─ virginica (4/6)
- └─ petal_length < 4.85
- ├─ virginica (2/3)
- └─ virginica (43/43)
-
-using Plots, TreeRecipe
-plot(tree) # for a graphical representation of the tree
+julia> tree = fitted_params(mach)
+
+
feature_importances(mach)
```
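The `predict` calls in these examples return `UnivariateFinite` distributions, which is what `pdf.` and `predict_mode` operate on. A small sketch that builds two such distributions by hand to show the same calls in isolation (toy probabilities; `pool=missing` because no categorical pool is attached):

```
using MLJBase                      # provides UnivariateFinite
using Distributions: pdf, mode     # generic pdf / mode methods

classes = ["setosa", "versicolor", "virginica"]
yhat = [UnivariateFinite(classes, [0.1, 0.2, 0.7]; pool=missing),
        UnivariateFinite(classes, [0.8, 0.1, 0.1]; pool=missing)]

pdf.(yhat, "virginica")   # per-observation probability of the "virginica" class
mode.(yhat)               # most probable class; predict_mode(mach, Xnew) computes these
```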
From 19ffa1667d38a83f5ddeac23ec3a2396f297f599 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Sat, 21 Sep 2024 17:48:08 +0200
Subject: [PATCH 18/60] still writing this long ass docstring
---
src/direct_mlj.jl | 127 ++++++++++++++++------------------------------
1 file changed, 45 insertions(+), 82 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index bc1826ab..5e72712c 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -306,30 +306,35 @@ where
Train the machine using `fit!(mach, rows=...)`.
-# Hyperparameters
+# Hyperparameters (format: name-type-default value-restrictions)
-- `max_depth=-1`: max depth of the decision tree (-1=any)
+- `flux_model::Flux.Chain = nothing`: a Flux model provided by the user and compatible with the dataset.
-- `min_samples_leaf=1`: max number of samples each leaf needs to have
+- `flux_loss = Flux.Losses.logitcrossentropy` : a Flux loss function
-- `min_samples_split=2`: min number of samples needed for a split
+- `optimiser = Adam()`: a Flux optimiser.
-- `min_purity_increase=0`: min purity needed for a split
+- `epochs::Integer = 1000::(_ > 0)`: the number of training epochs.
-- `n_subfeatures=0`: number of features to select at random (0 for all)
+- `batch_size::Integer = 32::(_ > 0)`: the batch size.
-- `post_prune=false`: set to `true` for post-fit pruning
+- `subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
-- `merge_purity_threshold=1.0`: (post-pruning) merge leaves having
- combined purity `>= merge_purity_threshold`
+- `subnetwork_indices = nothing`: the indices of the subnetworks.
-- `display_depth=5`: max depth to show when displaying the tree
+- `hessian_structure::Union{HessianStructure,Symbol,String} = :full::(_ in (:full, :diagonal))`: the structure of the Hessian matrix, either `:full` or `:diagonal`.
-- `feature_importance`: method to use for computing feature importances. One of `(:impurity,
- :split)`
+- `backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))`: the backend to use, either `:GGN` or `:EmpiricalFisher`.
-- `rng=Random.GLOBAL_RNG`: random number generator or seed
+- `σ::Float64 = 1.0`: the standard deviation of the prior distribution.
+- `μ₀::Float64 = 0.0`: the mean of the prior distribution.
+
+- `P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing`: the covariance matrix of the prior distribution.
+
+- `fit_prior_nsteps::Int = 100::(_ > 0) `: the number of steps used to fit the priors.
+
+- `link_approx::Symbol = :probit::(_ in (:probit, :plugin))`: the link approximation used to compute the class probabilities, either `:probit` or `:plugin`.
# Operations
@@ -348,34 +353,8 @@ The fields of `fitted_params(mach)` are:
- `raw_tree`: the raw `Node`, `Leaf` or `Root` object returned by the core DecisionTree.jl
algorithm
-- `tree`: a visualizable, wrapped version of `raw_tree` implementing the AbstractTrees.jl
- interface; see "Examples" below
-
-- `encoding`: dictionary of target classes keyed on integers used
- internally by DecisionTree.jl
-
-- `features`: the names of the features encountered in training, in an
- order consistent with the output of `print_tree` (see below)
-
-
-# Report
-
-The fields of `report(mach)` are:
-
-- `classes_seen`: list of target classes actually observed in training
-
-- `print_tree`: alternative method to print the fitted
- tree, with single argument the tree depth; interpretation requires
- internal integer-class encoding (see "Fitted parameters" above).
-
-- `features`: the names of the features encountered in training, in an
- order consistent with the output of `print_tree` (see below)
-
# Accessor functions
-- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs;
- the type of importance is determined by the hyperparameter `feature_importance` (see
- above)
# Examples
@@ -408,8 +387,6 @@ pdf.(yhat, "virginica") # probabilities for the "verginica" class
julia> tree = fitted_params(mach)
-
-feature_importances(mach)
```
See also [LaplaceRedux.jl](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl).
@@ -435,35 +412,39 @@ where
`Count`, or `<:OrderedFactor`; check column scitypes with `schema(X)`
- `y`: is the target, which can be any `AbstractVector` whose element
- scitype is `<:OrderedFactor` or `<:Multiclass`; check the scitype
+ scitype is `<:Continuous`; check the scitype
with `scitype(y)`
Train the machine using `fit!(mach, rows=...)`.
-# Hyperparameters
+# Hyperparameters (format: name-type-default value-restrictions)
+
+- `flux_model::Flux.Chain = nothing`: a Flux model provided by the user and compatible with the dataset.
-- `max_depth=-1`: max depth of the decision tree (-1=any)
+- `flux_loss = Flux.Losses.mse`: a Flux loss function.
-- `min_samples_leaf=1`: max number of samples each leaf needs to have
+- `optimiser = Adam()`: a Flux optimiser.
-- `min_samples_split=2`: min number of samples needed for a split
+- `epochs::Integer = 1000::(_ > 0)`: the number of training epochs.
-- `min_purity_increase=0`: min purity needed for a split
+- `batch_size::Integer = 32::(_ > 0)`: the batch size.
-- `n_subfeatures=0`: number of features to select at random (0 for all)
+- `subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
-- `post_prune=false`: set to `true` for post-fit pruning
+- `subnetwork_indices = nothing`: the indices of the subnetworks.
-- `merge_purity_threshold=1.0`: (post-pruning) merge leaves having
- combined purity `>= merge_purity_threshold`
+- `hessian_structure::Union{HessianStructure,Symbol,String} = :full::(_ in (:full, :diagonal))`: the structure of the Hessian matrix, either `:full` or `:diagonal`.
-- `display_depth=5`: max depth to show when displaying the tree
+- `backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))`: the backend to use, either `:GGN` or `:EmpiricalFisher`.
-- `feature_importance`: method to use for computing feature importances. One of `(:impurity,
- :split)`
+- `σ::Float64 = 1.0`: the standard deviation of the prior distribution.
-- `rng=Random.GLOBAL_RNG`: random number generator or seed
+- `μ₀::Float64 = 0.0`: the mean of the prior distribution.
+
+- `P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing`: the covariance matrix of the prior distribution.
+
+- `fit_prior_nsteps::Int = 100::(_ > 0) `: the number of steps used to fit the priors.
# Operations
@@ -483,46 +464,28 @@ The fields of `fitted_params(mach)` are:
- `raw_tree`: the raw `Node`, `Leaf` or `Root` object returned by the core DecisionTree.jl
algorithm
-- `tree`: a visualizable, wrapped version of `raw_tree` implementing the AbstractTrees.jl
- interface; see "Examples" below
-
-- `encoding`: dictionary of target classes keyed on integers used
- internally by DecisionTree.jl
-
-- `features`: the names of the features encountered in training, in an
- order consistent with the output of `print_tree` (see below)
-
-
-# Report
-
-The fields of `report(mach)` are:
-
-- `classes_seen`: list of target classes actually observed in training
-
-- `print_tree`: alternative method to print the fitted
- tree, with single argument the tree depth; interpretation requires
- internal integer-class encoding (see "Fitted parameters" above).
-
-- `features`: the names of the features encountered in training, in an
- order consistent with the output of `print_tree` (see below)
# Accessor functions
-- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs;
- the type of importance is determined by the hyperparameter `feature_importance` (see
- above)
+
# Examples
```
using MLJ
+using Flux
LaplaceRegressor = @load LaplaceRegressor pkg=LaplaceRedux
+flux_model = Chain(
+ Dense(4, 10, relu),
+ Dense(10, 10, relu),
+ Dense(10, 1)
+)
model = LaplaceRegressor(flux_model=flux_model)
-X, y = make_regression(100, 2; noise=0.5, sparse=0.2, outliers=0.1)
+X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)
mach = machine(model, X, y) |> fit!
-Xnew, _ = make_regression(3, 2; rng=123)
+Xnew, _ = make_regression(3, 4; rng=123)
yhat = predict(mach, Xnew) # probabilistic predictions
predict_mode(mach, Xnew) # point predictions
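The hyperparameter lists above give both the defaults and the admissible values; a short sketch of constructing a regressor with several of them overridden (all values illustrative, and the keyword is still `flux_model` at this point in the series):

```
using Flux

flux_model = Chain(Dense(4, 10, relu), Dense(10, 1))

model = LaplaceRegressor(
    flux_model        = flux_model,
    epochs            = 500,          # must satisfy _ > 0
    batch_size        = 16,           # must satisfy _ > 0
    subset_of_weights = :last_layer,  # one of :all, :last_layer, :subnetwork
    backend           = :GGN,         # one of :GGN, :EmpiricalFisher
    fit_prior_nsteps  = 200,
)
```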
From de0bd91ac850adcf6595fe8a8952c31705dce0c6 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Sun, 22 Sep 2024 06:32:37 +0200
Subject: [PATCH 19/60] added fit_params functions
---
src/direct_mlj.jl | 123 ++++++++++++++++++++++++++++++++++++++++++++--
1 file changed, 119 insertions(+), 4 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 5e72712c..4bd86eb0 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -107,6 +107,48 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
return (fitresult, cache, report)
end
+
+@doc """
+
+ function MMI.fitted_params(model::LaplaceRegressor, fitresult)
+
+
+ This function extracts the fitted parameters from a `LaplaceRegressor` model.
+
+ # Arguments
+ - `model::LaplaceRegressor`: The Laplace regression model.
+ - `fitresult`: the Laplace approximation (`la`).
+
+ # Returns
+ A named tuple containing:
+ - `μ`: The mean of the posterior distribution.
+ - `H`: The Hessian of the posterior distribution.
+ - `P`: The precision matrix of the posterior distribution.
+ - `Σ`: The covariance matrix of the posterior distribution.
+ - `n_data`: The number of data points.
+ - `n_params`: The number of parameters.
+ - `n_out`: The number of outputs.
+ - `loss`: The loss value of the posterior distribution.
+
+"""
+function MMI.fitted_params(model::LaplaceRegressor, fitresult)
+ la = fitresult
+ posterior = la.posterior
+ return (
+ μ = posterior.μ,
+ H = posterior.H,
+ P = posterior.P,
+ Σ = posterior.Σ,
+ n_data = posterior.n_data,
+ n_params = posterior.n_params,
+ n_out = posterior.n_out,
+ loss = posterior.loss
+ )
+end
+
+
+
+
@doc """
function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
@@ -234,6 +276,53 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
return ((la, decode), cache, report)
end
+
+
+
+
+
+@doc """
+
+ function MMI.fitted_params(model::LaplaceClassifier, fitresult)
+
+
+ This function extracts the fitted parameters from a `LaplaceClassifier` model.
+
+ # Arguments
+ - `model::LaplaceClassifier`: The Laplace classifier model.
+ - `fitresult`: A tuple containing the Laplace approximation (`la`) and a decode function.
+
+ # Returns
+ A named tuple containing:
+ - `μ`: The mean of the posterior distribution.
+ - `H`: The Hessian of the posterior distribution.
+ - `P`: The precision matrix of the posterior distribution.
+ - `Σ`: The covariance matrix of the posterior distribution.
+ - `n_data`: The number of data points.
+ - `n_params`: The number of parameters.
+ - `n_out`: The number of outputs.
+ - `loss`: The loss value of the posterior distribution.
+
+"""
+function MMI.fitted_params(model::LaplaceClassifier, fitresult)
+ la, decode = fitresult
+ posterior = la.posterior
+ return (
+ μ = posterior.μ,
+ H = posterior.H,
+ P = posterior.P,
+ Σ = posterior.Σ,
+ n_data = posterior.n_data,
+ n_params = posterior.n_params,
+ n_out = posterior.n_out,
+ loss = posterior.loss
+ )
+end
+
+
+
+
+
@doc """
Predicts the class probabilities for new data using a Laplace classifier.
@@ -350,8 +439,21 @@ Train the machine using `fit!(mach, rows=...)`.
The fields of `fitted_params(mach)` are:
-- `raw_tree`: the raw `Node`, `Leaf` or `Root` object returned by the core DecisionTree.jl
- algorithm
+ - `μ`: The mean of the posterior distribution.
+
+ - `H`: The Hessian of the posterior distribution.
+
+ - `P`: The precision matrix of the posterior distribution.
+
+ - `Σ`: The covariance matrix of the posterior distribution.
+
+ - `n_data`: The number of data points.
+
+ - `n_params`: The number of parameters.
+
+ - `n_out`: The number of outputs.
+
+ - `loss`: The loss value of the posterior distribution.
# Accessor functions
@@ -461,8 +563,21 @@ Train the machine using `fit!(mach, rows=...)`.
The fields of `fitted_params(mach)` are:
-- `raw_tree`: the raw `Node`, `Leaf` or `Root` object returned by the core DecisionTree.jl
- algorithm
+ - `μ`: The mean of the posterior distribution.
+
+ - `H`: The Hessian of the posterior distribution.
+
+ - `P`: The precision matrix of the posterior distribution.
+
+ - `Σ`: The covariance matrix of the posterior distribution.
+
+ - `n_data`: The number of data points.
+
+ - `n_params`: The number of parameters.
+
+ - `n_out`: The number of outputs.
+
+ - `loss`: The loss value of the posterior distribution.
# Accessor functions
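Once a machine wrapping one of these models has been fitted, the named tuple documented above is reachable through the usual accessor. A sketch of the access pattern, assuming `mach` is such a fitted machine:

```
fp = fitted_params(mach)   # fitted_params is exported by MLJ / MLJBase

fp.μ          # posterior mean of the weights covered by the Laplace approximation
fp.Σ          # posterior covariance matrix
fp.n_params   # number of parameters in the approximation
fp.loss       # loss value stored in the posterior

# with hessian_structure = :full, Σ is expected to be n_params × n_params
size(fp.Σ)
```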
From 87df85f7c7f931f25710c9ec7fcb345380808620 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Sun, 22 Sep 2024 14:52:07 +0200
Subject: [PATCH 20/60] switched to customized loop
---
src/direct_mlj.jl | 83 +++++++++++++++++++++++++++++++++++++++--------
1 file changed, 70 insertions(+), 13 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 4bd86eb0..9e6e5921 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -78,10 +78,40 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
y = reshape(y, 1, :)
data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size)
opt_state = Flux.setup(m.optimiser, m.flux_model)
+ loss_history=[]
for epoch in 1:(m.epochs)
- Flux.train!(m.flux_model, data_loader, opt_state) do model, X, y
- m.flux_loss(model(X), y)
+
+ loss_per_epoch= 0.0
+
+
+ for (X_batch, y_batch) in data_loader
+ # Forward pass: compute predictions
+ y_pred = m.flux_model(X_batch)
+
+ # Compute loss
+ loss = m.flux_loss(y_pred, y_batch)
+
+ # Compute gradients explicitly
+ grads = gradient(m.flux_model) do model
+ # Recompute predictions inside gradient context
+ y_pred = model(X_batch)
+ m.flux_loss(y_pred, y_batch)
+ end
+
+ # Update parameters using the optimizer and computed gradients
+ Flux.Optimise.update!(opt_state ,m.flux_model , grads[1])
+
+ # Accumulate the loss for this batch
+ loss_per_epoch += sum(loss) # Summing the batch loss
+
+ end
+
+ push!(loss_history,loss_per_epoch )
+
+ # Print loss every 100 epochs if verbosity is 1 or more
+ if verbosity >= 1 && epoch % 100 == 0
+ println("Epoch $epoch: Loss: $loss_per_epoch ")
end
end
@@ -102,7 +132,7 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
fitresult = la
- report = (status="success", message="Model fitted successfully")
+ report = (status="success", loss_history = loss_history)
cache = nothing
return (fitresult, cache, report)
end
@@ -243,17 +273,44 @@ end
"""
function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
X = MLJBase.matrix(X) |> permutedims
- decode = y[1]
- y_plain = MLJBase.int(y) .- 1
- y_onehot = Flux.onehotbatch(y_plain, unique(y_plain))
- data_loader = Flux.DataLoader((X, y_onehot); batchsize=m.batch_size)
- opt_state = Flux.setup(m.optimiser, m.flux_model)
+# Store the first label as decode function
+decode = y[1]
+
+# Convert labels to integer format starting from 0 for one-hot encoding
+y_plain = MLJBase.int(y) .- 1
+
+# One-hot encoding of labels
+unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding
+y_onehot = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding
+
+# Create a data loader for batching the data
+data_loader = Flux.DataLoader((X, y_onehot); batchsize=m.batch_size)
+
+# Set up the optimizer for the model
+opt_state = Flux.setup(m.optimiser, m.flux_model)
+loss_history = []
+
+# Training loop for the specified number of epochs
+for epoch in 1:m.epochs
+ loss_per_epoch = 0.0 # Initialize loss for the current epoch
+
+ # Training function for each batch in the data loader
+ Flux.train!(m.flux_model, data_loader, opt_state) do model, X_batch, y_batch
+ # Forward pass and compute the loss
+ loss = m.flux_loss(model(X_batch), y_batch)
+
+ # Accumulate the loss for the current epoch
+ loss_per_epoch += sum(loss)
+ end
+
+ # Record loss history for analysis
+ push!(loss_history, loss_per_epoch)
- for epoch in 1:(m.epochs)
- Flux.train!(m.flux_model, data_loader, opt_state) do model, X, y_onehot
- m.flux_loss(model(X), y_onehot)
- end
+ # Verbosity: print loss every 100 epochs
+ if verbosity >= 1 && epoch % 100 == 0
+ println("Epoch $epoch: Loss: $loss_per_epoch")
end
+end
la = LaplaceRedux.Laplace(
m.flux_model;
@@ -271,7 +328,7 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
LaplaceRedux.fit!(la, data_loader)
optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
- report = (status="success", message="Model fitted successfully")
+ report = (status="success", loss_history = loss_history)
cache = nothing
return ((la, decode), cache, report)
end
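The new loop follows Flux's explicit-gradient style: state from `Flux.setup`, `gradient` taken over the model itself, and an update with the first element of the gradient tuple. A compact standalone version of the same pattern on toy data (all names and sizes illustrative):

```
using Flux

model  = Chain(Dense(2, 8, relu), Dense(8, 1))
X      = rand(Float32, 2, 100)
y      = rand(Float32, 1, 100)
loader = Flux.DataLoader((X, y); batchsize=32)

opt_state    = Flux.setup(Adam(), model)
loss_history = Float64[]

for epoch in 1:10
    loss_per_epoch = 0.0
    for (xb, yb) in loader
        # loss before the update, mirroring the accumulated value in the patch
        loss_per_epoch += Flux.Losses.mse(model(xb), yb)
        # gradient with respect to the model itself; grads[1] holds the model's gradient
        grads = Flux.gradient(m -> Flux.Losses.mse(m(xb), yb), model)
        Flux.update!(opt_state, model, grads[1])
    end
    push!(loss_history, loss_per_epoch)
end
```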
From 24459a1b807230e529c93fe80dc14868975f4e45 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Sun, 22 Sep 2024 15:41:59 +0200
Subject: [PATCH 21/60] fixed error in custom loop
---
src/direct_mlj.jl | 18 ++++++++++++++----
1 file changed, 14 insertions(+), 4 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 9e6e5921..28407f4b 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -294,14 +294,24 @@ loss_history = []
for epoch in 1:m.epochs
loss_per_epoch = 0.0 # Initialize loss for the current epoch
- # Training function for each batch in the data loader
- Flux.train!(m.flux_model, data_loader, opt_state) do model, X_batch, y_batch
- # Forward pass and compute the loss
- loss = m.flux_loss(model(X_batch), y_batch)
+ for (X_batch, y_batch) in data_loader
+ # Compute gradients explicitly
+ grads = gradient(m.flux_model) do model
+ # Recompute predictions inside gradient context
+ y_pred = model(X_batch)
+ m.flux_loss(y_pred, y_batch)
+ end
+
+ # Update the model parameters using the computed gradients
+ Flux.Optimise.update!(opt_state, Flux.params(m.flux_model), grads)
+
+ # Compute the loss for this batch
+ loss = m.flux_loss(m.flux_model(X_batch), y_batch)
# Accumulate the loss for the current epoch
loss_per_epoch += sum(loss)
end
+
# Record loss history for analysis
push!(loss_history, loss_per_epoch)
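The labels reaching this classifier loop were already integer-coded and one-hot encoded earlier in `fit`; a small sketch of that encoding step on a toy categorical target (illustrative data, using `MLJBase.int` and `Flux.onehotbatch` as the patch does):

```
using MLJBase: categorical, int
using Flux

y = categorical(["setosa", "virginica", "setosa", "versicolor"])

y_plain       = int(y) .- 1                               # integer codes starting at 0
unique_labels = unique(y_plain)
y_onehot      = Flux.onehotbatch(y_plain, unique_labels)  # 3×4 one-hot matrix

size(y_onehot)   # (3, 4): one column per observation, as logitcrossentropy expects
```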
From 0e2ca0330f35378e09a7696b980ccac34beb6945 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Sun, 22 Sep 2024 16:16:46 +0200
Subject: [PATCH 22/60] various fixes
---
src/direct_mlj.jl | 100 ++++++++++++++++++++++++++++++++++++----------
1 file changed, 80 insertions(+), 20 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 28407f4b..87171553 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -13,7 +13,7 @@ using Distributions: Normal
A mutable struct representing a Laplace regression model.
It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.
It has the following Hyperparameters:
-- `flux_model`: A Flux model provided by the user and compatible with the dataset.
+- `model`: A Flux model provided by the user and compatible with the dataset.
- `flux_loss`: a Flux loss function.
- `optimiser`: a Flux optimiser.
- `epochs`: The number of training epochs.
@@ -28,7 +28,7 @@ It has the following Hyperparameters:
- `fit_prior_nsteps`: the number of steps used to fit the priors.
"""
MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic
- flux_model::Flux.Chain = nothing
+ model::Flux.Chain = nothing
flux_loss = Flux.Losses.mse
optimiser = Adam()
epochs::Integer = 1000::(_ > 0)
@@ -77,7 +77,7 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
X = MLJBase.matrix(X) |> permutedims
y = reshape(y, 1, :)
data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size)
- opt_state = Flux.setup(m.optimiser, m.flux_model)
+ opt_state = Flux.setup(m.optimiser, m.model)
loss_history=[]
for epoch in 1:(m.epochs)
@@ -87,20 +87,20 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
for (X_batch, y_batch) in data_loader
# Forward pass: compute predictions
- y_pred = m.flux_model(X_batch)
+ y_pred = m.model(X_batch)
# Compute loss
loss = m.flux_loss(y_pred, y_batch)
# Compute gradients explicitly
- grads = gradient(m.flux_model) do model
+ grads = gradient(m.model) do model
# Recompute predictions inside gradient context
y_pred = model(X_batch)
m.flux_loss(y_pred, y_batch)
end
# Update parameters using the optimizer and computed gradients
- Flux.Optimise.update!(opt_state ,m.flux_model , grads[1])
+ Flux.Optimise.update!(opt_state ,m.model , grads[1])
# Accumulate the loss for this batch
loss_per_epoch += sum(loss) # Summing the batch loss
@@ -116,7 +116,7 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
end
la = LaplaceRedux.Laplace(
- m.flux_model;
+ m.model;
likelihood=:regression,
subset_of_weights=m.subset_of_weights,
subnetwork_indices=m.subnetwork_indices,
@@ -179,6 +179,26 @@ end
+
+@doc """
+ MMI.training_losses(model::LaplaceRegressor, report)
+
+Retrieve the training loss history from the given `report`.
+
+# Arguments
+- `model::LaplaceRegressor`: The model for which the training losses are being retrieved.
+- `report`: An object containing the training report, which includes the loss history.
+
+# Returns
+- A collection representing the loss history from the training report.
+"""
+function MMI.training_losses(model::LaplaceRegressor, report)
+ return report.loss_history
+end
+
+
+
+
@doc """
function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
@@ -213,7 +233,7 @@ A mutable struct representing a Laplace Classification model.
It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.
The model also has the following parameters:
-- `flux_model`: A Flux model provided by the user and compatible with the dataset.
+- `model`: A Flux model provided by the user and compatible with the dataset.
- `flux_loss`: a Flux loss function.
- `optimiser`: a Flux optimiser.
- `epochs`: The number of training epochs.
@@ -229,7 +249,7 @@ The model also has the following parameters:
- `fit_prior_nsteps`: the number of steps used to fit the priors.
"""
MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic
- flux_model::Flux.Chain = nothing
+ model::Flux.Chain = nothing
flux_loss = Flux.Losses.logitcrossentropy
optimiser = Adam()
epochs::Integer = 1000::(_ > 0)
@@ -287,7 +307,7 @@ y_onehot = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding
data_loader = Flux.DataLoader((X, y_onehot); batchsize=m.batch_size)
# Set up the optimizer for the model
-opt_state = Flux.setup(m.optimiser, m.flux_model)
+opt_state = Flux.setup(m.optimiser, m.model)
loss_history = []
# Training loop for the specified number of epochs
@@ -296,17 +316,17 @@ for epoch in 1:m.epochs
for (X_batch, y_batch) in data_loader
# Compute gradients explicitly
- grads = gradient(m.flux_model) do model
+ grads = gradient(m.model) do model
# Recompute predictions inside gradient context
y_pred = model(X_batch)
m.flux_loss(y_pred, y_batch)
end
# Update the model parameters using the computed gradients
- Flux.Optimise.update!(opt_state, Flux.params(m.flux_model), grads)
+ Flux.Optimise.update!(opt_state, Flux.params(m.model), grads)
# Compute the loss for this batch
- loss = m.flux_loss(m.flux_model(X_batch), y_batch)
+ loss = m.flux_loss(m.model(X_batch), y_batch)
# Accumulate the loss for the current epoch
loss_per_epoch += sum(loss)
@@ -323,7 +343,7 @@ for epoch in 1:m.epochs
end
la = LaplaceRedux.Laplace(
- m.flux_model;
+ m.model;
likelihood=:classification,
subset_of_weights=m.subset_of_weights,
subnetwork_indices=m.subnetwork_indices,
@@ -387,7 +407,21 @@ function MMI.fitted_params(model::LaplaceClassifier, fitresult)
end
+@doc """
+ MMI.training_losses(model::LaplaceClassifier, report)
+
+Retrieve the training loss history from the given `report`.
+# Arguments
+- `model::LaplaceClassifier`: The model for which the training losses are being retrieved.
+- `report`: An object containing the training report, which includes the loss history.
+
+# Returns
+- A collection representing the loss history from the training report.
+"""
+function MMI.training_losses(model::LaplaceClassifier, report)
+ return report.loss_history
+end
@doc """
@@ -414,6 +448,29 @@ function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew)
return MLJBase.UnivariateFinite(MLJBase.classes(decode), predictions)
end
+
+MMI.metadata_pkg(
+ LaplaceRegressor,
+ name="LaplaceRedux",
+ package_uuid="??????",
+ package_url="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl",
+ is_pure_julia=true,
+ is_wrapper=true,
+ package_license = "https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl/blob/main/LICENSE",
+)
+
+MMI.metadata_pkg(
+ LaplaceClassifier,
+ name="LaplaceRedux",
+ package_uuid="dontknow",
+ package_url="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl",
+ is_pure_julia=true,
+ is_wrapper=true,
+ package_license = "https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl/blob/main/LICENSE",
+)
+
+
+
MLJBase.metadata_model(
LaplaceClassifier;
input_scitype=Union{
@@ -421,6 +478,7 @@ MLJBase.metadata_model(
MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types
},
target_scitype=AbstractArray{<:MLJBase.Finite}, # ordered factor or multiclass
+ supports_training_losses = true,
load_path="LaplaceRedux.LaplaceClassifier",
)
# metadata for each model,
@@ -431,7 +489,9 @@ MLJBase.metadata_model(
MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types
},
target_scitype=AbstractArray{MLJBase.Continuous},
+ supports_training_losses = true,
load_path="LaplaceRedux.LaplaceRegressor",
+
)
const DOC_LAPLACE_REDUX = "[Laplace Redux – Effortless Bayesian Deep Learning]"*
@@ -464,7 +524,7 @@ Train the machine using `fit!(mach, rows=...)`.
# Hyperparameters (format: name-type-default value-restrictions)
-- `flux_model::Flux.Chain = nothing`: a Flux model provided by the user and compatible with the dataset.
+- `model::Flux.Chain = nothing`: a Flux model provided by the user and compatible with the dataset.
- `flux_loss = Flux.Losses.logitcrossentropy` : a Flux loss function
@@ -535,14 +595,14 @@ X, y = @load_iris
# Define the Flux Chain model
using Flux
-flux_model = Chain(
+model = Chain(
Dense(4, 10, relu),
Dense(10, 10, relu),
Dense(10, 3)
)
#Define the LaplaceClassifier
-model = LaplaceClassifier(flux_model=flux_model)
+model = LaplaceClassifier(model=model)
mach = machine(model, X, y) |> fit!
@@ -589,7 +649,7 @@ Train the machine using `fit!(mach, rows=...)`.
# Hyperparameters (format: name-type-default value-restrictions)
-- `flux_model::Flux.Chain = nothing`: a Flux model provided by the user and compatible with the dataset.
+- `model::Flux.Chain = nothing`: a Flux model provided by the user and compatible with the dataset.
- `flux_loss = Flux.Losses.mse`: a Flux loss function.
@@ -657,12 +717,12 @@ The fields of `fitted_params(mach)` are:
using MLJ
using Flux
LaplaceRegressor = @load LaplaceRegressor pkg=LaplaceRedux
-flux_model = Chain(
+model = Chain(
Dense(4, 10, relu),
Dense(10, 10, relu),
Dense(10, 1)
)
-model = LaplaceRegressor(flux_model=flux_model)
+model = LaplaceRegressor(model=model)
X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)
mach = machine(model, X, y) |> fit!
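With `supports_training_losses = true` declared in the metadata and the `MMI.training_losses` methods defined, the loss history stored in the report becomes reachable through the standard MLJ accessor. A sketch of the full round trip on synthetic data (keyword `model` per the rename in this patch; `verbosity = 0` just silences the per-epoch printing):

```
using MLJBase, Flux

flux_chain  = Chain(Dense(4, 10, relu), Dense(10, 1))
laplace_reg = LaplaceRegressor(model = flux_chain, epochs = 100)

X, y = MLJBase.make_regression(100, 4; noise = 0.5)
mach = machine(laplace_reg, X, y)
fit!(mach; verbosity = 0)

losses = training_losses(mach)   # forwarded to report(mach).loss_history
length(losses)                    # one accumulated loss per epoch at this point in the series
```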
From 841d5ebf5898a56700534ef304d0aeee4550c36c Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Sun, 22 Sep 2024 18:08:02 +0200
Subject: [PATCH 23/60] added reformat. must update the docstring again....
---
src/direct_mlj.jl | 168 ++++++++++++++++++++++++++++------------------
1 file changed, 104 insertions(+), 64 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 87171553..690c248c 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -44,6 +44,14 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilist
fit_prior_nsteps::Int = 100::(_ > 0)
end
+
+# for fit:
+MMI.reformat(::LaplaceRegressor, X, y) = (MLJBase.matrix(X) |> permutedims,reshape(y, 1, :))
+
+MMI.reformat(::LaplaceRegressor, X) = (MLJBase.matrix(X) |> permutedims,)
+
+
+
@doc """
MMI.fit(m::LaplaceRegressor, verbosity, X, y)
@@ -74,8 +82,16 @@ This function performs the following steps:
"""
function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
- X = MLJBase.matrix(X) |> permutedims
+ #X = MLJBase.matrix(X) |> permutedims
+ #y = reshape(y, 1, :)
+
+ if Tables.istable(X)
+ X = Tables.matrix(X)|>permutedims
+ end
+
+ # Reshape y if necessary
y = reshape(y, 1, :)
+
data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size)
opt_state = Flux.setup(m.optimiser, m.model)
loss_history=[]
@@ -216,7 +232,10 @@ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
Finally, it creates Normal distributions from these means and variances and returns them as an array.
"""
function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
- Xnew = MLJBase.matrix(Xnew) |> permutedims
+ #Xnew = MLJBase.matrix(Xnew) |> permutedims
+ if Tables.istable(Xnew)
+ Xnew = Tables.matrix(Xnew)|>permutedims
+ end
la = fitresult
yhat = LaplaceRedux.predict(la, Xnew; ret_distr=false)
# Extract mean and variance matrices
@@ -266,6 +285,7 @@ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilis
link_approx::Symbol = :probit::(_ in (:probit, :plugin))
end
+
@doc """
function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
@@ -292,75 +312,85 @@ end
"""
function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
- X = MLJBase.matrix(X) |> permutedims
-# Store the first label as decode function
-decode = y[1]
-
-# Convert labels to integer format starting from 0 for one-hot encoding
-y_plain = MLJBase.int(y) .- 1
-
-# One-hot encoding of labels
-unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding
-y_onehot = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding
-
-# Create a data loader for batching the data
-data_loader = Flux.DataLoader((X, y_onehot); batchsize=m.batch_size)
-
-# Set up the optimizer for the model
-opt_state = Flux.setup(m.optimiser, m.model)
-loss_history = []
-
-# Training loop for the specified number of epochs
-for epoch in 1:m.epochs
- loss_per_epoch = 0.0 # Initialize loss for the current epoch
-
- for (X_batch, y_batch) in data_loader
- # Compute gradients explicitly
- grads = gradient(m.model) do model
- # Recompute predictions inside gradient context
- y_pred = model(X_batch)
- m.flux_loss(y_pred, y_batch)
- end
+ #X = MLJBase.matrix(X) |> permutedims
+
+ if Tables.istable(X)
+ X = Tables.matrix(X)|>permutedims
+ end
- # Update the model parameters using the computed gradients
- Flux.Optimise.update!(opt_state, Flux.params(m.model), grads)
- # Compute the loss for this batch
- loss = m.flux_loss(m.model(X_batch), y_batch)
+ # Store the first label as decode function
+ decode = y[1]
- # Accumulate the loss for the current epoch
- loss_per_epoch += sum(loss)
- end
+ # Convert labels to integer format starting from 0 for one-hot encoding
+ y_plain = MLJBase.int(y) .- 1
-
- # Record loss history for analysis
- push!(loss_history, loss_per_epoch)
+ # One-hot encoding of labels
+ unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding
+ y_onehot = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding
- # Verbosity: print loss every 100 epochs
- if verbosity >= 1 && epoch % 100 == 0
- println("Epoch $epoch: Loss: $loss_per_epoch")
- end
-end
+ # Create a data loader for batching the data
+ data_loader = Flux.DataLoader((X, y_onehot); batchsize=m.batch_size)
- la = LaplaceRedux.Laplace(
- m.model;
- likelihood=:classification,
- subset_of_weights=m.subset_of_weights,
- subnetwork_indices=m.subnetwork_indices,
- hessian_structure=m.hessian_structure,
- backend=m.backend,
- σ=m.σ,
- μ₀=m.μ₀,
- P₀=m.P₀,
- )
+ # Set up the optimizer for the model
+ opt_state = Flux.setup(m.optimiser, m.model)
+ loss_history = []
- # fit the Laplace model:
- LaplaceRedux.fit!(la, data_loader)
- optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
+ # Training loop for the specified number of epochs
+ for epoch in 1:(m.epochs)
- report = (status="success", loss_history = loss_history)
- cache = nothing
- return ((la, decode), cache, report)
+ loss_per_epoch= 0.0
+
+
+ for (X_batch, y_batch) in data_loader
+ # Forward pass: compute predictions
+ y_pred = m.model(X_batch)
+
+ # Compute loss
+ loss = m.flux_loss(y_pred, y_batch)
+
+ # Compute gradients explicitly
+ grads = gradient(m.model) do model
+ # Recompute predictions inside gradient context
+ y_pred = model(X_batch)
+ m.flux_loss(y_pred, y_batch)
+ end
+
+ # Update parameters using the optimizer and computed gradients
+ Flux.Optimise.update!(opt_state ,m.model , grads[1])
+
+ # Accumulate the loss for this batch
+ loss_per_epoch += sum(loss) # Summing the batch loss
+
+ end
+
+ push!(loss_history,loss_per_epoch )
+
+ # Print loss every 100 epochs if verbosity is 1 or more
+ if verbosity >= 1 && epoch % 100 == 0
+ println("Epoch $epoch: Loss: $loss_per_epoch ")
+ end
+ end
+
+ la = LaplaceRedux.Laplace(
+ m.model;
+ likelihood=:classification,
+ subset_of_weights=m.subset_of_weights,
+ subnetwork_indices=m.subnetwork_indices,
+ hessian_structure=m.hessian_structure,
+ backend=m.backend,
+ σ=m.σ,
+ μ₀=m.μ₀,
+ P₀=m.P₀,
+ )
+
+ # fit the Laplace model:
+ LaplaceRedux.fit!(la, data_loader)
+ optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
+
+ report = (status="success", loss_history = loss_history)
+ cache = nothing
+ return ((la, decode), cache, report)
end
@@ -439,8 +469,11 @@ The function transforms the new data `Xnew` into a matrix, applies the LaplaceRe
prediction function, and then returns the predictions as a `MLJBase.UnivariateFinite` object.
"""
function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew)
+ if Tables.istable(Xnew)
+ Xnew = Tables.matrix(Xnew)|>permutedims
+ end
la = fitresult
- Xnew = MLJBase.matrix(Xnew) |> permutedims
+ #Xnew = MLJBase.matrix(Xnew) |> permutedims
predictions =
LaplaceRedux.predict(la, Xnew; link_approx=m.link_approx, ret_distr=false) |>
permutedims
@@ -449,6 +482,13 @@ function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew)
end
+
+
+# for fit:
+MMI.reformat(::LaplaceClassifier, X, y) = (MLJBase.matrix(X) |> permutedims,y)
+
+MMI.reformat(::LaplaceClassifier, X) = (MLJBase.matrix(X) |> permutedims,)
+
MMI.metadata_pkg(
LaplaceRegressor,
name="LaplaceRedux",
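The `reformat` methods and the `Tables.istable` branches both perform the same conversion: a Tables.jl table becomes a features-by-observations matrix, and, for the regressor, the target becomes a 1×n row matrix. A sketch of that conversion on its own:

```
using Tables, MLJBase

X = (sepal_length = [6.4, 7.2, 7.4],
     sepal_width  = [2.8, 3.0, 2.8])

Tables.istable(X)                          # true: a named tuple of vectors is a table
Xmat = MLJBase.matrix(X) |> permutedims    # 2×3: one row per feature, one column per observation

y    = [0.1, 0.2, 0.3]
yrow = reshape(y, 1, :)                    # 1×3, the shape expected by the regressor's loss
```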
From de784f1c09fcd7dda9665fa3abde43aa26cb79f0 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Sun, 22 Sep 2024 19:14:12 +0200
Subject: [PATCH 24/60] worked on the docstring, then wrapped the interface in a module
---
src/direct_mlj.jl | 71 ++++++++++++++++++++++++++++++++---------------
1 file changed, 48 insertions(+), 23 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 690c248c..dd33e3c0 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -1,3 +1,4 @@
+module MLJLaplaceRedux
using Flux
using Random
using Tables
@@ -66,11 +67,11 @@ Fit a LaplaceRegressor model using the provided features and target values.
# Returns
- `fitresult`: The fitted Laplace model.
- `cache`: Currently unused, returns `nothing`.
-- `report`: A tuple containing the status and message of the fitting process.
+- `report`: A tuple containing the loss history of the fitting process.
# Description
This function performs the following steps:
-1. Converts the input features `X` to a matrix and transposes it.
+1. If X is a Table: converts the input features `X` to a matrix and transposes it.
2. Reshapes the target values `y` to shape (1,:).
3. Creates a data loader for batching the data.
4. Sets up the optimizer state using the Adam optimizer.
@@ -78,7 +79,7 @@ This function performs the following steps:
6. Initializes a Laplace model with the trained Flux model and specified parameters.
7. Fits the Laplace model using the data loader.
8. Optimizes the prior of the Laplace model.
-9. Returns the fitted Laplace model, a cache (currently `nothing`), and a report indicating success.
+9. Returns the fitted Laplace model, a cache (currently `nothing`), and a report listing training related statistics.
"""
function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
@@ -95,6 +96,7 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size)
opt_state = Flux.setup(m.optimiser, m.model)
loss_history=[]
+ push!(loss_history, m.flux_loss(m.model(X), y ))
for epoch in 1:(m.epochs)
@@ -148,7 +150,7 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
fitresult = la
- report = (status="success", loss_history = loss_history)
+ report = (loss_history = loss_history,)
cache = nothing
return (fitresult, cache, report)
end
@@ -304,11 +306,7 @@ end
- A tuple containing:
- `(la, decode)`: The fitted Laplace model and the decode function for the target labels.
- `cache`: A placeholder for any cached data (currently `nothing`).
- - `report`: A report dictionary containing the status and message of the fitting process.
-
- Notes:
- - The function uses the Flux library for neural network training and the LaplaceRedux library for fitting the Laplace approximation.
- - The `optimize_prior!` function is called to optimize the prior parameters of the Laplace model.
+ - `report`: A NamedTuple containing statistics related to the fitting process.
"""
function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
@@ -388,7 +386,7 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
LaplaceRedux.fit!(la, data_loader)
optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
- report = (status="success", loss_history = loss_history)
+ report = (loss_history = loss_history,)
cache = nothing
return ((la, decode), cache, report)
end
@@ -545,9 +543,13 @@ $(MMI.doc_header(LaplaceClassifier))
# Training data
-In MLJ or MLJBase, bind an instance `model` to data with
+In MLJ or MLJBase, given a dataset X, y and a Flux chain adapted to the dataset, pass the chain to the model
+
+laplace_model = LaplaceClassifier(model = Flux_Chain,kwargs...)
+
+then bind an instance `laplace_model` to data with
- mach = machine(model, X, y)
+ mach = machine(laplace_model, X, y)
where
@@ -601,6 +603,8 @@ Train the machine using `fit!(mach, rows=...)`.
- `predict_mode(mach, Xnew)`: instead return the mode of each
prediction above.
+- `training_losses(mach)`: returns the loss history from the report.
+
# Fitted parameters
@@ -622,6 +626,14 @@ The fields of `fitted_params(mach)` are:
- `loss`: The loss value of the posterior distribution.
+
+
+ # Report
+
+The fields of `report(mach)` are:
+
+- `loss_history`: an array containing the total loss per epoch.
+
# Accessor functions
@@ -652,9 +664,9 @@ Xnew = (sepal_length = [6.4, 7.2, 7.4],
petal_width = [2.1, 1.6, 1.9],)
yhat = predict(mach, Xnew) # probabilistic predictions
predict_mode(mach, Xnew) # point predictions
+training_losses(mach) # loss history per epoch
pdf.(yhat, "virginica") # probabilities for the "virginica" class
-
-julia> tree = fitted_params(mach)
+fitted_params(mach) # NamedTuple with the fitted params of Laplace
```
@@ -670,10 +682,13 @@ $(MMI.doc_header(LaplaceRegressor))
# Training data
-In MLJ or MLJBase, bind an instance `model` to data with
+In MLJ or MLJBase, given a dataset X, y and a Flux chain adapted to the dataset, pass the chain to the model
+
+laplace_model = LaplaceRegressor(model = Flux_Chain,kwargs...)
- mach = machine(model, X, y)
+then bind an instance `laplace_model` to data with
+ mach = machine(laplace_model, X, y)
where
- `X`: any table of input features (eg, a `DataFrame`) whose columns
@@ -689,7 +704,7 @@ Train the machine using `fit!(mach, rows=...)`.
# Hyperparameters (format: name-type-default value-restrictions)
-- `model::Flux.Chain = nothing`: a Flux model provided by the user and compatible with the dataset.
+- `model::Flux.Chain = nothing`: a Flux model provided by the user and compatible with the dataset.
- `flux_loss = Flux.Losses.logitcrossentropy` : a Flux loss function
@@ -725,6 +740,8 @@ Train the machine using `fit!(mach, rows=...)`.
- `predict_mode(mach, Xnew)`: instead return the mode of each
prediction above.
+- `training_losses(mach)`: returns the loss history from the report.
+
# Fitted parameters
@@ -747,6 +764,15 @@ The fields of `fitted_params(mach)` are:
- `loss`: The loss value of the posterior distribution.
+# Report
+
+The fields of `report(mach)` are:
+
+- `loss_history`: an array containing the total loss per epoch.
+
+
+
+
# Accessor functions
@@ -770,15 +796,14 @@ mach = machine(model, X, y) |> fit!
Xnew, _ = make_regression(3, 4; rng=123)
yhat = predict(mach, Xnew) # probabilistic predictions
predict_mode(mach, Xnew) # point predictions
+training_losses(mach) # loss history per epoch
+fitted_params(mach) # NamedTuple with the fitted params of Laplace
-julia> tree = fitted_params(mach)
-
-
-
-feature_importances(mach)
```
See also [LaplaceRedux.jl](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl).
"""
-LaplaceRegressor
\ No newline at end of file
+LaplaceRegressor
+
+end # module
\ No newline at end of file
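Because the regressor's `fit` now records the loss once before the first epoch and once per epoch afterwards, the history has `epochs + 1` entries; the test file added in the next patch asserts exactly that. A tiny sketch of the bookkeeping, with placeholder loss values:

```
epochs       = 5
loss_history = Float64[]

push!(loss_history, 1.0)               # initial loss, pushed before the training loop
for epoch in 1:epochs
    push!(loss_history, 1.0 / epoch)   # stands in for the accumulated per-epoch loss
end

length(loss_history) == epochs + 1     # true; cf. `length(history) == model.epochs + 1` in the tests
```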
From b7a99f6e40eb216dec08c34edd19a614e7d265de Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Mon, 23 Sep 2024 17:42:32 +0200
Subject: [PATCH 25/60] fixed uuid, made test file for direct_mlj. shut down
the tests for mljflux.
---
src/direct_mlj.jl | 14 +++---
test/direct_mlj_interface.jl | 98 ++++++++++++++++++++++++++++++++++++
test/runtests.jl | 8 ++-
3 files changed, 111 insertions(+), 9 deletions(-)
create mode 100644 test/direct_mlj_interface.jl
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index dd33e3c0..30c57045 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -1,4 +1,4 @@
-module MLJLaplaceRedux
+#module MLJLaplaceRedux
using Flux
using Random
using Tables
@@ -304,7 +304,7 @@ end
Returns:
- A tuple containing:
- - `(la, decode)`: The fitted Laplace model and the decode function for the target labels.
+ - `(la, y[1])`: The fitted Laplace model and the decode function for the target labels.
- `cache`: A placeholder for any cached data (currently `nothing`).
- `report`: A NamedTuple containing statistics related to the fitting process.
@@ -318,7 +318,7 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
# Store the first label as decode function
- decode = y[1]
+ #decode = y[1]
# Convert labels to integer format starting from 0 for one-hot encoding
y_plain = MLJBase.int(y) .- 1
@@ -388,7 +388,7 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
report = (loss_history = loss_history,)
cache = nothing
- return ((la, decode), cache, report)
+ return ((la, y[1]), cache, report)
end
@@ -490,7 +490,7 @@ MMI.reformat(::LaplaceClassifier, X) = (MLJBase.matrix(X) |> permutedims,)
MMI.metadata_pkg(
LaplaceRegressor,
name="LaplaceRedux",
- package_uuid="??????",
+ package_uuid="c52c1a26-f7c5-402b-80be-ba1e638ad478",
package_url="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl",
is_pure_julia=true,
is_wrapper=true,
@@ -500,7 +500,7 @@ MMI.metadata_pkg(
MMI.metadata_pkg(
LaplaceClassifier,
name="LaplaceRedux",
- package_uuid="dontknow",
+ package_uuid="c52c1a26-f7c5-402b-80be-ba1e638ad478",
package_url="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl",
is_pure_julia=true,
is_wrapper=true,
@@ -806,4 +806,4 @@ See also [LaplaceRedux.jl](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl
"""
LaplaceRegressor
-end # module
\ No newline at end of file
+#end # module
\ No newline at end of file
diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl
new file mode 100644
index 00000000..13c91a21
--- /dev/null
+++ b/test/direct_mlj_interface.jl
@@ -0,0 +1,98 @@
+using Random: Random
+import Random.seed!
+using MLJBase: MLJBase, categorical
+using Flux
+using StableRNGs
+
+
+@testset "Regression" begin
+ function basictest_regression(X, y, builder, optimiser, threshold)
+ optimiser = deepcopy(optimiser)
+
+ stable_rng = StableRNGs.StableRNG(123)
+
+ model = LaplaceRegression(;
+ builder=builder,
+ optimiser=optimiser,
+ acceleration=MLJBase.CPUThreads(),
+ loss=Flux.Losses.mse,
+ rng=stable_rng,
+ lambda=-1.0,
+ alpha=-1.0,
+ epochs=-1,
+ batch_size=-1,
+ subset_of_weights=:incorrect,
+ hessian_structure=:incorrect,
+ backend=:incorrect,
+ ret_distr=true,
+ )
+
+ fitresult, cache, _report = MLJBase.fit(model, 0, X, y)
+
+ history = _report.training_losses
+ @test length(history) == model.epochs + 1
+
+ # increase iterations and check update is incremental:
+ model.epochs = model.epochs + 3
+
+ fitresult, cache, _report = @test_logs(
+ (:info, r""), # one line of :info per extra epoch
+ (:info, r""),
+ (:info, r""),
+ MLJBase.update(model, 2, fitresult, cache, X, y)
+ )
+
+ @test :chain in keys(MLJBase.fitted_params(model, fitresult))
+
+ history = _report.training_losses
+ @test length(history) == model.epochs + 1
+
+ yhat = MLJBase.predict(model, fitresult, X)
+
+ # start fresh with small epochs:
+ model = LaplaceRegression(;
+ builder=builder,
+ optimiser=optimiser,
+ epochs=2,
+ acceleration=CPU1(),
+ rng=stable_rng,
+ )
+
+ fitresult, cache, _report = MLJBase.fit(model, 0, X, y)
+
+ # change batch_size and check it performs cold restart:
+ model.batch_size = 2
+ fitresult, cache, _report = @test_logs(
+ (:info, r""), # one line of :info per extra epoch
+ (:info, r""),
+ MLJBase.update(model, 2, fitresult, cache, X, y)
+ )
+
+ # change learning rate and check it does *not* restart:
+ model.optimiser.eta /= 2
+ fitresult, cache, _report = @test_logs(
+ MLJBase.update(model, 2, fitresult, cache, X, y)
+ )
+
+ # set `optimiser_changes_trigger_retraining = true` and change
+ # learning rate and check it does restart:
+ model.optimiser_changes_trigger_retraining = true
+ model.optimiser.eta /= 2
+ @test_logs(
+ (:info, r""), # one line of :info per extra epoch
+ (:info, r""),
+ MLJBase.update(model, 2, fitresult, cache, X, y)
+ )
+
+ return true
+ end
+
+ seed!(1234)
+ N = 300
+ X = MLJBase.table(rand(Float32, N, 4))
+ ycont = 2 * X.x1 - X.x3 + 0.1 * rand(N)
+ builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu)
+ optimiser = Flux.Optimise.Adam(0.03)
+
+ @test basictest_regression(X, ycont, builder, optimiser, 0.9)
+end
\ No newline at end of file
diff --git a/test/runtests.jl b/test/runtests.jl
index 0459cf5a..4995c6f9 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -35,7 +35,11 @@ using Test
include("krondecomposed.jl")
end
- @testset "MLJFlux" begin
- include("mlj_flux_interfacing.jl")
+ #@testset "MLJFlux" begin
+ #include("mlj_flux_interfacing.jl")
+ #end
+ @testset "ML" begin
+ include("direct_mlj_interface.jl")
end
+
end
From c44b8d83795690e32786a13c805c139728acb190 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Mon, 23 Sep 2024 18:54:51 +0200
Subject: [PATCH 26/60] added tests. should be good....
---
test/Project.toml | 1 +
test/direct_mlj_interface.jl | 133 +++++++++++------------------------
2 files changed, 44 insertions(+), 90 deletions(-)
diff --git a/test/Project.toml b/test/Project.toml
index 750ea47e..2d4f4647 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -7,6 +7,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl
index 13c91a21..66fe020d 100644
--- a/test/direct_mlj_interface.jl
+++ b/test/direct_mlj_interface.jl
@@ -3,96 +3,49 @@ import Random.seed!
using MLJBase: MLJBase, categorical
using Flux
using StableRNGs
+using MLJ
+using LaplaceRedux
@testset "Regression" begin
- function basictest_regression(X, y, builder, optimiser, threshold)
- optimiser = deepcopy(optimiser)
-
- stable_rng = StableRNGs.StableRNG(123)
-
- model = LaplaceRegression(;
- builder=builder,
- optimiser=optimiser,
- acceleration=MLJBase.CPUThreads(),
- loss=Flux.Losses.mse,
- rng=stable_rng,
- lambda=-1.0,
- alpha=-1.0,
- epochs=-1,
- batch_size=-1,
- subset_of_weights=:incorrect,
- hessian_structure=:incorrect,
- backend=:incorrect,
- ret_distr=true,
- )
-
- fitresult, cache, _report = MLJBase.fit(model, 0, X, y)
-
- history = _report.training_losses
- @test length(history) == model.epochs + 1
-
- # increase iterations and check update is incremental:
- model.epochs = model.epochs + 3
-
- fitresult, cache, _report = @test_logs(
- (:info, r""), # one line of :info per extra epoch
- (:info, r""),
- (:info, r""),
- MLJBase.update(model, 2, fitresult, cache, X, y)
- )
-
- @test :chain in keys(MLJBase.fitted_params(model, fitresult))
-
- history = _report.training_losses
- @test length(history) == model.epochs + 1
-
- yhat = MLJBase.predict(model, fitresult, X)
-
- # start fresh with small epochs:
- model = LaplaceRegression(;
- builder=builder,
- optimiser=optimiser,
- epochs=2,
- acceleration=CPU1(),
- rng=stable_rng,
- )
-
- fitresult, cache, _report = MLJBase.fit(model, 0, X, y)
-
- # change batch_size and check it performs cold restart:
- model.batch_size = 2
- fitresult, cache, _report = @test_logs(
- (:info, r""), # one line of :info per extra epoch
- (:info, r""),
- MLJBase.update(model, 2, fitresult, cache, X, y)
- )
-
- # change learning rate and check it does *not* restart:
- model.optimiser.eta /= 2
- fitresult, cache, _report = @test_logs(
- MLJBase.update(model, 2, fitresult, cache, X, y)
- )
-
- # set `optimiser_changes_trigger_retraining = true` and change
- # learning rate and check it does restart:
- model.optimiser_changes_trigger_retraining = true
- model.optimiser.eta /= 2
- @test_logs(
- (:info, r""), # one line of :info per extra epoch
- (:info, r""),
- MLJBase.update(model, 2, fitresult, cache, X, y)
- )
-
- return true
- end
-
- seed!(1234)
- N = 300
- X = MLJBase.table(rand(Float32, N, 4))
- ycont = 2 * X.x1 - X.x3 + 0.1 * rand(N)
- builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu)
- optimiser = Flux.Optimise.Adam(0.03)
-
- @test basictest_regression(X, ycont, builder, optimiser, 0.9)
-end
\ No newline at end of file
+ flux_model = Chain(
+ Dense(4, 10, relu),
+ Dense(10, 10, relu),
+ Dense(10, 1)
+ )
+ model = LaplaceRegressor(model=flux_model,epochs=50)
+
+ X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)
+ mach = machine(model, X, y) #|> MLJBase.fit! #|> (fitresult,cache,report)
+ MLJBase.fit!(mach)
+ Xnew, _ = make_regression(3, 4; rng=123)
+ yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions
+ MLJBase.predict_mode(mach, Xnew) # point predictions
+ MLJBase.fitted_params(mach) #fitted params function
+end
+
+
+
+@testset "Classification" begin
+# Define the model
+flux_model = Chain(
+ Dense(4, 10, relu),
+ Dense(10, 10, relu),
+ Dense(10, 3)
+)
+
+model = LaplaceClassifier(model=flux_model,epochs=50)
+
+X, y = @load_iris
+mach = machine(model, X, y)
+MLJBase.fit!(mach)
+Xnew = (sepal_length = [6.4, 7.2, 7.4],
+ sepal_width = [2.8, 3.0, 2.8],
+ petal_length = [5.6, 5.8, 6.1],
+ petal_width = [2.1, 1.6, 1.9],)
+yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions
+predict_mode(mach, Xnew) # point predictions
+pdf.(yhat, "virginica") # probabilities for the "verginica" class
+MLJBase.fitted_params(mach) # fitted params
+
+end
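A short illustrative aside (not part of the patch): the regressor's `predict` returns a vector of `Normal` distributions, so point estimates and uncertainties can be recovered from the test above with standard `Statistics` accessors. The snippet below is a sketch that assumes `mach` is the fitted `LaplaceRegressor` machine from the regression test set.

```julia
# Sketch only: post-processing the probabilistic regression predictions.
using Statistics: mean, std
using MLJBase

Xnew, _ = make_regression(3, 4; rng=123)
dists = MLJBase.predict(mach, Xnew)   # Vector of Normal distributions
point_estimates = mean.(dists)        # posterior predictive means
uncertainties = std.(dists)           # predictive standard deviations
```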
From b762185fb93cf81fd220718b435d60eeb817bc8f Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Mon, 23 Sep 2024 20:07:01 +0200
Subject: [PATCH 27/60] added mlj to the dependency in test
---
Project.toml | 1 +
1 file changed, 1 insertion(+)
diff --git a/Project.toml b/Project.toml
index 2028e0ba..f9cb74ea 100644
--- a/Project.toml
+++ b/Project.toml
@@ -25,6 +25,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
Aqua = "0.8"
+CategoricalDistributions = "0.1.15"
ChainRulesCore = "1.23.0"
Compat = "4.7.0"
ComputationalResources = "0.3.2"
From ced3da056e0c46b5eaf910b3dc705659a2f2f026 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Thu, 26 Sep 2024 00:45:43 +0200
Subject: [PATCH 28/60] prep for update + added mljmodelinterface to doc env
---
docs/Project.toml | 1 +
src/direct_mlj.jl | 27 ++++++++++++++++++---------
2 files changed, 19 insertions(+), 9 deletions(-)
diff --git a/docs/Project.toml b/docs/Project.toml
index 7118d5fe..624220c4 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -6,6 +6,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
+MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 30c57045..102196a2 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -9,7 +9,7 @@ import MLJModelInterface as MMI
using Distributions: Normal
"""
- MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic
+ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJBase.Probabilistic
A mutable struct representing a Laplace regression model.
It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.
@@ -28,8 +28,12 @@ It has the following Hyperparameters:
- `P₀`: the covariance matrix of the prior distribution.
- `fit_prior_nsteps`: the number of steps used to fit the priors.
"""
-MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic
- model::Flux.Chain = nothing
+MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJBase.Probabilistic
+ model::Flux.Chain = Chain(
+ Dense(4, 10, relu),
+ Dense(10, 10, relu),
+ Dense(10, 1)
+ )
flux_loss = Flux.Losses.mse
optimiser = Adam()
epochs::Integer = 1000::(_ > 0)
@@ -110,7 +114,7 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
# Compute loss
loss = m.flux_loss(y_pred, y_batch)
- # Compute gradients explicitly
+ # Compute gradients
grads = gradient(m.model) do model
# Recompute predictions inside gradient context
y_pred = model(X_batch)
@@ -151,11 +155,16 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
fitresult = la
report = (loss_history = loss_history,)
- cache = nothing
- return (fitresult, cache, report)
+ cache = (deepcopy(m),opt_state, loss_history)
+ return fitresult, cache, report
end
+
+
+
+
+
@doc """
function MMI.fitted_params(model::LaplaceRegressor, fitresult)
@@ -248,7 +257,7 @@ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
end
"""
- MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic
+ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJBase.Probabilistic
A mutable struct representing a Laplace Classification model.
It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.
@@ -269,7 +278,7 @@ The model also has the following parameters:
- `link_approx`: the link approximation to use, either `:probit` or `:plugin`.
- `fit_prior_nsteps`: the number of steps used to fit the priors.
"""
-MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic
+MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJBase.Probabilistic
model::Flux.Chain = nothing
flux_loss = Flux.Losses.logitcrossentropy
optimiser = Adam()
@@ -387,7 +396,7 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
report = (loss_history = loss_history,)
- cache = nothing
+ cache = (deepcopy(m),opt_state,loss_history)
return ((la, y[1]), cache, report)
end
From b700f851570b56e54b95389e4e43af8f5eeccb3a Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Wed, 2 Oct 2024 01:42:17 +0200
Subject: [PATCH 29/60] changed the loop so that it now uses optimisers from
optimisers.jl
---
docs/Project.toml | 1 +
src/direct_mlj.jl | 168 ++++++++++++++++++++++++++++------------------
2 files changed, 102 insertions(+), 67 deletions(-)
diff --git a/docs/Project.toml b/docs/Project.toml
index 624220c4..c8d6234a 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -7,6 +7,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 102196a2..1379aee7 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -1,4 +1,5 @@
#module MLJLaplaceRedux
+using Optimisers: Optimisers
using Flux
using Random
using Tables
@@ -8,34 +9,11 @@ using MLJBase
import MLJModelInterface as MMI
using Distributions: Normal
-"""
- MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJBase.Probabilistic
-
-A mutable struct representing a Laplace regression model.
-It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.
-It has the following Hyperparameters:
-- `model`: A Flux model provided by the user and compatible with the dataset.
-- `flux_loss` : a Flux loss function
-- `optimiser` = a Flux optimiser
-- `epochs`: The number of training epochs.
-- `batch_size`: The batch size.
-- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
-- `subnetwork_indices`: the indices of the subnetworks.
-- `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`.
-- `backend`: the backend to use, either `:GGN` or `:EmpiricalFisher`.
-- `σ`: the standard deviation of the prior distribution.
-- `μ₀`: the mean of the prior distribution.
-- `P₀`: the covariance matrix of the prior distribution.
-- `fit_prior_nsteps`: the number of steps used to fit the priors.
-"""
+
MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJBase.Probabilistic
- model::Flux.Chain = Chain(
- Dense(4, 10, relu),
- Dense(10, 10, relu),
- Dense(10, 1)
- )
+ model::Union{Flux.Chain,Nothing} = nothing
flux_loss = Flux.Losses.mse
- optimiser = Adam()
+ optimiser = Optimisers.Adam()
epochs::Integer = 1000::(_ > 0)
batch_size::Integer = 32::(_ > 0)
subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
@@ -52,7 +30,7 @@ end
# for fit:
MMI.reformat(::LaplaceRegressor, X, y) = (MLJBase.matrix(X) |> permutedims,reshape(y, 1, :))
-
+#for predict:
MMI.reformat(::LaplaceRegressor, X) = (MLJBase.matrix(X) |> permutedims,)
@@ -96,11 +74,13 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
# Reshape y if necessary
y = reshape(y, 1, :)
+ # Make a copy of the model because Flux does not allow mutating hyperparameters
+ copied_model = deepcopy(m.model)
data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size)
- opt_state = Flux.setup(m.optimiser, m.model)
+ state_tree = Optimisers.setup(m.optimiser, copied_model)
loss_history=[]
- push!(loss_history, m.flux_loss(m.model(X), y ))
+ push!(loss_history, m.flux_loss(copied_model(X), y ))
for epoch in 1:(m.epochs)
@@ -109,20 +89,20 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
for (X_batch, y_batch) in data_loader
# Forward pass: compute predictions
- y_pred = m.model(X_batch)
+ y_pred = copied_model(X_batch)
# Compute loss
loss = m.flux_loss(y_pred, y_batch)
# Compute gradients
- grads = gradient(m.model) do model
+ grads,_ = gradient(copied_model,X_batch) do model, X
# Recompute predictions inside gradient context
- y_pred = model(X_batch)
+ y_pred = model(X)
m.flux_loss(y_pred, y_batch)
end
# Update parameters using the optimizer and computed gradients
- Flux.Optimise.update!(opt_state ,m.model , grads[1])
+ state_tree, model = Optimisers.update!(state_tree ,copied_model, grads)
# Accumulate the loss for this batch
loss_per_epoch += sum(loss) # Summing the batch loss
@@ -138,7 +118,7 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
end
la = LaplaceRedux.Laplace(
- m.model;
+ copied_model;
likelihood=:regression,
subset_of_weights=m.subset_of_weights,
subnetwork_indices=m.subnetwork_indices,
@@ -155,12 +135,85 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
fitresult = la
report = (loss_history = loss_history,)
- cache = (deepcopy(m),opt_state, loss_history)
+ cache = (deepcopy(m),state_tree, loss_history)
return fitresult, cache, report
end
+# Define the function is_same_except
+function MMI.is_same_except(m1::LaplaceRegressor, m2::LaplaceRegressor, exceptions::Symbol...)
+ typeof(m1) === typeof(m2) || return false
+ names = propertynames(m1)
+ propertynames(m2) === names || return false
+
+ for name in names
+ if !(name in exceptions)
+ if !_isdefined(m1, name)
+ !_isdefined(m2, name) || return false
+ elseif _isdefined(m2, name)
+ if name in deep_properties(LaplaceRegressor)
+ _equal_to_depth_one(
+ getproperty(m1,name),
+ getproperty(m2, name)
+ ) || return false
+ else
+ (
+ is_same_except(
+ getproperty(m1, name),
+ getproperty(m2, name)
+ ) ||
+ getproperty(m1, name) isa AbstractRNG ||
+ getproperty(m2, name) isa AbstractRNG ||
+ (getproperty(m1, name) isa Flux.Chain && getproperty(m2, name) isa Flux.Chain && _equal_flux_chain(getproperty(m1, name), getproperty(m2, name)))
+ ) || return false
+ end
+ else
+ return false
+ end
+ end
+ end
+ return true
+end
+
+# Define helper functions used in is_same_except
+function _isdefined(obj, name)
+ return hasproperty(obj, name)
+end
+
+function deep_properties(::Type)
+ return Set{Symbol}()
+end
+
+function _equal_to_depth_one(a, b)
+ return a == b
+end
+
+function _equal_flux_chain(chain1::Flux.Chain, chain2::Flux.Chain)
+ if length(chain1.layers) != length(chain2.layers)
+ println("no length chain")
+ return false
+ end
+ params1 = Flux.params(chain1)
+ params2 = Flux.params(chain2)
+ if length(params1) != length(params2)
+ println("no length params")
+ return false
+ end
+ for (p1, p2) in zip(params1, params2)
+ if !isequal(p1, p2)
+ println(" params differs")
+ return false
+ end
+ end
+ for (layer1, layer2) in zip(chain1.layers, chain2.layers)
+ if typeof(layer1) != typeof(layer2)
+ println("layer differ")
+ return false
+ end
+ end
+ return true
+end
@@ -256,32 +309,11 @@ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
return [Normal(μ, sqrt(σ)) for (μ, σ) in zip(means, variances)]
end
-"""
- MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJBase.Probabilistic
-
-A mutable struct representing a Laplace Classification model.
-It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.
-The model also has the following parameters:
-
-- `model`: A Flux model provided by the user and compatible with the dataset.
-- `flux_loss` : a Flux loss function
-- `optimiser` = a Flux optimiser
-- `epochs`: The number of training epochs.
-- `batch_size`: The batch size.
-- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
-- `subnetwork_indices`: the indices of the subnetworks.
-- `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`.
-- `backend`: the backend to use, either `:GGN` or `:EmpiricalFisher`.
-- `σ`: the standard deviation of the prior distribution.
-- `μ₀`: the mean of the prior distribution.
-- `P₀`: the covariance matrix of the prior distribution.
-- `link_approx`: the link approximation to use, either `:probit` or `:plugin`.
-- `fit_prior_nsteps`: the number of steps used to fit the priors.
-"""
+
MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJBase.Probabilistic
- model::Flux.Chain = nothing
+ model::Union{Flux.Chain,Nothing} = nothing
flux_loss = Flux.Losses.logitcrossentropy
- optimiser = Adam()
+ optimiser = Optimisers.Adam()
epochs::Integer = 1000::(_ > 0)
batch_size::Integer = 32::(_ > 0)
subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
@@ -335,12 +367,14 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
# One-hot encoding of labels
unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding
y_onehot = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding
+ #copy model
+ copied_model = deepcopy(m.model)
# Create a data loader for batching the data
data_loader = Flux.DataLoader((X, y_onehot); batchsize=m.batch_size)
# Set up the optimizer for the model
- opt_state = Flux.setup(m.optimiser, m.model)
+ state_tree = Optimisers.setup(m.optimiser, copied_model)
loss_history = []
# Training loop for the specified number of epochs
@@ -351,20 +385,20 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
for (X_batch, y_batch) in data_loader
# Forward pass: compute predictions
- y_pred = m.model(X_batch)
+ y_pred = copied_model(X_batch)
# Compute loss
loss = m.flux_loss(y_pred, y_batch)
- # Compute gradients explicitly
- grads = gradient(m.model) do model
+ # Compute gradients
+ grads,_ = gradient(copied_model,X_batch) do model, X
# Recompute predictions inside gradient context
- y_pred = model(X_batch)
+ y_pred = model(X)
m.flux_loss(y_pred, y_batch)
end
# Update parameters using the optimizer and computed gradients
- Flux.Optimise.update!(opt_state ,m.model , grads[1])
+ state_tree, model = Optimisers.update!(state_tree ,copied_model, grads)
# Accumulate the loss for this batch
loss_per_epoch += sum(loss) # Summing the batch loss
@@ -396,7 +430,7 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
report = (loss_history = loss_history,)
- cache = (deepcopy(m),opt_state,loss_history)
+ cache = (deepcopy(m),state_tree,loss_history)
return ((la, y[1]), cache, report)
end
@@ -493,7 +527,7 @@ end
# for fit:
MMI.reformat(::LaplaceClassifier, X, y) = (MLJBase.matrix(X) |> permutedims,y)
-
+# for predict:
MMI.reformat(::LaplaceClassifier, X) = (MLJBase.matrix(X) |> permutedims,)
MMI.metadata_pkg(
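For readers less familiar with the explicit-gradient API adopted in this patch, the following is a minimal, self-contained sketch of the Optimisers.jl training pattern that the new loop follows; the toy chain, data, and learning rate are assumptions used only for illustration.

```julia
# Sketch only: the explicit-gradient Optimisers.jl loop mirrored by the patch above.
using Flux
using Optimisers: Optimisers

function train_sketch(model, X, y; epochs=10, lr=0.01)
    opt_state = Optimisers.setup(Optimisers.Adam(lr), model)
    for _ in 1:epochs
        # gradient w.r.t. the model itself (explicit mode), as in the new training loop
        grads, _ = Flux.gradient(model, X) do m, x
            Flux.Losses.mse(m(x), y)
        end
        # update! returns the new optimiser state and the updated model
        opt_state, model = Optimisers.update!(opt_state, model, grads)
    end
    return model
end

toy_model = Chain(Dense(2, 8, relu), Dense(8, 1))
train_sketch(toy_model, rand(Float32, 2, 32), rand(Float32, 1, 32))
```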
From da6fc76b345f66987d7b700241299b13ee50f4ed Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Thu, 3 Oct 2024 22:40:37 +0200
Subject: [PATCH 30/60] started joining the functions into a single common
function for both models
---
src/direct_mlj.jl | 257 ++++++++++++++++++++++++++++------------------
1 file changed, 157 insertions(+), 100 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 1379aee7..3b5b33af 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -9,6 +9,23 @@ using MLJBase
import MLJModelInterface as MMI
using Distributions: Normal
+MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJBase.Probabilistic
+ model::Union{Flux.Chain,Nothing} = nothing
+ flux_loss = Flux.Losses.logitcrossentropy
+ optimiser = Optimisers.Adam()
+ epochs::Integer = 1000::(_ > 0)
+ batch_size::Integer = 32::(_ > 0)
+ subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
+ subnetwork_indices = nothing
+ hessian_structure::Union{HessianStructure,Symbol,String} =
+ :full::(_ in (:full, :diagonal))
+ backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))
+ σ::Float64 = 1.0
+ μ₀::Float64 = 0.0
+ P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
+ fit_prior_nsteps::Int = 100::(_ > 0)
+ link_approx::Symbol = :probit::(_ in (:probit, :plugin))
+end
MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJBase.Probabilistic
model::Union{Flux.Chain,Nothing} = nothing
@@ -27,11 +44,17 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJBase.Probabilistic
fit_prior_nsteps::Int = 100::(_ > 0)
end
+Const_Models = Union{LaplaceRegressor,LaplaceClassifier}
# for fit:
-MMI.reformat(::LaplaceRegressor, X, y) = (MLJBase.matrix(X) |> permutedims,reshape(y, 1, :))
+MMI.reformat(::Const_Models, X, y) = (MLJBase.matrix(X) |> permutedims, reshape(y, 1, :))
#for predict:
-MMI.reformat(::LaplaceRegressor, X) = (MLJBase.matrix(X) |> permutedims,)
+MMI.reformat(::Const_Models, X) = (MLJBase.matrix(X) |> permutedims,)
+
+# for fit:
+MMI.reformat(::LaplaceClassifier, X, y) = (MLJBase.matrix(X) |> permutedims,y)
+# for predict:
+#MMI.reformat(::LaplaceClassifier, X) = (MLJBase.matrix(X) |> permutedims,)
@@ -68,19 +91,18 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
#X = MLJBase.matrix(X) |> permutedims
#y = reshape(y, 1, :)
- if Tables.istable(X)
- X = Tables.matrix(X)|>permutedims
- end
+ #if Tables.istable(X)
+ #X = Tables.matrix(X)|>permutedims
+ #end
# Reshape y if necessary
- y = reshape(y, 1, :)
+ #y = reshape(y, 1, :)
# Make a copy of the model because Flux does not allow mutating hyperparameters
copied_model = deepcopy(m.model)
data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size)
state_tree = Optimisers.setup(m.optimiser, copied_model)
loss_history=[]
- push!(loss_history, m.flux_loss(copied_model(X), y ))
for epoch in 1:(m.epochs)
@@ -95,14 +117,14 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
loss = m.flux_loss(y_pred, y_batch)
# Compute gradients
- grads,_ = gradient(copied_model,X_batch) do model, X
+ grads,_ = gradient(copied_model,X_batch) do grad_model, X
# Recompute predictions inside gradient context
- y_pred = model(X)
+ y_pred = grad_model(X)
m.flux_loss(y_pred, y_batch)
end
# Update parameters using the optimizer and computed gradients
- state_tree, model = Optimisers.update!(state_tree ,copied_model, grads)
+ state_tree, copied_model = Optimisers.update!(state_tree ,copied_model, grads)
# Accumulate the loss for this batch
loss_per_epoch += sum(loss) # Summing the batch loss
@@ -133,15 +155,126 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
LaplaceRedux.fit!(la, data_loader)
optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
- fitresult = la
+ fitresult = (la, y[1])
report = (loss_history = loss_history,)
- cache = (deepcopy(m),state_tree, loss_history)
+ cache = (deepcopy(m),state_tree,loss_history)
+ return fitresult, cache, report
+end
+
+function MMI.update(m::LaplaceRegressor, verbosity, old_fitresult, old_cache, X, y)
+ println("we are in the update function")
+
+ data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size)
+ old_model = old_cache[1]
+ old_state_tree = old_cache[2]
+ old_loss_history = old_cache[3]
+ old_la = old_fitresult[1]
+
+ epochs = m.epochs
+
+ if MMI.is_same_except(m, old_model,:epochs)
+
+
+ if epochs > old_model.epochs
+
+
+ for epoch in (old_model.epochs+1):(epochs)
+
+ loss_per_epoch= 0.0
+
+
+ for (X_batch, y_batch) in data_loader
+ # Forward pass: compute predictions
+ y_pred = old_la.model(X_batch)
+
+ # Compute loss
+ loss = m.flux_loss(y_pred, y_batch)
+
+ # Compute gradients
+ grads,_ = gradient(old_la.model,X_batch) do grad_model, X
+ # Recompute predictions inside gradient context
+ y_pred = grad_model(X)
+ m.flux_loss(y_pred, y_batch)
+ end
+
+ # Update parameters using the optimizer and computed gradients
+ old_state_tree,old_la.model = Optimisers.update!(old_state_tree,old_la.model, grads)
+
+ # Accumulate the loss for this batch
+ loss_per_epoch += sum(loss) # Summing the batch loss
+
+ end
+
+ push!(old_loss_history,loss_per_epoch )
+
+ # Print loss every 100 epochs if verbosity is 1 or more
+ if verbosity >= 1 && epoch % 100 == 0
+ println("Epoch $epoch: Loss: $loss_per_epoch ")
+ end
+ end
+
+ la = LaplaceRedux.Laplace(
+ old_la.model;
+ likelihood=:regression,
+ subset_of_weights=m.subset_of_weights,
+ subnetwork_indices=m.subnetwork_indices,
+ hessian_structure=m.hessian_structure,
+ backend=m.backend,
+ σ=m.σ,
+ μ₀=m.μ₀,
+ P₀=m.P₀,
+ )
+
+ # fit the Laplace model:
+ LaplaceRedux.fit!(la, data_loader)
+ optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
+
+ fitresult = (la, y[1])
+ report = (loss_history = old_loss_history,)
+ cache = (deepcopy(m),old_state_tree,old_loss_history)
+
+ else
+
+ nothing
+
+ end
+
+ end
+
+ if MMI.is_same_except(m, old_model,:fit_prior_nsteps,:subset_of_weights,:subnetwork_indices,:hessian_structure,:backend,:σ,:μ₀,:P₀)
+
+ println(" changing only the laplace optimization part")
+
+ la = LaplaceRedux.Laplace(
+ old_la.model;
+ likelihood=:regression,
+ subset_of_weights=m.subset_of_weights,
+ subnetwork_indices=m.subnetwork_indices,
+ hessian_structure=m.hessian_structure,
+ backend=m.backend,
+ σ=m.σ,
+ μ₀=m.μ₀,
+ P₀=m.P₀,
+ )
+
+ # fit the Laplace model:
+ LaplaceRedux.fit!(la, data_loader)
+ optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
+
+ fitresult = la
+ report = (loss_history = old_loss_history,)
+ cache = (deepcopy(m),old_state_tree,old_loss_history)
+
+ end
+
+
return fitresult, cache, report
end
# Define the function is_same_except
-function MMI.is_same_except(m1::LaplaceRegressor, m2::LaplaceRegressor, exceptions::Symbol...)
+function MMI.is_same_except(m1::Const_Models, m2::Const_Models, exceptions::Symbol...)
+ println("overloaded")
typeof(m1) === typeof(m2) || return false
names = propertynames(m1)
propertynames(m2) === names || return false
@@ -241,8 +374,8 @@ end
- `loss`: The loss value of the posterior distribution.
"""
-function MMI.fitted_params(model::LaplaceRegressor, fitresult)
- la = fitresult
+function MMI.fitted_params(model::Const_Models, fitresult)
+ la,decode = fitresult
posterior = la.posterior
return (
μ = posterior.μ,
@@ -261,18 +394,18 @@ end
@doc """
- MMI.training_losses(model::LaplaceRegressor, report)
+ MMI.training_losses(model::Union{LaplaceRegressor,LaplaceClassifier}, report)
Retrieve the training loss history from the given `report`.
# Arguments
-- `model::LaplaceRegressor`: The model for which the training losses are being retrieved.
+- `model`: The model for which the training losses are being retrieved.
- `report`: An object containing the training report, which includes the loss history.
# Returns
- A collection representing the loss history from the training report.
"""
-function MMI.training_losses(model::LaplaceRegressor, report)
+function MMI.training_losses(model::Const_Models, report)
return report.loss_history
end
@@ -282,11 +415,11 @@ end
@doc """
function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
- Predicts the response for new data using a fitted LaplaceRegressor model.
+ Predicts the response for new data using a fitted Laplace model.
# Arguments
- - `m::LaplaceRegressor`: The LaplaceRegressor model.
- - `fitresult`: The result of fitting the LaplaceRegressor model.
+ - `m::LaplaceRegressor`: The Laplace model.
+ - `fitresult`: The result of the fitting procedure.
- `Xnew`: The new data for which predictions are to be made.
# Returns
@@ -300,7 +433,7 @@ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
if Tables.istable(Xnew)
Xnew = Tables.matrix(Xnew)|>permutedims
end
- la = fitresult
+ la, y = fitresult
yhat = LaplaceRedux.predict(la, Xnew; ret_distr=false)
# Extract mean and variance matrices
means, variances = yhat
@@ -310,23 +443,7 @@ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
end
-MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJBase.Probabilistic
- model::Union{Flux.Chain,Nothing} = nothing
- flux_loss = Flux.Losses.logitcrossentropy
- optimiser = Optimisers.Adam()
- epochs::Integer = 1000::(_ > 0)
- batch_size::Integer = 32::(_ > 0)
- subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
- subnetwork_indices = nothing
- hessian_structure::Union{HessianStructure,Symbol,String} =
- :full::(_ in (:full, :diagonal))
- backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))
- σ::Float64 = 1.0
- μ₀::Float64 = 0.0
- P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
- fit_prior_nsteps::Int = 100::(_ > 0)
- link_approx::Symbol = :probit::(_ in (:probit, :plugin))
-end
+
@doc """
@@ -435,64 +552,7 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
end
-
-
-
-
-@doc """
-
- function MMI.fitted_params(model::LaplaceClassifier, fitresult)
-
- This function extracts the fitted parameters from a `LaplaceClassifier` model.
-
- # Arguments
- - `model::LaplaceClassifier`: The Laplace classifier model.
- - `fitresult`: A tuple containing the Laplace approximation (`la`) and a decode function.
-
- # Returns
- A named tuple containing:
- - `μ`: The mean of the posterior distribution.
- - `H`: The Hessian of the posterior distribution.
- - `P`: The precision matrix of the posterior distribution.
- - `Σ`: The covariance matrix of the posterior distribution.
- - `n_data`: The number of data points.
- - `n_params`: The number of parameters.
- - `n_out`: The number of outputs.
- - `loss`: The loss value of the posterior distribution.
-
-"""
-function MMI.fitted_params(model::LaplaceClassifier, fitresult)
- la, decode = fitresult
- posterior = la.posterior
- return (
- μ = posterior.μ,
- H = posterior.H,
- P = posterior.P,
- Σ = posterior.Σ,
- n_data = posterior.n_data,
- n_params = posterior.n_params,
- n_out = posterior.n_out,
- loss = posterior.loss
- )
-end
-
-
-@doc """
- MMI.training_losses(model::LaplaceClassifier, report)
-
-Retrieve the training loss history from the given `report`.
-
-# Arguments
-- `model::LaplaceClassifier`: The model for which the training losses are being retrieved.
-- `report`: An object containing the training report, which includes the loss history.
-
-# Returns
-- A collection representing the loss history from the training report.
-"""
-function MMI.training_losses(model::LaplaceClassifier, report)
- return report.loss_history
-end
@doc """
@@ -525,10 +585,7 @@ end
-# for fit:
-MMI.reformat(::LaplaceClassifier, X, y) = (MLJBase.matrix(X) |> permutedims,y)
-# for predict:
-MMI.reformat(::LaplaceClassifier, X) = (MLJBase.matrix(X) |> permutedims,)
+
MMI.metadata_pkg(
LaplaceRegressor,
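An illustrative, user-level view of what the `MMI.update` method drafted above is meant to enable (a sketch under the assumption that the interface behaves as written here): increasing `epochs` on an already-fitted machine resumes training from the cached optimiser state, while changing only Laplace-specific hyperparameters re-runs just the posterior fitting step.

```julia
# Sketch only: warm restarts enabled by the update method drafted above.
using MLJBase, Flux

chain = Chain(Dense(4, 16, relu), Dense(16, 1))
model = LaplaceRegressor(model=chain, epochs=100)

X, y = make_regression(100, 4; noise=0.5)
mach = machine(model, X, y)
fit!(mach)                    # full fit; cache stores (model copy, optimiser state, losses)

model.epochs = 150
fit!(mach)                    # dispatches to MMI.update: trains only the extra 50 epochs

model.fit_prior_nsteps = 200
fit!(mach)                    # only the Laplace posterior/prior optimization is re-run
```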
From 70df56889d86fdfdb8979e6552b74fbb5d6966a8 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Fri, 4 Oct 2024 02:18:30 +0200
Subject: [PATCH 31/60] various fixes
---
src/direct_mlj.jl | 53 ++++++++---------------------------------------
1 file changed, 9 insertions(+), 44 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 3b5b33af..8ffd65b6 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -51,11 +51,6 @@ MMI.reformat(::Const_Models, X, y) = (MLJBase.matrix(X) |> permutedims, reshape(
#for predict:
MMI.reformat(::Const_Models, X) = (MLJBase.matrix(X) |> permutedims,)
-# for fit:
-MMI.reformat(::LaplaceClassifier, X, y) = (MLJBase.matrix(X) |> permutedims,y)
-# for predict:
-#MMI.reformat(::LaplaceClassifier, X) = (MLJBase.matrix(X) |> permutedims,)
-
@doc """
@@ -87,16 +82,6 @@ This function performs the following steps:
9. Returns the fitted Laplace model, a cache (currently `nothing`), and a report listing training related statistics.
"""
function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
-
- #X = MLJBase.matrix(X) |> permutedims
- #y = reshape(y, 1, :)
-
- #if Tables.istable(X)
- #X = Tables.matrix(X)|>permutedims
- #end
-
- # Reshape y if necessary
- #y = reshape(y, 1, :)
# Make a copy of the model because Flux does not allow mutating hyperparameters
copied_model = deepcopy(m.model)
@@ -243,7 +228,7 @@ function MMI.update(m::LaplaceRegressor, verbosity, old_fitresult, old_cache, X,
if MMI.is_same_except(m, old_model,:fit_prior_nsteps,:subset_of_weights,:subnetwork_indices,:hessian_structure,:backend,:σ,:μ₀,:P₀)
- println(" changing only the laplace optimization part")
+ println(" updating only the laplace optimization part")
la = LaplaceRedux.Laplace(
old_la.model;
@@ -274,7 +259,6 @@ end
# Define the function is_same_except
function MMI.is_same_except(m1::Const_Models, m2::Const_Models, exceptions::Symbol...)
- println("overloaded")
typeof(m1) === typeof(m2) || return false
names = propertynames(m1)
propertynames(m2) === names || return false
@@ -323,24 +307,20 @@ end
function _equal_flux_chain(chain1::Flux.Chain, chain2::Flux.Chain)
if length(chain1.layers) != length(chain2.layers)
- println("no length chain")
return false
end
params1 = Flux.params(chain1)
params2 = Flux.params(chain2)
if length(params1) != length(params2)
- println("no length params")
return false
end
for (p1, p2) in zip(params1, params2)
if !isequal(p1, p2)
- println(" params differs")
return false
end
end
for (layer1, layer2) in zip(chain1.layers, chain2.layers)
if typeof(layer1) != typeof(layer2)
- println("layer differ")
return false
end
@@ -375,7 +355,7 @@ end
"""
function MMI.fitted_params(model::Const_Models, fitresult)
- la,decode = fitresult
+ la, decode = fitresult
posterior = la.posterior
return (
μ = posterior.μ,
@@ -429,10 +409,6 @@ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
Finally, it creates Normal distributions from these means and variances and returns them as an array.
"""
function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
- #Xnew = MLJBase.matrix(Xnew) |> permutedims
- if Tables.istable(Xnew)
- Xnew = Tables.matrix(Xnew)|>permutedims
- end
la, y = fitresult
yhat = LaplaceRedux.predict(la, Xnew; ret_distr=false)
# Extract mean and variance matrices
@@ -468,18 +444,11 @@ end
"""
function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
- #X = MLJBase.matrix(X) |> permutedims
-
- if Tables.istable(X)
- X = Tables.matrix(X)|>permutedims
- end
-
-
- # Store the first label as decode function
- #decode = y[1]
+
+
# Convert labels to integer format starting from 0 for one-hot encoding
- y_plain = MLJBase.int(y) .- 1
+ y_plain = MLJBase.int(y[1,:]) .- 1
# One-hot encoding of labels
unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding
@@ -508,9 +477,9 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
loss = m.flux_loss(y_pred, y_batch)
# Compute gradients
- grads,_ = gradient(copied_model,X_batch) do model, X
+ grads,_ = gradient(copied_model,X_batch) do grad_model, X
# Recompute predictions inside gradient context
- y_pred = model(X)
+ y_pred = grad_model(X)
m.flux_loss(y_pred, y_batch)
end
@@ -569,12 +538,8 @@ Predicts the class probabilities for new data using a Laplace classifier.
The function transforms the new data `Xnew` into a matrix, applies the LaplaceRedux
prediction function, and then returns the predictions as a `MLJBase.UnivariateFinite` object.
"""
-function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew)
- if Tables.istable(Xnew)
- Xnew = Tables.matrix(Xnew)|>permutedims
- end
- la = fitresult
- #Xnew = MLJBase.matrix(Xnew) |> permutedims
+function MMI.predict(m::LaplaceClassifier, fitresult, Xnew)
+ la,decode = fitresult
predictions =
LaplaceRedux.predict(la, Xnew; link_approx=m.link_approx, ret_distr=false) |>
permutedims
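The deletions above work because the shared `MMI.reformat` methods defined earlier already hand `fit` and `predict` a feature-major matrix (and, for the target, a 1×n row). A small sketch of that layout on a toy table, for illustration only:

```julia
# Sketch only: the data layout produced by the shared MMI.reformat methods.
using MLJBase

X = (x1 = rand(5), x2 = rand(5), x3 = rand(5))   # any Tables.jl-compatible table
y = rand(5)

Xmat = MLJBase.matrix(X) |> permutedims   # 3×5 matrix, one column per observation
yrow = reshape(y, 1, :)                   # 1×5 row vector, as the Flux losses expect
size(Xmat), size(yrow)                    # ((3, 5), (1, 5))
```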
From 988987229838c30797c56f3b885ff1b8317c3a53 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Fri, 4 Oct 2024 14:57:06 +0200
Subject: [PATCH 32/60] merged functions for both cases
---
src/direct_mlj.jl | 277 +++++++++++++++++++---------------------------
1 file changed, 113 insertions(+), 164 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 8ffd65b6..3d592c9c 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -44,44 +44,48 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJBase.Probabilistic
fit_prior_nsteps::Int = 100::(_ > 0)
end
-Const_Models = Union{LaplaceRegressor,LaplaceClassifier}
+Laplace_Models = Union{LaplaceRegressor,LaplaceClassifier}
# for fit:
-MMI.reformat(::Const_Models, X, y) = (MLJBase.matrix(X) |> permutedims, reshape(y, 1, :))
+MMI.reformat(::Laplace_Models, X, y) = (MLJBase.matrix(X) |> permutedims, reshape(y, 1, :))
#for predict:
-MMI.reformat(::Const_Models, X) = (MLJBase.matrix(X) |> permutedims,)
+MMI.reformat(::Laplace_Models, X) = (MLJBase.matrix(X) |> permutedims,)
@doc """
- MMI.fit(m::LaplaceRegressor, verbosity, X, y)
+ MMI.fit(m::Union{LaplaceRegressor,LaplaceClassifier}, verbosity, X, y)
-Fit a LaplaceRegressor model using the provided features and target values.
+Fit a Laplace model using the provided features and target values.
# Arguments
-- `m::LaplaceRegressor`: The LaplaceRegressor model to be fitted.
+- `m`: The Laplace (`LaplaceRegressor` or `LaplaceClassifier`) model to be fitted.
- `verbosity`: Verbosity level for logging.
- `X`: Input features, expected to be in a format compatible with MLJBase.matrix.
- `y`: Target values.
# Returns
-- `fitresult`: The fitted Laplace model.
-- `cache`: Currently unused, returns `nothing`.
-- `report`: A tuple containing the loss history of the fitting process.
-
-# Description
-This function performs the following steps:
-1. If X is a Table: converts the input features `X` to a matrix and transposes it.
-2. Reshapes the target values `y` to shape (1,:).
-3. Creates a data loader for batching the data.
-4. Sets up the optimizer state using the Adam optimizer.
-5. Trains the model for a specified number of epochs.
-6. Initializes a Laplace model with the trained Flux model and specified parameters.
-7. Fits the Laplace model using the data loader.
-8. Optimizes the prior of the Laplace model.
-9. Returns the fitted Laplace model, a cache (currently `nothing`), and a report listing training related statistics.
+- `fitresult`: a tuple (la, decode) containing the fitted Laplace model and y[1], the first element of the categorical y vector.
+- `cache`: a tuple containing a deepcopy of the model, the current state of the optimiser and the training loss history.
+- `report`: A NamedTuple containing the loss history of the fitting process.
"""
-function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
+function MMI.fit(m::Laplace_Models, verbosity, X, y)
+
+ decode = y[1]
+
+
+ if typeof(m) == LaplaceRegressor
+ nothing
+ else
+ # Convert labels to integer format starting from 0 for one-hot encoding
+ y_plain = MLJBase.int(y[1,:]) .- 1
+
+ # One-hot encoding of labels
+ unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding
+ y = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding
+
+ end
+
# Make a copy of the model because Flux does not allow mutating hyperparameters
copied_model = deepcopy(m.model)
@@ -136,18 +140,50 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
P₀=m.P₀,
)
+ if typeof(m) == LaplaceClassifier
+ la.likelihood = :classification
+ end
+
# fit the Laplace model:
LaplaceRedux.fit!(la, data_loader)
optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
- fitresult = (la, y[1])
+ fitresult = (la, decode)
report = (loss_history = loss_history,)
cache = (deepcopy(m),state_tree,loss_history)
return fitresult, cache, report
end
-function MMI.update(m::LaplaceRegressor, verbosity, old_fitresult, old_cache, X, y)
- println("we are in the update function")
+@doc """
+ MMI.update(m::Union{LaplaceRegressor,LaplaceClassifier}, verbosity, old_fitresult, old_cache, X, y)
+
+Update the Laplace model using the provided new data points.
+
+# Arguments
+- `m`: The Laplace (LaplaceRegressor or LaplaceClassifier) model to be fitted.
+- `verbosity`: Verbosity level for logging.
+- `old_fitresult`: The fitresult returned by the previous call to `fit` or `update`.
+- `old_cache`: The cache returned by the previous call to `fit` or `update`.
+- `X`: New input features, expected to be in a format compatible with MLJBase.matrix.
+- `y`: New target values.
+
+# Returns
+- `fitresult`: a tuple (la, decode) containing the updated fitted Laplace model and y[1], the first element of the categorical y vector.
+- `cache`: a tuple containing a deepcopy of the model, the updated optimiser state and the training loss history.
+- `report`: A NamedTuple containing the complete loss history of the fitting process.
+"""
+function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y)
+
+
+ if typeof(m) == LaplaceRegressor
+ nothing
+ else
+ # Convert labels to integer format starting from 0 for one-hot encoding
+ y_plain = MLJBase.int(y[1,:]) .- 1
+
+ # One-hot encoding of labels
+ unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding
+ y = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding
+
+ end
data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size)
old_model = old_cache[1]
@@ -209,6 +245,9 @@ function MMI.update(m::LaplaceRegressor, verbosity, old_fitresult, old_cache, X,
μ₀=m.μ₀,
P₀=m.P₀,
)
+ if typeof(m) == LaplaceClassifier
+ la.likelihood = :classification
+ end
# fit the Laplace model:
LaplaceRedux.fit!(la, data_loader)
@@ -241,6 +280,9 @@ function MMI.update(m::LaplaceRegressor, verbosity, old_fitresult, old_cache, X,
μ₀=m.μ₀,
P₀=m.P₀,
)
+ if typeof(m) == LaplaceClassifier
+ la.likelihood = :classification
+ end
# fit the Laplace model:
LaplaceRedux.fit!(la, data_loader)
@@ -257,8 +299,35 @@ function MMI.update(m::LaplaceRegressor, verbosity, old_fitresult, old_cache, X,
end
-# Define the function is_same_except
-function MMI.is_same_except(m1::Const_Models, m2::Const_Models, exceptions::Symbol...)
+@doc """
+ function MMI.is_same_except(m1::Laplace_Models, m2::Laplace_Models, exceptions::Symbol...)
+
+If both `m1` and `m2` are of `MLJType`, return `true` if the
+following conditions all hold, and `false` otherwise:
+
+- `typeof(m1) === typeof(m2)`
+
+- `propertynames(m1) === propertynames(m2)`
+
+- with the exception of properties listed as `exceptions` or bound to
+ an `AbstractRNG`, each pair of corresponding property values is
+ either "equal" or both undefined. (If a property appears as a
+ `propertyname` but not a `fieldname`, it is deemed as always defined.)
+
+The meaning of "equal" depends on the type of the property value:
+
+- values that are themselves of `MLJType` are "equal" if they are
+ equal in the sense of `is_same_except` with no exceptions.
+
+- values that are not of `MLJType` are "equal" if they are `==`.
+
+In the special case of a "deep" property, "equal" has a different
+meaning; see [`deep_properties`](@ref) for details.
+
+If `m1` or `m2` are not `MLJType` objects, then return `==(m1, m2)`.
+
+"""
+function MMI.is_same_except(m1::Laplace_Models, m2::Laplace_Models, exceptions::Symbol...)
typeof(m1) === typeof(m2) || return false
names = propertynames(m1)
propertynames(m2) === names || return false
@@ -354,7 +423,7 @@ end
- `loss`: The loss value of the posterior distribution.
"""
-function MMI.fitted_params(model::Const_Models, fitresult)
+function MMI.fitted_params(model::Laplace_Models, fitresult)
la, decode = fitresult
posterior = la.posterior
return (
@@ -385,7 +454,7 @@ Retrieve the training loss history from the given `report`.
# Returns
- A collection representing the loss history from the training report.
"""
-function MMI.training_losses(model::Const_Models, report)
+function MMI.training_losses(model::Laplace_Models, report)
return report.loss_history
end
@@ -403,149 +472,29 @@ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
- `Xnew`: The new data for which predictions are to be made.
# Returns
- - An array of Normal distributions, each centered around the predicted mean and variance for the corresponding input in `Xnew`.
-
- The function first converts `Xnew` to a matrix and permutes its dimensions. It then uses the `LaplaceRedux.predict` function to obtain the predicted means and variances.
-Finally, it creates Normal distributions from these means and variances and returns them as an array.
-"""
-function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
- la, y = fitresult
- yhat = LaplaceRedux.predict(la, Xnew; ret_distr=false)
- # Extract mean and variance matrices
- means, variances = yhat
-
- # Create Normal distributions from the means and variances
- return [Normal(μ, sqrt(σ)) for (μ, σ) in zip(means, variances)]
-end
-
-
-
-
-
-@doc """
-
- function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
-
- Description:
- This function fits a LaplaceClassifier model using the provided data. It first preprocesses the input data `X` and target labels `y`,
- then trains a neural network model using the Flux library. After training, it fits a Laplace approximation to the trained model.
-
- Arguments:
- - `m::LaplaceClassifier`: The LaplaceClassifier model to be fitted.
- - `verbosity`: Verbosity level for logging.
- - `X`: Input data features.
- - `y`: Target labels.
-
- Returns:
- - A tuple containing:
- - `(la, y[1])`: The fitted Laplace model and the decode function for the target labels.
- - `cache`: A placeholder for any cached data (currently `nothing`).
- - `report`: A NamedTuple containing statistics related to the fitting process.
-
+ for LaplaceRegressor:
+ - An array of Normal distributions, each centered around the predicted mean and variance for the corresponding input in `Xnew`.
+ for LaplaceClassifier:
+ - `MLJBase.UnivariateFinite`: The predicted class probabilities for the new data.
"""
-function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
-
-
-
- # Convert labels to integer format starting from 0 for one-hot encoding
- y_plain = MLJBase.int(y[1,:]) .- 1
-
- # One-hot encoding of labels
- unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding
- y_onehot = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding
- #copy model
- copied_model = deepcopy(m.model)
-
- # Create a data loader for batching the data
- data_loader = Flux.DataLoader((X, y_onehot); batchsize=m.batch_size)
-
- # Set up the optimizer for the model
- state_tree = Optimisers.setup(m.optimiser, copied_model)
- loss_history = []
-
- # Training loop for the specified number of epochs
- for epoch in 1:(m.epochs)
-
- loss_per_epoch= 0.0
-
-
- for (X_batch, y_batch) in data_loader
- # Forward pass: compute predictions
- y_pred = copied_model(X_batch)
-
- # Compute loss
- loss = m.flux_loss(y_pred, y_batch)
-
- # Compute gradients
- grads,_ = gradient(copied_model,X_batch) do grad_model, X
- # Recompute predictions inside gradient context
- y_pred = grad_model(X)
- m.flux_loss(y_pred, y_batch)
- end
-
- # Update parameters using the optimizer and computed gradients
- state_tree, model = Optimisers.update!(state_tree ,copied_model, grads)
+function MMI.predict(m::Laplace_Models, fitresult, Xnew)
+ la, decode = fitresult
+ if typeof(m)== LaplaceRegressor
+ yhat = LaplaceRedux.predict(la, Xnew; ret_distr=false)
+ # Extract mean and variance matrices
+ means, variances = yhat
- # Accumulate the loss for this batch
- loss_per_epoch += sum(loss) # Summing the batch loss
-
- end
+ # Create Normal distributions from the means and variances
+ return [Normal(μ, sqrt(σ)) for (μ, σ) in zip(means, variances)]
- push!(loss_history,loss_per_epoch )
+ else
+ predictions = LaplaceRedux.predict(la, Xnew; link_approx=m.link_approx, ret_distr=false) |> permutedims
- # Print loss every 100 epochs if verbosity is 1 or more
- if verbosity >= 1 && epoch % 100 == 0
- println("Epoch $epoch: Loss: $loss_per_epoch ")
- end
+ return MLJBase.UnivariateFinite(MLJBase.classes(decode), predictions)
end
-
- la = LaplaceRedux.Laplace(
- m.model;
- likelihood=:classification,
- subset_of_weights=m.subset_of_weights,
- subnetwork_indices=m.subnetwork_indices,
- hessian_structure=m.hessian_structure,
- backend=m.backend,
- σ=m.σ,
- μ₀=m.μ₀,
- P₀=m.P₀,
- )
-
- # fit the Laplace model:
- LaplaceRedux.fit!(la, data_loader)
- optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
-
- report = (loss_history = loss_history,)
- cache = (deepcopy(m),state_tree,loss_history)
- return ((la, y[1]), cache, report)
end
-
-
-
-@doc """
-Predicts the class probabilities for new data using a Laplace classifier.
-
- # Arguments
- - `m::LaplaceClassifier`: The Laplace classifier model.
- - `(fitresult, decode)`: A tuple containing the fitted model result and the decode function.
- - `Xnew`: The new data for which predictions are to be made.
-
- # Returns
- - `MLJBase.UnivariateFinite`: The predicted class probabilities for the new data.
-
-The function transforms the new data `Xnew` into a matrix, applies the LaplaceRedux
-prediction function, and then returns the predictions as a `MLJBase.UnivariateFinite` object.
-"""
-function MMI.predict(m::LaplaceClassifier, fitresult, Xnew)
- la,decode = fitresult
- predictions =
- LaplaceRedux.predict(la, Xnew; link_approx=m.link_approx, ret_distr=false) |>
- permutedims
-
- return MLJBase.UnivariateFinite(MLJBase.classes(decode), predictions)
-end
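A side note on the classifier branch of the merged `fit` above: the categorical target (already reshaped to a 1×n row by `reformat`) is converted to integer codes and then one-hot encoded before entering the Flux training loop. Below is a minimal sketch of that conversion on a toy label vector; it is illustrative only and not part of the patch.

```julia
# Sketch only: the label conversion used in the classifier branch of fit.
using MLJBase: MLJBase, categorical
using Flux

y = categorical(["setosa", "virginica", "setosa", "versicolor"])
y_plain = MLJBase.int(y) .- 1                          # integer codes starting at 0
y_onehot = Flux.onehotbatch(y_plain, unique(y_plain))  # one-hot matrix for the loss
size(y_onehot)                                         # (3, 4): classes × observations
```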
From 0f46fd6452eac5813b3f24218a0c27e8eac56e88 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Fri, 4 Oct 2024 15:02:47 +0200
Subject: [PATCH 33/60] julia formatter
---
src/direct_mlj.jl | 257 ++++++++++++++++++++--------------------------
1 file changed, 113 insertions(+), 144 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 3d592c9c..15d3b39f 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -10,7 +10,7 @@ import MLJModelInterface as MMI
using Distributions: Normal
MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJBase.Probabilistic
- model::Union{Flux.Chain,Nothing} = nothing
+ model::Union{Flux.Chain,Nothing} = nothing
flux_loss = Flux.Losses.logitcrossentropy
optimiser = Optimisers.Adam()
epochs::Integer = 1000::(_ > 0)
@@ -28,7 +28,7 @@ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJBase.Probabilistic
end
MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJBase.Probabilistic
- model::Union{Flux.Chain,Nothing} = nothing
+ model::Union{Flux.Chain,Nothing} = nothing
flux_loss = Flux.Losses.mse
optimiser = Optimisers.Adam()
epochs::Integer = 1000::(_ > 0)
@@ -51,8 +51,6 @@ MMI.reformat(::Laplace_Models, X, y) = (MLJBase.matrix(X) |> permutedims, reshap
#for predict:
MMI.reformat(::Laplace_Models, X) = (MLJBase.matrix(X) |> permutedims,)
-
-
@doc """
MMI.fit(m::Union{LaplaceRegressor,LaplaceClassifier}, verbosity, X, y)
@@ -70,20 +68,17 @@ Fit a Laplace model using the provided features and target values.
- `report`: A Namedtuple containing the loss history of the fitting process.
"""
function MMI.fit(m::Laplace_Models, verbosity, X, y)
-
decode = y[1]
-
if typeof(m) == LaplaceRegressor
nothing
else
# Convert labels to integer format starting from 0 for one-hot encoding
- y_plain = MLJBase.int(y[1,:]) .- 1
+ y_plain = MLJBase.int(y[1, :]) .- 1
# One-hot encoding of labels
unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding
y = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding
-
end
# Make a copy of the model because Flux does not allow to mutate hyperparameters
@@ -91,12 +86,10 @@ function MMI.fit(m::Laplace_Models, verbosity, X, y)
data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size)
state_tree = Optimisers.setup(m.optimiser, copied_model)
- loss_history=[]
+ loss_history = []
for epoch in 1:(m.epochs)
-
- loss_per_epoch= 0.0
-
+ loss_per_epoch = 0.0
for (X_batch, y_batch) in data_loader
# Forward pass: compute predictions
@@ -106,21 +99,20 @@ function MMI.fit(m::Laplace_Models, verbosity, X, y)
loss = m.flux_loss(y_pred, y_batch)
# Compute gradients
- grads,_ = gradient(copied_model,X_batch) do grad_model, X
+ grads, _ = gradient(copied_model, X_batch) do grad_model, X
# Recompute predictions inside gradient context
y_pred = grad_model(X)
m.flux_loss(y_pred, y_batch)
end
-
+
# Update parameters using the optimizer and computed gradients
- state_tree, copied_model = Optimisers.update!(state_tree ,copied_model, grads)
+ state_tree, copied_model = Optimisers.update!(state_tree, copied_model, grads)
# Accumulate the loss for this batch
loss_per_epoch += sum(loss) # Summing the batch loss
-
end
- push!(loss_history,loss_per_epoch )
+ push!(loss_history, loss_per_epoch)
# Print loss every 100 epochs if verbosity is 1 or more
if verbosity >= 1 && epoch % 100 == 0
@@ -149,8 +141,8 @@ function MMI.fit(m::Laplace_Models, verbosity, X, y)
optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
fitresult = (la, decode)
- report = (loss_history = loss_history,)
- cache = (deepcopy(m),state_tree,loss_history)
+ report = (loss_history=loss_history,)
+ cache = (deepcopy(m), state_tree, loss_history)
return fitresult, cache, report
end
@@ -170,19 +162,16 @@ Update the Laplace model using the provided new data points.
- `cache`: a tuple containing a deepcopy of the model, the updated current state of the optimiser and training loss history.
- `report`: A Namedtuple containing the complete loss history of the fitting process.
"""
-function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y)
-
-
+function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y)
if typeof(m) == LaplaceRegressor
nothing
else
# Convert labels to integer format starting from 0 for one-hot encoding
- y_plain = MLJBase.int(y[1,:]) .- 1
+ y_plain = MLJBase.int(y[1, :]) .- 1
# One-hot encoding of labels
unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding
y = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding
-
end
data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size)
@@ -193,80 +182,82 @@ function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y
epochs = m.epochs
- if MMI.is_same_except(m, old_model,:epochs)
-
-
+ if MMI.is_same_except(m, old_model, :epochs)
if epochs > old_model.epochs
+ for epoch in (old_model.epochs + 1):(epochs)
+ loss_per_epoch = 0.0
-
- for epoch in (old_model.epochs+1):(epochs)
-
- loss_per_epoch= 0.0
-
-
for (X_batch, y_batch) in data_loader
# Forward pass: compute predictions
y_pred = old_la.model(X_batch)
-
+
# Compute loss
loss = m.flux_loss(y_pred, y_batch)
-
+
# Compute gradients
- grads,_ = gradient(old_la.model,X_batch) do grad_model, X
+ grads, _ = gradient(old_la.model, X_batch) do grad_model, X
# Recompute predictions inside gradient context
y_pred = grad_model(X)
m.flux_loss(y_pred, y_batch)
end
-
+
# Update parameters using the optimizer and computed gradients
- old_state_tree,old_la.model = Optimisers.update!(old_state_tree,old_la.model, grads)
-
+ old_state_tree, old_la.model = Optimisers.update!(
+ old_state_tree, old_la.model, grads
+ )
+
# Accumulate the loss for this batch
loss_per_epoch += sum(loss) # Summing the batch loss
-
end
-
- push!(old_loss_history,loss_per_epoch )
-
+
+ push!(old_loss_history, loss_per_epoch)
+
# Print loss every 100 epochs if verbosity is 1 or more
if verbosity >= 1 && epoch % 100 == 0
println("Epoch $epoch: Loss: $loss_per_epoch ")
end
end
- la = LaplaceRedux.Laplace(
- old_la.model;
- likelihood=:regression,
- subset_of_weights=m.subset_of_weights,
- subnetwork_indices=m.subnetwork_indices,
- hessian_structure=m.hessian_structure,
- backend=m.backend,
- σ=m.σ,
- μ₀=m.μ₀,
- P₀=m.P₀,
- )
- if typeof(m) == LaplaceClassifier
- la.likelihood = :classification
- end
+ la = LaplaceRedux.Laplace(
+ old_la.model;
+ likelihood=:regression,
+ subset_of_weights=m.subset_of_weights,
+ subnetwork_indices=m.subnetwork_indices,
+ hessian_structure=m.hessian_structure,
+ backend=m.backend,
+ σ=m.σ,
+ μ₀=m.μ₀,
+ P₀=m.P₀,
+ )
+ if typeof(m) == LaplaceClassifier
+ la.likelihood = :classification
+ end
- # fit the Laplace model:
- LaplaceRedux.fit!(la, data_loader)
- optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
+ # fit the Laplace model:
+ LaplaceRedux.fit!(la, data_loader)
+ optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
- fitresult = (la, y[1])
- report = (loss_history = old_loss_history,)
- cache = (deepcopy(m),old_state_tree,old_loss_history)
+ fitresult = (la, y[1])
+ report = (loss_history=old_loss_history,)
+ cache = (deepcopy(m), old_state_tree, old_loss_history)
else
-
nothing
-
end
-
end
- if MMI.is_same_except(m, old_model,:fit_prior_nsteps,:subset_of_weights,:subnetwork_indices,:hessian_structure,:backend,:σ,:μ₀,:P₀)
-
+ if MMI.is_same_except(
+ m,
+ old_model,
+ :fit_prior_nsteps,
+ :subset_of_weights,
+ :subnetwork_indices,
+ :hessian_structure,
+ :backend,
+ :σ,
+ :μ₀,
+ :P₀,
+ )
println(" updating only the laplace optimization part")
la = LaplaceRedux.Laplace(
@@ -279,26 +270,23 @@ function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y
σ=m.σ,
μ₀=m.μ₀,
P₀=m.P₀,
- )
- if typeof(m) == LaplaceClassifier
- la.likelihood = :classification
- end
-
- # fit the Laplace model:
- LaplaceRedux.fit!(la, data_loader)
- optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
-
- fitresult = la
- report = (loss_history = old_loss_history,)
- cache = (deepcopy(m),old_state_tree,old_loss_history)
+ )
+ if typeof(m) == LaplaceClassifier
+ la.likelihood = :classification
+ end
- end
+ # fit the Laplace model:
+ LaplaceRedux.fit!(la, data_loader)
+ optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
+ fitresult = la
+ report = (loss_history=old_loss_history,)
+ cache = (deepcopy(m), old_state_tree, old_loss_history)
+ end
return fitresult, cache, report
end
-
@doc """
function MMI.is_same_except(m1::Laplace_Models, m2::Laplace_Models, exceptions::Symbol...)
@@ -327,7 +315,7 @@ meaning; see [`deep_properties`](@ref)) for details.
If `m1` or `m2` are not `MLJType` objects, then return `==(m1, m2)`.
"""
-function MMI.is_same_except(m1::Laplace_Models, m2::Laplace_Models, exceptions::Symbol...)
+function MMI.is_same_except(m1::Laplace_Models, m2::Laplace_Models, exceptions::Symbol...)
typeof(m1) === typeof(m2) || return false
names = propertynames(m1)
propertynames(m2) === names || return false
@@ -335,22 +323,21 @@ function MMI.is_same_except(m1::Laplace_Models, m2::Laplace_Models, exceptions::
for name in names
if !(name in exceptions)
if !_isdefined(m1, name)
- !_isdefined(m2, name) || return false
+ !_isdefined(m2, name) || return false
elseif _isdefined(m2, name)
if name in deep_properties(LaplaceRegressor)
- _equal_to_depth_one(
- getproperty(m1,name),
- getproperty(m2, name)
- ) || return false
+ _equal_to_depth_one(getproperty(m1, name), getproperty(m2, name)) ||
+ return false
else
(
- is_same_except(
- getproperty(m1, name),
- getproperty(m2, name)
- ) ||
+ is_same_except(getproperty(m1, name), getproperty(m2, name)) ||
getproperty(m1, name) isa AbstractRNG ||
getproperty(m2, name) isa AbstractRNG ||
- (getproperty(m1, name) isa Flux.Chain && getproperty(m2, name) isa Flux.Chain && _equal_flux_chain(getproperty(m1, name), getproperty(m2, name)))
+ (
+ getproperty(m1, name) isa Flux.Chain &&
+ getproperty(m2, name) isa Flux.Chain &&
+ _equal_flux_chain(getproperty(m1, name), getproperty(m2, name))
+ )
) || return false
end
else
@@ -392,14 +379,10 @@ function _equal_flux_chain(chain1::Flux.Chain, chain2::Flux.Chain)
if typeof(layer1) != typeof(layer2)
return false
end
-
end
return true
end
-
-
-
@doc """
function MMI.fitted_params(model::LaplaceRegressor, fitresult)
@@ -427,21 +410,17 @@ function MMI.fitted_params(model::Laplace_Models, fitresult)
la, decode = fitresult
posterior = la.posterior
return (
- μ = posterior.μ,
- H = posterior.H,
- P = posterior.P,
- Σ = posterior.Σ,
- n_data = posterior.n_data,
- n_params = posterior.n_params,
- n_out = posterior.n_out,
- loss = posterior.loss
+ μ=posterior.μ,
+ H=posterior.H,
+ P=posterior.P,
+ Σ=posterior.Σ,
+ n_data=posterior.n_data,
+ n_params=posterior.n_params,
+ n_out=posterior.n_out,
+ loss=posterior.loss,
)
end
-
-
-
-
@doc """
MMI.training_losses(model::Union{LaplaceRegressor,LaplaceClassifier}, report)
@@ -458,9 +437,6 @@ function MMI.training_losses(model::Laplace_Models, report)
return report.loss_history
end
-
-
-
@doc """
function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
@@ -479,7 +455,7 @@ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
"""
function MMI.predict(m::Laplace_Models, fitresult, Xnew)
la, decode = fitresult
- if typeof(m)== LaplaceRegressor
+ if typeof(m) == LaplaceRegressor
yhat = LaplaceRedux.predict(la, Xnew; ret_distr=false)
# Extract mean and variance matrices
means, variances = yhat
@@ -488,41 +464,34 @@ function MMI.predict(m::Laplace_Models, fitresult, Xnew)
return [Normal(μ, sqrt(σ)) for (μ, σ) in zip(means, variances)]
else
- predictions = LaplaceRedux.predict(la, Xnew; link_approx=m.link_approx, ret_distr=false) |> permutedims
+ predictions =
+ LaplaceRedux.predict(la, Xnew; link_approx=m.link_approx, ret_distr=false) |>
+ permutedims
return MLJBase.UnivariateFinite(MLJBase.classes(decode), predictions)
end
end
-
-
-
-
-
-
-
MMI.metadata_pkg(
- LaplaceRegressor,
- name="LaplaceRedux",
- package_uuid="c52c1a26-f7c5-402b-80be-ba1e638ad478",
- package_url="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl",
- is_pure_julia=true,
- is_wrapper=true,
- package_license = "https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl/blob/main/LICENSE",
+ LaplaceRegressor;
+ name="LaplaceRedux",
+ package_uuid="c52c1a26-f7c5-402b-80be-ba1e638ad478",
+ package_url="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl",
+ is_pure_julia=true,
+ is_wrapper=true,
+ package_license="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl/blob/main/LICENSE",
)
MMI.metadata_pkg(
- LaplaceClassifier,
- name="LaplaceRedux",
- package_uuid="c52c1a26-f7c5-402b-80be-ba1e638ad478",
- package_url="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl",
- is_pure_julia=true,
- is_wrapper=true,
- package_license = "https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl/blob/main/LICENSE",
+ LaplaceClassifier;
+ name="LaplaceRedux",
+ package_uuid="c52c1a26-f7c5-402b-80be-ba1e638ad478",
+ package_url="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl",
+ is_pure_julia=true,
+ is_wrapper=true,
+ package_license="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl/blob/main/LICENSE",
)
-
-
MLJBase.metadata_model(
LaplaceClassifier;
input_scitype=Union{
@@ -530,7 +499,7 @@ MLJBase.metadata_model(
MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types
},
target_scitype=AbstractArray{<:MLJBase.Finite}, # ordered factor or multiclass
- supports_training_losses = true,
+ supports_training_losses=true,
load_path="LaplaceRedux.LaplaceClassifier",
)
# metadata for each model,
@@ -541,16 +510,16 @@ MLJBase.metadata_model(
MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types
},
target_scitype=AbstractArray{MLJBase.Continuous},
- supports_training_losses = true,
+ supports_training_losses=true,
load_path="LaplaceRedux.LaplaceRegressor",
-
)
-const DOC_LAPLACE_REDUX = "[Laplace Redux – Effortless Bayesian Deep Learning]"*
- "(https://proceedings.neurips.cc/paper/2021/hash/a3923dbe2f702eff254d67b48ae2f06e-Abstract.html), originally published in "*
+const DOC_LAPLACE_REDUX =
+ "[Laplace Redux – Effortless Bayesian Deep Learning]" *
+ "(https://proceedings.neurips.cc/paper/2021/hash/a3923dbe2f702eff254d67b48ae2f06e-Abstract.html), originally published in " *
"Daxberger, E., Kristiadi, A., Immer, A., Eschenhagen, R., Bauer, M., Hennig, P. (2021): \"Laplace Redux – Effortless Bayesian Deep Learning.\", NIPS'21: Proceedings of the 35th International Conference on Neural Information Processing Systems*, Article No. 1537, pp. 20089–20103"
- """
+"""
$(MMI.doc_header(LaplaceClassifier))
`LaplaceClassifier` implements the $DOC_LAPLACE_REDUX for classification models.
@@ -820,4 +789,4 @@ See also [LaplaceRedux.jl](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl
"""
LaplaceRegressor
-#end # module
\ No newline at end of file
+#end # module
From ab8b6bfb38f003219afc876ed649580da1797ab0 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Tue, 15 Oct 2024 18:02:10 +0200
Subject: [PATCH 34/60] added unit tests
---
test/direct_mlj_interface.jl | 10 ++++++++++
test/mlj_flux_interfacing.jl | 2 ++
2 files changed, 12 insertions(+)
diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl
index 66fe020d..c42546a0 100644
--- a/test/direct_mlj_interface.jl
+++ b/test/direct_mlj_interface.jl
@@ -22,6 +22,11 @@ using LaplaceRedux
yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions
MLJBase.predict_mode(mach, Xnew) # point predictions
MLJBase.fitted_params(mach) #fitted params function
+ MLJBase.training_losses(mach) #training loss history
+ model.epochs= 100 #changing number of epochs
+ MLJBase.fit!(mach) #testing update function
+ model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
+ MLJBase.fit!(mach) #testing update function (the laplace part)
end
@@ -47,5 +52,10 @@ yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions
predict_mode(mach, Xnew) # point predictions
pdf.(yhat, "virginica") # probabilities for the "virginica" class
MLJBase.fitted_params(mach) # fitted params
+MLJBase.training_losses(mach) #training loss history
+model.epochs= 100 #changing number of epochs
+MLJBase.fit!(mach) #testing update function
+model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
+MLJBase.fit!(mach) #testing update function (the laplace part)
end
diff --git a/test/mlj_flux_interfacing.jl b/test/mlj_flux_interfacing.jl
index cb238daf..292397a8 100644
--- a/test/mlj_flux_interfacing.jl
+++ b/test/mlj_flux_interfacing.jl
@@ -1,3 +1,5 @@
+#deactivated in runtests.jl
+
using Random: Random
import Random.seed!
using MLJBase: MLJBase, categorical
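
The tests added in this patch exercise MLJ's update path: calling `fit!` again on a machine after changing a hyperparameter dispatches to `MMI.update` instead of retraining from scratch. A minimal sketch of the pattern for the regressor (mirroring the test file; the exact chain is an assumption):

using MLJBase, Flux, LaplaceRedux

flux_model = Chain(Dense(4, 10, relu), Dense(10, 1))
model = LaplaceRegressor(model=flux_model, epochs=50)

X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)
mach = machine(model, X, y)
MLJBase.fit!(mach)                  # first call: full MMI.fit

model.epochs = 100                  # only :epochs changed -> MMI.update resumes training
MLJBase.fit!(mach)

model.fit_prior_nsteps = 200        # only the prior-fitting steps changed -> refit the Laplace part
MLJBase.fit!(mach)
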
From 263cc67a68de24f1ec041a1f9d2185cf61eafb87 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Tue, 15 Oct 2024 18:15:50 +0200
Subject: [PATCH 35/60] more unit tests

---
test/direct_mlj_interface.jl | 15 ++++++++++++++-
1 file changed, 14 insertions(+), 1 deletion(-)
diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl
index c42546a0..16956064 100644
--- a/test/direct_mlj_interface.jl
+++ b/test/direct_mlj_interface.jl
@@ -25,6 +25,8 @@ using LaplaceRedux
MLJBase.training_losses(mach) #training loss history
model.epochs= 100 #changing number of epochs
MLJBase.fit!(mach) #testing update function
+ model.epochs= 50 #changing number of epochs to a lower number
+ MLJBase.fit!(mach) #testing update function
model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
MLJBase.fit!(mach) #testing update function (the laplace part)
end
@@ -35,7 +37,6 @@ end
# Define the model
flux_model = Chain(
Dense(4, 10, relu),
- Dense(10, 10, relu),
Dense(10, 3)
)
@@ -55,7 +56,19 @@ MLJBase.fitted_params(mach) # fitted params
MLJBase.training_losses(mach) #training loss history
model.epochs= 100 #changing number of epochs
MLJBase.fit!(mach) #testing update function
+model.epochs= 50 #changing number of epochs to a lower number
+MLJBase.fit!(mach) #testing update function
model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
MLJBase.fit!(mach) #testing update function (the laplace part)
+
+# Define a different model
+flux_model_two = Chain(
+ Dense(4, 6, relu),
+ Dense(6, 3)
+)
+
+model_two = LaplaceClassifier(model=flux_model_two,epochs=70)
+
+MLJBase.is_same_except(model_two, model, :epochs)
end
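
The final assertion relies on the package's `is_same_except` overload from earlier in this series, which treats two models as equal when they differ only in the listed fields (with special handling for RNGs, and `Flux.Chain`s compared layer by layer). A small sketch of the intended semantics (assuming the packages and `flux_model` from this test file are in scope):

m1 = LaplaceClassifier(model=flux_model, epochs=50)
m2 = LaplaceClassifier(model=flux_model, epochs=70)

MLJBase.is_same_except(m1, m2)           # false: the :epochs field differs
MLJBase.is_same_except(m1, m2, :epochs)  # true: :epochs is listed as an exception
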
From 453b49fc19e2fa42e5a228f9f43759a8ed2da613 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Tue, 15 Oct 2024 19:07:37 +0200
Subject: [PATCH 36/60] fix
---
src/direct_mlj.jl | 30 +++++++++++++++++++++++++-----
1 file changed, 25 insertions(+), 5 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 15d3b39f..fb366f0c 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -239,14 +239,20 @@ function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y
fitresult = (la, y[1])
report = (loss_history=old_loss_history,)
- cache = (deepcopy(m), old_state_tree, old_loss_history)
+ cache = old_cache
else
- nothing
+ fitresult = old_fitresult
+ report = (loss_history=old_loss_history,)
+ cache = (deepcopy(m), old_state_tree, old_loss_history)
end
- end
+
+
+
- if MMI.is_same_except(
+ return fitresult, cache, report
+
+ elseif MMI.is_same_except(
m,
old_model,
:fit_prior_nsteps,
@@ -282,9 +288,23 @@ function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y
fitresult = la
report = (loss_history=old_loss_history,)
cache = (deepcopy(m), old_state_tree, old_loss_history)
+
+ return fitresult, cache, report
+
+
+
+ else
+
+ fitresult, cache, report = MMI.fit(m, verbosity, X, y)
+
+
+ return fitresult, cache, report
+
+
+
+
end
- return fitresult, cache, report
end
@doc """
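
After this patch, `MMI.update` effectively chooses between three outcomes depending on which hyperparameters changed. A simplified sketch of that dispatch (the return values are illustrative labels, not the package's actual internals):

import MLJModelInterface as MMI

function update_branch(m, old_model)
    if MMI.is_same_except(m, old_model, :epochs)
        return :continue_training      # resume gradient steps from the cached optimiser state
    elseif MMI.is_same_except(m, old_model, :fit_prior_nsteps, :subset_of_weights,
                              :subnetwork_indices, :hessian_structure, :backend, :σ, :μ₀, :P₀)
        return :refit_laplace_only     # keep the trained chain, rebuild and refit the Laplace posterior
    else
        return :full_refit             # anything else (e.g. a new Flux chain) falls back to MMI.fit
    end
end
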
From 656b24ea1c0b2a1f13397e17899edfc5afdf5238 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Wed, 16 Oct 2024 05:14:53 +0200
Subject: [PATCH 37/60] changed unit tests and a minor fix in the update
 function; the if-statement problem is still unresolved.
---
src/direct_mlj.jl | 9 ++++++++-
test/direct_mlj_interface.jl | 8 ++++----
2 files changed, 12 insertions(+), 5 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index fb366f0c..b6bc3023 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -178,11 +178,12 @@ function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y
old_model = old_cache[1]
old_state_tree = old_cache[2]
old_loss_history = old_cache[3]
- old_la = old_fitresult[1]
+ #old_la = old_fitresult[1]
epochs = m.epochs
if MMI.is_same_except(m, old_model, :epochs)
+ old_la = old_fitresult[1]
if epochs > old_model.epochs
for epoch in (old_model.epochs + 1):(epochs)
loss_per_epoch = 0.0
@@ -264,7 +265,9 @@ function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y
:μ₀,
:P₀,
)
+
println(" updating only the laplace optimization part")
+ old_la = old_fitresult[1]
la = LaplaceRedux.Laplace(
old_la.model;
@@ -294,6 +297,10 @@ function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y
else
+ println(" I believe this error is provoked by that if statement in the fit function. This case should address the possibility that \n
+ the user change the flux chain. In this case the update! function should revert to the fallback fit! function, \n
+ however y has already been transformed during the previous fit! run . One way to solve this issue is to make sure that the \n
+ data preparation is completely done in reformat")
fitresult, cache, report = MMI.fit(m, verbosity, X, y)
diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl
index 16956064..ac7b4847 100644
--- a/test/direct_mlj_interface.jl
+++ b/test/direct_mlj_interface.jl
@@ -17,7 +17,7 @@ using LaplaceRedux
X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)
mach = machine(model, X, y) #|> MLJBase.fit! #|> (fitresult,cache,report)
- MLJBase.fit!(mach)
+ MLJBase.fit!(mach,verbosity=1)
Xnew, _ = make_regression(3, 4; rng=123)
yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions
MLJBase.predict_mode(mach, Xnew) # point predictions
@@ -44,7 +44,7 @@ model = LaplaceClassifier(model=flux_model,epochs=50)
X, y = @load_iris
mach = machine(model, X, y)
-MLJBase.fit!(mach)
+MLJBase.fit!(mach,verbosity=1)
Xnew = (sepal_length = [6.4, 7.2, 7.4],
sepal_width = [2.8, 3.0, 2.8],
petal_length = [5.6, 5.8, 6.1],
@@ -67,8 +67,8 @@ flux_model_two = Chain(
Dense(6, 3)
)
-model_two = LaplaceClassifier(model=flux_model_two,epochs=70)
+model.model = flux_model_two
-MLJBase.is_same_except(model_two, model, :epochs)
+MLJBase.fit!(mach)
end
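
The message added above spells out the planned fix, which the next patch implements: do all target preparation once in `MMI.reformat`, so that `fit` and `update` always receive the already-encoded representation. A condensed sketch of such a classifier front-end (simplified relative to the actual code added in the next patch):

using Flux
using MLJBase

function classifier_reformat(X, y)
    Xmatrix  = permutedims(MLJBase.matrix(X))               # features as a p × n matrix
    y_plain  = MLJBase.int(y) .- 1                           # integer codes starting at 0
    y_onehot = Flux.onehotbatch(y_plain, unique(y_plain))    # one-hot targets for training
    return Xmatrix, (y_onehot, y[1])                         # keep one label around for decoding predictions
end
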
From 7c4d74406c23413273271c7c06ad83f0d7587866 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Wed, 16 Oct 2024 09:03:27 +0200
Subject: [PATCH 38/60] the only things left to fix are the selectrows functions
---
.../mlj-interfacing/direct_mlj.ipynb | 1137 +++++++++++++++++
src/direct_mlj.jl | 79 +-
test/direct_mlj_interface.jl | 2 +
3 files changed, 1172 insertions(+), 46 deletions(-)
create mode 100644 dev/notebooks/mlj-interfacing/direct_mlj.ipynb
diff --git a/dev/notebooks/mlj-interfacing/direct_mlj.ipynb b/dev/notebooks/mlj-interfacing/direct_mlj.ipynb
new file mode 100644
index 00000000..4169dc0e
--- /dev/null
+++ b/dev/notebooks/mlj-interfacing/direct_mlj.ipynb
@@ -0,0 +1,1137 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "using Revise\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\u001b[32m\u001b[1m Activating\u001b[22m\u001b[39m project at `C:\\Users\\Pasqu\\Documents\\julia_projects\\LaplaceRedux.jl\\docs`\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[32m\u001b[1mStatus\u001b[22m\u001b[39m `C:\\Users\\Pasqu\\Documents\\julia_projects\\LaplaceRedux.jl\\docs\\Project.toml`\n",
+ " \u001b[90m[324d7699] \u001b[39mCategoricalArrays v0.10.8\n",
+ "\u001b[32m⌃\u001b[39m \u001b[90m[a93c6f00] \u001b[39mDataFrames v1.6.1\n",
+ "\u001b[32m⌃\u001b[39m \u001b[90m[31c24e10] \u001b[39mDistributions v0.25.111\n",
+ "\u001b[32m⌃\u001b[39m \u001b[90m[e30172f5] \u001b[39mDocumenter v1.6.0\n",
+ "\u001b[32m⌃\u001b[39m \u001b[90m[587475ba] \u001b[39mFlux v0.14.19\n",
+ " \u001b[90m[c52c1a26] \u001b[39mLaplaceRedux v1.1.1 `C:\\Users\\Pasqu\\Documents\\julia_projects\\LaplaceRedux.jl`\n",
+ "\u001b[33m⌅\u001b[39m \u001b[90m[094fc8d1] \u001b[39mMLJFlux v0.5.1\n",
+ " \u001b[90m[e80e1ace] \u001b[39mMLJModelInterface v1.11.0\n",
+ " \u001b[90m[3bd65402] \u001b[39mOptimisers v0.3.3\n",
+ " \u001b[90m[91a5bcdd] \u001b[39mPlots v1.40.8\n",
+ " \u001b[90m[ce6b1742] \u001b[39mRDatasets v0.7.7\n",
+ " \u001b[90m[860ef19b] \u001b[39mStableRNGs v1.0.2\n",
+ " \u001b[90m[2913bbd2] \u001b[39mStatsBase v0.34.3\n",
+ " \u001b[90m[bd369af6] \u001b[39mTables v1.12.0\n",
+ "\u001b[32m⌃\u001b[39m \u001b[90m[bd7198b4] \u001b[39mTaijaPlotting v1.2.0\n",
+ " \u001b[90m[592b5752] \u001b[39mTrapz v2.0.3\n",
+ "\u001b[32m⌃\u001b[39m \u001b[90m[e88e6eb3] \u001b[39mZygote v0.6.70\n",
+ "\u001b[32m⌃\u001b[39m \u001b[90m[02a925ec] \u001b[39mcuDNN v1.3.2\n",
+ " \u001b[90m[9a3f8284] \u001b[39mRandom\n",
+ " \u001b[90m[10745b16] \u001b[39mStatistics v1.10.0\n",
+ "\u001b[36m\u001b[1mInfo\u001b[22m\u001b[39m Packages marked with \u001b[32m⌃\u001b[39m and \u001b[33m⌅\u001b[39m have new versions available. Those with \u001b[32m⌃\u001b[39m may be upgradable, but those with \u001b[33m⌅\u001b[39m are restricted by compatibility constraints from upgrading. To see why use `status --outdated`\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING: using MLJBase.predict in module LaplaceRedux conflicts with an existing identifier.\n",
+ "WARNING: using MLJBase.fit! in module LaplaceRedux conflicts with an existing identifier.\n"
+ ]
+ }
+ ],
+ "source": [
+ "using Pkg\n",
+ "\n",
+ "\n",
+ "Pkg.activate(\"C:/Users/Pasqu/Documents/julia_projects/LaplaceRedux.jl/docs\")\n",
+ "Pkg.status()\n",
+ "using LaplaceRedux\n",
+ "using MLJBase\n",
+ "using Random\n",
+ "using DataFrames\n",
+ "using Flux"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(Tables.MatrixTable{Matrix{Float64}} with 100 rows, 2 columns, and schema:\n",
+ " :x1 Float64\n",
+ " :x2 Float64, [-0.15053240354230857, -0.16143107735113107, -0.28104782384528254, 0.8905842690519058, -0.2716955057136559, 0.9606721208381163, 0.14403243794060133, 0.13743002853667605, 0.820641892942472, 0.2270783932443115 … 0.4639933046961763, 0.30449384622096687, 0.2744588755171263, -31.785240173822324, 2.58951832655098, 0.10969223924903307, 0.1600255529817666, 0.5913011997917647, -0.39253898214541977, -0.7247243478011852])"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "X, y = make_regression(100, 2; noise=0.5, sparse=0.2, outliers=0.1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of rows: 100\n",
+ "Number of columns: 2\n"
+ ]
+ }
+ ],
+ "source": [
+ "using Tables\n",
+ "# Get the number of rows\n",
+ "num_rows = Tables.rows(X) |> length\n",
+ "\n",
+ "# Get the number of columns\n",
+ "num_columns = Tables.columnnames(X) |> length\n",
+ "\n",
+ "# Display the dimensions\n",
+ "println(\"Number of rows: \", num_rows)\n",
+ "println(\"Number of columns: \", num_columns)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Chain(\n",
+ " Dense(2 => 10, relu), \u001b[90m# 30 parameters\u001b[39m\n",
+ " Dense(10 => 10, relu), \u001b[90m# 110 parameters\u001b[39m\n",
+ " Dense(10 => 1), \u001b[90m# 11 parameters\u001b[39m\n",
+ ") \u001b[90m # Total: 6 arrays, \u001b[39m151 parameters, 988 bytes."
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "using Flux\n",
+ "# Define the model\n",
+ "flux_model = Chain(\n",
+ " Dense(2, 10, relu),\n",
+ " Dense(10, 10, relu),\n",
+ " Dense(10, 1)\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "LaplaceRegressor(\n",
+ " model = Chain(Dense(2 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), \n",
+ " flux_loss = Flux.Losses.mse, \n",
+ " optimiser = Adam(0.001, (0.9, 0.999), 1.0e-8), \n",
+ " epochs = 1000, \n",
+ " batch_size = 32, \n",
+ " subset_of_weights = :all, \n",
+ " subnetwork_indices = nothing, \n",
+ " hessian_structure = :full, \n",
+ " backend = :GGN, \n",
+ " σ = 1.0, \n",
+ " μ₀ = 0.0, \n",
+ " P₀ = nothing, \n",
+ " fit_prior_nsteps = 100)"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "\n",
+ "# Create an instance of LaplaceRegressor with the Flux model\n",
+ "laplace_regressor = LaplaceRegressor(\n",
+ " model = flux_model,\n",
+ " subset_of_weights = :all,\n",
+ " subnetwork_indices = nothing,\n",
+ " hessian_structure = :full,\n",
+ " backend = :GGN,\n",
+ " σ = 1.0,\n",
+ " μ₀ = 0.0,\n",
+ " P₀ = nothing,\n",
+ " fit_prior_nsteps = 100\n",
+ ")\n",
+ "\n",
+ "# Display the LaplaceRegressor object\n",
+ "laplace_regressor"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Tables.MatrixTable{Matrix{Float64}}"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "typeof(X)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+       "<p>17×5 DataFrame preview (HTML table output omitted; the same rows appear in the text/plain rendering below)</p>"
+ ],
+ "text/latex": [
+ "\\begin{tabular}{r|ccccc}\n",
+ "\t& x1 & x2 & x3 & x4 & y\\\\\n",
+ "\t\\hline\n",
+ "\t& Float64 & Float64 & Float64 & Float64 & Cat…\\\\\n",
+ "\t\\hline\n",
+ "\t1 & 5.89419 & 12.9979 & 6.18512 & 8.99286 & 2 \\\\\n",
+ "\t2 & -8.14834 & 6.23246 & -1.68497 & 8.96905 & 1 \\\\\n",
+ "\t3 & -4.88229 & 5.35276 & -0.45876 & 8.20756 & 1 \\\\\n",
+ "\t4 & 4.02585 & 6.94769 & 13.4032 & -0.0419223 & 2 \\\\\n",
+ "\t5 & -5.53635 & 6.55656 & -1.67063 & 8.77041 & 1 \\\\\n",
+ "\t6 & -6.61858 & 4.65032 & -1.15198 & 8.34897 & 1 \\\\\n",
+ "\t7 & 10.2344 & 11.4278 & 13.0544 & 8.53025 & 2 \\\\\n",
+ "\t8 & -6.05052 & 8.12027 & -3.68708 & 8.78732 & 1 \\\\\n",
+ "\t9 & -5.06769 & 4.8631 & -3.58346 & 8.41371 & 1 \\\\\n",
+ "\t10 & 10.8373 & 6.32472 & 9.79163 & 6.65962 & 2 \\\\\n",
+ "\t11 & -6.63226 & 5.45149 & -0.38861 & 9.0007 & 1 \\\\\n",
+ "\t12 & 1.62812 & 4.61073 & 11.6602 & 11.7241 & 2 \\\\\n",
+ "\t13 & -6.48679 & 6.68166 & -3.32569 & 9.19618 & 1 \\\\\n",
+ "\t14 & 7.95596 & 2.23928 & 12.6897 & 1.77857 & 2 \\\\\n",
+ "\t15 & -6.36466 & 5.82985 & -0.702502 & 8.44976 & 1 \\\\\n",
+ "\t16 & 8.03294 & 3.85901 & 5.50741 & 2.2014 & 2 \\\\\n",
+ "\t17 & -7.45067 & 7.01011 & -1.96187 & 7.84336 & 1 \\\\\n",
+ "\\end{tabular}\n"
+ ],
+ "text/plain": [
+ "\u001b[1m17×5 DataFrame\u001b[0m\n",
+ "\u001b[1m Row \u001b[0m│\u001b[1m x1 \u001b[0m\u001b[1m x2 \u001b[0m\u001b[1m x3 \u001b[0m\u001b[1m x4 \u001b[0m\u001b[1m y \u001b[0m\n",
+ " │\u001b[90m Float64 \u001b[0m\u001b[90m Float64 \u001b[0m\u001b[90m Float64 \u001b[0m\u001b[90m Float64 \u001b[0m\u001b[90m Cat… \u001b[0m\n",
+ "─────┼─────────────────────────────────────────────────\n",
+ " 1 │ 5.89419 12.9979 6.18512 8.99286 2\n",
+ " 2 │ -8.14834 6.23246 -1.68497 8.96905 1\n",
+ " 3 │ -4.88229 5.35276 -0.45876 8.20756 1\n",
+ " 4 │ 4.02585 6.94769 13.4032 -0.0419223 2\n",
+ " 5 │ -5.53635 6.55656 -1.67063 8.77041 1\n",
+ " 6 │ -6.61858 4.65032 -1.15198 8.34897 1\n",
+ " 7 │ 10.2344 11.4278 13.0544 8.53025 2\n",
+ " 8 │ -6.05052 8.12027 -3.68708 8.78732 1\n",
+ " 9 │ -5.06769 4.8631 -3.58346 8.41371 1\n",
+ " 10 │ 10.8373 6.32472 9.79163 6.65962 2\n",
+ " 11 │ -6.63226 5.45149 -0.38861 9.0007 1\n",
+ " 12 │ 1.62812 4.61073 11.6602 11.7241 2\n",
+ " 13 │ -6.48679 6.68166 -3.32569 9.19618 1\n",
+ " 14 │ 7.95596 2.23928 12.6897 1.77857 2\n",
+ " 15 │ -6.36466 5.82985 -0.702502 8.44976 1\n",
+ " 16 │ 8.03294 3.85901 5.50741 2.2014 2\n",
+ " 17 │ -7.45067 7.01011 -1.96187 7.84336 1"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "using MLJ, DataFrames\n",
+ "X, y = make_blobs(100, 4; centers=2, cluster_std=[1.0, 3.0 ])\n",
+ "dfBlobs = DataFrame(X)\n",
+ "dfBlobs.y = y\n",
+ "first(dfBlobs, 17)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Chain(\n",
+ " Dense(4 => 10, relu), \u001b[90m# 50 parameters\u001b[39m\n",
+ " Dense(10 => 10, relu), \u001b[90m# 110 parameters\u001b[39m\n",
+ " Dense(10 => 2), \u001b[90m# 22 parameters\u001b[39m\n",
+ ") \u001b[90m # Total: 6 arrays, \u001b[39m182 parameters, 1.086 KiB."
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "using Flux\n",
+ "# Define the model\n",
+ "flux_model = Chain(\n",
+ " Dense(4, 10, relu),\n",
+ " Dense(10, 10, relu),\n",
+ " Dense(10, 2)\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "LaplaceClassifier(\n",
+ " model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 2)), \n",
+ " flux_loss = Flux.Losses.logitcrossentropy, \n",
+ " optimiser = Adam(0.001, (0.9, 0.999), 1.0e-8), \n",
+ " epochs = 1000, \n",
+ " batch_size = 32, \n",
+ " subset_of_weights = :all, \n",
+ " subnetwork_indices = nothing, \n",
+ " hessian_structure = :full, \n",
+ " backend = :GGN, \n",
+ " σ = 1.0, \n",
+ " μ₀ = 0.0, \n",
+ " P₀ = nothing, \n",
+ " fit_prior_nsteps = 100, \n",
+ " link_approx = :probit)"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Create an instance of LaplaceRegressor with the Flux model\n",
+ "laplace_classifier = LaplaceClassifier(\n",
+ " model = flux_model,\n",
+ " subset_of_weights = :all,\n",
+ " subnetwork_indices = nothing,\n",
+ " hessian_structure = :full,\n",
+ " backend = :GGN,\n",
+ " σ = 1.0,\n",
+ " μ₀ = 0.0,\n",
+ " P₀ = nothing,\n",
+ " fit_prior_nsteps = 100\n",
+ ")\n",
+ "\n",
+ "# Display the LaplaceRegressor object\n",
+ "laplace_classifier"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "([5.894190948229072 -8.148338397036424 … 4.6901216310216896 8.754850004556681; 12.99791729188348 6.232458473633034 … 10.904369405648966 12.187158715203248; 6.185122057812649 -1.6849673572899138 … 10.09295616393206 15.616135680604215; 8.992861668498659 8.969052840985539 … 8.673387526088336 3.444822170655238], (Bool[1 0 … 1 1; 0 1 … 0 0], CategoricalArrays.CategoricalValue{Int64, UInt32} 2))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "X_test_reformat,y_test_reformat= MLJBase.reformat(laplace_classifier,X,y)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "2×100 OneHotMatrix(::Vector{UInt32}) with eltype Bool:\n",
+ " 1 ⋅ ⋅ 1 ⋅ ⋅ 1 ⋅ ⋅ 1 ⋅ 1 ⋅ … 1 1 1 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 1\n",
+ " ⋅ 1 1 ⋅ 1 1 ⋅ 1 1 ⋅ 1 ⋅ 1 ⋅ ⋅ ⋅ ⋅ 1 1 1 1 1 1 ⋅ ⋅"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "y_test_reformat[1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "2×3 view(OneHotMatrix(::Vector{UInt32}), :, [2, 3, 4]) with eltype Bool:\n",
+ " 0 0 1\n",
+ " 1 1 0"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "view(y_test_reformat[1],:, [2,3,4])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "4×3 view(::Matrix{Float64}, :, [2, 3, 4]) with eltype Float64:\n",
+ " -8.14834 -4.88229 4.02585\n",
+ " 6.23246 5.35276 6.94769\n",
+ " -1.68497 -0.45876 13.4032\n",
+ " 8.96905 8.20756 -0.0419223"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "view(X_test_reformat, :, [2,3,4])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "((sepal_length = [5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9 … 6.7, 6.9, 5.8, 6.8, 6.7, 6.7, 6.3, 6.5, 6.2, 5.9], sepal_width = [3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1 … 3.1, 3.1, 2.7, 3.2, 3.3, 3.0, 2.5, 3.0, 3.4, 3.0], petal_length = [1.4, 1.4, 1.3, 1.5, 1.4, 1.7, 1.4, 1.5, 1.4, 1.5 … 5.6, 5.1, 5.1, 5.9, 5.7, 5.2, 5.0, 5.2, 5.4, 5.1], petal_width = [0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1 … 2.4, 2.3, 1.9, 2.3, 2.5, 2.3, 1.9, 2.0, 2.3, 1.8]), CategoricalArrays.CategoricalValue{String, UInt32}[\"setosa\", \"setosa\", \"setosa\", \"setosa\", \"setosa\", \"setosa\", \"setosa\", \"setosa\", \"setosa\", \"setosa\" … \"virginica\", \"virginica\", \"virginica\", \"virginica\", \"virginica\", \"virginica\", \"virginica\", \"virginica\", \"virginica\", \"virginica\"])"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "using MLJ\n",
+ "#DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree\n",
+ "\n",
+ "using Flux\n",
+ "# Define the model\n",
+ "flux_model = Chain(\n",
+ " Dense(4, 10, relu),\n",
+ " Dense(10, 10, relu),\n",
+ " Dense(10, 3)\n",
+ ")\n",
+ "\n",
+ "model = LaplaceClassifier(model=flux_model)\n",
+ "\n",
+ "X, y = @load_iris\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "3-element Vector{String}:\n",
+ " \"setosa\"\n",
+ " \"versicolor\"\n",
+ " \"virginica\""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "levels(y)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "┌ Info: Training machine(LaplaceClassifier(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 3)), …), …).\n",
+ "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:499\n",
+ "┌ Warning: Layer with Float32 parameters got Float64 input.\n",
+ "│ The input will be converted, but any earlier layers may be very slow.\n",
+ "│ layer = Dense(4 => 10, relu)\n",
+ "│ summary(x) = 4×32 Matrix{Float64}\n",
+ "└ @ Flux C:\\Users\\Pasqu\\.julia\\packages\\Flux\\HBF2N\\src\\layers\\stateless.jl:60\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 100: Loss: 4.042439937591553 \n",
+ "Epoch 200: Loss: 3.533539354801178 \n",
+ "Epoch 300: Loss: 3.2006053030490875 \n",
+ "Epoch 400: Loss: 2.9201304614543915 \n",
+ "Epoch 500: Loss: 2.647124856710434 \n",
+ "Epoch 600: Loss: 2.3577273190021515 \n",
+ "Epoch 700: Loss: 2.0284294933080673 \n",
+ "Epoch 800: Loss: 1.6497879326343536 \n",
+ "Epoch 900: Loss: 1.3182911276817322 \n",
+ "Epoch 1000: Loss: 1.077766239643097 \n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "trained Machine; caches model-specific representations of data\n",
+ " model: LaplaceClassifier(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 3)), …)\n",
+ " args: \n",
+ " 1:\tSource @378 ⏎ Table{AbstractVector{Continuous}}\n",
+ " 2:\tSource @187 ⏎ AbstractVector{Multiclass{3}}\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "mach = machine(model, X, y) |> MLJBase.fit!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "3-element Vector{Float64}:\n",
+ " 0.9488185211171959\n",
+ " 0.7685919895442062\n",
+ " 0.9454937120287679"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "Xnew = (sepal_length = [6.4, 7.2, 7.4],\n",
+ " sepal_width = [2.8, 3.0, 2.8],\n",
+ " petal_length = [5.6, 5.8, 6.1],\n",
+ " petal_width = [2.1, 1.6, 1.9],)\n",
+ "yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions\n",
+ "predict_mode(mach, Xnew) # point predictions\n",
+ "pdf.(yhat, \"virginica\") # probabilities for the \"verginica\" class"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "CV(\n",
+ " nfolds = 3, \n",
+ " shuffle = false, \n",
+ " rng = TaskLocalRNG())"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "using Random\n",
+ "cv = CV(nfolds=3)\n",
+ "CV(\n",
+ " nfolds = 3,\n",
+ " shuffle = false,\n",
+ " rng = Random.default_rng())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(μ = Float32[0.36796558, -0.590969, -0.0005966219, -0.1918786, -0.21046183, 0.33172563, -0.6454487, -0.43057427, 0.09267368, -0.31252357 … 0.31747338, -0.42893136, -0.036560666, 0.12848923, -0.022290554, -0.17562295, 0.10335993, 2.0584264, -0.27860576, -0.9749556],\n",
+ " H = [19221.80854034424 0.0 … 175.01133728027344 416.5687526464462; 0.0 0.0 … 0.0 0.0; … ; 175.01133728027344 0.0 … 150.0 0.0; 416.5687526464462 0.0 … 0.0 150.0],\n",
+ " P = [19222.80854034424 0.0 … 175.01133728027344 416.5687526464462; 0.0 1.0 … 0.0 0.0; … ; 175.01133728027344 0.0 … 151.0 0.0; 416.5687526464462 0.0 … 0.0 151.0],\n",
+ " Σ = [0.4942412474220449 0.0 … 0.00012692197672594825 -0.00029493662610333626; 0.0 1.0 … -0.0 -0.0; … ; 0.00012692197672495198 0.0 … 0.019500175331854216 2.549733600297664e-5; -0.000294936626108007 0.0 … 2.5497336002974752e-5 0.019426448013349775],\n",
+ " n_data = 160,\n",
+ " n_params = 193,\n",
+ " n_out = 3,\n",
+ " loss = 76.21374633908272,)"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "MLJBase.fitted_params(mach)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 1100: Loss: 0.9005963057279587 \n",
+ "Epoch 1200: Loss: 0.7674828618764877 \n",
+ "Epoch 1300: Loss: 0.6668129041790962 "
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "┌ Info: Updating machine(LaplaceClassifier(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 3)), …), …).\n",
+ "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:500\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Epoch 1400: Loss: 0.5900264084339142 \n",
+ "Epoch 1500: Loss: 0.5311783701181412 \n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "trained Machine; caches model-specific representations of data\n",
+ " model: LaplaceClassifier(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 3)), …)\n",
+ " args: \n",
+ " 1:\tSource @378 ⏎ Table{AbstractVector{Continuous}}\n",
+ " 2:\tSource @187 ⏎ AbstractVector{Multiclass{3}}\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "model.epochs= 1500\n",
+ "uff=MLJBase.fit!(mach)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " updating only the laplace optimization part\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "┌ Info: Updating machine(LaplaceClassifier(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 3)), …), …).\n",
+ "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:500\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "trained Machine; caches model-specific representations of data\n",
+ " model: LaplaceClassifier(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 3)), …)\n",
+ " args: \n",
+ " 1:\tSource @378 ⏎ Table{AbstractVector{Continuous}}\n",
+ " 2:\tSource @187 ⏎ AbstractVector{Multiclass{3}}\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "model.fit_prior_nsteps = 200\n",
+ "uff=MLJBase.fit!(mach)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "┌ Info: Training machine(LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …), …).\n",
+ "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:499\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 100: Loss: 123.99619626934783 \n",
+ "Epoch 200: Loss: 110.70134709925446 \n",
+ "Epoch 300: Loss: 101.54836914035226 \n",
+ "Epoch 400: Loss: 93.59849316803738 \n",
+ "Epoch 500: Loss: 85.18775669356168 \n",
+ "Epoch 600: Loss: 75.76182416873147 \n",
+ "Epoch 700: Loss: 67.35201676247976 \n",
+ "Epoch 800: Loss: 61.22682561405186 \n",
+ "Epoch 900: Loss: 56.388587609986764 \n",
+ "Epoch 1000: Loss: 51.8525101259686 \n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "1×3 Matrix{Float64}:\n",
+ " -2.0295 -1.03661 1.29323"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "using MLJ\n",
+ "#LaplaceRegressor = @load LaplaceRegressor pkg=LaplaceRedux\n",
+ "flux_model = Chain(\n",
+ " Dense(4, 10, relu),\n",
+ " Dense(10, 10, relu),\n",
+ " Dense(10, 1)\n",
+ ")\n",
+ "model = LaplaceRegressor(model=flux_model)\n",
+ "\n",
+ "X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)\n",
+ "mach = machine(model, X, y) \n",
+ "uff=MLJBase.fit!(mach)\n",
+ "Xnew, _ = make_regression(3, 4; rng=123)\n",
+ "yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions\n",
+ "MLJBase.predict_mode(mach, Xnew) # point predictions\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(μ = Float32[1.0187618, 1.1801981, 1.5047036, 0.09397529, -1.0688609, -0.59974676, 1.2977643, -1.8645017, -0.6493797, 0.8598072 … -3.1050463, -2.267896, 1.018275, -2.0316908, 0.72548866, -2.7309833, -2.1938443, 0.8200029, 1.2517836, 0.3313799],\n",
+ " H = [8124.873962402344 1238.1352081298828 … 1980.7237396240234 324.216007232666; 1238.1352081298828 496.2319145202637 … 451.8493309020996 62.39444828033447; … ; 1980.7237396240234 451.8493309020996 … 1277.6469421386719 212.83043670654297; 324.216007232666 62.39444828033447 … 212.83043670654297 100.0],\n",
+ " P = [8125.873962402344 1238.1352081298828 … 1980.7237396240234 324.216007232666; 1238.1352081298828 497.2319145202637 … 451.8493309020996 62.39444828033447; … ; 1980.7237396240234 451.8493309020996 … 1278.6469421386719 212.83043670654297; 324.216007232666 62.39444828033447 … 212.83043670654297 101.0],\n",
+ " Σ = [0.08818398027482557 0.01068994964485469 … -0.028911073548363073 0.0029380249323247053; 0.01068994964484666 0.3740392430965252 … -0.03321343036791264 0.03808189340124685; … ; -0.02891107354841918 -0.03321343036788035 … 0.37993018751981145 -0.007007167070441715; 0.002938024932320219 0.03808189340124439 … -0.007007167070367749 0.8779610764649773],\n",
+ " n_data = 128,\n",
+ " n_params = 171,\n",
+ " n_out = 1,\n",
+ " loss = 25.89132467732224,)"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "MLJBase.fitted_params(mach)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "1000-element Vector{Any}:\n",
+ " 132.4757012476366\n",
+ " 132.34841820836348\n",
+ " 132.2594675536756\n",
+ " 132.1764167781873\n",
+ " 132.09720400291872\n",
+ " 132.02269947440837\n",
+ " 131.95014241005182\n",
+ " 131.87929837827625\n",
+ " 131.8104252672236\n",
+ " 131.74319014461662\n",
+ " ⋮\n",
+ " 52.175198365270234\n",
+ " 52.148401306572445\n",
+ " 52.10149557862212\n",
+ " 52.032551867920525\n",
+ " 51.99690913470439\n",
+ " 51.994054497037524\n",
+ " 51.97321907106968\n",
+ " 51.932579915942064\n",
+ " 51.8525101259686"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "a= MLJBase.training_losses(mach)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "┌ Info: Updating machine(LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …), …).\n",
+ "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:500\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 1100: Loss: 47.37399118448011 \n",
+ "Epoch 1200: Loss: 42.67007994813126 \n",
+ "Epoch 1300: Loss: 38.057605481177234 \n",
+ "Epoch 1400: Loss: 33.846777755430416 \n",
+ "Epoch 1500: Loss: 29.514834265471144 \n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "trained Machine; caches model-specific representations of data\n",
+ " model: LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …)\n",
+ " args: \n",
+ " 1:\tSource @447 ⏎ Table{AbstractVector{Continuous}}\n",
+ " 2:\tSource @645 ⏎ AbstractVector{Continuous}\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "model.epochs= 1500\n",
+ "uff=MLJBase.fit!(mach)\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "1500-element Vector{Any}:\n",
+ " 132.4757012476366\n",
+ " 132.34841820836348\n",
+ " 132.2594675536756\n",
+ " 132.1764167781873\n",
+ " 132.09720400291872\n",
+ " 132.02269947440837\n",
+ " 131.95014241005182\n",
+ " 131.87929837827625\n",
+ " 131.8104252672236\n",
+ " 131.74319014461662\n",
+ " ⋮\n",
+ " 29.979098347946184\n",
+ " 29.896050775635096\n",
+ " 29.83609799111873\n",
+ " 29.763058304997937\n",
+ " 29.69908817371154\n",
+ " 29.646864259231556\n",
+ " 29.623977677653713\n",
+ " 29.549562760589527\n",
+ " 29.514834265471144"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "MLJBase.training_losses(mach)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The number of epochs inserted is lower than the number of epochs already been trained. No update is necessary\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "┌ Info: Updating machine(LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …), …).\n",
+ "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:500\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "trained Machine; caches model-specific representations of data\n",
+ " model: LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …)\n",
+ " args: \n",
+ " 1:\tSource @447 ⏎ Table{AbstractVector{Continuous}}\n",
+ " 2:\tSource @645 ⏎ AbstractVector{Continuous}\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "model.epochs= 1200\n",
+ "uff=MLJBase.fit!(mach)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " updating only the laplace optimization part\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "┌ Info: Updating machine(LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …), …).\n",
+ "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:500\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "trained Machine; caches model-specific representations of data\n",
+ " model: LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …)\n",
+ " args: \n",
+ " 1:\tSource @447 ⏎ Table{AbstractVector{Continuous}}\n",
+ " 2:\tSource @645 ⏎ AbstractVector{Continuous}\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "model.fit_prior_nsteps = 200\n",
+ "uff=MLJBase.fit!(mach)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "MLJBase.selectrows(model, I, Xmatrix, y) = (view(Xmatrix, :, I), view(y[1], I))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#MLJBase.evaluate(model, X, y, resampling=cv, measure=l2, verbosity=0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "BoundsError",
+ "evalue": "BoundsError: attempt to access 4×100 Matrix{Float64} at index [[35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100], 1:100]",
+ "output_type": "error",
+ "traceback": [
+ "BoundsError: attempt to access 4×100 Matrix{Float64} at index [[35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100], 1:100]\n",
+ "\n",
+ "Stacktrace:\n",
+ " [1] throw_boundserror(A::Matrix{Float64}, I::Tuple{Vector{Int64}, Base.Slice{Base.OneTo{Int64}}})\n",
+ " @ Base .\\abstractarray.jl:737\n",
+ " [2] checkbounds\n",
+ " @ .\\abstractarray.jl:702 [inlined]\n",
+ " [3] _getindex\n",
+ " @ .\\multidimensional.jl:888 [inlined]\n",
+ " [4] getindex\n",
+ " @ .\\abstractarray.jl:1291 [inlined]\n",
+ " [5] _selectrows(::MLJModelInterface.FullInterface, ::Val{:other}, X::Matrix{Float64}, r::Vector{Int64})\n",
+ " @ MLJModelInterface C:\\Users\\Pasqu\\.julia\\packages\\MLJModelInterface\\y9x5A\\src\\data_utils.jl:350\n",
+ " [6] selectrows\n",
+ " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJModelInterface\\y9x5A\\src\\data_utils.jl:340 [inlined]\n",
+ " [7] selectrows\n",
+ " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJModelInterface\\y9x5A\\src\\data_utils.jl:336 [inlined]\n",
+ " [8] #29\n",
+ " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJModelInterface\\y9x5A\\src\\model_api.jl:88 [inlined]\n",
+ " [9] map\n",
+ " @ .\\tuple.jl:292 [inlined]\n",
+ " [10] selectrows(::LaplaceRegressor, ::Vector{Int64}, ::Matrix{Float64}, ::Tuple{Matrix{Float64}, Nothing})\n",
+ " @ MLJModelInterface C:\\Users\\Pasqu\\.julia\\packages\\MLJModelInterface\\y9x5A\\src\\model_api.jl:88\n",
+ " [11] fit_only!(mach::Machine{LaplaceRegressor, LaplaceRegressor, true}; rows::Vector{Int64}, verbosity::Int64, force::Bool, composite::Nothing)\n",
+ " @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:676\n",
+ " [12] fit_only!\n",
+ " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:617 [inlined]\n",
+ " [13] #fit!#63\n",
+ " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:789 [inlined]\n",
+ " [14] fit!\n",
+ " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:786 [inlined]\n",
+ " [15] fit_and_extract_on_fold\n",
+ " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\resampling.jl:1463 [inlined]\n",
+ " [16] (::MLJBase.var\"#277#278\"{MLJBase.var\"#fit_and_extract_on_fold#304\"{Vector{Tuple{Vector{Int64}, UnitRange{Int64}}}, Nothing, Nothing, Int64, Vector{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.FussyMeasure{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.Multimeasure{StatisticalMeasuresBase.SupportsMissingsMeasure{StatisticalMeasures.LPLossOnScalars{Int64}}, Nothing, StatisticalMeasuresBase.Mean, typeof(identity)}}, Nothing}}}, Vector{typeof(predict_mean)}, Bool, Bool, Vector{Float64}}, Machine{LaplaceRegressor, LaplaceRegressor, true}, Int64})(k::Int64)\n",
+ " @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\resampling.jl:1289\n",
+ " [17] _mapreduce(f::MLJBase.var\"#277#278\"{MLJBase.var\"#fit_and_extract_on_fold#304\"{Vector{Tuple{Vector{Int64}, UnitRange{Int64}}}, Nothing, Nothing, Int64, Vector{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.FussyMeasure{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.Multimeasure{StatisticalMeasuresBase.SupportsMissingsMeasure{StatisticalMeasures.LPLossOnScalars{Int64}}, Nothing, StatisticalMeasuresBase.Mean, typeof(identity)}}, Nothing}}}, Vector{typeof(predict_mean)}, Bool, Bool, Vector{Float64}}, Machine{LaplaceRegressor, LaplaceRegressor, true}, Int64}, op::typeof(vcat), ::IndexLinear, A::UnitRange{Int64})\n",
+ " @ Base .\\reduce.jl:440\n",
+ " [18] _mapreduce_dim\n",
+ " @ .\\reducedim.jl:365 [inlined]\n",
+ " [19] mapreduce\n",
+ " @ .\\reducedim.jl:357 [inlined]\n",
+ " [20] _evaluate!(func::MLJBase.var\"#fit_and_extract_on_fold#304\"{Vector{Tuple{Vector{Int64}, UnitRange{Int64}}}, Nothing, Nothing, Int64, Vector{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.FussyMeasure{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.Multimeasure{StatisticalMeasuresBase.SupportsMissingsMeasure{StatisticalMeasures.LPLossOnScalars{Int64}}, Nothing, StatisticalMeasuresBase.Mean, typeof(identity)}}, Nothing}}}, Vector{typeof(predict_mean)}, Bool, Bool, Vector{Float64}}, mach::Machine{LaplaceRegressor, LaplaceRegressor, true}, ::CPU1{Nothing}, nfolds::Int64, verbosity::Int64)\n",
+ " @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\resampling.jl:1288\n",
+ " [21] evaluate!(mach::Machine{LaplaceRegressor, LaplaceRegressor, true}, resampling::Vector{Tuple{Vector{Int64}, UnitRange{Int64}}}, weights::Nothing, class_weights::Nothing, rows::Nothing, verbosity::Int64, repeats::Int64, measures::Vector{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.FussyMeasure{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.Multimeasure{StatisticalMeasuresBase.SupportsMissingsMeasure{StatisticalMeasures.LPLossOnScalars{Int64}}, Nothing, StatisticalMeasuresBase.Mean, typeof(identity)}}, Nothing}}}, operations::Vector{typeof(predict_mean)}, acceleration::CPU1{Nothing}, force::Bool, per_observation_flag::Bool, logger::Nothing, user_resampling::CV, compact::Bool)\n",
+ " @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\resampling.jl:1510\n",
+ " [22] evaluate!(::Machine{LaplaceRegressor, LaplaceRegressor, true}, ::CV, ::Nothing, ::Nothing, ::Nothing, ::Int64, ::Int64, ::Vector{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.FussyMeasure{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.Multimeasure{StatisticalMeasuresBase.SupportsMissingsMeasure{StatisticalMeasures.LPLossOnScalars{Int64}}, Nothing, StatisticalMeasuresBase.Mean, typeof(identity)}}, Nothing}}}, ::Vector{typeof(predict_mean)}, ::CPU1{Nothing}, ::Bool, ::Bool, ::Nothing, ::CV, ::Bool)\n",
+ " @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\resampling.jl:1603\n",
+ " [23] evaluate!(mach::Machine{LaplaceRegressor, LaplaceRegressor, true}; resampling::CV, measures::Nothing, measure::StatisticalMeasuresBase.FussyMeasure{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.Multimeasure{StatisticalMeasuresBase.SupportsMissingsMeasure{StatisticalMeasures.LPLossOnScalars{Int64}}, Nothing, StatisticalMeasuresBase.Mean, typeof(identity)}}, Nothing}, weights::Nothing, class_weights::Nothing, operations::Nothing, operation::Nothing, acceleration::CPU1{Nothing}, rows::Nothing, repeats::Int64, force::Bool, check_measure::Bool, per_observation::Bool, verbosity::Int64, logger::Nothing, compact::Bool)\n",
+ " @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\resampling.jl:1232\n",
+ " [24] top-level scope\n",
+ " @ c:\\Users\\Pasqu\\Documents\\julia_projects\\LaplaceRedux.jl\\dev\\notebooks\\mlj-interfacing\\jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X54sZmlsZQ==.jl:1"
+ ]
+ }
+ ],
+ "source": [
+ "evaluate!(mach, resampling=cv, measure=l2, verbosity=0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#flux_model = Chain(\n",
+ " #Dense(4, 10, relu),\n",
+ " #Dense(10, 10, relu),\n",
+ " #Dense(10, 1))\n",
+ "\n",
+ "\n",
+ "#nested_flux_model = Chain(\n",
+ " #Chain(Dense(10, 5, relu), Dense(5, 5, relu)),\n",
+ " #Chain(Dense(5, 3, relu), Dense(3, 2)))\n",
+ "#model = LaplaceRegressor()\n",
+ "#model.model= nested_flux_model\n",
+ "\n",
+ "#copy_model= deepcopy(model)\n",
+ "\n",
+ "#copy_model.epochs= 2000\n",
+ "#copy_model.optimiser = Descent()\n",
+ "#MLJBase.is_same_except(model , copy_model,:epochs )"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Julia 1.10.5",
+ "language": "julia",
+ "name": "julia-1.10"
+ },
+ "language_info": {
+ "file_extension": ".jl",
+ "mimetype": "application/julia",
+ "name": "julia",
+ "version": "1.10.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index b6bc3023..1781fb6c 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -47,7 +47,25 @@ end
Laplace_Models = Union{LaplaceRegressor,LaplaceClassifier}
# for fit:
-MMI.reformat(::Laplace_Models, X, y) = (MLJBase.matrix(X) |> permutedims, reshape(y, 1, :))
+MMI.reformat(::LaplaceRegressor, X, y) = (MLJBase.matrix(X) |> permutedims, (reshape(y, 1, :),nothing))
+
+
+function MMI.reformat(::LaplaceClassifier, X, y)
+
+ X = MLJBase.matrix(X) |> permutedims
+
+
+ y= reshape(y, 1, :)
+ # Convert labels to integer format starting from 0 for one-hot encoding
+ y_plain = MLJBase.int(y[1, :]) .- 1
+ # One-hot encoding of labels
+ unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding
+ y_onehot = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding
+ return X,(y_onehot, y[1])
+end
+
+#MMI.selectrows(::LaplaceClassifier, I, Xmatrix, y) = (view(Xmatrix, :, I), (view(y[1],I),y[2]))
+#MMI.selectrows(::LaplaceRegressor, I, Xmatrix, y) = (view(Xmatrix, :,I), (view(y[1],I),y[2]) )
#for predict:
MMI.reformat(::Laplace_Models, X) = (MLJBase.matrix(X) |> permutedims,)
@@ -68,18 +86,10 @@ Fit a Laplace model using the provided features and target values.
- `report`: A Namedtuple containing the loss history of the fitting process.
"""
function MMI.fit(m::Laplace_Models, verbosity, X, y)
- decode = y[1]
+ decode = y[2]
- if typeof(m) == LaplaceRegressor
- nothing
- else
- # Convert labels to integer format starting from 0 for one-hot encoding
- y_plain = MLJBase.int(y[1, :]) .- 1
+ y= y[1]
- # One-hot encoding of labels
- unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding
- y = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding
- end
# Make a copy of the model because Flux does not allow to mutate hyperparameters
copied_model = deepcopy(m.model)
@@ -163,22 +173,14 @@ Update the Laplace model using the provided new data points.
- `report`: A Namedtuple containing the complete loss history of the fitting process.
"""
function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y)
- if typeof(m) == LaplaceRegressor
- nothing
- else
- # Convert labels to integer format starting from 0 for one-hot encoding
- y_plain = MLJBase.int(y[1, :]) .- 1
-
- # One-hot encoding of labels
- unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding
- y = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding
- end
+ decode = y[2]
+ y_up=y[1]
- data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size)
+ data_loader = Flux.DataLoader((X, y_up); batchsize=m.batch_size)
old_model = old_cache[1]
old_state_tree = old_cache[2]
old_loss_history = old_cache[3]
- #old_la = old_fitresult[1]
+
epochs = m.epochs
@@ -238,21 +240,18 @@ function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y
LaplaceRedux.fit!(la, data_loader)
optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
- fitresult = (la, y[1])
+ fitresult = (la, decode)
report = (loss_history=old_loss_history,)
- cache = old_cache
+ cache = (deepcopy(m), old_state_tree, old_loss_history)
else
- fitresult = old_fitresult
+            println("The number of epochs provided is lower than the number of epochs the model has already been trained for; no update is necessary.")
+ fitresult = (old_la, decode)
report = (loss_history=old_loss_history,)
cache = (deepcopy(m), old_state_tree, old_loss_history)
end
-
-
- return fitresult, cache, report
-
elseif MMI.is_same_except(
m,
old_model,
@@ -265,7 +264,6 @@ function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y
:μ₀,
:P₀,
)
-
println(" updating only the laplace optimization part")
old_la = old_fitresult[1]
@@ -288,30 +286,19 @@ function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y
LaplaceRedux.fit!(la, data_loader)
optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
- fitresult = la
+ fitresult = (la,decode)
report = (loss_history=old_loss_history,)
cache = (deepcopy(m), old_state_tree, old_loss_history)
- return fitresult, cache, report
-
-
else
- println(" I believe this error is provoked by that if statement in the fit function. This case should address the possibility that \n
- the user change the flux chain. In this case the update! function should revert to the fallback fit! function, \n
- however y has already been transformed during the previous fit! run . One way to solve this issue is to make sure that the \n
- data preparation is completely done in reformat")
-
- fitresult, cache, report = MMI.fit(m, verbosity, X, y)
-
-
- return fitresult, cache, report
-
-
+ fitresult, cache, report = MLJBase.fit(m, verbosity,X,y)
+
end
+ return fitresult, cache, report
end
@doc """
diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl
index ac7b4847..c225bfde 100644
--- a/test/direct_mlj_interface.jl
+++ b/test/direct_mlj_interface.jl
@@ -4,6 +4,7 @@ using MLJBase: MLJBase, categorical
using Flux
using StableRNGs
using MLJ
+using MLJ:predict,fit!
using LaplaceRedux
@@ -29,6 +30,7 @@ using LaplaceRedux
MLJBase.fit!(mach) #testing update function
model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
MLJBase.fit!(mach) #testing update function (the laplace part)
+ evaluate!(mach, resampling=cv, measure=l2, verbosity=0)
end
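For context, the `evaluate!` line added to the regression testset exercises MLJ's resampling machinery end to end, which is what exposed the data front-end issues addressed in the following patches. A minimal, self-contained sketch of the same call outside the test suite (the synthetic data, chain dimensions, and epoch count below are illustrative assumptions, not part of the patch):

using MLJ, Flux, LaplaceRedux   # assumed environment; LaplaceRegressor is the model added by this PR

# illustrative synthetic regression data: 4 features, 100 observations
X = (x1 = rand(Float32, 100), x2 = rand(Float32, 100),
     x3 = rand(Float32, 100), x4 = rand(Float32, 100))
y = rand(Float32, 100)

flux_model = Chain(Dense(4, 10, relu), Dense(10, 1))
model = LaplaceRegressor(model=flux_model, epochs=20)

mach = machine(model, X, y)
cv = CV(; nfolds=3)
evaluate!(mach, resampling=cv, measure=l2, verbosity=0)   # 3-fold CV with the l2 (squared error) measure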
From f872d965fdb546c2105ed8582e51b01e7ac63b3c Mon Sep 17 00:00:00 2001
From: pat-alt
Date: Wed, 16 Oct 2024 13:49:27 +0200
Subject: [PATCH 39/60] returning one-hot encoded directly
---
src/direct_mlj.jl | 36 ++++++++------------
test/direct_mlj_interface.jl | 64 ++++++++++++++++++------------------
2 files changed, 46 insertions(+), 54 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 1781fb6c..e5b611d1 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -44,30 +44,22 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJBase.Probabilistic
fit_prior_nsteps::Int = 100::(_ > 0)
end
-Laplace_Models = Union{LaplaceRegressor,LaplaceClassifier}
+LaplaceModels = Union{LaplaceRegressor,LaplaceClassifier}
# for fit:
-MMI.reformat(::LaplaceRegressor, X, y) = (MLJBase.matrix(X) |> permutedims, (reshape(y, 1, :),nothing))
-
+MMI.reformat(::LaplaceRegressor, X, y) = (MLJBase.matrix(X) |> permutedims, reshape(y, 1, :))
function MMI.reformat(::LaplaceClassifier, X, y)
X = MLJBase.matrix(X) |> permutedims
+ y = categorical(y)
+ unique_labels = y.pool.levels
+ y = Flux.onehotbatch(y, unique_labels) # One-hot encoding
-
- y= reshape(y, 1, :)
- # Convert labels to integer format starting from 0 for one-hot encoding
- y_plain = MLJBase.int(y[1, :]) .- 1
- # One-hot encoding of labels
- unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding
- y_onehot = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding
- return X,(y_onehot, y[1])
+ return X, y
end
-#MMI.selectrows(::LaplaceClassifier, I, Xmatrix, y) = (view(Xmatrix, :, I), (view(y[1],I),y[2]))
-#MMI.selectrows(::LaplaceRegressor, I, Xmatrix, y) = (view(Xmatrix, :,I), (view(y[1],I),y[2]) )
-#for predict:
-MMI.reformat(::Laplace_Models, X) = (MLJBase.matrix(X) |> permutedims,)
+MMI.reformat(::LaplaceModels, X) = (MLJBase.matrix(X) |> permutedims,)
@doc """
MMI.fit(m::Union{LaplaceRegressor,LaplaceClassifier}, verbosity, X, y)
@@ -85,7 +77,7 @@ Fit a Laplace model using the provided features and target values.
- `cache`: a tuple containing a deepcopy of the model, the current state of the optimiser and the training loss history.
- `report`: A Namedtuple containing the loss history of the fitting process.
"""
-function MMI.fit(m::Laplace_Models, verbosity, X, y)
+function MMI.fit(m::LaplaceModels, verbosity, X, y)
decode = y[2]
y= y[1]
@@ -172,7 +164,7 @@ Update the Laplace model using the provided new data points.
- `cache`: a tuple containing a deepcopy of the model, the updated current state of the optimiser and training loss history.
- `report`: A Namedtuple containing the complete loss history of the fitting process.
"""
-function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y)
+function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y)
decode = y[2]
y_up=y[1]
@@ -302,7 +294,7 @@ function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y
end
@doc """
- function MMI.is_same_except(m1::Laplace_Models, m2::Laplace_Models, exceptions::Symbol...)
+ function MMI.is_same_except(m1::LaplaceModels, m2::LaplaceModels, exceptions::Symbol...)
If both `m1` and `m2` are of `MLJType`, return `true` if the
following conditions all hold, and `false` otherwise:
@@ -329,7 +321,7 @@ meaning; see [`deep_properties`](@ref)) for details.
If `m1` or `m2` are not `MLJType` objects, then return `==(m1, m2)`.
"""
-function MMI.is_same_except(m1::Laplace_Models, m2::Laplace_Models, exceptions::Symbol...)
+function MMI.is_same_except(m1::LaplaceModels, m2::LaplaceModels, exceptions::Symbol...)
typeof(m1) === typeof(m2) || return false
names = propertynames(m1)
propertynames(m2) === names || return false
@@ -420,7 +412,7 @@ end
- `loss`: The loss value of the posterior distribution.
"""
-function MMI.fitted_params(model::Laplace_Models, fitresult)
+function MMI.fitted_params(model::LaplaceModels, fitresult)
la, decode = fitresult
posterior = la.posterior
return (
@@ -447,7 +439,7 @@ Retrieve the training loss history from the given `report`.
# Returns
- A collection representing the loss history from the training report.
"""
-function MMI.training_losses(model::Laplace_Models, report)
+function MMI.training_losses(model::LaplaceModels, report)
return report.loss_history
end
@@ -467,7 +459,7 @@ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
for LaplaceClassifier:
- `MLJBase.UnivariateFinite`: The predicted class probabilities for the new data.
"""
-function MMI.predict(m::Laplace_Models, fitresult, Xnew)
+function MMI.predict(m::LaplaceModels, fitresult, Xnew)
la, decode = fitresult
if typeof(m) == LaplaceRegressor
yhat = LaplaceRedux.predict(la, Xnew; ret_distr=false)
diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl
index c225bfde..6254a7de 100644
--- a/test/direct_mlj_interface.jl
+++ b/test/direct_mlj_interface.jl
@@ -7,6 +7,7 @@ using MLJ
using MLJ:predict,fit!
using LaplaceRedux
+cv = CV(; nfolds=3)
@testset "Regression" begin
flux_model = Chain(
@@ -36,41 +37,40 @@ end
@testset "Classification" begin
-# Define the model
-flux_model = Chain(
- Dense(4, 10, relu),
- Dense(10, 3)
-)
+ # Define the model
+ flux_model = Chain(
+ Dense(4, 10, relu),
+ Dense(10, 3)
+ )
-model = LaplaceClassifier(model=flux_model,epochs=50)
+ model = LaplaceClassifier(model=flux_model,epochs=50)
-X, y = @load_iris
-mach = machine(model, X, y)
-MLJBase.fit!(mach,verbosity=1)
-Xnew = (sepal_length = [6.4, 7.2, 7.4],
- sepal_width = [2.8, 3.0, 2.8],
- petal_length = [5.6, 5.8, 6.1],
- petal_width = [2.1, 1.6, 1.9],)
-yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions
-predict_mode(mach, Xnew) # point predictions
-pdf.(yhat, "virginica") # probabilities for the "verginica" class
-MLJBase.fitted_params(mach) # fitted params
-MLJBase.training_losses(mach) #training loss history
-model.epochs= 100 #changing number of epochs
-MLJBase.fit!(mach) #testing update function
-model.epochs= 50 #changing number of epochs to a lower number
-MLJBase.fit!(mach) #testing update function
-model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
-MLJBase.fit!(mach) #testing update function (the laplace part)
+ X, y = @load_iris
+ mach = machine(model, X, y)
+ MLJBase.fit!(mach,verbosity=1)
+ Xnew = (sepal_length = [6.4, 7.2, 7.4],
+ sepal_width = [2.8, 3.0, 2.8],
+ petal_length = [5.6, 5.8, 6.1],
+ petal_width = [2.1, 1.6, 1.9],)
+ yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions
+ predict_mode(mach, Xnew) # point predictions
+    pdf.(yhat, "virginica") # probabilities for the "virginica" class
+ MLJBase.fitted_params(mach) # fitted params
+ MLJBase.training_losses(mach) #training loss history
+ model.epochs= 100 #changing number of epochs
+ MLJBase.fit!(mach) #testing update function
+ model.epochs= 50 #changing number of epochs to a lower number
+ MLJBase.fit!(mach) #testing update function
+ model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
+ MLJBase.fit!(mach) #testing update function (the laplace part)
-# Define a different model
-flux_model_two = Chain(
- Dense(4, 6, relu),
- Dense(6, 3)
-)
+ # Define a different model
+ flux_model_two = Chain(
+ Dense(4, 6, relu),
+ Dense(6, 3)
+ )
-model.model = flux_model_two
+ model.model = flux_model_two
-MLJBase.fit!(mach)
-
+ MLJBase.fit!(mach)
end
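The classifier front-end introduced in this patch encodes the target once, inside `reformat`, rather than in `fit`. A minimal sketch of the encoding it relies on (the labels below are illustrative; `levels(y)` returns the same values as the `y.pool.levels` field used in the patch):

using Flux, CategoricalArrays

y = categorical(["setosa", "virginica", "setosa", "versicolor"])
labels = levels(y)                      # same values as y.pool.levels in the patch
y_onehot = Flux.onehotbatch(y, labels)  # one row per class, one column per observation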
From 71a361135980f1a812fd6de63aa71d55d13698b7 Mon Sep 17 00:00:00 2001
From: pat-alt
Date: Wed, 16 Oct 2024 15:12:50 +0200
Subject: [PATCH 40/60] nearly there I think
---
src/direct_mlj.jl | 19 +-
test/Manifest.toml | 539 +++++++++++++++++++++++++++++----------------
test/runtests.jl | 3 -
3 files changed, 357 insertions(+), 204 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index e5b611d1..cc6133fd 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -47,20 +47,23 @@ end
LaplaceModels = Union{LaplaceRegressor,LaplaceClassifier}
# for fit:
-MMI.reformat(::LaplaceRegressor, X, y) = (MLJBase.matrix(X) |> permutedims, reshape(y, 1, :))
+MMI.reformat(::LaplaceRegressor, X, y) = (MLJBase.matrix(X) |> permutedims, (reshape(y, 1, :), nothing))
function MMI.reformat(::LaplaceClassifier, X, y)
X = MLJBase.matrix(X) |> permutedims
y = categorical(y)
- unique_labels = y.pool.levels
- y = Flux.onehotbatch(y, unique_labels) # One-hot encoding
+ labels = y.pool.levels
+ y = Flux.onehotbatch(y, labels) # One-hot encoding
- return X, y
+ return X, (y, labels)
end
MMI.reformat(::LaplaceModels, X) = (MLJBase.matrix(X) |> permutedims,)
+MMI.selectrows(::LaplaceModels, I, Xmatrix, y) = (Xmatrix[:, I], (y[1][:,I], y[2]))
+MMI.selectrows(::LaplaceModels, I, Xmatrix) = (Xmatrix[:, I],)
+
@doc """
MMI.fit(m::Union{LaplaceRegressor,LaplaceClassifier}, verbosity, X, y)
@@ -78,10 +81,7 @@ Fit a Laplace model using the provided features and target values.
- `report`: A Namedtuple containing the loss history of the fitting process.
"""
function MMI.fit(m::LaplaceModels, verbosity, X, y)
- decode = y[2]
-
- y= y[1]
-
+ y, decode = y
# Make a copy of the model because Flux does not allow to mutate hyperparameters
copied_model = deepcopy(m.model)
@@ -165,8 +165,7 @@ Update the Laplace model using the provided new data points.
- `report`: A Namedtuple containing the complete loss history of the fitting process.
"""
function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y)
- decode = y[2]
- y_up=y[1]
+ y_up, decode = y
data_loader = Flux.DataLoader((X, y_up); batchsize=m.batch_size)
old_model = old_cache[1]
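The new `selectrows` methods above complete MLJ's data front-end: `reformat` runs once per machine, and resampling then takes per-fold views by observation index. A hypothetical illustration of that contract (shapes and index range are made up for the example):

# after reformat: features are stored as a matrix with one column per observation
Xmatrix = rand(Float32, 4, 10)            # 4 features × 10 observations
y = (rand(Float32, 1, 10), nothing)       # (reshaped target, decode), as returned for LaplaceRegressor
I = 1:5                                   # observations belonging to one fold
Xfold, yfold = Xmatrix[:, I], (y[1][:, I], y[2])   # what MMI.selectrows hands to fit for this fold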
diff --git a/test/Manifest.toml b/test/Manifest.toml
index 7f9e18a9..3448bc9a 100644
--- a/test/Manifest.toml
+++ b/test/Manifest.toml
@@ -1,8 +1,14 @@
# This file is machine-generated - editing it directly is not advised
-julia_version = "1.10.3"
+julia_version = "1.10.5"
manifest_format = "2.0"
-project_hash = "30dc96d6146892242111894ebf221bf701ee0fdd"
+project_hash = "48e3a5a4625c4493599b02acbfe0e972463bd78f"
+
+[[deps.ARFFFiles]]
+deps = ["CategoricalArrays", "Dates", "Parsers", "Tables"]
+git-tree-sha1 = "e8c8e0a2be6eb4f56b1672e46004463033daa409"
+uuid = "da404889-ca92-49ff-9e8b-0aa6b4d38dc8"
+version = "1.4.1"
[[deps.AbstractFFTs]]
deps = ["LinearAlgebra"]
@@ -21,24 +27,28 @@ uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
version = "0.4.5"
[[deps.Accessors]]
-deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown", "Test"]
-git-tree-sha1 = "f61b15be1d76846c0ce31d3fcfac5380ae53db6a"
+deps = ["CompositionsBase", "ConstructionBase", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown"]
+git-tree-sha1 = "b392ede862e506d451fc1616e79aa6f4c673dab8"
uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
-version = "0.1.37"
+version = "0.1.38"
[deps.Accessors.extensions]
AccessorsAxisKeysExt = "AxisKeys"
+ AccessorsDatesExt = "Dates"
AccessorsIntervalSetsExt = "IntervalSets"
AccessorsStaticArraysExt = "StaticArrays"
AccessorsStructArraysExt = "StructArrays"
+ AccessorsTestExt = "Test"
AccessorsUnitfulExt = "Unitful"
[deps.Accessors.weakdeps]
AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5"
+ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
+ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
[[deps.Adapt]]
@@ -59,9 +69,9 @@ version = "1.1.3"
[[deps.Aqua]]
deps = ["Compat", "Pkg", "Test"]
-git-tree-sha1 = "12e575f31a6f233ba2485ed86b9325b85df37c61"
+git-tree-sha1 = "49b1d7a9870c87ba13dc63f8ccfcf578cb266f95"
uuid = "4c88cf16-eb10-579e-8560-4a9242c79595"
-version = "0.8.7"
+version = "0.8.9"
[[deps.ArgCheck]]
git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4"
@@ -111,10 +121,10 @@ uuid = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
version = "0.3.9"
[[deps.BangBang]]
-deps = ["Accessors", "Compat", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires"]
-git-tree-sha1 = "08e5fc6620a8d83534bf6149795054f1b1e8370a"
+deps = ["Accessors", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires"]
+git-tree-sha1 = "e2144b631226d9eeab2d746ca8880b7ccff504ae"
uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
-version = "0.4.2"
+version = "0.4.3"
[deps.BangBang.extensions]
BangBangChainRulesCoreExt = "ChainRulesCore"
@@ -146,9 +156,9 @@ uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35"
version = "0.1.9"
[[deps.BufferedStreams]]
-git-tree-sha1 = "4ae47f9a4b1dc19897d3743ff13685925c5202ec"
+git-tree-sha1 = "6863c5b7fc997eadcabdbaf6c5f201dc30032643"
uuid = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d"
-version = "1.2.1"
+version = "1.2.2"
[[deps.Bzip2_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
@@ -169,15 +179,9 @@ version = "0.10.14"
[[deps.Cairo_jll]]
deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"]
-git-tree-sha1 = "a2f1c8c668c8e3cb4cca4e57a8efdb09067bb3fd"
+git-tree-sha1 = "009060c9a6168704143100f36ab08f06c2af4642"
uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a"
-version = "1.18.0+2"
-
-[[deps.Calculus]]
-deps = ["LinearAlgebra"]
-git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad"
-uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
-version = "0.5.1"
+version = "1.18.2+1"
[[deps.CategoricalArrays]]
deps = ["DataAPI", "Future", "Missings", "Printf", "Requires", "Statistics", "Unicode"]
@@ -206,15 +210,15 @@ version = "0.1.15"
[[deps.ChainRules]]
deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"]
-git-tree-sha1 = "227985d885b4dbce5e18a96f9326ea1e836e5a03"
+git-tree-sha1 = "be227d253d132a6d57f9ccf5f67c0fb6488afd87"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
-version = "1.69.0"
+version = "1.71.0"
[[deps.ChainRulesCore]]
deps = ["Compat", "LinearAlgebra"]
-git-tree-sha1 = "71acdbf594aab5bbb2cec89b208c41b4c411e49f"
+git-tree-sha1 = "3e4b134270b372f2ed4d4d0e936aabaefc1802bc"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-version = "1.24.0"
+version = "1.25.0"
weakdeps = ["SparseArrays"]
[deps.ChainRulesCore.extensions]
@@ -234,9 +238,9 @@ version = "0.10.4+0"
[[deps.CodecZlib]]
deps = ["TranscodingStreams", "Zlib_jll"]
-git-tree-sha1 = "b8fe8546d52ca154ac556809e10c75e6e7430ac8"
+git-tree-sha1 = "bce6804e5e6044c6daab27bb533d1295e4a2e759"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
-version = "0.7.5"
+version = "0.7.6"
[[deps.ColorSchemes]]
deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"]
@@ -272,16 +276,16 @@ uuid = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
version = "1.0.2"
[[deps.CommonSubexpressions]]
-deps = ["MacroTools", "Test"]
-git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7"
+deps = ["MacroTools"]
+git-tree-sha1 = "cda2cfaebb4be89c9084adaca7dd7333369715c5"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
-version = "0.3.0"
+version = "0.3.1"
[[deps.Compat]]
deps = ["TOML", "UUIDs"]
-git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248"
+git-tree-sha1 = "8ae8d32e09f0dcf42a36b90d4e17f5dd2e4c4215"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
-version = "4.15.0"
+version = "4.16.0"
weakdeps = ["Dates", "LinearAlgebra"]
[deps.Compat.extensions]
@@ -313,17 +317,18 @@ uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb"
version = "2.4.2"
[[deps.ConstructionBase]]
-deps = ["LinearAlgebra"]
-git-tree-sha1 = "d8a9c0b6ac2d9081bf76324b39c78ca3ce4f0c98"
+git-tree-sha1 = "76219f1ed5771adbb096743bff43fb5fdd4c1157"
uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
-version = "1.5.6"
+version = "1.5.8"
[deps.ConstructionBase.extensions]
ConstructionBaseIntervalSetsExt = "IntervalSets"
+ ConstructionBaseLinearAlgebraExt = "LinearAlgebra"
ConstructionBaseStaticArraysExt = "StaticArrays"
[deps.ConstructionBase.weakdeps]
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
+ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
[[deps.ContextVariablesX]]
@@ -370,10 +375,10 @@ uuid = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
version = "0.7.13"
[[deps.DataFrames]]
-deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"]
-git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8"
+deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"]
+git-tree-sha1 = "fb61b4812c49343d7ef0b533ba982c46021938a6"
uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
-version = "1.6.1"
+version = "1.7.0"
[[deps.DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
@@ -390,6 +395,12 @@ version = "1.0.0"
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
+[[deps.Dbus_jll]]
+deps = ["Artifacts", "Expat_jll", "JLLWrappers", "Libdl"]
+git-tree-sha1 = "fc173b380865f70627d7dd1190dc2fce6cc105af"
+uuid = "ee1fde0b-3d02-5ea6-8484-8dfef6360eab"
+version = "1.14.10+0"
+
[[deps.DecisionTree]]
deps = ["AbstractTrees", "DelimitedFiles", "LinearAlgebra", "Random", "ScikitLearnBase", "Statistics"]
git-tree-sha1 = "526ca14aaaf2d5a0e242f3a8a7966eb9065d7d78"
@@ -436,9 +447,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[deps.Distributions]]
deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"]
-git-tree-sha1 = "9c405847cc7ecda2dc921ccf18b47ca150d7317e"
+git-tree-sha1 = "d7477ecdafb813ddee2ae727afa94e9dcb5f3fb0"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
-version = "0.25.109"
+version = "0.25.112"
[deps.Distributions.extensions]
DistributionsChainRulesCoreExt = "ChainRulesCore"
@@ -461,11 +472,11 @@ deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
version = "1.6.0"
-[[deps.DualNumbers]]
-deps = ["Calculus", "NaNMath", "SpecialFunctions"]
-git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566"
-uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74"
-version = "0.6.8"
+[[deps.EarlyStopping]]
+deps = ["Dates", "Statistics"]
+git-tree-sha1 = "98fdf08b707aaf69f524a6cd0a67858cefe0cfb6"
+uuid = "792122b4-ca99-40de-a6bc-6742525f08b6"
+version = "0.3.0"
[[deps.EpollShim_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
@@ -487,9 +498,9 @@ version = "2.6.2+0"
[[deps.FFMPEG]]
deps = ["FFMPEG_jll"]
-git-tree-sha1 = "b57e3acbe22f8484b4b5ff66a7499717fe1a9cc8"
+git-tree-sha1 = "53ebe7511fa11d33bec688a9178fac4e49eeee00"
uuid = "c87230d0-a227-11e9-1b43-d7ebe4e7570a"
-version = "0.4.1"
+version = "0.4.2"
[[deps.FFMPEG_jll]]
deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "JLLWrappers", "LAME_jll", "Libdl", "Ogg_jll", "OpenSSL_jll", "Opus_jll", "PCRE2_jll", "Zlib_jll", "libaom_jll", "libass_jll", "libfdk_aac_jll", "libvorbis_jll", "x264_jll", "x265_jll"]
@@ -511,24 +522,29 @@ version = "0.1.1"
[[deps.FileIO]]
deps = ["Pkg", "Requires", "UUIDs"]
-git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322"
+git-tree-sha1 = "62ca0547a14c57e98154423419d8a342dca75ca9"
uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
-version = "1.16.3"
+version = "1.16.4"
[[deps.FilePathsBase]]
-deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"]
-git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa"
+deps = ["Compat", "Dates"]
+git-tree-sha1 = "7878ff7172a8e6beedd1dea14bd27c3c6340d361"
uuid = "48062228-2e41-5def-b9a4-89aafe57970f"
-version = "0.9.21"
+version = "0.9.22"
+weakdeps = ["Mmap", "Test"]
+
+ [deps.FilePathsBase.extensions]
+ FilePathsBaseMmapExt = "Mmap"
+ FilePathsBaseTestExt = "Test"
[[deps.FileWatching]]
uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
[[deps.FillArrays]]
deps = ["LinearAlgebra"]
-git-tree-sha1 = "0653c0a2396a6da5bc4766c43041ef5fd3efbe57"
+git-tree-sha1 = "6a70198746448456524cb442b8af316927ff3e1a"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
-version = "1.11.0"
+version = "1.13.0"
weakdeps = ["PDMats", "SparseArrays", "Statistics"]
[deps.FillArrays.extensions]
@@ -543,21 +559,27 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.8.5"
[[deps.Flux]]
-deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"]
-git-tree-sha1 = "edacf029ed6276301e455e34d7ceeba8cc34078a"
+deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLDataDevices", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "Setfield", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"]
+git-tree-sha1 = "37fa32a50c69c10c6ea1465d3054d98c75bd7777"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.14.16"
+version = "0.14.22"
[deps.Flux.extensions]
FluxAMDGPUExt = "AMDGPU"
FluxCUDAExt = "CUDA"
FluxCUDAcuDNNExt = ["CUDA", "cuDNN"]
+ FluxEnzymeExt = "Enzyme"
+ FluxMPIExt = "MPI"
+ FluxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
FluxMetalExt = "Metal"
[deps.Flux.weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+ MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
+ NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
[[deps.Fontconfig_jll]]
@@ -595,25 +617,25 @@ version = "1.0.14+0"
[[deps.Functors]]
deps = ["LinearAlgebra"]
-git-tree-sha1 = "8a66c07630d6428eaab3506a0eabfcf4a9edea05"
+git-tree-sha1 = "64d8e93700c7a3f28f717d265382d52fac9fa1c1"
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
-version = "0.4.11"
+version = "0.4.12"
[[deps.Future]]
deps = ["Random"]
uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"
[[deps.GLFW_jll]]
-deps = ["Artifacts", "JLLWrappers", "Libdl", "Libglvnd_jll", "Xorg_libXcursor_jll", "Xorg_libXi_jll", "Xorg_libXinerama_jll", "Xorg_libXrandr_jll", "xkbcommon_jll"]
-git-tree-sha1 = "3f74912a156096bd8fdbef211eff66ab446e7297"
+deps = ["Artifacts", "JLLWrappers", "Libdl", "Libglvnd_jll", "Xorg_libXcursor_jll", "Xorg_libXi_jll", "Xorg_libXinerama_jll", "Xorg_libXrandr_jll", "libdecor_jll", "xkbcommon_jll"]
+git-tree-sha1 = "532f9126ad901533af1d4f5c198867227a7bb077"
uuid = "0656b61e-2033-5cc2-a64a-77c0f6c09b89"
-version = "3.4.0+0"
+version = "3.4.0+1"
[[deps.GPUArrays]]
deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"]
-git-tree-sha1 = "a74c3f1cf56a3dfcdef0605f8cdb7015926aae30"
+git-tree-sha1 = "62ee71528cca49be797076a76bdc654a170a523e"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
-version = "10.3.0"
+version = "10.3.1"
[[deps.GPUArraysCore]]
deps = ["Adapt"]
@@ -623,15 +645,15 @@ version = "0.1.6"
[[deps.GR]]
deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Preferences", "Printf", "Qt6Wayland_jll", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "p7zip_jll"]
-git-tree-sha1 = "629693584cef594c3f6f99e76e7a7ad17e60e8d5"
+git-tree-sha1 = "ee28ddcd5517d54e417182fec3886e7412d3926f"
uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
-version = "0.73.7"
+version = "0.73.8"
[[deps.GR_jll]]
deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "FreeType2_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Qt6Base_jll", "Zlib_jll", "libpng_jll"]
-git-tree-sha1 = "a8863b69c2a0859f2c2c87ebdc4c6712e88bdf0d"
+git-tree-sha1 = "f31929b9e67066bee48eec8b03c0df47d31a74b3"
uuid = "d2c73de3-f751-5644-a686-071e5b155ba9"
-version = "0.73.7+0"
+version = "0.73.8+0"
[[deps.GZip]]
deps = ["Libdl", "Zlib_jll"]
@@ -647,9 +669,9 @@ version = "0.21.0+0"
[[deps.Glib_jll]]
deps = ["Artifacts", "Gettext_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Libiconv_jll", "Libmount_jll", "PCRE2_jll", "Zlib_jll"]
-git-tree-sha1 = "7c82e6a6cd34e9d935e9aa4051b66c6ff3af59ba"
+git-tree-sha1 = "674ff0db93fffcd11a3573986e550d66cd4fd71f"
uuid = "7746bdde-850d-59dc-9ae8-88ece973131d"
-version = "2.80.2+0"
+version = "2.80.5+0"
[[deps.Glob]]
git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496"
@@ -680,10 +702,10 @@ version = "0.17.2"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
[[deps.HDF5_jll]]
-deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"]
-git-tree-sha1 = "82a471768b513dc39e471540fdadc84ff80ff997"
+deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"]
+git-tree-sha1 = "38c8874692d48d5440d5752d6c74b0c6b0b60739"
uuid = "0234f1f7-429e-5d53-9886-15a909be8d59"
-version = "1.14.3+3"
+version = "1.14.2+1"
[[deps.HTTP]]
deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"]
@@ -692,22 +714,22 @@ uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3"
version = "1.10.8"
[[deps.HarfBuzz_jll]]
-deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg"]
-git-tree-sha1 = "129acf094d168394e80ee1dc4bc06ec835e510a3"
+deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll"]
+git-tree-sha1 = "401e4f3f30f43af2c8478fc008da50096ea5240f"
uuid = "2e76f6c2-a576-52d4-95c1-20adfe4de566"
-version = "2.8.1+1"
+version = "8.3.1+0"
[[deps.Hwloc_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
-git-tree-sha1 = "5e19e1e4fa3e71b774ce746274364aef0234634e"
+git-tree-sha1 = "dd3b49277ec2bb2c6b94eb1604d4d0616016f7a6"
uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8"
-version = "2.11.1+0"
+version = "2.11.2+0"
[[deps.HypergeometricFunctions]]
-deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"]
-git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685"
+deps = ["LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"]
+git-tree-sha1 = "7c4195be1649ae622304031ed46a2f4df989f1eb"
uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a"
-version = "0.3.23"
+version = "0.3.24"
[[deps.IRTools]]
deps = ["InteractiveUtils", "MacroTools"]
@@ -762,14 +784,14 @@ uuid = "7d512f48-7fb1-5a58-b986-67e6dc259f01"
version = "0.7.0"
[[deps.InverseFunctions]]
-deps = ["Test"]
-git-tree-sha1 = "18c59411ece4838b18cd7f537e56cf5e41ce5bfd"
+git-tree-sha1 = "a779299d77cd080bf77b97535acecd73e1c5e5cb"
uuid = "3587e190-3f89-42d0-90ee-14403ec27112"
-version = "0.1.15"
-weakdeps = ["Dates"]
+version = "0.1.17"
+weakdeps = ["Dates", "Test"]
[deps.InverseFunctions.extensions]
- DatesExt = "Dates"
+ InverseFunctionsDatesExt = "Dates"
+ InverseFunctionsTestExt = "Test"
[[deps.InvertedIndices]]
git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038"
@@ -781,28 +803,34 @@ git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2"
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
version = "0.2.2"
+[[deps.IterationControl]]
+deps = ["EarlyStopping", "InteractiveUtils"]
+git-tree-sha1 = "e663925ebc3d93c1150a7570d114f9ea2f664726"
+uuid = "b3c1a2ee-3fec-4384-bf48-272ea71de57c"
+version = "0.5.4"
+
[[deps.IteratorInterfaceExtensions]]
git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
uuid = "82899510-4779-5014-852e-03e436cf321d"
version = "1.0.0"
[[deps.JLD2]]
-deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "PrecompileTools", "Reexport", "Requires", "TranscodingStreams", "UUIDs", "Unicode"]
-git-tree-sha1 = "5fe858cb863e211c6dedc8cce2dc0752d4ab6e2b"
+deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Requires", "TranscodingStreams"]
+git-tree-sha1 = "aeab5c68eb2cf326619bf71235d8f4561c62fe22"
uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
-version = "0.4.50"
+version = "0.5.5"
[[deps.JLFzf]]
deps = ["Pipe", "REPL", "Random", "fzf_jll"]
-git-tree-sha1 = "a53ebe394b71470c7f97c2e7e170d51df21b17af"
+git-tree-sha1 = "39d64b09147620f5ffbf6b2d3255be3c901bec63"
uuid = "1019f520-868f-41f5-a6de-eb00f4b6a39c"
-version = "0.1.7"
+version = "0.1.8"
[[deps.JLLWrappers]]
deps = ["Artifacts", "Preferences"]
-git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca"
+git-tree-sha1 = "be3dc50a92e5a386872a493a10050136d4703f9b"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
-version = "1.5.0"
+version = "1.6.1"
[[deps.JSON]]
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
@@ -824,9 +852,9 @@ version = "1.14.0"
[[deps.JpegTurbo_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
-git-tree-sha1 = "c84a835e1a09b289ffcd2271bf2a337bbdda6637"
+git-tree-sha1 = "25ee0be4d43d0269027024d75a24c24d6c6e590c"
uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8"
-version = "3.0.3+0"
+version = "3.0.4+0"
[[deps.JuliaVariables]]
deps = ["MLStyle", "NameResolution"]
@@ -835,16 +863,20 @@ uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec"
version = "0.2.4"
[[deps.KernelAbstractions]]
-deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"]
-git-tree-sha1 = "d0448cebd5919e06ca5edc7a264631790de810ec"
+deps = ["Adapt", "Atomix", "InteractiveUtils", "MacroTools", "PrecompileTools", "Requires", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"]
+git-tree-sha1 = "04e52f596d0871fa3890170fa79cb15e481e4cd8"
uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
-version = "0.9.22"
+version = "0.9.28"
[deps.KernelAbstractions.extensions]
EnzymeExt = "EnzymeCore"
+ LinearAlgebraExt = "LinearAlgebra"
+ SparseArraysExt = "SparseArrays"
[deps.KernelAbstractions.weakdeps]
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
+ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
[[deps.LAME_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
@@ -853,16 +885,16 @@ uuid = "c1c5ebd0-6772-5130-a774-d5fcae4a789d"
version = "3.100.2+0"
[[deps.LERC_jll]]
-deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
-git-tree-sha1 = "bf36f528eec6634efc60d7ec062008f171071434"
+deps = ["Artifacts", "JLLWrappers", "Libdl"]
+git-tree-sha1 = "36bdbc52f13a7d1dcb0f3cd694e01677a515655b"
uuid = "88015f11-f218-50d7-93a8-a6af411a945d"
-version = "3.0.0+1"
+version = "4.0.0+0"
[[deps.LLVM]]
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"]
-git-tree-sha1 = "020abd49586480c1be84f57da0017b5d3db73f7c"
+git-tree-sha1 = "4ad43cb0a4bb5e5b1506e1d1f48646d7e0c80363"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
-version = "8.0.0"
+version = "9.1.2"
weakdeps = ["BFloat16s"]
[deps.LLVM.extensions]
@@ -870,41 +902,49 @@ weakdeps = ["BFloat16s"]
[[deps.LLVMExtra_jll]]
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
-git-tree-sha1 = "c2636c264861edc6d305e6b4d528f09566d24c5e"
+git-tree-sha1 = "05a8bd5a42309a9ec82f700876903abce1017dd3"
uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
-version = "0.0.30+0"
+version = "0.0.34+0"
[[deps.LLVMOpenMP_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
-git-tree-sha1 = "d986ce2d884d49126836ea94ed5bfb0f12679713"
+git-tree-sha1 = "78211fb6cbc872f77cad3fc0b6cf647d923f4929"
uuid = "1d63c593-3942-5779-bab2-d838dc0a180e"
-version = "15.0.7+0"
+version = "18.1.7+0"
[[deps.LZO_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
-git-tree-sha1 = "70c5da094887fd2cae843b8db33920bac4b6f07d"
+git-tree-sha1 = "854a9c268c43b77b0a27f22d7fab8d33cdb3a731"
uuid = "dd4b983a-f0e5-5f8d-a1b7-129d4a5fb1ac"
-version = "2.10.2+0"
+version = "2.10.2+1"
[[deps.LaTeXStrings]]
-git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec"
+git-tree-sha1 = "dda21b8cbd6a6c40d9d02a73230f9d70fed6918c"
uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
-version = "1.3.1"
+version = "1.4.0"
[[deps.Latexify]]
deps = ["Format", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Requires"]
-git-tree-sha1 = "5b0d630f3020b82c0775a51d05895852f8506f50"
+git-tree-sha1 = "ce5f5621cac23a86011836badfedf664a612cee4"
uuid = "23fbe1c1-3f47-55db-b15f-69d7ec21a316"
-version = "0.16.4"
+version = "0.16.5"
[deps.Latexify.extensions]
DataFramesExt = "DataFrames"
+ SparseArraysExt = "SparseArrays"
SymEngineExt = "SymEngine"
[deps.Latexify.weakdeps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
+ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SymEngine = "123dc426-2d89-5057-bbad-38513e3affd8"
+[[deps.LatinHypercubeSampling]]
+deps = ["Random", "StableRNGs", "StatsBase", "Test"]
+git-tree-sha1 = "825289d43c753c7f1bf9bed334c253e9913997f8"
+uuid = "a5e1c1ea-c99a-51d3-a14d-a9a37257b02d"
+version = "1.9.0"
+
[[deps.LazyArtifacts]]
deps = ["Artifacts", "Pkg"]
uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
@@ -985,9 +1025,9 @@ version = "2.40.1+0"
[[deps.Libtiff_jll]]
deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "LERC_jll", "Libdl", "XZ_jll", "Zlib_jll", "Zstd_jll"]
-git-tree-sha1 = "2da088d113af58221c52828a80378e16be7d037a"
+git-tree-sha1 = "b404131d06f7886402758c9ce2214b636eb4d54a"
uuid = "89763e89-9b03-5906-acba-b20f662cd828"
-version = "4.5.1+1"
+version = "4.7.0+0"
[[deps.Libuuid_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
@@ -1030,35 +1070,109 @@ git-tree-sha1 = "1d2dd9b186742b0f317f2530ddcbf00eebb18e96"
uuid = "23992714-dd62-5051-b70f-ba57cb901cac"
version = "0.10.7"
+[[deps.MLDataDevices]]
+deps = ["Adapt", "Functors", "Preferences", "Random"]
+git-tree-sha1 = "e16288e37e76d68c3f1c418e0a2bec88d98d55fc"
+uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
+version = "1.2.0"
+
+ [deps.MLDataDevices.extensions]
+ MLDataDevicesAMDGPUExt = "AMDGPU"
+ MLDataDevicesCUDAExt = "CUDA"
+ MLDataDevicesChainRulesCoreExt = "ChainRulesCore"
+ MLDataDevicesFillArraysExt = "FillArrays"
+ MLDataDevicesGPUArraysExt = "GPUArrays"
+ MLDataDevicesMLUtilsExt = "MLUtils"
+ MLDataDevicesMetalExt = ["GPUArrays", "Metal"]
+ MLDataDevicesReactantExt = "Reactant"
+ MLDataDevicesRecursiveArrayToolsExt = "RecursiveArrayTools"
+ MLDataDevicesReverseDiffExt = "ReverseDiff"
+ MLDataDevicesSparseArraysExt = "SparseArrays"
+ MLDataDevicesTrackerExt = "Tracker"
+ MLDataDevicesZygoteExt = "Zygote"
+ MLDataDevicescuDNNExt = ["CUDA", "cuDNN"]
+ MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"]
+
+ [deps.MLDataDevices.weakdeps]
+ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
+ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
+ GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
+ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+ Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
+ Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
+ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
+ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
+ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
+ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
+ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+ cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
+ oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
+
[[deps.MLDatasets]]
deps = ["CSV", "Chemfiles", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "ImageShow", "JLD2", "JSON3", "LazyModules", "MAT", "MLUtils", "NPZ", "Pickle", "Printf", "Requires", "SparseArrays", "Statistics", "Tables"]
-git-tree-sha1 = "55ed5f79697232389d894d05e93633a03e774818"
+git-tree-sha1 = "361c2692ee730944764945859f1a6b31072e275d"
uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458"
-version = "0.7.16"
+version = "0.7.18"
+
+[[deps.MLFlowClient]]
+deps = ["Dates", "FilePathsBase", "HTTP", "JSON", "ShowCases", "URIs", "UUIDs"]
+git-tree-sha1 = "9abb12b62debc27261c008daa13627255bf79967"
+uuid = "64a0f543-368b-4a9a-827a-e71edb2a0b83"
+version = "0.5.1"
+
+[[deps.MLJ]]
+deps = ["CategoricalArrays", "ComputationalResources", "Distributed", "Distributions", "LinearAlgebra", "MLJBalancing", "MLJBase", "MLJEnsembles", "MLJFlow", "MLJIteration", "MLJModels", "MLJTuning", "OpenML", "Pkg", "ProgressMeter", "Random", "Reexport", "ScientificTypes", "StatisticalMeasures", "Statistics", "StatsBase", "Tables"]
+git-tree-sha1 = "bd2072e9cd65be0a3cb841f3d8cda1d2cacfe5db"
+uuid = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
+version = "0.20.5"
+
+[[deps.MLJBalancing]]
+deps = ["MLJBase", "MLJModelInterface", "MLUtils", "OrderedCollections", "Random", "StatsBase"]
+git-tree-sha1 = "f707a01a92d664479522313907c07afa5d81df19"
+uuid = "45f359ea-796d-4f51-95a5-deb1a414c586"
+version = "0.1.5"
[[deps.MLJBase]]
deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LearnAPI", "LinearAlgebra", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "RecipesBase", "Reexport", "ScientificTypes", "Serialization", "StatisticalMeasuresBase", "StatisticalTraits", "Statistics", "StatsBase", "Tables"]
git-tree-sha1 = "6f45e12073bc2f2e73ed0473391db38c31e879c9"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
version = "1.7.0"
+weakdeps = ["StatisticalMeasures"]
[deps.MLJBase.extensions]
DefaultMeasuresExt = "StatisticalMeasures"
- [deps.MLJBase.weakdeps]
- StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
-
[[deps.MLJDecisionTreeInterface]]
deps = ["CategoricalArrays", "DecisionTree", "MLJModelInterface", "Random", "Tables"]
git-tree-sha1 = "90ef4d3b6cacec631c57cc034e1e61b4aa0ce511"
uuid = "c6f25543-311c-4c74-83dc-3ea6d1015661"
version = "0.4.2"
+[[deps.MLJEnsembles]]
+deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Distributed", "Distributions", "MLJModelInterface", "ProgressMeter", "Random", "ScientificTypesBase", "StatisticalMeasuresBase", "StatsBase"]
+git-tree-sha1 = "84a5be55a364bb6b6dc7780bbd64317ebdd3ad1e"
+uuid = "50ed68f4-41fd-4504-931a-ed422449fee0"
+version = "0.4.3"
+
+[[deps.MLJFlow]]
+deps = ["MLFlowClient", "MLJBase", "MLJModelInterface"]
+git-tree-sha1 = "508bff8071d7d1902d6f1b9d1e868d58821f1cfe"
+uuid = "7b7b8358-b45c-48ea-a8ef-7ca328ad328f"
+version = "0.5.0"
+
[[deps.MLJFlux]]
deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables"]
-git-tree-sha1 = "50c7f24b84005a2a80875c10d4f4059df17a0f68"
+git-tree-sha1 = "98fd05da1bc1527f7849efb645ef1921ccf97c9a"
uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845"
-version = "0.5.1"
+version = "0.6.0"
+
+[[deps.MLJIteration]]
+deps = ["IterationControl", "MLJBase", "Random", "Serialization"]
+git-tree-sha1 = "ad16cfd261e28204847f509d1221a581286448ae"
+uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
+version = "0.6.3"
[[deps.MLJModelInterface]]
deps = ["Random", "ScientificTypesBase", "StatisticalTraits"]
@@ -1072,6 +1186,12 @@ git-tree-sha1 = "410da88e0e6ece5467293d2c76b51b7c6df7d072"
uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
version = "0.16.17"
+[[deps.MLJTuning]]
+deps = ["ComputationalResources", "Distributed", "Distributions", "LatinHypercubeSampling", "MLJBase", "ProgressMeter", "Random", "RecipesBase", "StatisticalMeasuresBase"]
+git-tree-sha1 = "38aab60b1274ce7d6da784808e3be69e585dbbf6"
+uuid = "03970b2e-30c4-11ea-3135-d1576263f10f"
+version = "0.8.8"
+
[[deps.MLStyle]]
git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8"
uuid = "d8e11817-5142-5d16-987a-aa16d5891078"
@@ -1085,9 +1205,9 @@ version = "0.4.4"
[[deps.MPICH_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"]
-git-tree-sha1 = "19d4bd098928a3263693991500d05d74dbdc2004"
+git-tree-sha1 = "7715e65c47ba3941c502bffb7f266a41a7f54423"
uuid = "7cb0a576-ebde-5e09-9194-50597f1243b4"
-version = "4.2.2+0"
+version = "4.2.3+0"
[[deps.MPIPreferences]]
deps = ["Libdl", "Preferences"]
@@ -1097,9 +1217,9 @@ version = "0.1.11"
[[deps.MPItrampoline_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"]
-git-tree-sha1 = "8c35d5420193841b2f367e658540e8d9e0601ed0"
+git-tree-sha1 = "70e830dab5d0775183c99fc75e4c24c614ed7142"
uuid = "f1f71cc9-e9ae-5b93-9b94-4fe0e1ad3748"
-version = "5.4.0+0"
+version = "5.5.1+0"
[[deps.MacroTools]]
deps = ["Markdown", "Random"]
@@ -1134,9 +1254,9 @@ version = "0.3.2"
[[deps.Metalhead]]
deps = ["Artifacts", "BSON", "ChainRulesCore", "Flux", "Functors", "JLD2", "LazyArtifacts", "MLUtils", "NNlib", "PartialFunctions", "Random", "Statistics"]
-git-tree-sha1 = "5aac9a2b511afda7bf89df5044a2e0b429f83152"
+git-tree-sha1 = "aef476e4958303f5ea9e1deb81a1ba2f510d4e11"
uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
-version = "0.9.3"
+version = "0.9.4"
[deps.Metalhead.extensions]
MetalheadCUDAExt = "CUDA"
@@ -1182,10 +1302,10 @@ uuid = "6f286f6a-111f-5878-ab1e-185364afe411"
version = "0.10.3"
[[deps.NNlib]]
-deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"]
-git-tree-sha1 = "190dcada8cf9520198058c4544862b1f88c6c577"
+deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Random", "Statistics"]
+git-tree-sha1 = "da09a1e112fd75f9af2a5229323f01b56ec96a4c"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.9.21"
+version = "0.9.24"
[deps.NNlib.extensions]
NNlibAMDGPUExt = "AMDGPU"
@@ -1193,12 +1313,14 @@ version = "0.9.21"
NNlibCUDAExt = "CUDA"
NNlibEnzymeCoreExt = "EnzymeCore"
NNlibFFTWExt = "FFTW"
+ NNlibForwardDiffExt = "ForwardDiff"
[deps.NNlib.weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
+ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
[[deps.NPZ]]
@@ -1254,11 +1376,17 @@ deps = ["Artifacts", "Libdl"]
uuid = "05823500-19ac-5b8b-9628-191a04bc5112"
version = "0.8.1+2"
+[[deps.OpenML]]
+deps = ["ARFFFiles", "HTTP", "JSON", "Markdown", "Pkg", "Scratch"]
+git-tree-sha1 = "6efb039ae888699d5a74fb593f6f3e10c7193e33"
+uuid = "8b6db2d4-7670-4922-a472-f9537c81ab66"
+version = "0.3.1"
+
[[deps.OpenMPI_jll]]
-deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"]
-git-tree-sha1 = "e25c1778a98e34219a00455d6e4384e017ea9762"
+deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML", "Zlib_jll"]
+git-tree-sha1 = "bfce6d523861a6c562721b262c0d1aaeead2647f"
uuid = "fe0851c0-eecd-5654-98d4-656369965a5c"
-version = "4.1.6+0"
+version = "5.0.5+0"
[[deps.OpenSSL]]
deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"]
@@ -1268,9 +1396,9 @@ version = "1.4.3"
[[deps.OpenSSL_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
-git-tree-sha1 = "a028ee3cb5641cccc4c24e90c36b0a4f7707bdf5"
+git-tree-sha1 = "7493f61f55a6cce7325f197443aa80d32554ba10"
uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95"
-version = "3.0.14+0"
+version = "3.0.15+1"
[[deps.OpenSpecFun_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
@@ -1285,10 +1413,10 @@ uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
version = "0.3.3"
[[deps.Opus_jll]]
-deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
-git-tree-sha1 = "51a08fb14ec28da2ec7a927c4337e4332c2a4720"
+deps = ["Artifacts", "JLLWrappers", "Libdl"]
+git-tree-sha1 = "6703a85cb3781bd5909d48730a67205f3f31a575"
uuid = "91d4177d-7536-5919-b921-800302f37372"
-version = "1.3.2+0"
+version = "1.3.3+0"
[[deps.OrderedCollections]]
git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5"
@@ -1318,6 +1446,12 @@ git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f"
uuid = "5432bcbf-9aad-5242-b902-cca2824c8663"
version = "0.5.12"
+[[deps.Pango_jll]]
+deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "FriBidi_jll", "Glib_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl"]
+git-tree-sha1 = "e127b609fb9ecba6f201ba7ab753d5a605d53801"
+uuid = "36c8627f-9965-5494-a995-c6b170f724f3"
+version = "1.54.1+0"
+
[[deps.Parameters]]
deps = ["OrderedCollections", "UnPack"]
git-tree-sha1 = "34c0e9ad262e5f7fc75b10a9952ca7692cfc5fbe"
@@ -1378,9 +1512,9 @@ version = "1.4.1"
[[deps.Plots]]
deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "PrecompileTools", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "RelocatableFolders", "Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "TOML", "UUIDs", "UnicodeFun", "UnitfulLatexify", "Unzip"]
-git-tree-sha1 = "082f0c4b70c202c37784ce4bfbc33c9f437685bf"
+git-tree-sha1 = "45470145863035bb124ca51b320ed35d071cc6c2"
uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
-version = "1.40.5"
+version = "1.40.8"
[deps.Plots.extensions]
FileIOExt = "FileIO"
@@ -1426,9 +1560,9 @@ version = "0.4.2"
[[deps.PrettyTables]]
deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"]
-git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7"
+git-tree-sha1 = "1101cd475833706e4d0e7b122218257178f48f34"
uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
-version = "2.3.2"
+version = "2.4.0"
[[deps.Printf]]
deps = ["Unicode"]
@@ -1447,9 +1581,9 @@ uuid = "92933f4c-e287-5a05-a399-4b506db050ca"
version = "1.10.2"
[[deps.PtrArrays]]
-git-tree-sha1 = "f011fbb92c4d401059b2212c05c0601b70f8b759"
+git-tree-sha1 = "77a42d78b6a92df47ab37e177b2deac405e1c88f"
uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d"
-version = "1.2.0"
+version = "1.2.1"
[[deps.Qt6Base_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Fontconfig_jll", "Glib_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "OpenSSL_jll", "Vulkan_Loader_jll", "Xorg_libSM_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Xorg_libxcb_jll", "Xorg_xcb_util_cursor_jll", "Xorg_xcb_util_image_jll", "Xorg_xcb_util_keysyms_jll", "Xorg_xcb_util_renderutil_jll", "Xorg_xcb_util_wm_jll", "Zlib_jll", "libinput_jll", "xkbcommon_jll"]
@@ -1477,9 +1611,15 @@ version = "6.7.1+1"
[[deps.QuadGK]]
deps = ["DataStructures", "LinearAlgebra"]
-git-tree-sha1 = "9b23c31e76e333e6fb4c1595ae6afa74966a729e"
+git-tree-sha1 = "cda3b045cf9ef07a08ad46731f5a3165e56cf3da"
uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
-version = "2.9.4"
+version = "2.11.1"
+
+ [deps.QuadGK.extensions]
+ QuadGKEnzymeExt = "Enzyme"
+
+ [deps.QuadGK.weakdeps]
+ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
[[deps.REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
@@ -1526,15 +1666,15 @@ version = "1.3.0"
[[deps.Rmath]]
deps = ["Random", "Rmath_jll"]
-git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b"
+git-tree-sha1 = "852bd0f55565a9e973fcfee83a84413270224dc4"
uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
-version = "0.7.1"
+version = "0.8.0"
[[deps.Rmath_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
-git-tree-sha1 = "d483cd324ce5cf5d61b77930f0bbd6cb61927d21"
+git-tree-sha1 = "58cdd8fb2201a6267e1db87ff148dd6c1dbd8ad8"
uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f"
-version = "0.4.2+0"
+version = "0.5.1+0"
[[deps.SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
@@ -1590,9 +1730,9 @@ uuid = "992d4aef-0814-514b-bc4d-f2e9a6c4116f"
version = "1.0.3"
[[deps.SimpleBufferStream]]
-git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1"
+git-tree-sha1 = "f305871d2f381d21527c770d4788c06c097c9bc1"
uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7"
-version = "1.1.0"
+version = "1.2.0"
[[deps.SimpleTraits]]
deps = ["InteractiveUtils", "MacroTools"]
@@ -1664,6 +1804,20 @@ git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682"
uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
version = "1.4.3"
+[[deps.StatisticalMeasures]]
+deps = ["CategoricalArrays", "CategoricalDistributions", "Distributions", "LearnAPI", "LinearAlgebra", "MacroTools", "OrderedCollections", "PrecompileTools", "ScientificTypesBase", "StatisticalMeasuresBase", "Statistics", "StatsBase"]
+git-tree-sha1 = "c1d4318fa41056b839dfbb3ee841f011fa6e8518"
+uuid = "a19d573c-0a75-4610-95b3-7071388c7541"
+version = "0.1.7"
+
+ [deps.StatisticalMeasures.extensions]
+ LossFunctionsExt = "LossFunctions"
+ ScientificTypesExt = "ScientificTypes"
+
+ [deps.StatisticalMeasures.weakdeps]
+ LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
+ ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
+
[[deps.StatisticalMeasuresBase]]
deps = ["CategoricalArrays", "InteractiveUtils", "MLUtils", "MacroTools", "OrderedCollections", "PrecompileTools", "ScientificTypesBase", "Statistics"]
git-tree-sha1 = "17dfb22e2e4ccc9cd59b487dce52883e0151b4d3"
@@ -1695,9 +1849,9 @@ version = "0.34.3"
[[deps.StatsFuns]]
deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"]
-git-tree-sha1 = "cef0472124fab0695b58ca35a77c6fb942fdab8a"
+git-tree-sha1 = "b423576adc27097764a90e163157bcfc9acf0f46"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
-version = "1.3.1"
+version = "1.3.2"
weakdeps = ["ChainRulesCore", "InverseFunctions"]
[deps.StatsFuns.extensions]
@@ -1724,9 +1878,9 @@ version = "0.3.7"
[[deps.StringManipulation]]
deps = ["PrecompileTools"]
-git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5"
+git-tree-sha1 = "a6b1675a536c5ad1a60e5a5153e1fee12eb146e3"
uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e"
-version = "0.3.4"
+version = "0.4.0"
[[deps.StructArrays]]
deps = ["ConstructionBase", "DataAPI", "Tables"]
@@ -1743,9 +1897,9 @@ weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"]
[[deps.StructTypes]]
deps = ["Dates", "UUIDs"]
-git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70"
+git-tree-sha1 = "159331b30e94d7b11379037feeb9b690950cace8"
uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
-version = "1.10.0"
+version = "1.11.0"
[[deps.SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
@@ -1774,10 +1928,9 @@ uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
version = "1.12.0"
[[deps.TaijaBase]]
-deps = ["CategoricalArrays", "Distributions", "Flux", "MLUtils", "Optimisers", "StatsBase", "Tables"]
-git-tree-sha1 = "1c80c4472c6ab6e8c9fa544a22d907295b388dd0"
+git-tree-sha1 = "4076f60078b12095ca71a2c26e2e4515e3a6a5e5"
uuid = "10284c91-9f28-4c9a-abbf-ee43576dfff6"
-version = "1.2.2"
+version = "1.2.3"
[[deps.TaijaData]]
deps = ["CSV", "CounterfactualExplanations", "DataAPI", "DataFrames", "Flux", "LazyArtifacts", "MLDatasets", "MLJBase", "MLJModels", "Random", "StatsBase"]
@@ -1801,21 +1954,18 @@ deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[deps.TranscodingStreams]]
-git-tree-sha1 = "96612ac5365777520c3c5396314c8cf7408f436a"
+git-tree-sha1 = "0c45878dcfdcfa8480052b6ab162cdd138781742"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
-version = "0.11.1"
-weakdeps = ["Random", "Test"]
-
- [deps.TranscodingStreams.extensions]
- TestExt = ["Test", "Random"]
+version = "0.11.3"
[[deps.Transducers]]
-deps = ["Accessors", "Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"]
-git-tree-sha1 = "5215a069867476fc8e3469602006b9670e68da23"
+deps = ["Accessors", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"]
+git-tree-sha1 = "7deeab4ff96b85c5f72c824cae53a1398da3d1cb"
uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999"
-version = "0.4.82"
+version = "0.4.84"
[deps.Transducers.extensions]
+ TransducersAdaptExt = "Adapt"
TransducersBlockArraysExt = "BlockArrays"
TransducersDataFramesExt = "DataFrames"
TransducersLazyArraysExt = "LazyArrays"
@@ -1823,6 +1973,7 @@ version = "0.4.82"
TransducersReferenceablesExt = "Referenceables"
[deps.Transducers.weakdeps]
+ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
@@ -1905,9 +2056,9 @@ version = "0.2.1"
[[deps.UnsafeAtomicsLLVM]]
deps = ["LLVM", "UnsafeAtomics"]
-git-tree-sha1 = "bf2c553f25e954a9b38c9c0593a59bb13113f9e5"
+git-tree-sha1 = "2d17fabcd17e67d7625ce9c531fb9f40b7c42ce4"
uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
-version = "0.1.5"
+version = "0.2.1"
[[deps.Unzip]]
git-tree-sha1 = "ca0969166a028236229f63514992fc073799bb78"
@@ -1945,9 +2096,9 @@ version = "1.6.1"
[[deps.XML2_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Zlib_jll"]
-git-tree-sha1 = "d9717ce3518dc68a99e6b96300813760d887a01d"
+git-tree-sha1 = "1165b0443d0eca63ac1e32b8c0eb69ed2f4f8127"
uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a"
-version = "2.13.1+0"
+version = "2.13.3+0"
[[deps.XSLT_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "XML2_jll", "Zlib_jll"]
@@ -2118,15 +2269,15 @@ version = "1.2.13+1"
[[deps.Zstd_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
-git-tree-sha1 = "e678132f07ddb5bfa46857f0d7620fb9be675d3b"
+git-tree-sha1 = "555d1076590a6cc2fdee2ef1469451f872d8b41b"
uuid = "3161d3a3-bdf6-5164-811a-617609db77b4"
-version = "1.5.6+0"
+version = "1.5.6+1"
[[deps.Zygote]]
deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"]
-git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54"
+git-tree-sha1 = "f816633be6dc5c0ed9ffedda157ecfda0b3b6a69"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.70"
+version = "0.6.72"
[deps.Zygote.extensions]
ZygoteColorsExt = "Colors"
@@ -2152,9 +2303,9 @@ version = "3.2.9+0"
[[deps.fzf_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
-git-tree-sha1 = "a68c9655fbe6dfcab3d972808f1aafec151ce3f8"
+git-tree-sha1 = "936081b536ae4aa65415d869287d43ef3cb576b2"
uuid = "214eeab7-80f7-51ab-84ad-2988db7cef09"
-version = "0.43.0+0"
+version = "0.53.0+0"
[[deps.gperf_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
@@ -2175,15 +2326,21 @@ uuid = "a4ae2306-e953-59d6-aa16-d00cac43593b"
version = "3.9.0+0"
[[deps.libass_jll]]
-deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"]
-git-tree-sha1 = "5982a94fcba20f02f42ace44b9894ee2b140fe47"
+deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl", "Zlib_jll"]
+git-tree-sha1 = "e17c115d55c5fbb7e52ebedb427a0dca79d4484e"
uuid = "0ac62f75-1d6f-5e53-bd7c-93b484bb37c0"
-version = "0.15.1+0"
+version = "0.15.2+0"
[[deps.libblastrampoline_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
-version = "5.8.0+1"
+version = "5.11.0+0"
+
+[[deps.libdecor_jll]]
+deps = ["Artifacts", "Dbus_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "Pango_jll", "Wayland_jll", "xkbcommon_jll"]
+git-tree-sha1 = "9bf7903af251d2050b467f76bdbe57ce541f7f4f"
+uuid = "1183f4f0-6f2a-5f1a-908b-139f9cdfea6f"
+version = "0.2.2+0"
[[deps.libevdev_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
@@ -2192,10 +2349,10 @@ uuid = "2db6ffa8-e38f-5e21-84af-90c45d0032cc"
version = "1.11.0+0"
[[deps.libfdk_aac_jll]]
-deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
-git-tree-sha1 = "daacc84a041563f965be61859a36e17c4e4fcd55"
+deps = ["Artifacts", "JLLWrappers", "Libdl"]
+git-tree-sha1 = "8a22cf860a7d27e4f3498a0fe0811a7957badb38"
uuid = "f638f0a6-7fb0-5443-88ba-1cc74229b280"
-version = "2.0.2+0"
+version = "2.0.3+0"
[[deps.libinput_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "eudev_jll", "libevdev_jll", "mtdev_jll"]
@@ -2205,15 +2362,15 @@ version = "1.18.0+0"
[[deps.libpng_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Zlib_jll"]
-git-tree-sha1 = "d7015d2e18a5fd9a4f47de711837e980519781a4"
+git-tree-sha1 = "b70c870239dc3d7bc094eb2d6be9b73d27bef280"
uuid = "b53b4c65-9356-5827-b1ea-8c7a1a84506f"
-version = "1.6.43+1"
+version = "1.6.44+0"
[[deps.libvorbis_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Ogg_jll", "Pkg"]
-git-tree-sha1 = "b910cb81ef3fe6e78bf6acee440bda86fd6ae00c"
+git-tree-sha1 = "490376214c4721cdaca654041f635213c6165cb3"
uuid = "f27f6e37-5d2b-51aa-960f-b287f2bc3b7a"
-version = "1.3.7+1"
+version = "1.3.7+2"
[[deps.mtdev_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
diff --git a/test/runtests.jl b/test/runtests.jl
index 4995c6f9..031224ea 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -35,9 +35,6 @@ using Test
include("krondecomposed.jl")
end
- #@testset "MLJFlux" begin
- #include("mlj_flux_interfacing.jl")
- #end
@testset "ML" begin
include("direct_mlj_interface.jl")
end
From 74d778e0fd47e5dc7703454a63b8a26106cd23b8 Mon Sep 17 00:00:00 2001
From: pat-alt
Date: Wed, 16 Oct 2024 15:25:01 +0200
Subject: [PATCH 41/60] one more issue with regression
---
src/direct_mlj.jl | 2 +-
test/direct_mlj_interface.jl | 3 ++-
2 files changed, 3 insertions(+), 2 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index cc6133fd..819f9b85 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -473,7 +473,7 @@ function MMI.predict(m::LaplaceModels, fitresult, Xnew)
LaplaceRedux.predict(la, Xnew; link_approx=m.link_approx, ret_distr=false) |>
permutedims
- return MLJBase.UnivariateFinite(MLJBase.classes(decode), predictions)
+ return MLJBase.UnivariateFinite(decode, predictions; pool=missing)
end
end
diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl
index 6254a7de..629bce86 100644
--- a/test/direct_mlj_interface.jl
+++ b/test/direct_mlj_interface.jl
@@ -31,7 +31,7 @@ cv = CV(; nfolds=3)
MLJBase.fit!(mach) #testing update function
model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
MLJBase.fit!(mach) #testing update function (the laplace part)
- evaluate!(mach, resampling=cv, measure=l2, verbosity=0)
+ # evaluate!(mach, resampling=cv, measure=l2, verbosity=0)
end
@@ -63,6 +63,7 @@ end
MLJBase.fit!(mach) #testing update function
model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
MLJBase.fit!(mach) #testing update function (the laplace part)
+ evaluate!(mach; resampling=cv, measure=brier_loss, verbosity=0)
# Define a different model
flux_model_two = Chain(
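For context on the `UnivariateFinite` change in the patch above: when the class levels are passed as a plain vector (here `decode`) rather than as a `CategoricalArray`, `MLJBase.UnivariateFinite` needs `pool=missing` so that it can construct its own categorical pool. A minimal, hedged sketch of the idea (the labels and probabilities below are made up for illustration):

    using MLJBase

    labels = ["setosa", "versicolor", "virginica"]   # hypothetical class labels
    probs  = [0.1 0.2 0.7;                           # one row of probabilities
              0.8 0.1 0.1]                           # per observation

    # With pool=missing a fresh categorical pool is built from `labels`, and a
    # vector of UnivariateFinite distributions is returned, one per row.
    yhat = MLJBase.UnivariateFinite(labels, probs; pool=missing)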
From 80784bb02c4af72a4186bed9cd4aed8f8769f463 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Fri, 18 Oct 2024 09:11:47 +0200
Subject: [PATCH 42/60] fixed predict so that it returns a vector of
 distributions -> fixed evaluate!
---
docs/Manifest.toml | 8 ++++----
test/direct_mlj_interface.jl | 18 +++++++++++++-----
2 files changed, 17 insertions(+), 9 deletions(-)
diff --git a/docs/Manifest.toml b/docs/Manifest.toml
index 6d42a1e2..a97bfd27 100644
--- a/docs/Manifest.toml
+++ b/docs/Manifest.toml
@@ -2,7 +2,7 @@
julia_version = "1.10.5"
manifest_format = "2.0"
-project_hash = "0bd11d5fa58aad2714bf7893e520fc7c086ef3ca"
+project_hash = "616a9e89f5c520a58672ad91b5525001e0dadab3"
[[deps.ANSIColoredPrinters]]
git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c"
@@ -955,10 +955,10 @@ uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
version = "1.3.1"
[[deps.LaplaceRedux]]
-deps = ["ChainRulesCore", "Compat", "ComputationalResources", "Distributions", "Flux", "LinearAlgebra", "MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"]
-git-tree-sha1 = "a84b72a27c93c72a6af5d22216eb81a419b1b97a"
+deps = ["CategoricalDistributions", "ChainRulesCore", "Compat", "ComputationalResources", "Distributions", "Flux", "LinearAlgebra", "MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"]
+path = "C:\\Users\\Pasqu\\Documents\\julia_projects\\LaplaceRedux.jl"
uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
-version = "1.0.2"
+version = "1.1.1"
[[deps.Latexify]]
deps = ["Format", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Requires"]
diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl
index 629bce86..c69f9cc5 100644
--- a/test/direct_mlj_interface.jl
+++ b/test/direct_mlj_interface.jl
@@ -18,11 +18,12 @@ cv = CV(; nfolds=3)
model = LaplaceRegressor(model=flux_model,epochs=50)
X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)
+ #train, test = partition(eachindex(y), 0.7); # 70:30 split
mach = machine(model, X, y) #|> MLJBase.fit! #|> (fitresult,cache,report)
- MLJBase.fit!(mach,verbosity=1)
- Xnew, _ = make_regression(3, 4; rng=123)
- yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions
- MLJBase.predict_mode(mach, Xnew) # point predictions
+ MLJBase.fit!(mach, verbosity=1)
+ #Xnew, ynew = make_regression(3, 4; rng=123)
+ yhat = MLJBase.predict(mach, X) # probabilistic predictions
+ MLJBase.predict_mode(mach, X) # point predictions
MLJBase.fitted_params(mach) #fitted params function
MLJBase.training_losses(mach) #training loss history
model.epochs= 100 #changing number of epochs
@@ -31,7 +32,14 @@ cv = CV(; nfolds=3)
MLJBase.fit!(mach) #testing update function
model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
MLJBase.fit!(mach) #testing update function (the laplace part)
- # evaluate!(mach, resampling=cv, measure=l2, verbosity=0)
+ yhat = MLJBase.predict(mach, X) # probabilistic predictions
+ println( typeof(yhat) )
+ println( size(yhat) )
+ println( typeof(y) )
+ println( size(y) )
+
+ evaluate!(mach, resampling=cv, measure=log_loss, verbosity=0)
+
end
From be80e32e291897a1ceef508c7bcc485a0cb29751 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Fri, 18 Oct 2024 09:11:47 +0200
Subject: [PATCH 43/60] amend: fixed predict so that it returns a vector of
 distributions -> fixed evaluate!
---
docs/Manifest.toml | 8 ++++----
src/direct_mlj.jl | 2 +-
test/direct_mlj_interface.jl | 18 +++++++++++++-----
3 files changed, 18 insertions(+), 10 deletions(-)
diff --git a/docs/Manifest.toml b/docs/Manifest.toml
index 6d42a1e2..a97bfd27 100644
--- a/docs/Manifest.toml
+++ b/docs/Manifest.toml
@@ -2,7 +2,7 @@
julia_version = "1.10.5"
manifest_format = "2.0"
-project_hash = "0bd11d5fa58aad2714bf7893e520fc7c086ef3ca"
+project_hash = "616a9e89f5c520a58672ad91b5525001e0dadab3"
[[deps.ANSIColoredPrinters]]
git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c"
@@ -955,10 +955,10 @@ uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
version = "1.3.1"
[[deps.LaplaceRedux]]
-deps = ["ChainRulesCore", "Compat", "ComputationalResources", "Distributions", "Flux", "LinearAlgebra", "MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"]
-git-tree-sha1 = "a84b72a27c93c72a6af5d22216eb81a419b1b97a"
+deps = ["CategoricalDistributions", "ChainRulesCore", "Compat", "ComputationalResources", "Distributions", "Flux", "LinearAlgebra", "MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"]
+path = "C:\\Users\\Pasqu\\Documents\\julia_projects\\LaplaceRedux.jl"
uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
-version = "1.0.2"
+version = "1.1.1"
[[deps.Latexify]]
deps = ["Format", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Requires"]
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 819f9b85..14c5c8fd 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -466,7 +466,7 @@ function MMI.predict(m::LaplaceModels, fitresult, Xnew)
means, variances = yhat
# Create Normal distributions from the means and variances
- return [Normal(μ, sqrt(σ)) for (μ, σ) in zip(means, variances)]
+ return vec([Normal(μ, sqrt(σ)) for (μ, σ) in zip(means, variances)])
else
predictions =
diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl
index 629bce86..c69f9cc5 100644
--- a/test/direct_mlj_interface.jl
+++ b/test/direct_mlj_interface.jl
@@ -18,11 +18,12 @@ cv = CV(; nfolds=3)
model = LaplaceRegressor(model=flux_model,epochs=50)
X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)
+ #train, test = partition(eachindex(y), 0.7); # 70:30 split
mach = machine(model, X, y) #|> MLJBase.fit! #|> (fitresult,cache,report)
- MLJBase.fit!(mach,verbosity=1)
- Xnew, _ = make_regression(3, 4; rng=123)
- yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions
- MLJBase.predict_mode(mach, Xnew) # point predictions
+ MLJBase.fit!(mach, verbosity=1)
+ #Xnew, ynew = make_regression(3, 4; rng=123)
+ yhat = MLJBase.predict(mach, X) # probabilistic predictions
+ MLJBase.predict_mode(mach, X) # point predictions
MLJBase.fitted_params(mach) #fitted params function
MLJBase.training_losses(mach) #training loss history
model.epochs= 100 #changing number of epochs
@@ -31,7 +32,14 @@ cv = CV(; nfolds=3)
MLJBase.fit!(mach) #testing update function
model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
MLJBase.fit!(mach) #testing update function (the laplace part)
- # evaluate!(mach, resampling=cv, measure=l2, verbosity=0)
+ yhat = MLJBase.predict(mach, X) # probabilistic predictions
+ println( typeof(yhat) )
+ println( size(yhat) )
+ println( typeof(y) )
+ println( size(y) )
+
+ evaluate!(mach, resampling=cv, measure=log_loss, verbosity=0)
+
end
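The `vec` wrapper added in the patch above matters because `means` and `variances` come out of the network as 1×n row matrices, so the comprehension over `zip` yields a 1×n `Matrix` of distributions, whereas MLJ's probabilistic measures such as `log_loss` expect a plain `Vector` with one distribution per observation. A small, hedged sketch of the shape issue (values are made up):

    using Distributions: Normal

    means     = [0.3 1.2 -0.5]   # hypothetical 1×3 row of predicted means
    variances = [0.9 1.1  0.8]   # hypothetical 1×3 row of predicted variances

    dists = [Normal(μ, sqrt(σ)) for (μ, σ) in zip(means, variances)]
    size(dists)         # (1, 3): a Matrix, not what evaluate!/log_loss expect
    yhat  = vec(dists)  # length-3 Vector of Normal, one per observation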
From f4fcd958ff6f0722d051b7c2f1d10f0ac0c5e410 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Fri, 18 Oct 2024 09:36:38 +0200
Subject: [PATCH 44/60] made a mess with commits; remove leftover debug printlns
---
test/direct_mlj_interface.jl | 5 -----
1 file changed, 5 deletions(-)
diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl
index c69f9cc5..83d4e86f 100644
--- a/test/direct_mlj_interface.jl
+++ b/test/direct_mlj_interface.jl
@@ -33,11 +33,6 @@ cv = CV(; nfolds=3)
model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
MLJBase.fit!(mach) #testing update function (the laplace part)
yhat = MLJBase.predict(mach, X) # probabilistic predictions
- println( typeof(yhat) )
- println( size(yhat) )
- println( typeof(y) )
- println( size(y) )
-
evaluate!(mach, resampling=cv, measure=log_loss, verbosity=0)
end
From 851784fabea41be5c8c5cee183f897610a86c00e Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Fri, 18 Oct 2024 11:27:05 +0200
Subject: [PATCH 45/60] trying to increase patch coverage
---
test/direct_mlj_interface.jl | 15 +++++++++++++--
1 file changed, 13 insertions(+), 2 deletions(-)
diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl
index 83d4e86f..da0782e7 100644
--- a/test/direct_mlj_interface.jl
+++ b/test/direct_mlj_interface.jl
@@ -4,8 +4,7 @@ using MLJBase: MLJBase, categorical
using Flux
using StableRNGs
using MLJ
-using MLJ:predict,fit!
-using LaplaceRedux
+import LaplaceRedux: LaplaceClassifier, LaplaceRegressor
cv = CV(; nfolds=3)
@@ -35,6 +34,18 @@ cv = CV(; nfolds=3)
yhat = MLJBase.predict(mach, X) # probabilistic predictions
evaluate!(mach, resampling=cv, measure=log_loss, verbosity=0)
+
+ # Define a different model
+ flux_model_two = Chain(
+ Dense(4, 6, relu),
+ Dense(6, 1)
+ )
+ # test update! fallback to fit!
+ model.model = flux_model_two
+ MLJBase.fit!(mach)
+
+
+
end
From 0752b83180b7282d3726da8baafebb48bc503f9d Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Fri, 18 Oct 2024 11:46:16 +0200
Subject: [PATCH 46/60] add an is_same_except test to satisfy codecov patch coverage
---
test/direct_mlj_interface.jl | 3 +++
1 file changed, 3 insertions(+)
diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl
index da0782e7..84d0440a 100644
--- a/test/direct_mlj_interface.jl
+++ b/test/direct_mlj_interface.jl
@@ -44,6 +44,9 @@ cv = CV(; nfolds=3)
model.model = flux_model_two
MLJBase.fit!(mach)
+ model_two = LaplaceRegressor(model=flux_model_two,epochs=100)
+ @test !MLJBase.is_same_except(model,model_two)
+
end
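The `is_same_except` assertion added above follows the MLJModelInterface semantics: two models are "the same" if their hyperparameters match, except for any property names passed explicitly as exceptions. A rough, hedged illustration (the chain and hyperparameter values are made up):

    using Flux
    import MLJModelInterface as MMI
    import LaplaceRedux: LaplaceRegressor

    chain = Chain(Dense(4, 6, relu), Dense(6, 1))
    m1 = LaplaceRegressor(; model=chain, epochs=100)
    m2 = LaplaceRegressor(; model=chain, epochs=50)

    MMI.is_same_except(m1, m2)           # false: the models differ in `epochs`
    MMI.is_same_except(m1, m2, :epochs)  # true: the `epochs` difference is ignored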
From 573ffd8956af000acfd6ee5bf33336960c66156a Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Mon, 21 Oct 2024 12:30:27 +0200
Subject: [PATCH 47/60] remove unused is_same_except helper functions
---
src/direct_mlj.jl | 7 -------
1 file changed, 7 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 14c5c8fd..98165616 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -353,18 +353,11 @@ function MMI.is_same_except(m1::LaplaceModels, m2::LaplaceModels, exceptions::Sy
return true
end
-# Define helper functions used in is_same_except
function _isdefined(obj, name)
return hasproperty(obj, name)
end
-function deep_properties(::Type)
- return Set{Symbol}()
-end
-function _equal_to_depth_one(a, b)
- return a == b
-end
function _equal_flux_chain(chain1::Flux.Chain, chain2::Flux.Chain)
if length(chain1.layers) != length(chain2.layers)
From db14b84f30ae256521150fd070f843d6890eb46e Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Mon, 21 Oct 2024 13:35:30 +0200
Subject: [PATCH 48/60] fixed _isdefined
---
src/direct_mlj.jl | 9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 98165616..b1ffa328 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -352,13 +352,16 @@ function MMI.is_same_except(m1::LaplaceModels, m2::LaplaceModels, exceptions::Sy
end
return true
end
-
-function _isdefined(obj, name)
- return hasproperty(obj, name)
+function _isdefined(object, name)
+ pnames = propertynames(object)
+ fnames = fieldnames(typeof(object))
+ name in pnames && !(name in fnames) && return true
+ isdefined(object, name)
end
+
function _equal_flux_chain(chain1::Flux.Chain, chain2::Flux.Chain)
if length(chain1.layers) != length(chain2.layers)
return false
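The `_isdefined` fix above treats a name that appears in `propertynames` but not in `fieldnames` (i.e. a computed property) as always defined, and otherwise defers to `Base.isdefined` on the backing field. A minimal sketch of that logic with a toy struct (the struct and its computed property are invented for illustration):

    function _isdefined(object, name)
        pnames = propertynames(object)
        fnames = fieldnames(typeof(object))
        name in pnames && !(name in fnames) && return true
        return isdefined(object, name)
    end

    struct Toy
        x::Int
    end
    Base.propertynames(::Toy) = (:x, :double)   # advertise a computed property
    Base.getproperty(t::Toy, s::Symbol) =
        s === :double ? 2 * getfield(t, :x) : getfield(t, s)

    t = Toy(3)
    _isdefined(t, :double)  # true: property with no backing field
    _isdefined(t, :x)       # true: falls back to isdefined(t, :x)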
From 82c5714681b8ab0aa5e8b39581ef8ef8125f7939 Mon Sep 17 00:00:00 2001
From: pat-alt
Date: Tue, 22 Oct 2024 13:07:51 +0200
Subject: [PATCH 49/60] trying to fix docs issue; stop importing the MLJ and
 MLJBase namespaces
---
src/direct_mlj.jl | 6 +++---
test/direct_mlj_interface.jl | 1 -
2 files changed, 3 insertions(+), 4 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index b1ffa328..317d4b78 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -5,7 +5,7 @@ using Random
using Tables
using LinearAlgebra
using LaplaceRedux
-using MLJBase
+using MLJBase: MLJBase
import MLJModelInterface as MMI
using Distributions: Normal
@@ -315,7 +315,7 @@ The meaining of "equal" depends on the type of the property value:
- values that are not of `MLJType` are "equal" if they are `==`.
In the special case of a "deep" property, "equal" has a different
-meaning; see [`deep_properties`](@ref)) for details.
+meaning; see [`MMI.StatTraits.deep_properties`](@ref)) for details.
If `m1` or `m2` are not `MLJType` objects, then return `==(m1, m2)`.
@@ -330,7 +330,7 @@ function MMI.is_same_except(m1::LaplaceModels, m2::LaplaceModels, exceptions::Sy
if !_isdefined(m1, name)
!_isdefined(m2, name) || return false
elseif _isdefined(m2, name)
- if name in deep_properties(LaplaceRegressor)
+ if name in MMI.StatTraits.deep_properties(LaplaceRegressor)
_equal_to_depth_one(getproperty(m1, name), getproperty(m2, name)) ||
return false
else
diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl
index 84d0440a..8853ec4f 100644
--- a/test/direct_mlj_interface.jl
+++ b/test/direct_mlj_interface.jl
@@ -3,7 +3,6 @@ import Random.seed!
using MLJBase: MLJBase, categorical
using Flux
using StableRNGs
-using MLJ
import LaplaceRedux: LaplaceClassifier, LaplaceRegressor
cv = CV(; nfolds=3)
From 72020130bb669fe287f955001b0b39b02354a108 Mon Sep 17 00:00:00 2001
From: pat-alt
Date: Tue, 22 Oct 2024 13:08:23 +0200
Subject: [PATCH 50/60] formatting
---
src/direct_mlj.jl | 28 +++++++----------
test/direct_mlj_interface.jl | 61 +++++++++++++-----------------------
test/runtests.jl | 1 -
3 files changed, 33 insertions(+), 57 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 317d4b78..b76cdd93 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -47,10 +47,11 @@ end
LaplaceModels = Union{LaplaceRegressor,LaplaceClassifier}
# for fit:
-MMI.reformat(::LaplaceRegressor, X, y) = (MLJBase.matrix(X) |> permutedims, (reshape(y, 1, :), nothing))
-
-function MMI.reformat(::LaplaceClassifier, X, y)
+function MMI.reformat(::LaplaceRegressor, X, y)
+ return (MLJBase.matrix(X) |> permutedims, (reshape(y, 1, :), nothing))
+end
+function MMI.reformat(::LaplaceClassifier, X, y)
X = MLJBase.matrix(X) |> permutedims
y = categorical(y)
labels = y.pool.levels
@@ -61,7 +62,7 @@ end
MMI.reformat(::LaplaceModels, X) = (MLJBase.matrix(X) |> permutedims,)
-MMI.selectrows(::LaplaceModels, I, Xmatrix, y) = (Xmatrix[:, I], (y[1][:,I], y[2]))
+MMI.selectrows(::LaplaceModels, I, Xmatrix, y) = (Xmatrix[:, I], (y[1][:, I], y[2]))
MMI.selectrows(::LaplaceModels, I, Xmatrix) = (Xmatrix[:, I],)
@doc """
@@ -171,7 +172,6 @@ function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y)
old_model = old_cache[1]
old_state_tree = old_cache[2]
old_loss_history = old_cache[3]
-
epochs = m.epochs
@@ -236,13 +236,14 @@ function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y)
cache = (deepcopy(m), old_state_tree, old_loss_history)
else
- println("The number of epochs inserted is lower than the number of epochs already been trained. No update is necessary")
+ println(
+ "The number of epochs inserted is lower than the number of epochs already been trained. No update is necessary",
+ )
fitresult = (old_la, decode)
report = (loss_history=old_loss_history,)
cache = (deepcopy(m), old_state_tree, old_loss_history)
end
-
elseif MMI.is_same_except(
m,
old_model,
@@ -277,16 +278,12 @@ function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y)
LaplaceRedux.fit!(la, data_loader)
optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
- fitresult = (la,decode)
+ fitresult = (la, decode)
report = (loss_history=old_loss_history,)
cache = (deepcopy(m), old_state_tree, old_loss_history)
-
else
-
- fitresult, cache, report = MLJBase.fit(m, verbosity,X,y)
-
-
+ fitresult, cache, report = MLJBase.fit(m, verbosity, X, y)
end
return fitresult, cache, report
@@ -356,12 +353,9 @@ function _isdefined(object, name)
pnames = propertynames(object)
fnames = fieldnames(typeof(object))
name in pnames && !(name in fnames) && return true
- isdefined(object, name)
+ return isdefined(object, name)
end
-
-
-
function _equal_flux_chain(chain1::Flux.Chain, chain2::Flux.Chain)
if length(chain1.layers) != length(chain2.layers)
return false
diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl
index 8853ec4f..c9e5b689 100644
--- a/test/direct_mlj_interface.jl
+++ b/test/direct_mlj_interface.jl
@@ -8,84 +8,67 @@ import LaplaceRedux: LaplaceClassifier, LaplaceRegressor
cv = CV(; nfolds=3)
@testset "Regression" begin
- flux_model = Chain(
- Dense(4, 10, relu),
- Dense(10, 10, relu),
- Dense(10, 1)
- )
- model = LaplaceRegressor(model=flux_model,epochs=50)
-
+ flux_model = Chain(Dense(4, 10, relu), Dense(10, 10, relu), Dense(10, 1))
+ model = LaplaceRegressor(; model=flux_model, epochs=50)
+
X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)
#train, test = partition(eachindex(y), 0.7); # 70:30 split
mach = machine(model, X, y) #|> MLJBase.fit! #|> (fitresult,cache,report)
- MLJBase.fit!(mach, verbosity=1)
+ MLJBase.fit!(mach; verbosity=1)
#Xnew, ynew = make_regression(3, 4; rng=123)
yhat = MLJBase.predict(mach, X) # probabilistic predictions
MLJBase.predict_mode(mach, X) # point predictions
MLJBase.fitted_params(mach) #fitted params function
MLJBase.training_losses(mach) #training loss history
- model.epochs= 100 #changing number of epochs
+ model.epochs = 100 #changing number of epochs
MLJBase.fit!(mach) #testing update function
- model.epochs= 50 #changing number of epochs to a lower number
+ model.epochs = 50 #changing number of epochs to a lower number
MLJBase.fit!(mach) #testing update function
model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
MLJBase.fit!(mach) #testing update function (the laplace part)
yhat = MLJBase.predict(mach, X) # probabilistic predictions
- evaluate!(mach, resampling=cv, measure=log_loss, verbosity=0)
-
+ evaluate!(mach; resampling=cv, measure=log_loss, verbosity=0)
# Define a different model
- flux_model_two = Chain(
- Dense(4, 6, relu),
- Dense(6, 1)
- )
+ flux_model_two = Chain(Dense(4, 6, relu), Dense(6, 1))
# test update! fallback to fit!
model.model = flux_model_two
MLJBase.fit!(mach)
- model_two = LaplaceRegressor(model=flux_model_two,epochs=100)
- @test !MLJBase.is_same_except(model,model_two)
-
-
-
+ model_two = LaplaceRegressor(; model=flux_model_two, epochs=100)
+ @test !MLJBase.is_same_except(model, model_two)
end
-
-
@testset "Classification" begin
# Define the model
- flux_model = Chain(
- Dense(4, 10, relu),
- Dense(10, 3)
- )
+ flux_model = Chain(Dense(4, 10, relu), Dense(10, 3))
- model = LaplaceClassifier(model=flux_model,epochs=50)
+ model = LaplaceClassifier(; model=flux_model, epochs=50)
X, y = @load_iris
mach = machine(model, X, y)
- MLJBase.fit!(mach,verbosity=1)
- Xnew = (sepal_length = [6.4, 7.2, 7.4],
- sepal_width = [2.8, 3.0, 2.8],
- petal_length = [5.6, 5.8, 6.1],
- petal_width = [2.1, 1.6, 1.9],)
+ MLJBase.fit!(mach; verbosity=1)
+ Xnew = (
+ sepal_length=[6.4, 7.2, 7.4],
+ sepal_width=[2.8, 3.0, 2.8],
+ petal_length=[5.6, 5.8, 6.1],
+ petal_width=[2.1, 1.6, 1.9],
+ )
yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions
predict_mode(mach, Xnew) # point predictions
pdf.(yhat, "virginica") # probabilities for the "verginica" class
MLJBase.fitted_params(mach) # fitted params
MLJBase.training_losses(mach) #training loss history
- model.epochs= 100 #changing number of epochs
+ model.epochs = 100 #changing number of epochs
MLJBase.fit!(mach) #testing update function
- model.epochs= 50 #changing number of epochs to a lower number
+ model.epochs = 50 #changing number of epochs to a lower number
MLJBase.fit!(mach) #testing update function
model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
MLJBase.fit!(mach) #testing update function (the laplace part)
evaluate!(mach; resampling=cv, measure=brier_loss, verbosity=0)
# Define a different model
- flux_model_two = Chain(
- Dense(4, 6, relu),
- Dense(6, 3)
- )
+ flux_model_two = Chain(Dense(4, 6, relu), Dense(6, 3))
model.model = flux_model_two
diff --git a/test/runtests.jl b/test/runtests.jl
index 031224ea..ed16d37b 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -38,5 +38,4 @@ using Test
@testset "ML" begin
include("direct_mlj_interface.jl")
end
-
end
From a05e25fe69db103959f1720243e7e4a4e1064805 Mon Sep 17 00:00:00 2001
From: pat-alt
Date: Tue, 22 Oct 2024 13:12:51 +0200
Subject: [PATCH 51/60] removing mlj_flux
---
src/LaplaceRedux.jl | 1 -
1 file changed, 1 deletion(-)
diff --git a/src/LaplaceRedux.jl b/src/LaplaceRedux.jl
index 15ab3d4f..08496e0f 100644
--- a/src/LaplaceRedux.jl
+++ b/src/LaplaceRedux.jl
@@ -19,7 +19,6 @@ export fit!, predict
export optimize_prior!,
glm_predictive_distribution, posterior_covariance, posterior_precision
-include("mlj_flux.jl")
export LaplaceClassification
export LaplaceRegression
From 05df2e1408db43ac1eff64931efb4522fe1fffa9 Mon Sep 17 00:00:00 2001
From: pat-alt
Date: Tue, 22 Oct 2024 15:11:42 +0200
Subject: [PATCH 52/60] fixed issues: drop MLJFlux/mlj_flux.jl and update compat bounds
---
.github/workflows/CI.yml | 1 +
Project.toml | 8 +-
src/LaplaceRedux.jl | 2 -
src/direct_mlj.jl | 4 +-
src/mlj_flux.jl | 494 -----------------------------------
test/Manifest.toml | 101 ++++---
test/direct_mlj_interface.jl | 19 +-
test/runtests.jl | 2 +-
8 files changed, 85 insertions(+), 546 deletions(-)
delete mode 100644 src/mlj_flux.jl
diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 64bf7af1..22e67adc 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -22,6 +22,7 @@ jobs:
version:
- '1.9'
- '1.10'
+ - '1'
os:
- ubuntu-latest
- windows-latest
diff --git a/Project.toml b/Project.toml
index f9cb74ea..d78ce06a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -12,7 +12,6 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
-MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
@@ -33,18 +32,17 @@ Distributions = "0.25.109"
Flux = "0.12, 0.13, 0.14"
LinearAlgebra = "1.7, 1.10"
MLJBase = "1"
-MLJFlux = "0.5"
MLJModelInterface = "1.8.0"
MLUtils = "0.4"
Optimisers = "0.2, 0.3"
ProgressMeter = "1.7.2"
-Random = "1.9, 1.10"
+Random = "1"
Statistics = "1"
Tables = "1.10.1"
-Test = "1.9, 1.10"
+Test = "1"
Tullio = "0.3.5"
Zygote = "0.6"
-julia = "1.9, 1.10"
+julia = "1.10"
[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
diff --git a/src/LaplaceRedux.jl b/src/LaplaceRedux.jl
index 08496e0f..b6324f95 100644
--- a/src/LaplaceRedux.jl
+++ b/src/LaplaceRedux.jl
@@ -19,8 +19,6 @@ export fit!, predict
export optimize_prior!,
glm_predictive_distribution, posterior_covariance, posterior_precision
-export LaplaceClassification
-export LaplaceRegression
include("calibration_functions.jl")
export empirical_frequency_binary_classification,
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index b76cdd93..42ec720a 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -53,7 +53,7 @@ end
function MMI.reformat(::LaplaceClassifier, X, y)
X = MLJBase.matrix(X) |> permutedims
- y = categorical(y)
+ y = MLJBase.categorical(y)
labels = y.pool.levels
y = Flux.onehotbatch(y, labels) # One-hot encoding
@@ -332,7 +332,7 @@ function MMI.is_same_except(m1::LaplaceModels, m2::LaplaceModels, exceptions::Sy
return false
else
(
- is_same_except(getproperty(m1, name), getproperty(m2, name)) ||
+ MMI.is_same_except(getproperty(m1, name), getproperty(m2, name)) ||
getproperty(m1, name) isa AbstractRNG ||
getproperty(m2, name) isa AbstractRNG ||
(
diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl
deleted file mode 100644
index 0b1fa68f..00000000
--- a/src/mlj_flux.jl
+++ /dev/null
@@ -1,494 +0,0 @@
-using Flux
-using MLJFlux
-using ProgressMeter: Progress, next!, BarGlyphs
-using Random
-using Tables
-using LinearAlgebra
-using LaplaceRedux
-using ComputationalResources
-using MLJBase: MLJBase
-import MLJModelInterface as MMI
-using Optimisers: Optimisers
-
-"""
- MLJBase.@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic
-
-A mutable struct representing a Laplace regression model that extends the `MLJFlux.MLJFluxProbabilistic` abstract type.
-It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.
-
-The model is defined by the following default parameters for all `MLJFlux` models:
-
-- `builder`: a Flux model that constructs the neural network.
-- `optimiser`: a Flux optimiser.
-- `loss`: a loss function that takes the predicted output and the true output as arguments.
-- `epochs`: the number of epochs.
-- `batch_size`: the size of a batch.
-- `lambda`: the regularization strength.
-- `alpha`: the regularization mix (0 for all l2, 1 for all l1).
-- `rng`: a random number generator.
-- `optimiser_changes_trigger_retraining`: a boolean indicating whether changes in the optimiser trigger retraining.
-- `acceleration`: the computational resource to use.
-
-The model also has the following parameters, which are specific to the Laplace approximation:
-
-- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
-- `subnetwork_indices`: the indices of the subnetworks.
-- `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`.
-- `backend`: the backend to use, either `:GGN` or `:EmpiricalFisher`.
-- `σ`: the standard deviation of the prior distribution.
-- `μ₀`: the mean of the prior distribution.
-- `P₀`: the covariance matrix of the prior distribution.
-- `ret_distr`: a boolean that tells predict to either return distributions (true) objects from Distributions.jl or just the probabilities.
-- `fit_prior_nsteps`: the number of steps used to fit the priors.
-"""
-MLJBase.@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic
- builder = MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish)
- optimiser = Optimisers.Adam()
- loss = Flux.Losses.mse
- epochs::Int = 10::(_ > 0)
- batch_size::Int = 1::(_ > 0)
- lambda::Float64 = 1.0
- alpha::Float64 = 0.0
- rng::Union{AbstractRNG,Int64} = Random.GLOBAL_RNG
- optimiser_changes_trigger_retraining::Bool = false::(_ in (true, false))
- acceleration = CPU1()::(_ in (CPU1(), CUDALibs()))
- subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
- subnetwork_indices = nothing
- hessian_structure::Union{HessianStructure,Symbol,String} =
- :full::(_ in (:full, :diagonal))
- backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))
- σ::Float64 = 1.0
- μ₀::Float64 = 0.0
- P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
- ret_distr::Bool = false::(_ in (true, false))
- fit_prior_nsteps::Int = 100::(_ > 0)
-end
-
-"""
- MLJBase.@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic
-
-A mutable struct representing a Laplace Classification model that extends the MLJFluxProbabilistic abstract type.
-It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.
-
-The model is defined by the following default parameters for all `MLJFlux` models:
-- `builder`: a Flux model that constructs the neural network.
-- `finaliser`: a Flux model that processes the output of the neural network.
-- `optimiser`: a Flux optimiser.
-- `loss`: a loss function that takes the predicted output and the true output as arguments.
-- `epochs`: the number of epochs.
-- `batch_size`: the size of a batch.
-- `lambda`: the regularization strength.
-- `alpha`: the regularization mix (0 for all l2, 1 for all l1).
-- `rng`: a random number generator.
-- `optimiser_changes_trigger_retraining`: a boolean indicating whether changes in the optimiser trigger retraining.
-- `acceleration`: the computational resource to use.
-
-The model also has the following parameters, which are specific to the Laplace approximation:
-
-- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
-- `subnetwork_indices`: the indices of the subnetworks.
-- `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`.
-- `backend`: the backend to use, either `:GGN` or `:EmpiricalFisher`.
-- `σ`: the standard deviation of the prior distribution.
-- `μ₀`: the mean of the prior distribution.
-- `P₀`: the covariance matrix of the prior distribution.
-- `link_approx`: the link approximation to use, either `:probit` or `:plugin`.
-- `predict_proba`: a boolean that select whether to predict probabilities or not.
-- `ret_distr`: a boolean that tells predict to either return distributions (true) objects from Distributions.jl or just the probabilities.
-- `fit_prior_nsteps`: the number of steps used to fit the priors.
-"""
-MLJBase.@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic
- builder = MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish)
- finaliser = Flux.softmax
- optimiser = Optimisers.Adam()
- loss = Flux.crossentropy
- epochs::Int = 10::(_ > 0)
- batch_size::Int = 1::(_ > 0)
- lambda::Float64 = 1.0
- alpha::Float64 = 0.0
- rng::Union{AbstractRNG,Int64} = Random.GLOBAL_RNG
- optimiser_changes_trigger_retraining::Bool = false::(_ in (true, false))
- acceleration = CPU1()::(_ in (CPU1(), CUDALibs()))
- subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
- subnetwork_indices::Vector{Vector{Int}} = Vector{Vector{Int}}([])
- hessian_structure::Union{HessianStructure,Symbol,String} =
- :full::(_ in (:full, :diagonal))
- backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))
- σ::Float64 = 1.0
- μ₀::Float64 = 0.0
- P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
- link_approx::Symbol = :probit::(_ in (:probit, :plugin))
- predict_proba::Bool = true::(_ in (true, false))
- ret_distr::Bool = false::(_ in (true, false))
- fit_prior_nsteps::Int = 100::(_ > 0)
-end
-
-const MLJ_Laplace = Union{LaplaceClassification,LaplaceRegression}
-
-"""
- MLJFlux.shape(model::LaplaceRegression, X, y)
-
-Compute the the number of features of the X input dataset and the number of variables to predict from the y output dataset.
-
-# Arguments
-- `model::LaplaceRegression`: The LaplaceRegression model to fit.
-- `X`: The input data for training.
-- `y`: The target labels for training one-hot encoded.
-
-# Returns
-- (input size, output size)
-"""
-function MLJFlux.shape(model::LaplaceRegression, X, y)
- X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X
- n_input = size(X, 2)
- dims = size(y)
- if length(dims) == 1
- n_output = 1
- else
- n_output = dims[1]
- end
- return (n_input, n_output)
-end
-
-"""
- MLJFlux.build(model::LaplaceRegression, rng, shape)
-
-Builds an MLJFlux model for Laplace regression compatible with the dimensions of the input and output layers specified by `shape`.
-
-# Arguments
-- `model::LaplaceRegression`: The Laplace regression model.
-- `rng`: A random number generator to ensure reproducibility.
-- `shape`: A tuple or array specifying the dimensions of the input and output layers.
-
-# Returns
-- The constructed MLJFlux model, compatible with the specified input and output dimensions.
-"""
-function MLJFlux.build(model::LaplaceRegression, rng, shape)
- chain = MLJFlux.build(model.builder, rng, shape...)
- return chain
-end
-
-"""
- MLJFlux.fitresult(model::LaplaceRegression, chain, y)
-
-Computes the fit result for a Laplace Regression model, returning the model chain and the number of output variables in the target data.
-
-# Arguments
-- `model::LaplaceRegression`: The Laplace Regression model to be evaluated.
-- `chain`: The trained model chain.
-- `y`: The target data, typically a vector of class labels.
-
-# Returns
- A tuple containing:
- - The trained Flux chain.
- - a deepcopy of the laplace model.
-"""
-function MLJFlux.fitresult(model::LaplaceRegression, chain, y)
- return (chain, deepcopy(model))
-end
-
-"""
- MLJFlux.train(model::LaplaceRegression, chain, regularized_optimiser, optimiser_state, epochs, verbosity, X, y)
-
-Fit the LaplaceRegression model using Flux.jl.
-
-# Arguments
-- `model::LaplaceRegression`: The LaplaceRegression model.
-- `regularized_optimiser`: the regularized optimiser to apply to the loss function.
-- `optimiser_state`: thestate of the optimiser.
-- `epochs`: The number of epochs for training.
-- `verbosity`: The verbosity level for training.
-- `X`: The input data for training.
-- `y`: The target labels for training.
-
-# Returns (la, optimiser_state, history )
-where
-- `la`: the fitted Laplace model.
-- `optimiser_state`: the state of the optimiser.
-- `history`: the training loss history.
-"""
-function MLJFlux.train(
- model::LaplaceRegression,
- chain,
- regularized_optimiser,
- optimiser_state,
- epochs,
- verbosity,
- X,
- y,
-)
- X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X
-
- # Initialize history:
- history = []
- verbose_laplace = false
- # intitialize and start progress meter:
- meter = Progress(
- epochs + 1;
- dt=1.0,
- desc="Optimising neural net:",
- barglyphs=BarGlyphs("[=> ]"),
- barlen=25,
- color=:yellow,
- )
- verbosity != 1 || next!(meter)
-
- # initiate history:
- loss = model.loss
- losses = (loss(chain(X[i]), y[i]) for i in 1:length(y))
- history = [mean(losses)]
-
- for i in 1:epochs
- chain, optimiser_state, current_loss = MLJFlux.train_epoch(
- model, chain, regularized_optimiser, optimiser_state, X, y
- )
- verbosity < 2 || @info "Loss is $(round(current_loss; sigdigits=4))"
- verbosity != 1 || next!(meter)
- push!(history, current_loss)
- end
-
- if !isa(chain, AbstractLaplace)
- la = LaplaceRedux.Laplace(
- chain;
- likelihood=:regression,
- subset_of_weights=model.subset_of_weights,
- subnetwork_indices=model.subnetwork_indices,
- hessian_structure=model.hessian_structure,
- backend=model.backend,
- σ=model.σ,
- μ₀=model.μ₀,
- P₀=model.P₀,
- )
- else
- la = chain
- end
-
- # fit the Laplace model:
- LaplaceRedux.fit!(la, zip(X, y))
- optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps)
-
- return la, optimiser_state, history
-end
-
-"""
- predict(model::LaplaceRegression, Xnew)
-
-Predict the output for new input data using a Laplace regression model.
-
-# Arguments
-- `model::LaplaceRegression`: The trained Laplace regression model.
-- the fitresult output produced by MLJFlux.fit!
-- `Xnew`: The new input data.
-
-# Returns
-- The predicted output for the new input data.
-
-"""
-function MLJFlux.predict(model::LaplaceRegression, fitresult, Xnew)
- Xnew = MLJBase.matrix(Xnew) |> permutedims
- la = fitresult[1]
- yhat = LaplaceRedux.predict(la, Xnew; ret_distr=model.ret_distr)
- return yhat
-end
-
-"""
- MLJFlux.shape(model::LaplaceClassification, X, y)
-
-Compute the the number of features of the dataset X and the number of unique classes in y.
-
-# Arguments
-- `model::LaplaceClassification`: The LaplaceClassification model to fit.
-- `X`: The input data for training.
-- `y`: The target labels for training one-hot encoded.
-
-# Returns
-- (input size, output size)
-"""
-
-function MLJFlux.shape(model::LaplaceClassification, X, y)
- X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X
- n_input = size(X, 2)
- levels = unique(y)
- n_output = length(levels)
- return (n_input, n_output)
-end
-
-"""
- MLJFlux.build(model::LaplaceClassification, rng, shape)
-
-Builds an MLJFlux model for Laplace classification compatible with the dimensions of the input and output layers specified by `shape`.
-
-# Arguments
-- `model::LaplaceClassification`: The Laplace classification model.
-- `rng`: A random number generator to ensure reproducibility.
-- `shape`: A tuple or array specifying the dimensions of the input and output layers.
-
-# Returns
-- The constructed MLJFlux model, compatible with the specified input and output dimensions.
-"""
-function MLJFlux.build(model::LaplaceClassification, rng, shape)
- chain = Flux.Chain(MLJFlux.build(model.builder, rng, shape...), model.finaliser)
-
- return chain
-end
-
-"""
- MLJFlux.fitresult(model::LaplaceClassification, chain, y)
-
-Computes the fit result for a Laplace classification model, returning the model chain and the number of unique classes in the target data.
-
-# Arguments
-- `model::LaplaceClassification`: The Laplace classification model to be evaluated.
-- `chain`: The trained model chain.
-- `y`: The target data, typically a vector of class labels.
-
-# Returns
-# Returns
- A tuple containing:
- - The trained Flux chain.
- - a deepcopy of the laplace model.
-"""
-function MLJFlux.fitresult(model::LaplaceClassification, chain, y)
- return (chain, deepcopy(model))
-end
-
-"""
- MLJFlux.train(model::LaplaceClassification, chain, regularized_optimiser, optimiser_state, epochs, verbosity, X, y)
-
-Fit the LaplaceRegression model using Flux.jl.
-
-# Arguments
-- `model::LaplaceClassification`: The LaplaceClassification model.
-- `regularized_optimiser`: the regularized optimiser to apply to the loss function.
-- `optimiser_state`: thestate of the optimiser.
-- `epochs`: The number of epochs for training.
-- `verbosity`: The verbosity level for training.
-- `X`: The input data for training.
-- `y`: The target labels for training.
-
-# Returns (fitresult, cache, report )
-where
-- `la`: the fitted Laplace model.
-- `optimiser_state`: the state of the optimiser.
-- `history`: the training loss history.
-"""
-function MLJFlux.train(
- model::LaplaceClassification,
- chain,
- regularized_optimiser,
- optimiser_state,
- epochs,
- verbosity,
- X,
- y,
-)
- X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X
-
- # Initialize history:
- history = []
- verbose_laplace = false
- # intitialize and start progress meter:
- meter = Progress(
- epochs + 1;
- dt=1.0,
- desc="Optimising neural net:",
- barglyphs=BarGlyphs("[=> ]"),
- barlen=25,
- color=:yellow,
- )
- verbosity != 1 || next!(meter)
-
- # initiate history:
- loss = model.loss
- n_batches = length(y)
- losses = (loss(chain(X[i]), y[i]) for i in 1:n_batches)
- history = [mean(losses)]
-
- for i in 1:epochs
- chain, optimiser_state, current_loss = MLJFlux.train_epoch(
- model, chain, regularized_optimiser, optimiser_state, X, y
- )
- verbosity < 2 || @info "Loss is $(round(current_loss; sigdigits=4))"
- verbosity != 1 || next!(meter)
- push!(history, current_loss)
- end
-
- if !isa(chain, AbstractLaplace)
- la = LaplaceRedux.Laplace(
- chain;
- likelihood=:regression,
- subset_of_weights=model.subset_of_weights,
- subnetwork_indices=model.subnetwork_indices,
- hessian_structure=model.hessian_structure,
- backend=model.backend,
- σ=model.σ,
- μ₀=model.μ₀,
- P₀=model.P₀,
- )
- else
- la = chain
- end
-
- # fit the Laplace model:
- LaplaceRedux.fit!(la, zip(X, y))
- optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps)
-
- return la, optimiser_state, history
-end
-
-"""
- predict(model::LaplaceClassification, Xnew)
-
-Predicts the class labels for new data using the LaplaceClassification model.
-
-# Arguments
-- `model::LaplaceClassification`: The trained LaplaceClassification model.
-- fitresult: the fitresult output produced by MLJFlux.fit!
-- `Xnew`: The new data to make predictions on.
-
-# Returns
-An array of predicted class labels.
-
-"""
-function MLJFlux.predict(model::LaplaceClassification, fitresult, Xnew)
- la = fitresult[1]
- Xnew = MLJBase.matrix(Xnew) |> permutedims
- predictions = LaplaceRedux.predict(
- la,
- Xnew;
- link_approx=model.link_approx,
- predict_proba=model.predict_proba,
- ret_distr=model.ret_distr,
- )
-
- return predictions
-end
-
-# metadata for each model,
-MLJBase.metadata_model(
- LaplaceClassification;
- input=Union{
- AbstractMatrix{MLJBase.Finite},
- MLJBase.Table(MLJBase.Finite),
- AbstractMatrix{MLJBase.Continuous},
- MLJBase.Table(MLJBase.Continuous),
- MLJBase.Table{AbstractVector{MLJBase.Continuous}},
- MLJBase.Table{AbstractVector{MLJBase.Finite}},
- },
- target=Union{AbstractArray{MLJBase.Finite},AbstractArray{MLJBase.Continuous}},
- path="MLJFlux.LaplaceClassification",
-)
-# metadata for each model,
-MLJBase.metadata_model(
- LaplaceRegression;
- input=Union{
- AbstractMatrix{MLJBase.Continuous},
- MLJBase.Table(MLJBase.Continuous),
- AbstractMatrix{MLJBase.Finite},
- MLJBase.Table(MLJBase.Finite),
- MLJBase.Table{AbstractVector{MLJBase.Continuous}},
- MLJBase.Table{AbstractVector{MLJBase.Finite}},
- },
- target=Union{AbstractArray{MLJBase.Finite},AbstractArray{MLJBase.Continuous}},
- path="MLJFlux.LaplaceRegression",
-)
diff --git a/test/Manifest.toml b/test/Manifest.toml
index 3448bc9a..bed7d79f 100644
--- a/test/Manifest.toml
+++ b/test/Manifest.toml
@@ -1,14 +1,14 @@
# This file is machine-generated - editing it directly is not advised
-julia_version = "1.10.5"
+julia_version = "1.11.1"
manifest_format = "2.0"
project_hash = "48e3a5a4625c4493599b02acbfe0e972463bd78f"
[[deps.ARFFFiles]]
deps = ["CategoricalArrays", "Dates", "Parsers", "Tables"]
-git-tree-sha1 = "e8c8e0a2be6eb4f56b1672e46004463033daa409"
+git-tree-sha1 = "678eb18590a8bc6674363da4d5faa4ac09c40a18"
uuid = "da404889-ca92-49ff-9e8b-0aa6b4d38dc8"
-version = "1.4.1"
+version = "1.5.0"
[[deps.AbstractFFTs]]
deps = ["LinearAlgebra"]
@@ -80,7 +80,7 @@ version = "2.3.0"
[[deps.ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
-version = "1.1.1"
+version = "1.1.2"
[[deps.Arpack]]
deps = ["Arpack_jll", "Libdl", "LinearAlgebra", "Logging"]
@@ -96,6 +96,7 @@ version = "3.5.1+1"
[[deps.Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
+version = "1.11.0"
[[deps.Atomix]]
deps = ["UnsafeAtomics"]
@@ -144,6 +145,7 @@ version = "0.4.3"
[[deps.Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
+version = "1.11.0"
[[deps.Baselet]]
git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e"
@@ -162,9 +164,9 @@ version = "1.2.2"
[[deps.Bzip2_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
-git-tree-sha1 = "9e2a6b69137e6969bab0152632dcb3bc108c8bdd"
+git-tree-sha1 = "8873e196c2eb87962a2048b3b8e08946535864a1"
uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0"
-version = "1.0.8+1"
+version = "1.0.8+2"
[[deps.CEnum]]
git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc"
@@ -173,9 +175,9 @@ version = "0.5.0"
[[deps.CSV]]
deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"]
-git-tree-sha1 = "6c834533dc1fabd820c1db03c839bf97e45a3fab"
+git-tree-sha1 = "deddd8725e5e1cc49ee205a1964256043720a6c3"
uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
-version = "0.10.14"
+version = "0.10.15"
[[deps.Cairo_jll]]
deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"]
@@ -394,6 +396,7 @@ version = "1.0.0"
[[deps.Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
+version = "1.11.0"
[[deps.Dbus_jll]]
deps = ["Artifacts", "Expat_jll", "JLLWrappers", "Libdl"]
@@ -444,6 +447,7 @@ weakdeps = ["ChainRulesCore", "SparseArrays"]
[[deps.Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
+version = "1.11.0"
[[deps.Distributions]]
deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"]
@@ -539,6 +543,7 @@ weakdeps = ["Mmap", "Test"]
[[deps.FileWatching]]
uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
+version = "1.11.0"
[[deps.FillArrays]]
deps = ["LinearAlgebra"]
@@ -624,6 +629,7 @@ version = "0.4.12"
[[deps.Future]]
deps = ["Random"]
uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"
+version = "1.11.0"
[[deps.GLFW_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Libglvnd_jll", "Xorg_libXcursor_jll", "Xorg_libXi_jll", "Xorg_libXinerama_jll", "Xorg_libXrandr_jll", "libdecor_jll", "xkbcommon_jll"]
@@ -776,6 +782,7 @@ version = "1.4.2"
[[deps.InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
+version = "1.11.0"
[[deps.InternedStrings]]
deps = ["Random", "Test"]
@@ -816,9 +823,9 @@ version = "1.0.0"
[[deps.JLD2]]
deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Requires", "TranscodingStreams"]
-git-tree-sha1 = "aeab5c68eb2cf326619bf71235d8f4561c62fe22"
+git-tree-sha1 = "b464b9b461ee989b435a689a4f7d870b68d467ed"
uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
-version = "0.5.5"
+version = "0.5.6"
[[deps.JLFzf]]
deps = ["Pipe", "REPL", "Random", "fzf_jll"]
@@ -840,9 +847,9 @@ version = "0.21.4"
[[deps.JSON3]]
deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"]
-git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b"
+git-tree-sha1 = "1d322381ef7b087548321d3f878cb4c9bd8f8f9b"
uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
-version = "1.14.0"
+version = "1.14.1"
[deps.JSON3.extensions]
JSON3ArrowExt = ["ArrowTypes"]
@@ -948,6 +955,7 @@ version = "1.9.0"
[[deps.LazyArtifacts]]
deps = ["Artifacts", "Pkg"]
uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
+version = "1.11.0"
[[deps.LazyModules]]
git-tree-sha1 = "a560dd966b386ac9ae60bdd3a3d3a326062d3c3e"
@@ -968,16 +976,17 @@ version = "0.6.4"
[[deps.LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
-version = "8.4.0+0"
+version = "8.6.0+0"
[[deps.LibGit2]]
deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
+version = "1.11.0"
[[deps.LibGit2_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"]
uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5"
-version = "1.6.4+0"
+version = "1.7.2+0"
[[deps.LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
@@ -986,6 +995,7 @@ version = "1.11.0+1"
[[deps.Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
+version = "1.11.0"
[[deps.Libffi_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
@@ -1038,6 +1048,7 @@ version = "2.40.1+0"
[[deps.LinearAlgebra]]
deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+version = "1.11.0"
[[deps.LogExpFunctions]]
deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
@@ -1057,6 +1068,7 @@ version = "0.3.28"
[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
+version = "1.11.0"
[[deps.LoggingExtras]]
deps = ["Dates", "Logging"]
@@ -1071,10 +1083,10 @@ uuid = "23992714-dd62-5051-b70f-ba57cb901cac"
version = "0.10.7"
[[deps.MLDataDevices]]
-deps = ["Adapt", "Functors", "Preferences", "Random"]
-git-tree-sha1 = "e16288e37e76d68c3f1c418e0a2bec88d98d55fc"
+deps = ["Adapt", "Compat", "Functors", "Preferences", "Random"]
+git-tree-sha1 = "f19f2629ad20176e524c71d06e1c29689ab002fa"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
-version = "1.2.0"
+version = "1.4.0"
[deps.MLDataDevices.extensions]
MLDataDevicesAMDGPUExt = "AMDGPU"
@@ -1235,6 +1247,7 @@ version = "0.4.2"
[[deps.Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
+version = "1.11.0"
[[deps.MbedTLS]]
deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "NetworkOptions", "Random", "Sockets"]
@@ -1245,7 +1258,7 @@ version = "1.1.9"
[[deps.MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
-version = "2.28.2+1"
+version = "2.28.6+0"
[[deps.Measures]]
git-tree-sha1 = "c13304c81eec1ed3af7fc20e75fb6b26092a1102"
@@ -1284,6 +1297,7 @@ version = "1.2.0"
[[deps.Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
+version = "1.11.0"
[[deps.MosaicViews]]
deps = ["MappedArrays", "OffsetArrays", "PaddedViews", "StackViews"]
@@ -1293,7 +1307,7 @@ version = "0.3.4"
[[deps.MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
-version = "2023.1.10"
+version = "2023.12.12"
[[deps.MultivariateStats]]
deps = ["Arpack", "Distributions", "LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI", "StatsBase"]
@@ -1369,7 +1383,7 @@ version = "0.2.5"
[[deps.OpenBLAS_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
-version = "0.3.23+4"
+version = "0.3.27+1"
[[deps.OpenLibm_jll]]
deps = ["Artifacts", "Libdl"]
@@ -1494,9 +1508,13 @@ uuid = "30392449-352a-5448-841d-b1acce4e97dc"
version = "0.43.4+0"
[[deps.Pkg]]
-deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
+deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "Random", "SHA", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
-version = "1.10.0"
+version = "1.11.0"
+weakdeps = ["REPL"]
+
+ [deps.Pkg.extensions]
+ REPLExt = "REPL"
[[deps.PlotThemes]]
deps = ["PlotUtils", "Statistics"]
@@ -1505,10 +1523,10 @@ uuid = "ccf2f8ad-2431-5c83-bf29-c5338b663b6a"
version = "3.2.0"
[[deps.PlotUtils]]
-deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "Statistics"]
-git-tree-sha1 = "7b1a9df27f072ac4c9c7cbe5efb198489258d1f5"
+deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "StableRNGs", "Statistics"]
+git-tree-sha1 = "650a022b2ce86c7dcfbdecf00f78afeeb20e5655"
uuid = "995b91a9-d308-5afd-9ec6-746e21dbc043"
-version = "1.4.1"
+version = "1.4.2"
[[deps.Plots]]
deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "PrecompileTools", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "RelocatableFolders", "Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "TOML", "UUIDs", "UnicodeFun", "UnitfulLatexify", "Unzip"]
@@ -1567,6 +1585,7 @@ version = "2.4.0"
[[deps.Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
+version = "1.11.0"
[[deps.ProgressLogging]]
deps = ["Logging", "SHA", "UUIDs"]
@@ -1622,12 +1641,14 @@ version = "2.11.1"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
[[deps.REPL]]
-deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
+deps = ["InteractiveUtils", "Markdown", "Sockets", "StyledStrings", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
+version = "1.11.0"
[[deps.Random]]
deps = ["SHA"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+version = "1.11.0"
[[deps.RealDot]]
deps = ["LinearAlgebra"]
@@ -1711,6 +1732,7 @@ version = "1.4.5"
[[deps.Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
+version = "1.11.0"
[[deps.Setfield]]
deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"]
@@ -1742,6 +1764,7 @@ version = "0.9.4"
[[deps.Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
+version = "1.11.0"
[[deps.SortingAlgorithms]]
deps = ["DataStructures"]
@@ -1752,7 +1775,7 @@ version = "1.2.1"
[[deps.SparseArrays]]
deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
-version = "1.10.0"
+version = "1.11.0"
[[deps.SparseInverseSubset]]
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
@@ -1790,9 +1813,9 @@ version = "0.1.1"
[[deps.StaticArrays]]
deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"]
-git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50"
+git-tree-sha1 = "777657803913ffc7e8cc20f0fd04b634f871af8f"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
-version = "1.9.7"
+version = "1.9.8"
weakdeps = ["ChainRulesCore", "Statistics"]
[deps.StaticArrays.extensions]
@@ -1831,9 +1854,14 @@ uuid = "64bff920-2084-43da-a3e6-9bb72801c0c9"
version = "3.4.0"
[[deps.Statistics]]
-deps = ["LinearAlgebra", "SparseArrays"]
+deps = ["LinearAlgebra"]
+git-tree-sha1 = "ae3bb1eb3bba077cd276bc5cfc337cc65c3075c0"
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
-version = "1.10.0"
+version = "1.11.1"
+weakdeps = ["SparseArrays"]
+
+ [deps.Statistics.extensions]
+ SparseArraysExt = ["SparseArrays"]
[[deps.StatsAPI]]
deps = ["LinearAlgebra"]
@@ -1901,6 +1929,10 @@ git-tree-sha1 = "159331b30e94d7b11379037feeb9b690950cace8"
uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
version = "1.11.0"
+[[deps.StyledStrings]]
+uuid = "f489334b-da3d-4c2e-b8f0-e476e12c162b"
+version = "1.11.0"
+
[[deps.SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
@@ -1908,7 +1940,7 @@ uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
[[deps.SuiteSparse_jll]]
deps = ["Artifacts", "Libdl", "libblastrampoline_jll"]
uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c"
-version = "7.2.1+1"
+version = "7.7.0+0"
[[deps.TOML]]
deps = ["Dates"]
@@ -1952,6 +1984,7 @@ version = "0.1.1"
[[deps.Test]]
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+version = "1.11.0"
[[deps.TranscodingStreams]]
git-tree-sha1 = "0c45878dcfdcfa8480052b6ab162cdd138781742"
@@ -2011,6 +2044,7 @@ version = "1.5.1"
[[deps.UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
+version = "1.11.0"
[[deps.UnPack]]
git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b"
@@ -2019,6 +2053,7 @@ version = "1.0.2"
[[deps.Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
+version = "1.11.0"
[[deps.UnicodeFun]]
deps = ["REPL"]
@@ -2381,7 +2416,7 @@ version = "1.1.6+0"
[[deps.nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
-version = "1.52.0+1"
+version = "1.59.0+0"
[[deps.p7zip_jll]]
deps = ["Artifacts", "Libdl"]
diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl
index c9e5b689..f4b6e5fc 100644
--- a/test/direct_mlj_interface.jl
+++ b/test/direct_mlj_interface.jl
@@ -1,19 +1,20 @@
using Random: Random
import Random.seed!
using MLJBase: MLJBase, categorical
+using MLJ: MLJ
using Flux
using StableRNGs
import LaplaceRedux: LaplaceClassifier, LaplaceRegressor
-cv = CV(; nfolds=3)
+cv = MLJBase.CV(; nfolds=3)
@testset "Regression" begin
flux_model = Chain(Dense(4, 10, relu), Dense(10, 10, relu), Dense(10, 1))
model = LaplaceRegressor(; model=flux_model, epochs=50)
- X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)
+ X, y = MLJ.make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)
#train, test = partition(eachindex(y), 0.7); # 70:30 split
- mach = machine(model, X, y) #|> MLJBase.fit! #|> (fitresult,cache,report)
+ mach = MLJ.machine(model, X, y) #|> MLJBase.fit! #|> (fitresult,cache,report)
MLJBase.fit!(mach; verbosity=1)
#Xnew, ynew = make_regression(3, 4; rng=123)
yhat = MLJBase.predict(mach, X) # probabilistic predictions
@@ -27,7 +28,7 @@ cv = CV(; nfolds=3)
model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
MLJBase.fit!(mach) #testing update function (the laplace part)
yhat = MLJBase.predict(mach, X) # probabilistic predictions
- evaluate!(mach; resampling=cv, measure=log_loss, verbosity=0)
+ MLJ.evaluate!(mach; resampling=cv, measure=MLJ.log_loss, verbosity=0)
# Define a different model
flux_model_two = Chain(Dense(4, 6, relu), Dense(6, 1))
@@ -45,8 +46,8 @@ end
model = LaplaceClassifier(; model=flux_model, epochs=50)
- X, y = @load_iris
- mach = machine(model, X, y)
+ X, y = MLJ.@load_iris
+ mach = MLJ.machine(model, X, y)
MLJBase.fit!(mach; verbosity=1)
Xnew = (
sepal_length=[6.4, 7.2, 7.4],
@@ -55,8 +56,8 @@ end
petal_width=[2.1, 1.6, 1.9],
)
yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions
- predict_mode(mach, Xnew) # point predictions
- pdf.(yhat, "virginica") # probabilities for the "verginica" class
+ MLJBase.predict_mode(mach, Xnew) # point predictions
+    MLJBase.pdf.(yhat, "virginica") # probabilities for the "virginica" class
MLJBase.fitted_params(mach) # fitted params
MLJBase.training_losses(mach) #training loss history
model.epochs = 100 #changing number of epochs
@@ -65,7 +66,7 @@ end
MLJBase.fit!(mach) #testing update function
model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
MLJBase.fit!(mach) #testing update function (the laplace part)
- evaluate!(mach; resampling=cv, measure=brier_loss, verbosity=0)
+ MLJ.evaluate!(mach; resampling=cv, measure=MLJ.brier_loss, verbosity=0)
# Define a different model
flux_model_two = Chain(Dense(4, 6, relu), Dense(6, 3))
diff --git a/test/runtests.jl b/test/runtests.jl
index ed16d37b..1184e74a 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -35,7 +35,7 @@ using Test
include("krondecomposed.jl")
end
- @testset "ML" begin
+ @testset "MLJ" begin
include("direct_mlj_interface.jl")
end
end
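
[Editor's note] For readers unfamiliar with the MLJ workflow exercised by the updated tests above, here is a minimal sketch (not part of the patch; network size, epochs, and data sizes are illustrative) of how the regression interface is driven end to end. All function names appear in the test diff above:

    using MLJ, MLJBase, Flux
    import LaplaceRedux: LaplaceRegressor

    flux_model = Chain(Dense(4, 10, relu), Dense(10, 1))   # small illustrative network
    model = LaplaceRegressor(; model=flux_model, epochs=20)
    X, y = MLJ.make_regression(100, 4; noise=0.5)          # synthetic regression data
    mach = MLJ.machine(model, X, y)                         # bind model and data
    MLJBase.fit!(mach; verbosity=0)                         # train and fit the Laplace approximation
    yhat = MLJBase.predict(mach, X)                          # probabilistic predictions
    MLJ.evaluate!(mach; resampling=MLJBase.CV(; nfolds=3), measure=MLJ.log_loss, verbosity=0)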
From 02abec26671c0b2f8911e03e06460291301332ce Mon Sep 17 00:00:00 2001
From: pat-alt
Date: Wed, 23 Oct 2024 10:03:57 +0200
Subject: [PATCH 53/60] removing reference to deep_properties
---
docs/Manifest.toml | 414 ++++++++++++++++++++++++++++-----------------
docs/Project.toml | 1 +
src/direct_mlj.jl | 4 +-
3 files changed, 266 insertions(+), 153 deletions(-)
diff --git a/docs/Manifest.toml b/docs/Manifest.toml
index a97bfd27..d01b6356 100644
--- a/docs/Manifest.toml
+++ b/docs/Manifest.toml
@@ -1,8 +1,8 @@
# This file is machine-generated - editing it directly is not advised
-julia_version = "1.10.5"
+julia_version = "1.11.1"
manifest_format = "2.0"
-project_hash = "616a9e89f5c520a58672ad91b5525001e0dadab3"
+project_hash = "aafefca0e42df191a3fa625af8a8b686a9db4944"
[[deps.ANSIColoredPrinters]]
git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c"
@@ -26,31 +26,35 @@ uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
version = "0.4.5"
[[deps.Accessors]]
-deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown", "Test"]
-git-tree-sha1 = "f61b15be1d76846c0ce31d3fcfac5380ae53db6a"
+deps = ["CompositionsBase", "ConstructionBase", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown"]
+git-tree-sha1 = "b392ede862e506d451fc1616e79aa6f4c673dab8"
uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
-version = "0.1.37"
+version = "0.1.38"
[deps.Accessors.extensions]
AccessorsAxisKeysExt = "AxisKeys"
+ AccessorsDatesExt = "Dates"
AccessorsIntervalSetsExt = "IntervalSets"
AccessorsStaticArraysExt = "StaticArrays"
AccessorsStructArraysExt = "StructArrays"
+ AccessorsTestExt = "Test"
AccessorsUnitfulExt = "Unitful"
[deps.Accessors.weakdeps]
AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5"
+ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
+ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
[[deps.Adapt]]
deps = ["LinearAlgebra", "Requires"]
-git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099"
+git-tree-sha1 = "d80af0733c99ea80575f612813fa6aa71022d33a"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
-version = "4.0.4"
+version = "4.1.0"
weakdeps = ["StaticArrays"]
[deps.Adapt.extensions]
@@ -69,7 +73,7 @@ version = "2.3.0"
[[deps.ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
-version = "1.1.1"
+version = "1.1.2"
[[deps.Arpack]]
deps = ["Arpack_jll", "Libdl", "LinearAlgebra", "Logging"]
@@ -115,6 +119,7 @@ version = "7.16.0"
[[deps.Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
+version = "1.11.0"
[[deps.Atomix]]
deps = ["UnsafeAtomics"]
@@ -157,6 +162,7 @@ version = "0.4.3"
[[deps.Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
+version = "1.11.0"
[[deps.Baselet]]
git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e"
@@ -170,9 +176,9 @@ version = "0.1.9"
[[deps.Bzip2_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
-git-tree-sha1 = "9e2a6b69137e6969bab0152632dcb3bc108c8bdd"
+git-tree-sha1 = "8873e196c2eb87962a2048b3b8e08946535864a1"
uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0"
-version = "1.0.8+1"
+version = "1.0.8+2"
[[deps.CEnum]]
git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc"
@@ -181,15 +187,15 @@ version = "0.5.0"
[[deps.CSV]]
deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"]
-git-tree-sha1 = "6c834533dc1fabd820c1db03c839bf97e45a3fab"
+git-tree-sha1 = "deddd8725e5e1cc49ee205a1964256043720a6c3"
uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
-version = "0.10.14"
+version = "0.10.15"
[[deps.CUDA]]
-deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "Crayons", "DataFrames", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LLVMLoopInfo", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "NVTX", "Preferences", "PrettyTables", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "StaticArrays", "Statistics"]
-git-tree-sha1 = "fdd9dfb67dfefd548f51000cc400bb51003de247"
+deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "Crayons", "DataFrames", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LLVMLoopInfo", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "NVTX", "Preferences", "PrettyTables", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "StaticArrays", "Statistics", "demumble_jll"]
+git-tree-sha1 = "e0725a467822697171af4dae15cec10b4fc19053"
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
-version = "5.4.3"
+version = "5.5.2"
[deps.CUDA.extensions]
ChainRulesCoreExt = "ChainRulesCore"
@@ -203,9 +209,9 @@ version = "5.4.3"
[[deps.CUDA_Driver_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
-git-tree-sha1 = "325058b426c2b421e3d2df3d5fa646d72d2e3e7e"
+git-tree-sha1 = "ccd1e54610c222fadfd4737dac66bff786f63656"
uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc"
-version = "0.9.2+0"
+version = "0.10.3+0"
[[deps.CUDA_Runtime_Discovery]]
deps = ["Libdl"]
@@ -215,21 +221,21 @@ version = "0.3.5"
[[deps.CUDA_Runtime_jll]]
deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
-git-tree-sha1 = "afea94249b821dc754a8ca6695d3daed851e1f5a"
+git-tree-sha1 = "e43727b237b2879a34391eeb81887699a26f8f2f"
uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
-version = "0.14.1+0"
+version = "0.15.3+0"
[[deps.CUDNN_jll]]
deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
-git-tree-sha1 = "cbf7d75f8c58b147bdf6acea2e5bc96cececa6d4"
+git-tree-sha1 = "9851af16a2f357a793daa0f13634c82bc7e40419"
uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645"
-version = "9.0.0+1"
+version = "9.4.0+0"
[[deps.Cairo_jll]]
deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"]
-git-tree-sha1 = "a2f1c8c668c8e3cb4cca4e57a8efdb09067bb3fd"
+git-tree-sha1 = "009060c9a6168704143100f36ab08f06c2af4642"
uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a"
-version = "1.18.0+2"
+version = "1.18.2+1"
[[deps.CategoricalArrays]]
deps = ["DataAPI", "Future", "Missings", "Printf", "Requires", "Statistics", "Unicode"]
@@ -263,15 +269,15 @@ version = "0.1.15"
[[deps.ChainRules]]
deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"]
-git-tree-sha1 = "227985d885b4dbce5e18a96f9326ea1e836e5a03"
+git-tree-sha1 = "be227d253d132a6d57f9ccf5f67c0fb6488afd87"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
-version = "1.69.0"
+version = "1.71.0"
[[deps.ChainRulesCore]]
deps = ["Compat", "LinearAlgebra"]
-git-tree-sha1 = "71acdbf594aab5bbb2cec89b208c41b4c411e49f"
+git-tree-sha1 = "3e4b134270b372f2ed4d4d0e936aabaefc1802bc"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-version = "1.24.0"
+version = "1.25.0"
weakdeps = ["SparseArrays"]
[deps.ChainRulesCore.extensions]
@@ -285,9 +291,9 @@ version = "0.7.6"
[[deps.ColorSchemes]]
deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"]
-git-tree-sha1 = "b5278586822443594ff615963b0c09755771b3e0"
+git-tree-sha1 = "13951eb68769ad1cd460cdb2e64e5e95f1bf123d"
uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
-version = "3.26.0"
+version = "3.27.0"
[[deps.ColorTypes]]
deps = ["FixedPointNumbers", "Random"]
@@ -411,10 +417,10 @@ uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.16.0"
[[deps.DataFrames]]
-deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"]
-git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8"
+deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"]
+git-tree-sha1 = "fb61b4812c49343d7ef0b533ba982c46021938a6"
uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
-version = "1.6.1"
+version = "1.7.0"
[[deps.DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
@@ -430,6 +436,7 @@ version = "1.0.0"
[[deps.Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
+version = "1.11.0"
[[deps.Dbus_jll]]
deps = ["Artifacts", "Expat_jll", "JLLWrappers", "Libdl"]
@@ -486,12 +493,13 @@ weakdeps = ["ChainRulesCore", "SparseArrays"]
[[deps.Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
+version = "1.11.0"
[[deps.Distributions]]
deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"]
-git-tree-sha1 = "e6c693a0e4394f8fda0e51a5bdf5aef26f8235e9"
+git-tree-sha1 = "d7477ecdafb813ddee2ae727afa94e9dcb5f3fb0"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
-version = "0.25.111"
+version = "0.25.112"
weakdeps = ["ChainRulesCore", "DensityInterface", "Test"]
[deps.Distributions.extensions]
@@ -499,6 +507,12 @@ weakdeps = ["ChainRulesCore", "DensityInterface", "Test"]
DistributionsDensityInterfaceExt = "DensityInterface"
DistributionsTestExt = "Test"
+[[deps.DocInventories]]
+deps = ["CodecZlib", "Downloads", "TOML"]
+git-tree-sha1 = "e97cfa8680a39396924dcdca4b7ff1014ed5c499"
+uuid = "43dc2714-ed3b-44b5-b226-857eda1aa7de"
+version = "1.0.0"
+
[[deps.DocStringExtensions]]
deps = ["LibGit2"]
git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d"
@@ -507,9 +521,21 @@ version = "0.9.3"
[[deps.Documenter]]
deps = ["ANSIColoredPrinters", "AbstractTrees", "Base64", "CodecZlib", "Dates", "DocStringExtensions", "Downloads", "Git", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "MarkdownAST", "Pkg", "PrecompileTools", "REPL", "RegistryInstances", "SHA", "TOML", "Test", "Unicode"]
-git-tree-sha1 = "9d29b99b6b2b6bc2382a4c8dbec6eb694f389853"
+git-tree-sha1 = "5a1ee886566f2fa9318df1273d8b778b9d42712d"
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
-version = "1.6.0"
+version = "1.7.0"
+
+[[deps.DocumenterInterLinks]]
+deps = ["CodecZlib", "DocInventories", "Documenter", "DocumenterInventoryWritingBackport", "Markdown", "MarkdownAST", "TOML"]
+git-tree-sha1 = "00dceb038f6cb24f4d8d6a9f2feb85bbe58305fd"
+uuid = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656"
+version = "1.0.0"
+
+[[deps.DocumenterInventoryWritingBackport]]
+deps = ["CodecZlib", "Documenter", "TOML"]
+git-tree-sha1 = "1b89024e375353961bb98b9818b44a4e38961cc4"
+uuid = "195adf08-069f-4855-af3e-8933a2cdae94"
+version = "0.1.0"
[[deps.Downloads]]
deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
@@ -541,9 +567,9 @@ version = "0.1.10"
[[deps.FFMPEG]]
deps = ["FFMPEG_jll"]
-git-tree-sha1 = "b57e3acbe22f8484b4b5ff66a7499717fe1a9cc8"
+git-tree-sha1 = "53ebe7511fa11d33bec688a9178fac4e49eeee00"
uuid = "c87230d0-a227-11e9-1b43-d7ebe4e7570a"
-version = "0.4.1"
+version = "0.4.2"
[[deps.FFMPEG_jll]]
deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "JLLWrappers", "LAME_jll", "Libdl", "Ogg_jll", "OpenSSL_jll", "Opus_jll", "PCRE2_jll", "Zlib_jll", "libaom_jll", "libass_jll", "libfdk_aac_jll", "libvorbis_jll", "x264_jll", "x265_jll"]
@@ -565,9 +591,9 @@ version = "0.1.1"
[[deps.FileIO]]
deps = ["Pkg", "Requires", "UUIDs"]
-git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322"
+git-tree-sha1 = "62ca0547a14c57e98154423419d8a342dca75ca9"
uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
-version = "1.16.3"
+version = "1.16.4"
[[deps.FilePathsBase]]
deps = ["Compat", "Dates"]
@@ -582,6 +608,7 @@ weakdeps = ["Mmap", "Test"]
[[deps.FileWatching]]
uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
+version = "1.11.0"
[[deps.FillArrays]]
deps = ["LinearAlgebra"]
@@ -596,19 +623,21 @@ weakdeps = ["PDMats", "SparseArrays", "Statistics"]
FillArraysStatisticsExt = "Statistics"
[[deps.FiniteDiff]]
-deps = ["ArrayInterface", "LinearAlgebra", "Setfield", "SparseArrays"]
-git-tree-sha1 = "f9219347ebf700e77ca1d48ef84e4a82a6701882"
+deps = ["ArrayInterface", "LinearAlgebra", "Setfield"]
+git-tree-sha1 = "b10bdafd1647f57ace3885143936749d61638c3b"
uuid = "6a86dc24-6348-571c-b903-95158fe2bd41"
-version = "2.24.0"
+version = "2.26.0"
[deps.FiniteDiff.extensions]
FiniteDiffBandedMatricesExt = "BandedMatrices"
FiniteDiffBlockBandedMatricesExt = "BlockBandedMatrices"
+ FiniteDiffSparseArraysExt = "SparseArrays"
FiniteDiffStaticArraysExt = "StaticArrays"
[deps.FiniteDiff.weakdeps]
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
+ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
[[deps.FixedPointNumbers]]
@@ -618,23 +647,27 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.8.5"
[[deps.Flux]]
-deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"]
-git-tree-sha1 = "fbf100b4bed74c9b6fac0ebd1031e04977d35b3b"
+deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLDataDevices", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "Setfield", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"]
+git-tree-sha1 = "37fa32a50c69c10c6ea1465d3054d98c75bd7777"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.14.19"
+version = "0.14.22"
[deps.Flux.extensions]
FluxAMDGPUExt = "AMDGPU"
FluxCUDAExt = "CUDA"
FluxCUDAcuDNNExt = ["CUDA", "cuDNN"]
FluxEnzymeExt = "Enzyme"
+ FluxMPIExt = "MPI"
+ FluxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
FluxMetalExt = "Metal"
[deps.Flux.weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+ MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
+ NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
[[deps.Fontconfig_jll]]
@@ -679,6 +712,7 @@ version = "0.4.12"
[[deps.Future]]
deps = ["Random"]
uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"
+version = "1.11.0"
[[deps.GLFW_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Libglvnd_jll", "Xorg_libXcursor_jll", "Xorg_libXi_jll", "Xorg_libXinerama_jll", "Xorg_libXrandr_jll", "libdecor_jll", "xkbcommon_jll"]
@@ -699,22 +733,22 @@ uuid = "46192b85-c4d5-4398-a991-12ede77f4527"
version = "0.1.6"
[[deps.GPUCompiler]]
-deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Preferences", "Scratch", "Serialization", "TOML", "TimerOutputs", "UUIDs"]
-git-tree-sha1 = "ab29216184312f99ff957b32cd63c2fe9c928b91"
+deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "PrecompileTools", "Preferences", "Scratch", "Serialization", "TOML", "TimerOutputs", "UUIDs"]
+git-tree-sha1 = "1d6f290a5eb1201cd63574fbc4440c788d5cb38f"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
-version = "0.26.7"
+version = "0.27.8"
[[deps.GR]]
deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Preferences", "Printf", "Qt6Wayland_jll", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "p7zip_jll"]
-git-tree-sha1 = "629693584cef594c3f6f99e76e7a7ad17e60e8d5"
+git-tree-sha1 = "ee28ddcd5517d54e417182fec3886e7412d3926f"
uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
-version = "0.73.7"
+version = "0.73.8"
[[deps.GR_jll]]
deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "FreeType2_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Qt6Base_jll", "Zlib_jll", "libpng_jll"]
-git-tree-sha1 = "a8863b69c2a0859f2c2c87ebdc4c6712e88bdf0d"
+git-tree-sha1 = "f31929b9e67066bee48eec8b03c0df47d31a74b3"
uuid = "d2c73de3-f751-5644-a686-071e5b155ba9"
-version = "0.73.7+0"
+version = "0.73.8+0"
[[deps.Gettext_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "XML2_jll"]
@@ -730,15 +764,15 @@ version = "1.3.1"
[[deps.Git_jll]]
deps = ["Artifacts", "Expat_jll", "JLLWrappers", "LibCURL_jll", "Libdl", "Libiconv_jll", "OpenSSL_jll", "PCRE2_jll", "Zlib_jll"]
-git-tree-sha1 = "d18fb8a1f3609361ebda9bf029b60fd0f120c809"
+git-tree-sha1 = "ea372033d09e4552a04fd38361cd019f9003f4f4"
uuid = "f8c6e375-362e-5223-8a59-34ff63f689eb"
-version = "2.44.0+2"
+version = "2.46.2+0"
[[deps.Glib_jll]]
deps = ["Artifacts", "Gettext_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Libiconv_jll", "Libmount_jll", "PCRE2_jll", "Zlib_jll"]
-git-tree-sha1 = "7c82e6a6cd34e9d935e9aa4051b66c6ff3af59ba"
+git-tree-sha1 = "674ff0db93fffcd11a3573986e550d66cd4fd71f"
uuid = "7746bdde-850d-59dc-9ae8-88ece973131d"
-version = "2.80.2+0"
+version = "2.80.5+0"
[[deps.Graphite2_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
@@ -814,11 +848,12 @@ version = "1.4.2"
[[deps.InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
+version = "1.11.0"
[[deps.InverseFunctions]]
-git-tree-sha1 = "2787db24f4e03daf859c6509ff87764e4182f7d1"
+git-tree-sha1 = "a779299d77cd080bf77b97535acecd73e1c5e5cb"
uuid = "3587e190-3f89-42d0-90ee-14403ec27112"
-version = "0.1.16"
+version = "0.1.17"
weakdeps = ["Dates", "Test"]
[deps.InverseFunctions.extensions]
@@ -848,9 +883,9 @@ version = "1.0.0"
[[deps.JLD2]]
deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Requires", "TranscodingStreams"]
-git-tree-sha1 = "a0746c21bdc986d0dc293efa6b1faee112c37c28"
+git-tree-sha1 = "b464b9b461ee989b435a689a4f7d870b68d467ed"
uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
-version = "0.4.53"
+version = "0.5.6"
[[deps.JLFzf]]
deps = ["Pipe", "REPL", "Random", "fzf_jll"]
@@ -860,9 +895,9 @@ version = "0.1.8"
[[deps.JLLWrappers]]
deps = ["Artifacts", "Preferences"]
-git-tree-sha1 = "f389674c99bfcde17dc57454011aa44d5a260a40"
+git-tree-sha1 = "be3dc50a92e5a386872a493a10050136d4703f9b"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
-version = "1.6.0"
+version = "1.6.1"
[[deps.JSON]]
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
@@ -872,9 +907,9 @@ version = "0.21.4"
[[deps.JpegTurbo_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
-git-tree-sha1 = "c84a835e1a09b289ffcd2271bf2a337bbdda6637"
+git-tree-sha1 = "25ee0be4d43d0269027024d75a24c24d6c6e590c"
uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8"
-version = "3.0.3+0"
+version = "3.0.4+0"
[[deps.JuliaNVTXCallbacks_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
@@ -890,9 +925,9 @@ version = "0.2.4"
[[deps.KernelAbstractions]]
deps = ["Adapt", "Atomix", "InteractiveUtils", "MacroTools", "PrecompileTools", "Requires", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"]
-git-tree-sha1 = "cb1cff88ef2f3a157cbad75bbe6b229e1975e498"
+git-tree-sha1 = "04e52f596d0871fa3890170fa79cb15e481e4cd8"
uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
-version = "0.9.25"
+version = "0.9.28"
[deps.KernelAbstractions.extensions]
EnzymeExt = "EnzymeCore"
@@ -911,16 +946,16 @@ uuid = "c1c5ebd0-6772-5130-a774-d5fcae4a789d"
version = "3.100.2+0"
[[deps.LERC_jll]]
-deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
-git-tree-sha1 = "bf36f528eec6634efc60d7ec062008f171071434"
+deps = ["Artifacts", "JLLWrappers", "Libdl"]
+git-tree-sha1 = "36bdbc52f13a7d1dcb0f3cd694e01677a515655b"
uuid = "88015f11-f218-50d7-93a8-a6af411a945d"
-version = "3.0.0+1"
+version = "4.0.0+0"
[[deps.LLVM]]
-deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"]
-git-tree-sha1 = "2470e69781ddd70b8878491233cd09bc1bd7fc96"
+deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Unicode"]
+git-tree-sha1 = "d422dfd9707bec6617335dc2ea3c5172a87d5908"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
-version = "8.1.0"
+version = "9.1.3"
weakdeps = ["BFloat16s"]
[deps.LLVM.extensions]
@@ -928,9 +963,9 @@ weakdeps = ["BFloat16s"]
[[deps.LLVMExtra_jll]]
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
-git-tree-sha1 = "597d1c758c9ae5d985ba4202386a607c675ee700"
+git-tree-sha1 = "05a8bd5a42309a9ec82f700876903abce1017dd3"
uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
-version = "0.0.31+0"
+version = "0.0.34+0"
[[deps.LLVMLoopInfo]]
git-tree-sha1 = "2e5c102cfc41f48ae4740c7eca7743cc7e7b75ea"
@@ -939,24 +974,24 @@ version = "1.0.0"
[[deps.LLVMOpenMP_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
-git-tree-sha1 = "e16271d212accd09d52ee0ae98956b8a05c4b626"
+git-tree-sha1 = "78211fb6cbc872f77cad3fc0b6cf647d923f4929"
uuid = "1d63c593-3942-5779-bab2-d838dc0a180e"
-version = "17.0.6+0"
+version = "18.1.7+0"
[[deps.LZO_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
-git-tree-sha1 = "70c5da094887fd2cae843b8db33920bac4b6f07d"
+git-tree-sha1 = "854a9c268c43b77b0a27f22d7fab8d33cdb3a731"
uuid = "dd4b983a-f0e5-5f8d-a1b7-129d4a5fb1ac"
-version = "2.10.2+0"
+version = "2.10.2+1"
[[deps.LaTeXStrings]]
-git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec"
+git-tree-sha1 = "dda21b8cbd6a6c40d9d02a73230f9d70fed6918c"
uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
-version = "1.3.1"
+version = "1.4.0"
[[deps.LaplaceRedux]]
-deps = ["CategoricalDistributions", "ChainRulesCore", "Compat", "ComputationalResources", "Distributions", "Flux", "LinearAlgebra", "MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"]
-path = "C:\\Users\\Pasqu\\Documents\\julia_projects\\LaplaceRedux.jl"
+deps = ["CategoricalDistributions", "ChainRulesCore", "Compat", "ComputationalResources", "Distributions", "Flux", "LinearAlgebra", "MLJBase", "MLJModelInterface", "MLUtils", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"]
+path = ".."
uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
version = "1.1.1"
@@ -984,6 +1019,7 @@ version = "1.2.2"
[[deps.LazyArtifacts]]
deps = ["Artifacts", "Pkg"]
uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
+version = "1.11.0"
[[deps.LearnAPI]]
deps = ["InteractiveUtils", "Statistics"]
@@ -999,16 +1035,17 @@ version = "0.6.4"
[[deps.LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
-version = "8.4.0+0"
+version = "8.6.0+0"
[[deps.LibGit2]]
deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
+version = "1.11.0"
[[deps.LibGit2_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"]
uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5"
-version = "1.6.4+0"
+version = "1.7.2+0"
[[deps.LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
@@ -1017,6 +1054,7 @@ version = "1.11.0+1"
[[deps.Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
+version = "1.11.0"
[[deps.Libffi_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
@@ -1056,9 +1094,9 @@ version = "2.40.1+0"
[[deps.Libtiff_jll]]
deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "LERC_jll", "Libdl", "XZ_jll", "Zlib_jll", "Zstd_jll"]
-git-tree-sha1 = "2da088d113af58221c52828a80378e16be7d037a"
+git-tree-sha1 = "b404131d06f7886402758c9ce2214b636eb4d54a"
uuid = "89763e89-9b03-5906-acba-b20f662cd828"
-version = "4.5.1+1"
+version = "4.7.0+0"
[[deps.Libuuid_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
@@ -1075,6 +1113,7 @@ version = "7.3.0"
[[deps.LinearAlgebra]]
deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+version = "1.11.0"
[[deps.LinearMaps]]
deps = ["LinearAlgebra"]
@@ -1106,6 +1145,7 @@ version = "0.3.28"
[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
+version = "1.11.0"
[[deps.LoggingExtras]]
deps = ["Dates", "Logging"]
@@ -1113,6 +1153,46 @@ git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075"
uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36"
version = "1.0.3"
+[[deps.MLDataDevices]]
+deps = ["Adapt", "Compat", "Functors", "LinearAlgebra", "Preferences", "Random"]
+git-tree-sha1 = "3207c2e66164e6366440ad3f0243a8d67abb4a47"
+uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
+version = "1.4.1"
+
+ [deps.MLDataDevices.extensions]
+ MLDataDevicesAMDGPUExt = "AMDGPU"
+ MLDataDevicesCUDAExt = "CUDA"
+ MLDataDevicesChainRulesCoreExt = "ChainRulesCore"
+ MLDataDevicesFillArraysExt = "FillArrays"
+ MLDataDevicesGPUArraysExt = "GPUArrays"
+ MLDataDevicesMLUtilsExt = "MLUtils"
+ MLDataDevicesMetalExt = ["GPUArrays", "Metal"]
+ MLDataDevicesReactantExt = "Reactant"
+ MLDataDevicesRecursiveArrayToolsExt = "RecursiveArrayTools"
+ MLDataDevicesReverseDiffExt = "ReverseDiff"
+ MLDataDevicesSparseArraysExt = "SparseArrays"
+ MLDataDevicesTrackerExt = "Tracker"
+ MLDataDevicesZygoteExt = "Zygote"
+ MLDataDevicescuDNNExt = ["CUDA", "cuDNN"]
+ MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"]
+
+ [deps.MLDataDevices.weakdeps]
+ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
+ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
+ GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
+ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+ Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
+ Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
+ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
+ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
+ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
+ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
+ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+ cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
+ oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
+
[[deps.MLJBase]]
deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LearnAPI", "LinearAlgebra", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "RecipesBase", "Reexport", "ScientificTypes", "Serialization", "StatisticalMeasuresBase", "StatisticalTraits", "Statistics", "StatsBase", "Tables"]
git-tree-sha1 = "6f45e12073bc2f2e73ed0473391db38c31e879c9"
@@ -1175,6 +1255,7 @@ version = "0.5.13"
[[deps.Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
+version = "1.11.0"
[[deps.MarkdownAST]]
deps = ["AbstractTrees", "Markdown"]
@@ -1191,7 +1272,7 @@ version = "1.1.9"
[[deps.MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
-version = "2.28.2+1"
+version = "2.28.6+0"
[[deps.Measures]]
git-tree-sha1 = "c13304c81eec1ed3af7fc20e75fb6b26092a1102"
@@ -1200,9 +1281,9 @@ version = "0.3.2"
[[deps.Metalhead]]
deps = ["Artifacts", "BSON", "ChainRulesCore", "Flux", "Functors", "JLD2", "LazyArtifacts", "MLUtils", "NNlib", "PartialFunctions", "Random", "Statistics"]
-git-tree-sha1 = "5aac9a2b511afda7bf89df5044a2e0b429f83152"
+git-tree-sha1 = "aef476e4958303f5ea9e1deb81a1ba2f510d4e11"
uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
-version = "0.9.3"
+version = "0.9.4"
weakdeps = ["CUDA"]
[deps.Metalhead.extensions]
@@ -1222,6 +1303,7 @@ version = "1.2.0"
[[deps.Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
+version = "1.11.0"
[[deps.Mocking]]
deps = ["Compat", "ExprTools"]
@@ -1231,7 +1313,7 @@ version = "0.8.1"
[[deps.MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
-version = "2023.1.10"
+version = "2023.12.12"
[[deps.MultivariateStats]]
deps = ["Arpack", "Distributions", "LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI", "StatsBase"]
@@ -1246,10 +1328,10 @@ uuid = "d41bc354-129a-5804-8e4c-c37616107c6c"
version = "7.8.3"
[[deps.NNlib]]
-deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"]
-git-tree-sha1 = "ae52c156a63bb647f80c26319b104e99e5977e51"
+deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Random", "Statistics"]
+git-tree-sha1 = "da09a1e112fd75f9af2a5229323f01b56ec96a4c"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.9.22"
+version = "0.9.24"
[deps.NNlib.extensions]
NNlibAMDGPUExt = "AMDGPU"
@@ -1257,12 +1339,14 @@ version = "0.9.22"
NNlibCUDAExt = "CUDA"
NNlibEnzymeCoreExt = "EnzymeCore"
NNlibFFTWExt = "FFTW"
+ NNlibForwardDiffExt = "ForwardDiff"
[deps.NNlib.weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
+ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
[[deps.NVTX]]
@@ -1302,9 +1386,9 @@ version = "0.2.3"
[[deps.NearestNeighbors]]
deps = ["Distances", "StaticArrays"]
-git-tree-sha1 = "91a67b4d73842da90b526011fa85c5c4c9343fe0"
+git-tree-sha1 = "3cebfc94a0754cc329ebc3bab1e6c89621e791ad"
uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
-version = "0.4.18"
+version = "0.4.20"
[[deps.NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
@@ -1325,7 +1409,7 @@ version = "0.2.5"
[[deps.OpenBLAS_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
-version = "0.3.23+4"
+version = "0.3.27+1"
[[deps.OpenLibm_jll]]
deps = ["Artifacts", "Libdl"]
@@ -1340,9 +1424,9 @@ version = "1.4.3"
[[deps.OpenSSL_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
-git-tree-sha1 = "a028ee3cb5641cccc4c24e90c36b0a4f7707bdf5"
+git-tree-sha1 = "7493f61f55a6cce7325f197443aa80d32554ba10"
uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95"
-version = "3.0.14+0"
+version = "3.0.15+1"
[[deps.OpenSpecFun_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
@@ -1432,9 +1516,13 @@ uuid = "30392449-352a-5448-841d-b1acce4e97dc"
version = "0.43.4+0"
[[deps.Pkg]]
-deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
+deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "Random", "SHA", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
-version = "1.10.0"
+version = "1.11.0"
+weakdeps = ["REPL"]
+
+ [deps.Pkg.extensions]
+ REPLExt = "REPL"
[[deps.PlotThemes]]
deps = ["PlotUtils", "Statistics"]
@@ -1443,10 +1531,10 @@ uuid = "ccf2f8ad-2431-5c83-bf29-c5338b663b6a"
version = "3.2.0"
[[deps.PlotUtils]]
-deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "Statistics"]
-git-tree-sha1 = "7b1a9df27f072ac4c9c7cbe5efb198489258d1f5"
+deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "StableRNGs", "Statistics"]
+git-tree-sha1 = "650a022b2ce86c7dcfbdecf00f78afeeb20e5655"
uuid = "995b91a9-d308-5afd-9ec6-746e21dbc043"
-version = "1.4.1"
+version = "1.4.2"
[[deps.Plots]]
deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "PrecompileTools", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "RelocatableFolders", "Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "TOML", "UUIDs", "UnicodeFun", "UnitfulLatexify", "Unzip"]
@@ -1499,13 +1587,14 @@ version = "0.2.0"
[[deps.PrettyTables]]
deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"]
-git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7"
+git-tree-sha1 = "1101cd475833706e4d0e7b122218257178f48f34"
uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
-version = "2.3.2"
+version = "2.4.0"
[[deps.Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
+version = "1.11.0"
[[deps.ProgressLogging]]
deps = ["Logging", "SHA", "UUIDs"]
@@ -1550,9 +1639,9 @@ version = "6.7.1+1"
[[deps.QuadGK]]
deps = ["DataStructures", "LinearAlgebra"]
-git-tree-sha1 = "1d587203cf851a51bf1ea31ad7ff89eff8d625ea"
+git-tree-sha1 = "cda3b045cf9ef07a08ad46731f5a3165e56cf3da"
uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
-version = "2.11.0"
+version = "2.11.1"
[deps.QuadGK.extensions]
QuadGKEnzymeExt = "Enzyme"
@@ -1573,12 +1662,14 @@ uuid = "ce6b1742-4840-55fa-b093-852dadbb1d8b"
version = "0.7.7"
[[deps.REPL]]
-deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
+deps = ["InteractiveUtils", "Markdown", "Sockets", "StyledStrings", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
+version = "1.11.0"
[[deps.Random]]
deps = ["SHA"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+version = "1.11.0"
[[deps.Random123]]
deps = ["Random", "RandomNumbers"]
@@ -1647,15 +1738,15 @@ version = "1.3.0"
[[deps.Rmath]]
deps = ["Random", "Rmath_jll"]
-git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b"
+git-tree-sha1 = "852bd0f55565a9e973fcfee83a84413270224dc4"
uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
-version = "0.7.1"
+version = "0.8.0"
[[deps.Rmath_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
-git-tree-sha1 = "e60724fd3beea548353984dc61c943ecddb0e29a"
+git-tree-sha1 = "58cdd8fb2201a6267e1db87ff148dd6c1dbd8ad8"
uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f"
-version = "0.4.3+0"
+version = "0.5.1+0"
[[deps.SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
@@ -1692,6 +1783,7 @@ version = "1.4.5"
[[deps.Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
+version = "1.11.0"
[[deps.Setfield]]
deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"]
@@ -1711,9 +1803,9 @@ uuid = "992d4aef-0814-514b-bc4d-f2e9a6c4116f"
version = "1.0.3"
[[deps.SimpleBufferStream]]
-git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1"
+git-tree-sha1 = "f305871d2f381d21527c770d4788c06c097c9bc1"
uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7"
-version = "1.1.0"
+version = "1.2.0"
[[deps.SimpleTraits]]
deps = ["InteractiveUtils", "MacroTools"]
@@ -1723,6 +1815,7 @@ version = "0.9.4"
[[deps.Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
+version = "1.11.0"
[[deps.SortingAlgorithms]]
deps = ["DataStructures"]
@@ -1733,7 +1826,7 @@ version = "1.2.1"
[[deps.SparseArrays]]
deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
-version = "1.10.0"
+version = "1.11.0"
[[deps.SparseInverseSubset]]
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
@@ -1765,9 +1858,9 @@ version = "1.0.2"
[[deps.StaticArrays]]
deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"]
-git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50"
+git-tree-sha1 = "777657803913ffc7e8cc20f0fd04b634f871af8f"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
-version = "1.9.7"
+version = "1.9.8"
weakdeps = ["ChainRulesCore", "Statistics"]
[deps.StaticArrays.extensions]
@@ -1792,9 +1885,14 @@ uuid = "64bff920-2084-43da-a3e6-9bb72801c0c9"
version = "3.4.0"
[[deps.Statistics]]
-deps = ["LinearAlgebra", "SparseArrays"]
+deps = ["LinearAlgebra"]
+git-tree-sha1 = "ae3bb1eb3bba077cd276bc5cfc337cc65c3075c0"
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
-version = "1.10.0"
+version = "1.11.1"
+weakdeps = ["SparseArrays"]
+
+ [deps.Statistics.extensions]
+ SparseArraysExt = ["SparseArrays"]
[[deps.StatsAPI]]
deps = ["LinearAlgebra"]
@@ -1810,9 +1908,9 @@ version = "0.34.3"
[[deps.StatsFuns]]
deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"]
-git-tree-sha1 = "cef0472124fab0695b58ca35a77c6fb942fdab8a"
+git-tree-sha1 = "b423576adc27097764a90e163157bcfc9acf0f46"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
-version = "1.3.1"
+version = "1.3.2"
weakdeps = ["ChainRulesCore", "InverseFunctions"]
[deps.StatsFuns.extensions]
@@ -1821,9 +1919,9 @@ weakdeps = ["ChainRulesCore", "InverseFunctions"]
[[deps.StringManipulation]]
deps = ["PrecompileTools"]
-git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5"
+git-tree-sha1 = "a6b1675a536c5ad1a60e5a5153e1fee12eb146e3"
uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e"
-version = "0.3.4"
+version = "0.4.0"
[[deps.StructArrays]]
deps = ["ConstructionBase", "DataAPI", "Tables"]
@@ -1838,6 +1936,10 @@ weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"]
StructArraysSparseArraysExt = "SparseArrays"
StructArraysStaticArraysExt = "StaticArrays"
+[[deps.StyledStrings]]
+uuid = "f489334b-da3d-4c2e-b8f0-e476e12c162b"
+version = "1.11.0"
+
[[deps.SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
@@ -1845,7 +1947,7 @@ uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
[[deps.SuiteSparse_jll]]
deps = ["Artifacts", "Libdl", "libblastrampoline_jll"]
uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c"
-version = "7.2.1+1"
+version = "7.7.0+0"
[[deps.TOML]]
deps = ["Dates"]
@@ -1854,9 +1956,9 @@ version = "1.0.3"
[[deps.TZJData]]
deps = ["Artifacts"]
-git-tree-sha1 = "1607ad46cf8d642aa779a1d45af1c8620dbf6915"
+git-tree-sha1 = "36b40607bf2bf856828690e097e1c799623b0602"
uuid = "dc5dba14-91b3-4cab-a142-028a31da12f7"
-version = "1.2.0+2024a"
+version = "1.3.0+2024b"
[[deps.TableTraits]]
deps = ["IteratorInterfaceExtensions"]
@@ -1871,16 +1973,15 @@ uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
version = "1.12.0"
[[deps.TaijaBase]]
-deps = ["CategoricalArrays", "Distributions", "Flux", "MLUtils", "Optimisers", "StatsBase", "Tables"]
-git-tree-sha1 = "1c80c4472c6ab6e8c9fa544a22d907295b388dd0"
+git-tree-sha1 = "4076f60078b12095ca71a2c26e2e4515e3a6a5e5"
uuid = "10284c91-9f28-4c9a-abbf-ee43576dfff6"
-version = "1.2.2"
+version = "1.2.3"
[[deps.TaijaPlotting]]
-deps = ["CategoricalArrays", "ConformalPrediction", "CounterfactualExplanations", "DataAPI", "Distributions", "Flux", "LaplaceRedux", "LinearAlgebra", "MLJBase", "MLUtils", "MultivariateStats", "NaturalSort", "NearestNeighborModels", "Plots", "Trapz"]
-git-tree-sha1 = "2fc71041e1c215cf6ef3dc2d3b8419499c4b40ff"
+deps = ["CategoricalArrays", "ConformalPrediction", "CounterfactualExplanations", "DataAPI", "Distributions", "LaplaceRedux", "LinearAlgebra", "MLJBase", "MLUtils", "MultivariateStats", "NaturalSort", "NearestNeighborModels", "OneHotArrays", "Plots", "RecipesBase", "Trapz"]
+git-tree-sha1 = "01c76535e8b87b05d0fbc275d44714c78e49fe6e"
uuid = "bd7198b4-c7d6-400c-9bab-9a24614b0240"
-version = "1.2.0"
+version = "1.3.0"
[[deps.Tar]]
deps = ["ArgTools", "SHA"]
@@ -1896,6 +1997,7 @@ version = "0.1.1"
[[deps.Test]]
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+version = "1.11.0"
[[deps.ThreadsX]]
deps = ["Accessors", "ArgCheck", "BangBang", "ConstructionBase", "InitialValues", "MicroCollections", "Referenceables", "SplittablesBase", "Transducers"]
@@ -1905,9 +2007,9 @@ version = "0.1.12"
[[deps.TimeZones]]
deps = ["Dates", "Downloads", "InlineStrings", "Mocking", "Printf", "Scratch", "TZJData", "Unicode", "p7zip_jll"]
-git-tree-sha1 = "b92aebdd3555f3a7e3267cf17702033c2814ef48"
+git-tree-sha1 = "8323074bc977aa85cf5ad71099a83ac75b0ac107"
uuid = "f269a46b-ccf7-5d73-abea-4c690281aa53"
-version = "1.18.0"
+version = "1.18.1"
weakdeps = ["RecipesBase"]
[deps.TimeZones.extensions]
@@ -1915,22 +2017,23 @@ weakdeps = ["RecipesBase"]
[[deps.TimerOutputs]]
deps = ["ExprTools", "Printf"]
-git-tree-sha1 = "5a13ae8a41237cff5ecf34f73eb1b8f42fff6531"
+git-tree-sha1 = "3a6f063d690135f5c1ba351412c82bae4d1402bf"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
-version = "0.5.24"
+version = "0.5.25"
[[deps.TranscodingStreams]]
-git-tree-sha1 = "e84b3a11b9bece70d14cce63406bbc79ed3464d2"
+git-tree-sha1 = "0c45878dcfdcfa8480052b6ab162cdd138781742"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
-version = "0.11.2"
+version = "0.11.3"
[[deps.Transducers]]
-deps = ["Accessors", "Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"]
-git-tree-sha1 = "5215a069867476fc8e3469602006b9670e68da23"
+deps = ["Accessors", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"]
+git-tree-sha1 = "7deeab4ff96b85c5f72c824cae53a1398da3d1cb"
uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999"
-version = "0.4.82"
+version = "0.4.84"
[deps.Transducers.extensions]
+ TransducersAdaptExt = "Adapt"
TransducersBlockArraysExt = "BlockArrays"
TransducersDataFramesExt = "DataFrames"
TransducersLazyArraysExt = "LazyArrays"
@@ -1938,6 +2041,7 @@ version = "0.4.82"
TransducersReferenceablesExt = "Referenceables"
[deps.Transducers.weakdeps]
+ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
@@ -1975,6 +2079,7 @@ version = "1.5.1"
[[deps.UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
+version = "1.11.0"
[[deps.UnPack]]
git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b"
@@ -1983,6 +2088,7 @@ version = "1.0.2"
[[deps.Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
+version = "1.11.0"
[[deps.UnicodeFun]]
deps = ["REPL"]
@@ -2221,15 +2327,15 @@ version = "1.2.13+1"
[[deps.Zstd_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
-git-tree-sha1 = "e678132f07ddb5bfa46857f0d7620fb9be675d3b"
+git-tree-sha1 = "555d1076590a6cc2fdee2ef1469451f872d8b41b"
uuid = "3161d3a3-bdf6-5164-811a-617609db77b4"
-version = "1.5.6+0"
+version = "1.5.6+1"
[[deps.Zygote]]
deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"]
-git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54"
+git-tree-sha1 = "f816633be6dc5c0ed9ffedda157ecfda0b3b6a69"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.70"
+version = "0.6.72"
[deps.Zygote.extensions]
ZygoteColorsExt = "Colors"
@@ -2249,9 +2355,15 @@ version = "0.2.5"
[[deps.cuDNN]]
deps = ["CEnum", "CUDA", "CUDA_Runtime_Discovery", "CUDNN_jll"]
-git-tree-sha1 = "4909e87d6d62c29a897d54d9001c63932e41cb0e"
+git-tree-sha1 = "4b3ac62501ca73263eaa0d034c772f13c647fba6"
uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
-version = "1.3.2"
+version = "1.4.0"
+
+[[deps.demumble_jll]]
+deps = ["Artifacts", "JLLWrappers", "Libdl"]
+git-tree-sha1 = "6498e3581023f8e530f34760d18f75a69e3a4ea8"
+uuid = "1e29f10c-031c-5a83-9565-69cddfc27673"
+version = "1.3.0+0"
[[deps.eudev_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "gperf_jll"]
@@ -2314,9 +2426,9 @@ version = "1.18.0+0"
[[deps.libpng_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Zlib_jll"]
-git-tree-sha1 = "d7015d2e18a5fd9a4f47de711837e980519781a4"
+git-tree-sha1 = "b70c870239dc3d7bc094eb2d6be9b73d27bef280"
uuid = "b53b4c65-9356-5827-b1ea-8c7a1a84506f"
-version = "1.6.43+1"
+version = "1.6.44+0"
[[deps.libvorbis_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Ogg_jll", "Pkg"]
@@ -2333,7 +2445,7 @@ version = "1.1.6+0"
[[deps.nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
-version = "1.52.0+1"
+version = "1.59.0+0"
[[deps.p7zip_jll]]
deps = ["Artifacts", "Libdl"]
diff --git a/docs/Project.toml b/docs/Project.toml
index c8d6234a..e8f6d849 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -3,6 +3,7 @@ CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
+DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 42ec720a..c45d996b 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -312,7 +312,7 @@ The meaning of "equal" depends on the type of the property value:
- values that are not of `MLJType` are "equal" if they are `==`.
In the special case of a "deep" property, "equal" has a different
-meaning; see [`MMI.StatTraits.deep_properties`](@ref)) for details.
+meaning; see [`MLJBase.deep_properties`](@ref)) for details.
If `m1` or `m2` are not `MLJType` objects, then return `==(m1, m2)`.
@@ -327,7 +327,7 @@ function MMI.is_same_except(m1::LaplaceModels, m2::LaplaceModels, exceptions::Sy
if !_isdefined(m1, name)
!_isdefined(m2, name) || return false
elseif _isdefined(m2, name)
- if name in MMI.StatTraits.deep_properties(LaplaceRegressor)
+ if name in MLJBase.deep_properties(LaplaceRegressor)
_equal_to_depth_one(getproperty(m1, name), getproperty(m2, name)) ||
return false
else
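
[Editor's note] To make the intent of the changed lines concrete, here is a rough sketch (not the package's actual implementation; `_equal_to_depth_one` is the helper referenced in the diff, and error handling for undefined fields is omitted) of the property-wise comparison that `is_same_except` performs, with `MLJBase.deep_properties` selecting the fields that are only compared to depth one:

    function roughly_same(m1, m2, exceptions::Symbol...)
        for name in propertynames(m1)
            name in exceptions && continue
            a = getproperty(m1, name)
            b = getproperty(m2, name)
            if name in MLJBase.deep_properties(typeof(m1))
                # "deep" properties are compared field by field, one level down
                _equal_to_depth_one(a, b) || return false
            else
                a == b || return false
            end
        end
        return true
    end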
From 374aca5bf79b3c5bfedcd86d657a4b1c1fef287b Mon Sep 17 00:00:00 2001
From: pat-alt
Date: Wed, 23 Oct 2024 10:16:16 +0200
Subject: [PATCH 54/60] hadn't saved file
---
src/direct_mlj.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index c45d996b..aa644bd9 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -312,7 +312,7 @@ The meaning of "equal" depends on the type of the property value:
- values that are not of `MLJType` are "equal" if they are `==`.
In the special case of a "deep" property, "equal" has a different
-meaning; see [`MLJBase.deep_properties`](@ref)) for details.
+meaning; see `MLJBase.deep_properties` for details.
If `m1` or `m2` are not `MLJType` objects, then return `==(m1, m2)`.
From 59917f86d3c42a8f3c939ffa6b21b6336ecd6ec2 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Wed, 30 Oct 2024 00:46:40 +0100
Subject: [PATCH 55/60] added default mlp
---
src/baselaplace/predicting.jl | 2 +-
src/direct_mlj.jl | 70 +++++++++++++++++++++++++++++++++--
test/direct_mlj_interface.jl | 40 ++++++++++++++------
test/runtests.jl | 4 +-
4 files changed, 99 insertions(+), 17 deletions(-)
diff --git a/src/baselaplace/predicting.jl b/src/baselaplace/predicting.jl
index d26c9a07..feb457b3 100644
--- a/src/baselaplace/predicting.jl
+++ b/src/baselaplace/predicting.jl
@@ -93,7 +93,7 @@ Computes the Bayesian predictive distribution from a neural network with a Lapl
- `link_approx::Symbol=:probit`: Link function approximation. Options are `:probit` and `:plugin`.
- `predict_proba::Bool=true`: If `true` (default) apply a sigmoid or a softmax function to the output of the Flux model.
- `return_distr::Bool=false`: if `false` (default), the function outputs either the direct output of the chain or pseudo-probabilities (if `predict_proba=true`).
- if `true` predict return a probability distribution.
+ if `true` predict returns a probability distribution.
# Returns
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index aa644bd9..82c6d81b 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -65,6 +65,63 @@ MMI.reformat(::LaplaceModels, X) = (MLJBase.matrix(X) |> permutedims,)
MMI.selectrows(::LaplaceModels, I, Xmatrix, y) = (Xmatrix[:, I], (y[1][:, I], y[2]))
MMI.selectrows(::LaplaceModels, I, Xmatrix) = (Xmatrix[:, I],)
+
+
+"""
+ function features_shape(model::LaplaceRegression, X, y)
+
+Compute the number of features of the X input dataset and the number of variables to predict from the y output dataset.
+
+# Arguments
+- `model::LaplaceModels`: The Laplace model to fit.
+- `X`: The input data for training.
+- `y`: The target labels for training, one-hot encoded.
+
+# Returns
+- (input size, output size)
+"""
+function features_shape(model::LaplaceModels, X, y)
+ #X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X
+ n_input = size(X, 1)
+ dims = size(y)
+ if length(dims) == 1
+ n_output = 1
+ else
+ n_output = dims[1]
+ end
+ return (n_input, n_output)
+end
+
+
+"""
+ default_build( seed::Int, shape)
+
+Builds a default MLP Flux model compatible with the dimensions of the dataset, with reproducible initial weights.
+
+# Arguments
+- `seed::Int`: The seed for random number generation.
+- `shape`: a tuple containing the dimensions of the input layer and the output layer.
+
+# Returns
+- The constructed Flux model, a simple MLP with 2 hidden layers of 20 neurons each and input and output layers compatible with the dataset.
+"""
+function default_build(seed::Int, shape)
+ Random.seed!(seed)
+ (n_input, n_output) = shape
+
+ chain = Chain(
+ Dense(n_input, 20, relu),
+ Dense(20, 20, relu),
+ #Dense(20, 20, relu),
+ Dense(20, n_output)
+ )
+
+ return chain
+end
+
+
+
+
@doc """
MMI.fit(m::Union{LaplaceRegressor,LaplaceClassifier}, verbosity, X, y)
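A minimal sketch (not part of the patch) of the shape logic that features_shape implements, assuming features sit along the first dimension as in MMI.reformat above:

```julia
# Single-target regression: y is a vector, so the fallback n_output = 1 applies.
X = rand(Float32, 4, 100)                          # 4 features, 100 observations
y = rand(Float32, 100)

n_input  = size(X, 1)                              # 4
n_output = length(size(y)) == 1 ? 1 : size(y, 1)   # 1

# Multi-target case: a 3×100 target matrix yields n_output = 3.
y_multi = rand(Float32, 3, 100)
length(size(y_multi)) == 1 ? 1 : size(y_multi, 1)  # 3
```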
@@ -84,6 +141,13 @@ Fit a Laplace model using the provided features and target values.
function MMI.fit(m::LaplaceModels, verbosity, X, y)
y, decode = y
+ if (m.model === nothing)
+ shape = features_shape(m, X, y)
+
+ m.model = default_build(11, shape)
+
+ end
+
# Make a copy of the model because Flux does not allow to mutate hyperparameters
copied_model = deepcopy(m.model)
@@ -141,7 +205,7 @@ function MMI.fit(m::LaplaceModels, verbosity, X, y)
# fit the Laplace model:
LaplaceRedux.fit!(la, data_loader)
- optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
+ optimize_prior!(la; verbosity=verbosity, n_steps=m.fit_prior_nsteps)
fitresult = (la, decode)
report = (loss_history=loss_history,)
@@ -229,7 +293,7 @@ function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y)
# fit the Laplace model:
LaplaceRedux.fit!(la, data_loader)
- optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
+ optimize_prior!(la; verbosity=verbosity, n_steps=m.fit_prior_nsteps)
fitresult = (la, decode)
report = (loss_history=old_loss_history,)
@@ -276,7 +340,7 @@ function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y)
# fit the Laplace model:
LaplaceRedux.fit!(la, data_loader)
- optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)
+ optimize_prior!(la; verbosity=verbosity, n_steps=m.fit_prior_nsteps)
fitresult = (la, decode)
report = (loss_history=old_loss_history,)
diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl
index f4b6e5fc..e2ec4022 100644
--- a/test/direct_mlj_interface.jl
+++ b/test/direct_mlj_interface.jl
@@ -9,24 +9,24 @@ import LaplaceRedux: LaplaceClassifier, LaplaceRegressor
cv = MLJBase.CV(; nfolds=3)
@testset "Regression" begin
+ @info " testing interface for LaplaceRegressor"
flux_model = Chain(Dense(4, 10, relu), Dense(10, 10, relu), Dense(10, 1))
model = LaplaceRegressor(; model=flux_model, epochs=50)
X, y = MLJ.make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)
#train, test = partition(eachindex(y), 0.7); # 70:30 split
- mach = MLJ.machine(model, X, y) #|> MLJBase.fit! #|> (fitresult,cache,report)
- MLJBase.fit!(mach; verbosity=1)
- #Xnew, ynew = make_regression(3, 4; rng=123)
+ mach = MLJ.machine(model, X, y)
+ MLJBase.fit!(mach; verbosity=0)
yhat = MLJBase.predict(mach, X) # probabilistic predictions
MLJBase.predict_mode(mach, X) # point predictions
MLJBase.fitted_params(mach) #fitted params function
MLJBase.training_losses(mach) #training loss history
model.epochs = 100 #changing number of epochs
- MLJBase.fit!(mach) #testing update function
+ MLJBase.fit!(mach; verbosity=0) #testing update function
model.epochs = 50 #changing number of epochs to a lower number
- MLJBase.fit!(mach) #testing update function
+ MLJBase.fit!(mach; verbosity=0) #testing update function
model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
- MLJBase.fit!(mach) #testing update function (the laplace part)
+ MLJBase.fit!(mach; verbosity=0) #testing update function (the laplace part)
yhat = MLJBase.predict(mach, X) # probabilistic predictions
MLJ.evaluate!(mach; resampling=cv, measure=MLJ.log_loss, verbosity=0)
@@ -34,13 +34,29 @@ cv = MLJBase.CV(; nfolds=3)
flux_model_two = Chain(Dense(4, 6, relu), Dense(6, 1))
# test update! fallback to fit!
model.model = flux_model_two
- MLJBase.fit!(mach)
+ MLJBase.fit!(mach; verbosity=0)
model_two = LaplaceRegressor(; model=flux_model_two, epochs=100)
@test !MLJBase.is_same_except(model, model_two)
+
+
+ #testing default mlp builder
+ model = LaplaceRegressor(; model=nothing, epochs=50)
+ mach = MLJ.machine(model, X, y)
+ MLJBase.fit!(mach; verbosity=0)
+ yhat = MLJBase.predict(mach, X) # probabilistic predictions
+ MLJBase.predict_mode(mach, X) # point predictions
+ MLJBase.fitted_params(mach) #fitted params function
+ MLJBase.training_losses(mach) #training loss history
+ model.epochs = 100 #changing number of epochs
+ MLJBase.fit!(mach; verbosity=0) #testing update function
+
+
+
end
@testset "Classification" begin
+ @info " testing interface for LaplaceClassifier"
# Define the model
flux_model = Chain(Dense(4, 10, relu), Dense(10, 3))
@@ -48,7 +64,7 @@ end
X, y = MLJ.@load_iris
mach = MLJ.machine(model, X, y)
- MLJBase.fit!(mach; verbosity=1)
+ MLJBase.fit!(mach; verbosity=0)
Xnew = (
sepal_length=[6.4, 7.2, 7.4],
sepal_width=[2.8, 3.0, 2.8],
@@ -61,11 +77,11 @@ end
MLJBase.fitted_params(mach) # fitted params
MLJBase.training_losses(mach) #training loss history
model.epochs = 100 #changing number of epochs
- MLJBase.fit!(mach) #testing update function
+ MLJBase.fit!(mach; verbosity=0) #testing update function
model.epochs = 50 #changing number of epochs to a lower number
- MLJBase.fit!(mach) #testing update function
+ MLJBase.fit!(mach; verbosity=0) #testing update function
model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
- MLJBase.fit!(mach) #testing update function (the laplace part)
+ MLJBase.fit!(mach; verbosity=0) #testing update function (the laplace part)
MLJ.evaluate!(mach; resampling=cv, measure=MLJ.brier_loss, verbosity=0)
# Define a different model
@@ -73,5 +89,5 @@ end
model.model = flux_model_two
- MLJBase.fit!(mach)
+ MLJBase.fit!(mach; verbosity=0)
end
diff --git a/test/runtests.jl b/test/runtests.jl
index 92d97033..b29d8e14 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -34,5 +34,7 @@ using Test
@testset "KronDecomposed" begin
include("krondecomposed.jl")
end
-
+ @testset "Interface" begin
+ include("direct_mlj_interface.jl")
+ end
end
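Taken together, this patch lets users omit the Flux model entirely. A hedged end-to-end sketch of that path, mirroring the new tests (package names as in test/direct_mlj_interface.jl):

```julia
using MLJ, MLJBase
import LaplaceRedux: LaplaceRegressor

X, y = MLJ.make_regression(100, 4; noise=0.5)

model = LaplaceRegressor(; model=nothing, epochs=50)  # no Chain: the default MLP builder kicks in
mach  = MLJ.machine(model, X, y)
MLJBase.fit!(mach; verbosity=0)

yhat = MLJBase.predict(mach, X)   # probabilistic predictions
MLJBase.predict_mode(mach, X)     # point predictions
MLJBase.training_losses(mach)     # loss history from the report
```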
From 9da5d7daaff901f63dcc9d4be4108eb68d96d170 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Wed, 30 Oct 2024 01:09:07 +0100
Subject: [PATCH 56/60] reducing number of epochs and trying to extend patch
coverage
---
src/direct_mlj.jl | 6 +++---
test/direct_mlj_interface.jl | 25 ++++++++++++++++---------
2 files changed, 19 insertions(+), 12 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 82c6d81b..f2d4a204 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -68,7 +68,7 @@ MMI.selectrows(::LaplaceModels, I, Xmatrix) = (Xmatrix[:, I],)
"""
- function features_shape(model::LaplaceRegression, X, y)
+ function dataset_shape(model::LaplaceRegression, X, y)
Compute the number of features of the X input dataset and the number of var
@@ -80,7 +80,7 @@ Compute the number of features of the X input dataset and the number of var
# Returns
- (input size, output size)
"""
-function features_shape(model::LaplaceModels, X, y)
+function dataset_shape(model::LaplaceModels, X, y)
#X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X
n_input = size(X, 1)
dims = size(y)
@@ -142,7 +142,7 @@ function MMI.fit(m::LaplaceModels, verbosity, X, y)
y, decode = y
if (m.model === nothing)
- shape = features_shape(m, X, y)
+ shape = dataset_shape(m, X, y)
m.model = default_build(11, shape)
diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl
index e2ec4022..88e26e37 100644
--- a/test/direct_mlj_interface.jl
+++ b/test/direct_mlj_interface.jl
@@ -11,8 +11,9 @@ cv = MLJBase.CV(; nfolds=3)
@testset "Regression" begin
@info " testing interface for LaplaceRegressor"
flux_model = Chain(Dense(4, 10, relu), Dense(10, 10, relu), Dense(10, 1))
- model = LaplaceRegressor(; model=flux_model, epochs=50)
+ model = LaplaceRegressor(; model=flux_model, epochs=20)
+ #testing more complex dataset
X, y = MLJ.make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)
#train, test = partition(eachindex(y), 0.7); # 70:30 split
mach = MLJ.machine(model, X, y)
@@ -21,9 +22,9 @@ cv = MLJBase.CV(; nfolds=3)
MLJBase.predict_mode(mach, X) # point predictions
MLJBase.fitted_params(mach) #fitted params function
MLJBase.training_losses(mach) #training loss history
- model.epochs = 100 #changing number of epochs
+ model.epochs = 40 #changing number of epochs
MLJBase.fit!(mach; verbosity=0) #testing update function
- model.epochs = 50 #changing number of epochs to a lower number
+ model.epochs = 30 #changing number of epochs to a lower number
MLJBase.fit!(mach; verbosity=0) #testing update function
model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
MLJBase.fit!(mach; verbosity=0) #testing update function (the laplace part)
@@ -41,15 +42,21 @@ cv = MLJBase.CV(; nfolds=3)
#testing default mlp builder
- model = LaplaceRegressor(; model=nothing, epochs=50)
+ model = LaplaceRegressor(; model=nothing, epochs=20)
mach = MLJ.machine(model, X, y)
- MLJBase.fit!(mach; verbosity=0)
+ MLJBase.fit!(mach; verbosity=1)
yhat = MLJBase.predict(mach, X) # probabilistic predictions
MLJBase.predict_mode(mach, X) # point predictions
MLJBase.fitted_params(mach) #fitted params function
MLJBase.training_losses(mach) #training loss history
model.epochs = 100 #changing number of epochs
- MLJBase.fit!(mach; verbosity=0) #testing update function
+ MLJBase.fit!(mach; verbosity=1) #testing update function
+
+ #testing dataset_shape for one dimensional function
+ X, y = MLJ.make_regression(100, 1; noise=0.5, sparse=0.2, outliers=0.1)
+ model = LaplaceRegressor(; model=nothing, epochs=20)
+ mach = MLJ.machine(model, X, y)
+ MLJBase.fit!(mach; verbosity=0)
@@ -60,7 +67,7 @@ end
# Define the model
flux_model = Chain(Dense(4, 10, relu), Dense(10, 3))
- model = LaplaceClassifier(; model=flux_model, epochs=50)
+ model = LaplaceClassifier(; model=flux_model, epochs=20)
X, y = MLJ.@load_iris
mach = MLJ.machine(model, X, y)
@@ -76,9 +83,9 @@ end
MLJBase.pdf.(yhat, "virginica") # probabilities for the "virginica" class
MLJBase.fitted_params(mach) # fitted params
MLJBase.training_losses(mach) #training loss history
- model.epochs = 100 #changing number of epochs
+ model.epochs = 40 #changing number of epochs
MLJBase.fit!(mach; verbosity=0) #testing update function
- model.epochs = 50 #changing number of epochs to a lower number
+ model.epochs = 30 #changing number of epochs to a lower number
MLJBase.fit!(mach; verbosity=0) #testing update function
model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
MLJBase.fit!(mach; verbosity=0) #testing update function (the laplace part)
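For completeness, a sketch of the classifier calls the tests above exercise (hypothetical standalone version; the values in Xnew are illustrative new observations, not taken from the test file):

```julia
using MLJ, MLJBase, Flux
import LaplaceRedux: LaplaceClassifier

X, y = MLJ.@load_iris
model = LaplaceClassifier(; model=Chain(Dense(4, 10, relu), Dense(10, 3)), epochs=20)
mach  = MLJ.machine(model, X, y)
MLJBase.fit!(mach; verbosity=0)

Xnew = (
    sepal_length=[6.4, 7.2, 7.4],
    sepal_width=[2.8, 3.0, 2.8],
    petal_length=[5.6, 5.8, 6.1],
    petal_width=[2.1, 1.6, 1.9],
)

yhat = MLJBase.predict(mach, Xnew)   # vector of probabilistic predictions
MLJBase.pdf.(yhat, "virginica")      # per-observation probability of "virginica"
MLJBase.predict_mode(mach, Xnew)     # most probable class labels
```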
From 503763dd3a0c0c91d195d09f31d625914826e01c Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Wed, 30 Oct 2024 17:33:15 +0100
Subject: [PATCH 57/60] removed the else because it seems to have no role.
---
src/direct_mlj.jl | 4 ++--
test/direct_mlj_interface.jl | 2 --
2 files changed, 2 insertions(+), 4 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index f2d4a204..3e87a520 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -346,8 +346,8 @@ function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y)
report = (loss_history=old_loss_history,)
cache = (deepcopy(m), old_state_tree, old_loss_history)
- else
- fitresult, cache, report = MLJBase.fit(m, verbosity, X, y)
+ #else
+ #fitresult, cache, report = MLJBase.fit(m, verbosity, X, y)
end
return fitresult, cache, report
diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl
index 88e26e37..5e658c58 100644
--- a/test/direct_mlj_interface.jl
+++ b/test/direct_mlj_interface.jl
@@ -37,8 +37,6 @@ cv = MLJBase.CV(; nfolds=3)
model.model = flux_model_two
MLJBase.fit!(mach; verbosity=0)
- model_two = LaplaceRegressor(; model=flux_model_two, epochs=100)
- @test !MLJBase.is_same_except(model, model_two)
#testing default mlp builder
From e78e9b89a24b22fd3838023d05d52d4f6db5125f Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Wed, 30 Oct 2024 17:34:12 +0100
Subject: [PATCH 58/60] oops, forgot to remove the comment
---
src/direct_mlj.jl | 2 --
1 file changed, 2 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 3e87a520..a1628cad 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -346,8 +346,6 @@ function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y)
report = (loss_history=old_loss_history,)
cache = (deepcopy(m), old_state_tree, old_loss_history)
- #else
- #fitresult, cache, report = MLJBase.fit(m, verbosity, X, y)
end
return fitresult, cache, report
From 6a5f26f9478cc6700f1c214520bed3c3a2ba3c30 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Wed, 30 Oct 2024 18:06:09 +0100
Subject: [PATCH 59/60] various changes in the documentation
---
src/direct_mlj.jl | 9 +++------
test/direct_mlj_interface.jl | 2 ++
2 files changed, 5 insertions(+), 6 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index a1628cad..2a8187a5 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -1,4 +1,3 @@
-#module MLJLaplaceRedux
using Optimisers: Optimisers
using Flux
using Random
@@ -81,7 +80,6 @@ Compute the number of features of the X input dataset and the number of var
- (input size, output size)
"""
function dataset_shape(model::LaplaceModels, X, y)
- #X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X
n_input = size(X, 1)
dims = size(y)
if length(dims) == 1
@@ -142,6 +140,7 @@ function MMI.fit(m::LaplaceModels, verbosity, X, y)
y, decode = y
if (m.model === nothing)
+ @warn "Warning: no Flux model has been provided in the model. LaplaceRedux will use a standard MLP with 3 hidden layers with 20 neurons each and input and output layers compatible with the dataset."
shape = dataset_shape(m, X, y)
m.model = default_build(11, shape)
@@ -606,7 +605,7 @@ Train the machine using `fit!(mach, rows=...)`.
# Hyperparameters (format: name-type-default value-restrictions)
-- `model::Flux.Chain = nothing`: a Flux model provided by the user and compatible with the dataset.
+- `model::Union{Flux.Chain,Nothing} = nothing`: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 3 hidden layers with 20 neurons each.
- `flux_loss = Flux.Losses.logitcrossentropy` : a Flux loss function
@@ -744,8 +743,7 @@ Train the machine using `fit!(mach, rows=...)`.
# Hyperparameters (format: name-type-default value-restrictions)
-- `model::Flux.Chain = nothing`: a Flux model provided by the user and compatible with the dataset.
-
+- `model::Union{Flux.Chain,Nothing} = nothing`: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 3 hidden layers with 20 neurons each.
- `flux_loss = Flux.Losses.logitcrossentropy` : a Flux loss function
- `optimiser = Adam()` a Flux optimiser
@@ -846,4 +844,3 @@ See also [LaplaceRedux.jl](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl
"""
LaplaceRegressor
-#end # module
diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl
index 5e658c58..8f2b0438 100644
--- a/test/direct_mlj_interface.jl
+++ b/test/direct_mlj_interface.jl
@@ -36,6 +36,8 @@ cv = MLJBase.CV(; nfolds=3)
# test update! fallback to fit!
model.model = flux_model_two
MLJBase.fit!(mach; verbosity=0)
+ model_two = LaplaceRegressor(; model=flux_model_two, epochs=100)
+ @test !MLJBase.is_same_except(model, model_two)
From 12f2584462f5dbad7090941f3d489071fcb258a8 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Wed, 30 Oct 2024 18:12:54 +0100
Subject: [PATCH 60/60] fix default MLP hidden layers and align docstrings
---
src/direct_mlj.jl | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 2a8187a5..61aab7dd 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -110,7 +110,7 @@ function default_build(seed::Int, shape)
chain = Chain(
Dense(n_input, 20, relu),
Dense(20, 20, relu),
- #Dense(20, 20, relu),
+ Dense(20, 20, relu),
Dense(20, n_output)
)
@@ -140,7 +140,7 @@ function MMI.fit(m::LaplaceModels, verbosity, X, y)
y, decode = y
if (m.model === nothing)
- @warn "Warning: no Flux model has been provided in the model. LaplaceRedux will use a standard MLP with 3 hidden layers with 20 neurons each and input and output layers compatible with the dataset."
+ @warn "Warning: no Flux model has been provided in the model. LaplaceRedux will use a standard MLP with 2 hidden layers with 20 neurons each and input and output layers compatible with the dataset."
shape = dataset_shape(m, X, y)
m.model = default_build(11, shape)
@@ -605,7 +605,7 @@ Train the machine using `fit!(mach, rows=...)`.
# Hyperparameters (format: name-type-default value-restrictions)
-- `model::Union{Flux.Chain,Nothing} = nothing`: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 3 hidden layers with 20 neurons each.
+- `model::Union{Flux.Chain,Nothing} = nothing`: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 2 hidden layers with 20 neurons each.
- `flux_loss = Flux.Losses.logitcrossentropy` : a Flux loss function
@@ -743,7 +743,7 @@ Train the machine using `fit!(mach, rows=...)`.
# Hyperparameters (format: name-type-default value-restrictions)
-- `model::Union{Flux.Chain,Nothing} = nothing`: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 3 hidden layers with 20 neurons each.
+- `model::Union{Flux.Chain,Nothing} = nothing`: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 2 hidden layers with 20 neurons each.
- `flux_loss = Flux.Losses.logitcrossentropy` : a Flux loss function
- `optimiser = Adam()` a Flux optimiser
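After this last hunk, the default builder produces the architecture below (sketch only; the seed 11 and layer widths are the ones hard-coded in default_build and MMI.fit, while n_input and n_output stand in for the values dataset_shape would return):

```julia
using Flux, Random

n_input, n_output = 4, 1   # example shape, e.g. 4 features and a single target

Random.seed!(11)           # reproducible initial weights, as in default_build
chain = Chain(
    Dense(n_input, 20, relu),
    Dense(20, 20, relu),
    Dense(20, 20, relu),
    Dense(20, n_output),
)
```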