From 51ed52680144505936d0268d6e41a8fac9049721 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Thu, 12 Sep 2024 20:27:06 +0200 Subject: [PATCH 01/60] new first commit --- Project.toml | 1 + src/LaplaceRedux.jl | 3 + src/baselaplace/predicting.jl | 2 +- src/direct_mlj.jl | 142 ++++++++++++++++++++++++++++++++++ 4 files changed, 147 insertions(+), 1 deletion(-) create mode 100644 src/direct_mlj.jl diff --git a/Project.toml b/Project.toml index 451430e3..2028e0ba 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Patrick Altmeyer"] version = "1.1.1" [deps] +CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3" diff --git a/src/LaplaceRedux.jl b/src/LaplaceRedux.jl index 9a36d18e..4d8df9ec 100644 --- a/src/LaplaceRedux.jl +++ b/src/LaplaceRedux.jl @@ -30,4 +30,7 @@ export empirical_frequency_binary_classification, sharpness_regression, extract_mean_and_variance, sigma_scaling, rescale_stddev + + + include("direct_mlj.jl") end diff --git a/src/baselaplace/predicting.jl b/src/baselaplace/predicting.jl index 4a0a1d6c..71b74a4b 100644 --- a/src/baselaplace/predicting.jl +++ b/src/baselaplace/predicting.jl @@ -93,7 +93,7 @@ Computes the Bayesian predictivie distribution from a neural network with a Lapl - `link_approx::Symbol=:probit`: Link function approximation. Options are `:probit` and `:plugin`. - `predict_proba::Bool=true`: If `true` (default) apply a sigmoid or a softmax function to the output of the Flux model. - `return_distr::Bool=false`: if `false` (default), the function outputs either the direct output of the chain or pseudo-probabilities (if `predict_proba=true`). - if `true` predict return a Bernoulli distribution in binary classification tasks and a categorical distribution in multiclassification tasks. + if `true` predict return a probability distribution. # Returns diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl new file mode 100644 index 00000000..ed046ff5 --- /dev/null +++ b/src/direct_mlj.jl @@ -0,0 +1,142 @@ +using Flux +using ProgressMeter: Progress, next!, BarGlyphs +using Random +using Tables +using LinearAlgebra +using LaplaceRedux +using ComputationalResources +using MLJBase: MLJBase +import MLJModelInterface as MMI +using Optimisers: Optimisers + +""" + MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic + +A mutable struct representing a Laplace regression model. +It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. +It has the following Hyperparameters: +- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. +- `subnetwork_indices`: the indices of the subnetworks. +- `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`. +- `backend`: the backend to use, either `:GGN` or `:EmpiricalFisher`. +- `σ`: the standard deviation of the prior distribution. +- `μ₀`: the mean of the prior distribution. +- `P₀`: the covariance matrix of the prior distribution. +- `ret_distr`: a boolean that tells predict to either return distributions (true) objects from Distributions.jl or just the probabilities. +- `fit_prior_nsteps`: the number of steps used to fit the priors. 
+""" +MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic + subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) + subnetwork_indices = nothing + hessian_structure::Union{HessianStructure,Symbol,String} = + :full::(_ in (:full, :diagonal)) + backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher)) + σ::Float64 = 1.0 + μ₀::Float64 = 0.0 + P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing + ret_distr::Bool = false::(_ in (true, false)) + fit_prior_nsteps::Int = 100::(_ > 0) +end + + +function MLJModelInterface.fit(m::LaplaceRegressor, verbosity, X, y, w=nothing) + + X = MLJBase.matrix(X) + + + + + cache = nothing + return (fitresult, cache, report) +end + + + + + + + + + + + + + + + + + + + + + +""" + MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic + +A mutable struct representing a Laplace Classification model. +It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. + + +The model also has the following parameters: + +- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. +- `subnetwork_indices`: the indices of the subnetworks. +- `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`. +- `backend`: the backend to use, either `:GGN` or `:EmpiricalFisher`. +- `σ`: the standard deviation of the prior distribution. +- `μ₀`: the mean of the prior distribution. +- `P₀`: the covariance matrix of the prior distribution. +- `link_approx`: the link approximation to use, either `:probit` or `:plugin`. +- `predict_proba`: a boolean that select whether to predict probabilities or not. +- `ret_distr`: a boolean that tells predict to either return distributions (true) objects from Distributions.jl or just the probabilities. +- `fit_prior_nsteps`: the number of steps used to fit the priors. 
+""" +MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic + subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) + subnetwork_indices::Vector{Vector{Int}} = Vector{Vector{Int}}([]) + hessian_structure::Union{HessianStructure,Symbol,String} = + :full::(_ in (:full, :diagonal)) + backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher)) + σ::Float64 = 1.0 + μ₀::Float64 = 0.0 + P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing + link_approx::Symbol = :probit::(_ in (:probit, :plugin)) + predict_proba::Bool = true::(_ in (true, false)) + ret_distr::Bool = false::(_ in (true, false)) + fit_prior_nsteps::Int = 100::(_ > 0) +end + + + +function MLJModelInterface.fit(m::LaplaceClassifier, verbosity, X, y, w=nothing) + + + + + + cache = nothing + return (fitresult, cache, report) +end + + + + +MLJBase.metadata_model( + LaplaceClassifier; + input_scitype=Union{ + AbstractMatrix{<:Union{MLJBase.Finite, MLJBase.Continuous}}, # matrix with mixed types + MLJBase.Table(MLJBase.Finite, MLJBase.Contintuous), # table with mixed types + }, + target_scitype=AbstractArray{<:MLJBase.Finite}, # ordered factor or multiclass + load_path="LaplaceRedux.LaplaceClassification", +) +# metadata for each model, +MLJBase.metadata_model( + LaplaceRegressor; + input_scitype=Union{ + AbstractMatrix{<:Union{MLJBase.Finite, MLJBase.Continuous}}, # matrix with mixed types + MLJBase.Table(MLJBase.Finite, MLJBase.Contintuous), # table with mixed types + }, + target_scitype=AbstractArray{MLJBase.Continuous}, + load_path="LaplaceRedux.MLJFlux.LaplaceRegression", +) \ No newline at end of file From e8e96d13f956b7a74a00bb7246cf4f9b82f9274c Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Fri, 13 Sep 2024 00:40:22 +0200 Subject: [PATCH 02/60] various stuff --- src/direct_mlj.jl | 99 ++++++++++++++++++++++++++++++++++------------- 1 file changed, 73 insertions(+), 26 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index ed046ff5..b1d0f01e 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -1,13 +1,11 @@ using Flux -using ProgressMeter: Progress, next!, BarGlyphs using Random using Tables using LinearAlgebra using LaplaceRedux -using ComputationalResources -using MLJBase: MLJBase +using MLJBase import MLJModelInterface as MMI -using Optimisers: Optimisers +using Distributions: Normal """ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic @@ -15,6 +13,7 @@ using Optimisers: Optimisers A mutable struct representing a Laplace regression model. It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. It has the following Hyperparameters: +- flux_model???? - `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. - `subnetwork_indices`: the indices of the subnetworks. - `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`. @@ -26,6 +25,7 @@ It has the following Hyperparameters: - `fit_prior_nsteps`: the number of steps used to fit the priors. 
""" MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic + flux_model::Flux subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) subnetwork_indices = nothing hessian_structure::Union{HessianStructure,Symbol,String} = @@ -39,13 +39,30 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilist end -function MLJModelInterface.fit(m::LaplaceRegressor, verbosity, X, y, w=nothing) +function MMI.fit(m::LaplaceRegressor, verbosity, X, y, w=nothing) + features = Tables.schema(X).names X = MLJBase.matrix(X) - - - + la = LaplaceRedux.Laplace( + m.flux_model; + likelihood=:regression, + subset_of_weights=model.subset_of_weights, + subnetwork_indices=model.subnetwork_indices, + hessian_structure=model.hessian_structure, + backend=model.backend, + σ=m.σ, + μ₀=m.μ₀, + P₀=m.P₀, + ) + + # fit the Laplace model: + LaplaceRedux.fit!(la, zip(X, y)) + optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) + + + fitresult=la + report = (status="success", message="Model fitted successfully") cache = nothing return (fitresult, cache, report) end @@ -53,18 +70,12 @@ end - - - - - - - - - - - - +function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) + Xnew = MLJBase.matrix(Xnew) |> permutedims + la = fitresult[1] + yhat = LaplaceRedux.predict(la, Xnew; ret_distr=model.ret_distr) + return [Normal(μᵢ, σ̂) for (μᵢ,σ) ∈ yhat] +end @@ -79,6 +90,7 @@ It uses Laplace approximation to estimate the posterior distribution of the weig The model also has the following parameters: +- `model`: Flux ??? - `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. - `subnetwork_indices`: the indices of the subnetworks. - `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`. @@ -92,6 +104,8 @@ The model also has the following parameters: - `fit_prior_nsteps`: the number of steps used to fit the priors. 
""" MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic + + model::Flux subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) subnetwork_indices::Vector{Vector{Int}} = Vector{Vector{Int}}([]) hessian_structure::Union{HessianStructure,Symbol,String} = @@ -108,14 +122,47 @@ end -function MLJModelInterface.fit(m::LaplaceClassifier, verbosity, X, y, w=nothing) - - - +function MMI.fit(m::LaplaceClassifier, verbosity, X, y, w=nothing) + features = Tables.schema(X).names + Xmatrix = MLJBase.matrix(X) + decode = y[1] + y_plain = MLJBase.int(y) .- 1 # 0, 1 of type Int + + la = LaplaceRedux.Laplace( + m.flux_model; + likelihood=:classification, + subset_of_weights=model.subset_of_weights, + subnetwork_indices=model.subnetwork_indices, + hessian_structure=model.hessian_structure, + backend=model.backend, + σ=m.σ, + μ₀=m.μ₀, + P₀=m.P₀, + ) + + # fit the Laplace model: + LaplaceRedux.fit!(la, zip(X, y_plain)) + optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) + + + fitresult=la + report = (status="success", message="Model fitted successfully") + cache = nothing + return (fitresult, decode), cache, report +end - cache = nothing - return (fitresult, cache, report) +function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew) + la = fitresult + Xnew = MLJBase.matrix(Xnew) |> permutedims + predictions = LaplaceRedux.predict( + la, + Xnew; + link_approx=model.link_approx, + predict_proba=model.predict_proba, + ret_distr=model.ret_distr, + ) + return [MLJBase.UnivariateFinite(MLJBase.classes(decode), prediction) for prediction in predictions] end From 255cf19dfb00255ef099317501a8d6bc86f769a6 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Fri, 13 Sep 2024 01:28:13 +0200 Subject: [PATCH 03/60] fixes --- src/direct_mlj.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index b1d0f01e..d341f4a1 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -25,7 +25,7 @@ It has the following Hyperparameters: - `fit_prior_nsteps`: the number of steps used to fit the priors. 
""" MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic - flux_model::Flux + flux_model::Flux.Chain= nothing subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) subnetwork_indices = nothing hessian_structure::Union{HessianStructure,Symbol,String} = @@ -105,7 +105,7 @@ The model also has the following parameters: """ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic - model::Flux + #model::Flux subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) subnetwork_indices::Vector{Vector{Int}} = Vector{Vector{Int}}([]) hessian_structure::Union{HessianStructure,Symbol,String} = @@ -172,18 +172,18 @@ MLJBase.metadata_model( LaplaceClassifier; input_scitype=Union{ AbstractMatrix{<:Union{MLJBase.Finite, MLJBase.Continuous}}, # matrix with mixed types - MLJBase.Table(MLJBase.Finite, MLJBase.Contintuous), # table with mixed types + MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types }, target_scitype=AbstractArray{<:MLJBase.Finite}, # ordered factor or multiclass - load_path="LaplaceRedux.LaplaceClassification", + load_path="LaplaceRedux.LaplaceClassifier", ) # metadata for each model, MLJBase.metadata_model( LaplaceRegressor; input_scitype=Union{ AbstractMatrix{<:Union{MLJBase.Finite, MLJBase.Continuous}}, # matrix with mixed types - MLJBase.Table(MLJBase.Finite, MLJBase.Contintuous), # table with mixed types + MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types }, target_scitype=AbstractArray{MLJBase.Continuous}, - load_path="LaplaceRedux.MLJFlux.LaplaceRegression", + load_path="LaplaceRedux.LaplaceRegressor", ) \ No newline at end of file From 35ac2d8c25ebca59518754889b3c2a22f7796e29 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Fri, 13 Sep 2024 06:08:06 +0200 Subject: [PATCH 04/60] changes --- src/direct_mlj.jl | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index d341f4a1..cb4c3c17 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -25,7 +25,7 @@ It has the following Hyperparameters: - `fit_prior_nsteps`: the number of steps used to fit the priors. 
""" MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic - flux_model::Flux.Chain= nothing + flux_model::Flux.Chain = nothing subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) subnetwork_indices = nothing hessian_structure::Union{HessianStructure,Symbol,String} = @@ -39,7 +39,7 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilist end -function MMI.fit(m::LaplaceRegressor, verbosity, X, y, w=nothing) +function MMI.fit(m::LaplaceRegressor, verbosity, X, y) features = Tables.schema(X).names X = MLJBase.matrix(X) @@ -47,15 +47,17 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y, w=nothing) la = LaplaceRedux.Laplace( m.flux_model; likelihood=:regression, - subset_of_weights=model.subset_of_weights, - subnetwork_indices=model.subnetwork_indices, - hessian_structure=model.hessian_structure, - backend=model.backend, + subset_of_weights=m.subset_of_weights, + subnetwork_indices=m.subnetwork_indices, + hessian_structure=m.hessian_structure, + backend=m.backend, σ=m.σ, μ₀=m.μ₀, P₀=m.P₀, ) + println(la) + # fit the Laplace model: LaplaceRedux.fit!(la, zip(X, y)) optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) @@ -122,7 +124,7 @@ end -function MMI.fit(m::LaplaceClassifier, verbosity, X, y, w=nothing) +function MMI.fit(m::LaplaceClassifier, verbosity, X, y) features = Tables.schema(X).names Xmatrix = MLJBase.matrix(X) decode = y[1] From 44469a20bdc62e489c5c2173b3a7498f1dbf263d Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Wed, 18 Sep 2024 08:13:25 +0200 Subject: [PATCH 05/60] there is still a problem with the classifier --- src/direct_mlj.jl | 49 +++++++++++++++++++++++++++++++++++++---------- 1 file changed, 39 insertions(+), 10 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index cb4c3c17..b8da8f70 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -13,7 +13,8 @@ using Distributions: Normal A mutable struct representing a Laplace regression model. It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. It has the following Hyperparameters: -- flux_model???? +- `flux_model`: A flux model provided by the user and compatible with the dataset. +- `epochs`: The number of training epochs. - `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. - `subnetwork_indices`: the indices of the subnetworks. - `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`. @@ -25,7 +26,9 @@ It has the following Hyperparameters: - `fit_prior_nsteps`: the number of steps used to fit the priors. 
""" MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic + flux_model::Flux.Chain = nothing + epochs::Integer = 1000::(_ > 0) subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) subnetwork_indices = nothing hessian_structure::Union{HessianStructure,Symbol,String} = @@ -42,7 +45,19 @@ end function MMI.fit(m::LaplaceRegressor, verbosity, X, y) features = Tables.schema(X).names - X = MLJBase.matrix(X) + X = MLJBase.matrix(X) |> permutedims + y = reshape(y, 1,:) + data_loader = Flux.DataLoader((X,y), batchsize=10) + opt_state = Flux.setup(Adam(), m.flux_model) + loss(y_hat, y) = mean(Flux.Losses.mse.(y_hat, y)) + + for epoch in 1:m.epochs + Flux.train!(m.flux_model,data_loader, opt_state) do model, X, y + loss(model(X), y) + + end + end + la = LaplaceRedux.Laplace( m.flux_model; @@ -56,11 +71,10 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) P₀=m.P₀, ) - println(la) # fit the Laplace model: - LaplaceRedux.fit!(la, zip(X, y)) - optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) + LaplaceRedux.fit!(la, data_loader ) + optimize_prior!(la; verbose= false, n_steps=m.fit_prior_nsteps) fitresult=la @@ -92,7 +106,8 @@ It uses Laplace approximation to estimate the posterior distribution of the weig The model also has the following parameters: -- `model`: Flux ??? +- `flux_model`: A flux model provided by the user and compatible with the dataset. +- `epochs`: The number of training epochs. - `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. - `subnetwork_indices`: the indices of the subnetworks. - `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`. @@ -107,7 +122,8 @@ The model also has the following parameters: """ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic - #model::Flux + flux_model::Flux.Chain = nothing + epochs::Integer = 1000::(_ > 0) subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) subnetwork_indices::Vector{Vector{Int}} = Vector{Vector{Int}}([]) hessian_structure::Union{HessianStructure,Symbol,String} = @@ -130,6 +146,17 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) decode = y[1] y_plain = MLJBase.int(y) .- 1 # 0, 1 of type Int + + + for epoch in 1:m.epochs + Flux.train!(m.flux_model,data_loader, opt_state) do model, X, y + loss(model(X), y) + + end + end + + + la = LaplaceRedux.Laplace( m.flux_model; likelihood=:classification, @@ -142,15 +169,17 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) P₀=m.P₀, ) + + + # fit the Laplace model: LaplaceRedux.fit!(la, zip(X, y_plain)) optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) - - fitresult=la + report = (status="success", message="Model fitted successfully") cache = nothing - return (fitresult, decode), cache, report + return (la, decode), cache, report end From 53152645ebec793d86d37fc1d4e818aaf75f45c4 Mon Sep 17 00:00:00 2001 From: "pasquale c." 
<343guiltyspark@outlook.it> Date: Wed, 18 Sep 2024 11:03:37 +0200 Subject: [PATCH 06/60] almost fixed --- src/direct_mlj.jl | 71 ++++++++++++++++++++++++++++------------------- 1 file changed, 43 insertions(+), 28 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index b8da8f70..d30b6c71 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -15,6 +15,7 @@ It uses Laplace approximation to estimate the posterior distribution of the weig It has the following Hyperparameters: - `flux_model`: A flux model provided by the user and compatible with the dataset. - `epochs`: The number of training epochs. +- `batch_size`: The batch size. - `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. - `subnetwork_indices`: the indices of the subnetworks. - `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`. @@ -29,6 +30,7 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilist flux_model::Flux.Chain = nothing epochs::Integer = 1000::(_ > 0) + batch_size::Integer= 32::(_ > 0) subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) subnetwork_indices = nothing hessian_structure::Union{HessianStructure,Symbol,String} = @@ -47,7 +49,7 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) X = MLJBase.matrix(X) |> permutedims y = reshape(y, 1,:) - data_loader = Flux.DataLoader((X,y), batchsize=10) + data_loader = Flux.DataLoader((X,y), batchsize=m.batch_size) opt_state = Flux.setup(Adam(), m.flux_model) loss(y_hat, y) = mean(Flux.Losses.mse.(y_hat, y)) @@ -88,9 +90,13 @@ end function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) Xnew = MLJBase.matrix(Xnew) |> permutedims - la = fitresult[1] - yhat = LaplaceRedux.predict(la, Xnew; ret_distr=model.ret_distr) - return [Normal(μᵢ, σ̂) for (μᵢ,σ) ∈ yhat] + la = fitresult + yhat = LaplaceRedux.predict(la, Xnew; ret_distr=m.ret_distr) + # Extract mean and variance matrices + means, variances = yhat + + # Create Normal distributions from the means and variances + return [Normal(μ, sqrt(σ)) for (μ, σ) in zip(means, variances)] end @@ -108,6 +114,7 @@ The model also has the following parameters: - `flux_model`: A flux model provided by the user and compatible with the dataset. - `epochs`: The number of training epochs. +- `batch_size`: The batch size. - `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. - `subnetwork_indices`: the indices of the subnetworks. - `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`. 
@@ -124,33 +131,41 @@ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilis flux_model::Flux.Chain = nothing epochs::Integer = 1000::(_ > 0) + batch_size::Integer= 32::(_ > 0) subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) - subnetwork_indices::Vector{Vector{Int}} = Vector{Vector{Int}}([]) + subnetwork_indices = nothing hessian_structure::Union{HessianStructure,Symbol,String} = :full::(_ in (:full, :diagonal)) backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher)) σ::Float64 = 1.0 μ₀::Float64 = 0.0 P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing - link_approx::Symbol = :probit::(_ in (:probit, :plugin)) - predict_proba::Bool = true::(_ in (true, false)) ret_distr::Bool = false::(_ in (true, false)) fit_prior_nsteps::Int = 100::(_ > 0) + link_approx::Symbol = :probit::(_ in (:probit, :plugin)) end - +#link_approx::Symbol = :probit::(_ in (:probit, :plugin)) function MMI.fit(m::LaplaceClassifier, verbosity, X, y) features = Tables.schema(X).names - Xmatrix = MLJBase.matrix(X) + X = MLJBase.matrix(X) |> permutedims + y = reshape(y, 1,:) + decode = y[1] y_plain = MLJBase.int(y) .- 1 # 0, 1 of type Int + loss(y_hat, y) = mean(Flux.Losses.crossentropy(y_hat, y)) + + data_loader = Flux.DataLoader((X,y_plain), batchsize=m.batch_size) + opt_state = Flux.setup(Adam(), m.flux_model) + + for epoch in 1:m.epochs - Flux.train!(m.flux_model,data_loader, opt_state) do model, X, y - loss(model(X), y) + Flux.train!(m.flux_model,data_loader, opt_state) do model, X, y_plain + loss(model(X), y_plain) end end @@ -160,10 +175,10 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) la = LaplaceRedux.Laplace( m.flux_model; likelihood=:classification, - subset_of_weights=model.subset_of_weights, - subnetwork_indices=model.subnetwork_indices, - hessian_structure=model.hessian_structure, - backend=model.backend, + subset_of_weights=m.subset_of_weights, + subnetwork_indices=m.subnetwork_indices, + hessian_structure=m.hessian_structure, + backend=m.backend, σ=m.σ, μ₀=m.μ₀, P₀=m.P₀, @@ -173,27 +188,27 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) # fit the Laplace model: - LaplaceRedux.fit!(la, zip(X, y_plain)) - optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) - + LaplaceRedux.fit!(la, data_loader ) + optimize_prior!(la; verbose= false, n_steps=m.fit_prior_nsteps) + + fitresult=la report = (status="success", message="Model fitted successfully") cache = nothing - return (la, decode), cache, report + return ((fitresult,decode), cache, report) end function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew) - la = fitresult + #la = fitresult Xnew = MLJBase.matrix(Xnew) |> permutedims - predictions = LaplaceRedux.predict( - la, - Xnew; - link_approx=model.link_approx, - predict_proba=model.predict_proba, - ret_distr=model.ret_distr, - ) - return [MLJBase.UnivariateFinite(MLJBase.classes(decode), prediction) for prediction in predictions] + #predictions = LaplaceRedux.predict( + #la, + #Xnew; + #link_approx=model.link_approx, + #predict_proba=model.predict_proba, + #ret_distr=model.ret_distr) + #return [MLJBase.UnivariateFinite(MLJBase.classes(decode), prediction) for prediction in predictions] end From f6e7a00c4c16288f085657bd53caea79e70eb256 Mon Sep 17 00:00:00 2001 From: "pasquale c." 
<343guiltyspark@outlook.it> Date: Wed, 18 Sep 2024 11:20:36 +0200 Subject: [PATCH 07/60] works but i have to fix the hyperparameters --- src/direct_mlj.jl | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index d30b6c71..207a101c 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -156,7 +156,7 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) y_plain = MLJBase.int(y) .- 1 # 0, 1 of type Int - loss(y_hat, y) = mean(Flux.Losses.crossentropy(y_hat, y)) + loss(y_hat, y) = Flux.Losses.logitcrossentropy(y_hat, y) data_loader = Flux.DataLoader((X,y_plain), batchsize=m.batch_size) opt_state = Flux.setup(Adam(), m.flux_model) @@ -191,24 +191,21 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) LaplaceRedux.fit!(la, data_loader ) optimize_prior!(la; verbose= false, n_steps=m.fit_prior_nsteps) - - fitresult=la report = (status="success", message="Model fitted successfully") cache = nothing - return ((fitresult,decode), cache, report) + return ((la,decode), cache, report) end function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew) - #la = fitresult + la = fitresult Xnew = MLJBase.matrix(Xnew) |> permutedims - #predictions = LaplaceRedux.predict( - #la, - #Xnew; - #link_approx=model.link_approx, - #predict_proba=model.predict_proba, - #ret_distr=model.ret_distr) - #return [MLJBase.UnivariateFinite(MLJBase.classes(decode), prediction) for prediction in predictions] + predictions = LaplaceRedux.predict( + la, + Xnew; + link_approx=m.link_approx, + ret_distr=m.ret_distr) + return [MLJBase.UnivariateFinite(MLJBase.classes(decode), prediction,augment=true) for prediction in predictions] end From d7c4f7b65cbf00404f43cf51ca851696f3643dda Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Wed, 18 Sep 2024 11:46:17 +0200 Subject: [PATCH 08/60] question on parameters.... 
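For context, the interface at this point is meant to be driven end-to-end roughly as sketched below (the three-feature toy table, chain sizes and epoch count are made up for illustration; the remaining hyperparameters keep their defaults):

```julia
using MLJBase, Flux, Distributions
using LaplaceRedux: LaplaceRegressor

# Hypothetical toy data: 100 rows, 3 continuous features.
X = MLJBase.table(rand(Float32, 100, 3))
y = rand(Float32, 100)

# User-supplied Flux chain matching the number of input features.
chain = Chain(Dense(3, 16, relu), Dense(16, 1))

model = LaplaceRegressor(; flux_model=chain, epochs=100, batch_size=32)
mach = machine(model, X, y)
fit!(mach)

yhat = predict(mach, X)  # vector of Normal distributions
mean.(yhat)              # predictive means
std.(yhat)               # predictive standard deviations
```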
--- src/direct_mlj.jl | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 207a101c..46f826b5 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -29,6 +29,7 @@ It has the following Hyperparameters: MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic flux_model::Flux.Chain = nothing + flux_loss = Flux.Losses.mse epochs::Integer = 1000::(_ > 0) batch_size::Integer= 32::(_ > 0) subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) @@ -39,7 +40,7 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilist σ::Float64 = 1.0 μ₀::Float64 = 0.0 P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing - ret_distr::Bool = false::(_ in (true, false)) + #ret_distr::Bool = false::(_ in (true, false)) fit_prior_nsteps::Int = 100::(_ > 0) end @@ -51,11 +52,10 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) y = reshape(y, 1,:) data_loader = Flux.DataLoader((X,y), batchsize=m.batch_size) opt_state = Flux.setup(Adam(), m.flux_model) - loss(y_hat, y) = mean(Flux.Losses.mse.(y_hat, y)) for epoch in 1:m.epochs Flux.train!(m.flux_model,data_loader, opt_state) do model, X, y - loss(model(X), y) + m.flux_loss(model(X), y) end end @@ -91,7 +91,7 @@ end function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) Xnew = MLJBase.matrix(Xnew) |> permutedims la = fitresult - yhat = LaplaceRedux.predict(la, Xnew; ret_distr=m.ret_distr) + yhat = LaplaceRedux.predict(la, Xnew; ret_distr= false) # Extract mean and variance matrices means, variances = yhat @@ -130,6 +130,7 @@ The model also has the following parameters: MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic flux_model::Flux.Chain = nothing + flux_loss = Flux.Losses.logitcrossentropy epochs::Integer = 1000::(_ > 0) batch_size::Integer= 32::(_ > 0) subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) @@ -140,12 +141,12 @@ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilis σ::Float64 = 1.0 μ₀::Float64 = 0.0 P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing - ret_distr::Bool = false::(_ in (true, false)) + #ret_distr::Bool = false::(_ in (true, false)) fit_prior_nsteps::Int = 100::(_ > 0) link_approx::Symbol = :probit::(_ in (:probit, :plugin)) end -#link_approx::Symbol = :probit::(_ in (:probit, :plugin)) + function MMI.fit(m::LaplaceClassifier, verbosity, X, y) features = Tables.schema(X).names @@ -156,7 +157,7 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) y_plain = MLJBase.int(y) .- 1 # 0, 1 of type Int - loss(y_hat, y) = Flux.Losses.logitcrossentropy(y_hat, y) + #loss(y_hat, y) = Flux.Losses.logitcrossentropy(y_hat, y) data_loader = Flux.DataLoader((X,y_plain), batchsize=m.batch_size) opt_state = Flux.setup(Adam(), m.flux_model) @@ -165,7 +166,7 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) for epoch in 1:m.epochs Flux.train!(m.flux_model,data_loader, opt_state) do model, X, y_plain - loss(model(X), y_plain) + m.flux_loss(model(X), y_plain) end end @@ -204,7 +205,7 @@ function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew) la, Xnew; link_approx=m.link_approx, - ret_distr=m.ret_distr) + ret_distr=false) return [MLJBase.UnivariateFinite(MLJBase.classes(decode), prediction,augment=true) for prediction in predictions] end From bbab4600302eb8c8b3836eab6b86cc0ab4ac74f7 Mon Sep 17 00:00:00 2001 From: "pasquale c." 
<343guiltyspark@outlook.it> Date: Wed, 18 Sep 2024 13:17:15 +0200 Subject: [PATCH 09/60] there is some problem with the one hot encoding --- src/direct_mlj.jl | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 46f826b5..52e34a2d 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -46,7 +46,7 @@ end function MMI.fit(m::LaplaceRegressor, verbosity, X, y) - features = Tables.schema(X).names + #features = Tables.schema(X).names X = MLJBase.matrix(X) |> permutedims y = reshape(y, 1,:) @@ -149,24 +149,25 @@ end function MMI.fit(m::LaplaceClassifier, verbosity, X, y) - features = Tables.schema(X).names + #features = Tables.schema(X).names X = MLJBase.matrix(X) |> permutedims - y = reshape(y, 1,:) - - decode = y[1] + decode = MMI.decoder(y[1]) + #y = reshape(y, 1,:) y_plain = MLJBase.int(y) .- 1 # 0, 1 of type Int + y_onehot = Flux.onehotbatch(y_plain, unique(y_plain) ) + #loss(y_hat, y) = Flux.Losses.logitcrossentropy(y_hat, y) - data_loader = Flux.DataLoader((X,y_plain), batchsize=m.batch_size) + data_loader = Flux.DataLoader((X,y_onehot), batchsize=m.batch_size) opt_state = Flux.setup(Adam(), m.flux_model) for epoch in 1:m.epochs - Flux.train!(m.flux_model,data_loader, opt_state) do model, X, y_plain - m.flux_loss(model(X), y_plain) + Flux.train!(m.flux_model,data_loader, opt_state) do model, X, y_onehot + m.flux_loss(model(X), y_onehot) end end @@ -206,7 +207,7 @@ function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew) Xnew; link_approx=m.link_approx, ret_distr=false) - return [MLJBase.UnivariateFinite(MLJBase.classes(decode), prediction,augment=true) for prediction in predictions] + return [MLJBase.UnivariateFinite(MLJBase.classes(decode), prediction, pool=decode,augment=true) for prediction in predictions] end From 8af38aed4b5c1895fd2f08793b85f12c1130e2ae Mon Sep 17 00:00:00 2001 From: "pasquale c." 
<343guiltyspark@outlook.it> Date: Thu, 19 Sep 2024 12:22:04 +0200 Subject: [PATCH 10/60] fixed error in univariatefinite --- src/direct_mlj.jl | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 52e34a2d..d2b42368 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -149,22 +149,13 @@ end function MMI.fit(m::LaplaceClassifier, verbosity, X, y) - #features = Tables.schema(X).names X = MLJBase.matrix(X) |> permutedims - decode = MMI.decoder(y[1]) - #y = reshape(y, 1,:) - y_plain = MLJBase.int(y) .- 1 # 0, 1 of type Int + decode = y[1] + y_plain = MLJBase.int(y) .- 1 y_onehot = Flux.onehotbatch(y_plain, unique(y_plain) ) - - - - #loss(y_hat, y) = Flux.Losses.logitcrossentropy(y_hat, y) - data_loader = Flux.DataLoader((X,y_onehot), batchsize=m.batch_size) opt_state = Flux.setup(Adam(), m.flux_model) - - for epoch in 1:m.epochs Flux.train!(m.flux_model,data_loader, opt_state) do model, X, y_onehot m.flux_loss(model(X), y_onehot) @@ -173,7 +164,6 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) end - la = LaplaceRedux.Laplace( m.flux_model; likelihood=:classification, @@ -186,9 +176,6 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) P₀=m.P₀, ) - - - # fit the Laplace model: LaplaceRedux.fit!(la, data_loader ) optimize_prior!(la; verbose= false, n_steps=m.fit_prior_nsteps) @@ -206,8 +193,12 @@ function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew) la, Xnew; link_approx=m.link_approx, - ret_distr=false) - return [MLJBase.UnivariateFinite(MLJBase.classes(decode), prediction, pool=decode,augment=true) for prediction in predictions] + ret_distr=false) |>permutedims + + + + + return [MLJBase.UnivariateFinite(MLJBase.classes(decode), prediction) for prediction in eachrow(predictions)] end From fe19d4d405c631632677ecea1d1cc5ae88b70359 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Thu, 19 Sep 2024 12:26:25 +0200 Subject: [PATCH 11/60] performance improvement --- src/direct_mlj.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index d2b42368..20fa29a6 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -198,7 +198,7 @@ function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew) - return [MLJBase.UnivariateFinite(MLJBase.classes(decode), prediction) for prediction in eachrow(predictions)] + return MLJBase.UnivariateFinite(MLJBase.classes(decode), predictions) end From d809afbbfa65670d0f2e2ea2ffc1af796ee7ad66 Mon Sep 17 00:00:00 2001 From: "pasquale c." 
<343guiltyspark@outlook.it> Date: Sat, 21 Sep 2024 03:47:34 +0200 Subject: [PATCH 12/60] JuliaFormatter --- dev/issues/predict_slow.jl | 6 +-- docs/make.jl | 2 +- src/LaplaceRedux.jl | 6 +-- src/baselaplace/predicting.jl | 1 - src/calibration_functions.jl | 9 ++-- src/data/functions.jl | 4 +- src/direct_mlj.jl | 85 ++++++++++++----------------------- test/calibration.jl | 4 +- test/data.jl | 2 +- 9 files changed, 45 insertions(+), 74 deletions(-) diff --git a/dev/issues/predict_slow.jl b/dev/issues/predict_slow.jl index ae713cbd..95c7c1f8 100644 --- a/dev/issues/predict_slow.jl +++ b/dev/issues/predict_slow.jl @@ -4,16 +4,16 @@ using Optimisers X = MLJBase.table(rand(Float32, 100, 3)); y = coerce(rand("abc", 100), Multiclass); -model = LaplaceClassification(optimiser=Optimisers.Adam(0.1), epochs=100); +model = LaplaceClassification(; optimiser=Optimisers.Adam(0.1), epochs=100); fitresult, _, _ = MLJBase.fit(model, 2, X, y); la = fitresult[1]; Xmat = matrix(X) |> permutedims; # Single test sample: -Xtest = Xmat[:,1:10]; +Xtest = Xmat[:, 1:10]; Xtest_tab = MLJBase.table(Xtest'); MLJBase.predict(model, fitresult, Xtest_tab); # warm up LaplaceRedux.predict(la, Xmat); # warm up @time MLJBase.predict(model, fitresult, Xtest_tab); @time LaplaceRedux.predict(la, Xtest); -@time glm_predictive_distribution(la, Xtest); \ No newline at end of file +@time glm_predictive_distribution(la, Xtest); diff --git a/docs/make.jl b/docs/make.jl index 8aedd777..f24ea50b 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -24,7 +24,7 @@ makedocs(; "Calibrated forecasts" => "tutorials/calibration.md", ], "Reference" => "reference.md", - "MLJ interface"=> "mlj_interface.md", + "MLJ interface" => "mlj_interface.md", "Additional Resources" => "resources/_resources.md", ], ) diff --git a/src/LaplaceRedux.jl b/src/LaplaceRedux.jl index 4d8df9ec..15ab3d4f 100644 --- a/src/LaplaceRedux.jl +++ b/src/LaplaceRedux.jl @@ -29,8 +29,8 @@ export empirical_frequency_binary_classification, empirical_frequency_regression, sharpness_regression, extract_mean_and_variance, - sigma_scaling, rescale_stddev + sigma_scaling, + rescale_stddev - - include("direct_mlj.jl") +include("direct_mlj.jl") end diff --git a/src/baselaplace/predicting.jl b/src/baselaplace/predicting.jl index 71b74a4b..d26c9a07 100644 --- a/src/baselaplace/predicting.jl +++ b/src/baselaplace/predicting.jl @@ -144,7 +144,6 @@ function predict( else return fμ, pred_var end - end # Classification: diff --git a/src/calibration_functions.jl b/src/calibration_functions.jl index 1d313248..d6aa5062 100644 --- a/src/calibration_functions.jl +++ b/src/calibration_functions.jl @@ -78,7 +78,7 @@ Outputs: \ """ function sharpness_classification(y_binary, distributions::Vector{Bernoulli{Float64}}) mean_class_one = mean(mean.(distributions[findall(y_binary .== 1)])) - mean_class_zero = mean( 1 .- mean.(distributions[findall(y_binary .== 0)])) + mean_class_zero = mean(1 .- mean.(distributions[findall(y_binary .== 0)])) return mean_class_one, mean_class_zero end @@ -176,8 +176,9 @@ Inputs: \ Outputs: \ - `sigma`: the scalar that maximize the likelihood. 
""" -function sigma_scaling(distr::Vector{Normal{T}}, y_cal::Vector{<:AbstractFloat} - ) where T <: AbstractFloat +function sigma_scaling( + distr::Vector{Normal{T}}, y_cal::Vector{<:AbstractFloat} +) where {T<:AbstractFloat} means, variances = extract_mean_and_variance(distr) sigma = sqrt(1 / length(y_cal) * sum(norm.(y_cal .- means) ./ variances)) @@ -198,4 +199,4 @@ Outputs: \ function rescale_stddev(distr::Vector{Normal{T}}, s::T) where {T<:AbstractFloat} rescaled_distr = [Normal(mean(d), std(d) * s) for d in distr] return rescaled_distr -end \ No newline at end of file +end diff --git a/src/data/functions.jl b/src/data/functions.jl index 99e4440c..5ae83805 100644 --- a/src/data/functions.jl +++ b/src/data/functions.jl @@ -1,4 +1,4 @@ -import Random +using Random: Random """ toy_data_linear(N=100) @@ -42,7 +42,7 @@ toy_data_non_linear() ``` """ -function toy_data_non_linear( N=100; seed=nothing) +function toy_data_non_linear(N=100; seed=nothing) #set seed if available if seed !== nothing diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 20fa29a6..73758254 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -27,11 +27,10 @@ It has the following Hyperparameters: - `fit_prior_nsteps`: the number of steps used to fit the priors. """ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic - flux_model::Flux.Chain = nothing flux_loss = Flux.Losses.mse epochs::Integer = 1000::(_ > 0) - batch_size::Integer= 32::(_ > 0) + batch_size::Integer = 32::(_ > 0) subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) subnetwork_indices = nothing hessian_structure::Union{HessianStructure,Symbol,String} = @@ -44,22 +43,19 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilist fit_prior_nsteps::Int = 100::(_ > 0) end - function MMI.fit(m::LaplaceRegressor, verbosity, X, y) #features = Tables.schema(X).names X = MLJBase.matrix(X) |> permutedims - y = reshape(y, 1,:) - data_loader = Flux.DataLoader((X,y), batchsize=m.batch_size) + y = reshape(y, 1, :) + data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size) opt_state = Flux.setup(Adam(), m.flux_model) - for epoch in 1:m.epochs - Flux.train!(m.flux_model,data_loader, opt_state) do model, X, y + for epoch in 1:(m.epochs) + Flux.train!(m.flux_model, data_loader, opt_state) do model, X, y m.flux_loss(model(X), y) - end - end - + end la = LaplaceRedux.Laplace( m.flux_model; @@ -73,25 +69,20 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) P₀=m.P₀, ) - # fit the Laplace model: - LaplaceRedux.fit!(la, data_loader ) - optimize_prior!(la; verbose= false, n_steps=m.fit_prior_nsteps) + LaplaceRedux.fit!(la, data_loader) + optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) - - fitresult=la + fitresult = la report = (status="success", message="Model fitted successfully") - cache = nothing + cache = nothing return (fitresult, cache, report) end - - - function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) Xnew = MLJBase.matrix(Xnew) |> permutedims la = fitresult - yhat = LaplaceRedux.predict(la, Xnew; ret_distr= false) + yhat = LaplaceRedux.predict(la, Xnew; ret_distr=false) # Extract mean and variance matrices means, variances = yhat @@ -99,10 +90,6 @@ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) return [Normal(μ, sqrt(σ)) for (μ, σ) in zip(means, variances)] end - - - - """ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic @@ -128,11 +115,10 @@ The model also has the following parameters: - 
`fit_prior_nsteps`: the number of steps used to fit the priors. """ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic - flux_model::Flux.Chain = nothing flux_loss = Flux.Losses.logitcrossentropy epochs::Integer = 1000::(_ > 0) - batch_size::Integer= 32::(_ > 0) + batch_size::Integer = 32::(_ > 0) subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) subnetwork_indices = nothing hessian_structure::Union{HessianStructure,Symbol,String} = @@ -146,23 +132,19 @@ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilis link_approx::Symbol = :probit::(_ in (:probit, :plugin)) end - - function MMI.fit(m::LaplaceClassifier, verbosity, X, y) X = MLJBase.matrix(X) |> permutedims decode = y[1] - y_plain = MLJBase.int(y) .- 1 - y_onehot = Flux.onehotbatch(y_plain, unique(y_plain) ) - data_loader = Flux.DataLoader((X,y_onehot), batchsize=m.batch_size) + y_plain = MLJBase.int(y) .- 1 + y_onehot = Flux.onehotbatch(y_plain, unique(y_plain)) + data_loader = Flux.DataLoader((X, y_onehot); batchsize=m.batch_size) opt_state = Flux.setup(Adam(), m.flux_model) - for epoch in 1:m.epochs - Flux.train!(m.flux_model,data_loader, opt_state) do model, X, y_onehot + for epoch in 1:(m.epochs) + Flux.train!(m.flux_model, data_loader, opt_state) do model, X, y_onehot m.flux_loss(model(X), y_onehot) - end - end - + end la = LaplaceRedux.Laplace( m.flux_model; @@ -177,37 +159,28 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) ) # fit the Laplace model: - LaplaceRedux.fit!(la, data_loader ) - optimize_prior!(la; verbose= false, n_steps=m.fit_prior_nsteps) + LaplaceRedux.fit!(la, data_loader) + optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) report = (status="success", message="Model fitted successfully") - cache = nothing - return ((la,decode), cache, report) + cache = nothing + return ((la, decode), cache, report) end - function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew) la = fitresult Xnew = MLJBase.matrix(Xnew) |> permutedims - predictions = LaplaceRedux.predict( - la, - Xnew; - link_approx=m.link_approx, - ret_distr=false) |>permutedims - - - + predictions = + LaplaceRedux.predict(la, Xnew; link_approx=m.link_approx, ret_distr=false) |> + permutedims return MLJBase.UnivariateFinite(MLJBase.classes(decode), predictions) end - - - MLJBase.metadata_model( LaplaceClassifier; input_scitype=Union{ - AbstractMatrix{<:Union{MLJBase.Finite, MLJBase.Continuous}}, # matrix with mixed types + AbstractMatrix{<:Union{MLJBase.Finite,MLJBase.Continuous}}, # matrix with mixed types MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types }, target_scitype=AbstractArray{<:MLJBase.Finite}, # ordered factor or multiclass @@ -217,9 +190,9 @@ MLJBase.metadata_model( MLJBase.metadata_model( LaplaceRegressor; input_scitype=Union{ - AbstractMatrix{<:Union{MLJBase.Finite, MLJBase.Continuous}}, # matrix with mixed types + AbstractMatrix{<:Union{MLJBase.Finite,MLJBase.Continuous}}, # matrix with mixed types MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types }, - target_scitype=AbstractArray{MLJBase.Continuous}, + target_scitype=AbstractArray{MLJBase.Continuous}, load_path="LaplaceRedux.LaplaceRegressor", -) \ No newline at end of file +) diff --git a/test/calibration.jl b/test/calibration.jl index 847543ce..988ff1ef 100644 --- a/test/calibration.jl +++ b/test/calibration.jl @@ -3,8 +3,6 @@ using LaplaceRedux using Distributions using Trapz - - # Test for `sharpness_regression` function 
@testset "sharpness_regression distributions tests" begin @info " testing sharpness_regression with distributions" @@ -201,4 +199,4 @@ end distributions = [Normal(0, 1), Normal(2, 1), Normal(4, 1)] rescaled_distr = rescale_stddev(distributions, 2.0) @test rescaled_distr == [Normal(0, 2), Normal(2, 2), Normal(4, 2)] -end \ No newline at end of file +end diff --git a/test/data.jl b/test/data.jl index e6d61881..05a5b30d 100644 --- a/test/data.jl +++ b/test/data.jl @@ -15,7 +15,7 @@ for fun in fun_list # Generate data with the same seed Random.seed!(seed) xs1, ys1 = fun(N; seed=seed) - + Random.seed!(seed) xs2, ys2 = fun(N; seed=seed) From 33d84f570c1fcbbcd91c3ab3c7c0c80d10bdd8eb Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Sat, 21 Sep 2024 07:22:10 +0200 Subject: [PATCH 13/60] juliaformatter+docstrings --- src/direct_mlj.jl | 83 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 73758254..77409e16 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -43,6 +43,34 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilist fit_prior_nsteps::Int = 100::(_ > 0) end +@doc """ + MMI.fit(m::LaplaceRegressor, verbosity, X, y) + +Fit a LaplaceRegressor model using the provided features and target values. + +# Arguments +- `m::LaplaceRegressor`: The LaplaceRegressor model to be fitted. +- `verbosity`: Verbosity level for logging. +- `X`: Input features, expected to be in a format compatible with MLJBase.matrix. +- `y`: Target values. + +# Returns +- `fitresult`: The fitted Laplace model. +- `cache`: Currently unused, returns `nothing`. +- `report`: A tuple containing the status and message of the fitting process. + +# Description +This function performs the following steps: +1. Converts the input features `X` to a matrix and transposes it. +2. Reshapes the target values `y` to shape (1,:). +3. Creates a data loader for batching the data. +4. Sets up the optimizer state using the Adam optimizer. +5. Trains the model for a specified number of epochs. +6. Initializes a Laplace model with the trained Flux model and specified parameters. +7. Fits the Laplace model using the data loader. +8. Optimizes the prior of the Laplace model. +9. Returns the fitted Laplace model, a cache (currently `nothing`), and a report indicating success. +""" function MMI.fit(m::LaplaceRegressor, verbosity, X, y) #features = Tables.schema(X).names @@ -79,6 +107,22 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) return (fitresult, cache, report) end +@doc """ +function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) + + Predicts the response for new data using a fitted LaplaceRegressor model. + + # Arguments + - `m::LaplaceRegressor`: The LaplaceRegressor model. + - `fitresult`: The result of fitting the LaplaceRegressor model. + - `Xnew`: The new data for which predictions are to be made. + + # Returns + - An array of Normal distributions, each centered around the predicted mean and variance for the corresponding input in `Xnew`. + + The function first converts `Xnew` to a matrix and permutes its dimensions. It then uses the `LaplaceRedux.predict` function to obtain the predicted means and variances. +Finally, it creates Normal distributions from these means and variances and returns them as an array. 
+""" function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) Xnew = MLJBase.matrix(Xnew) |> permutedims la = fitresult @@ -132,6 +176,31 @@ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilis link_approx::Symbol = :probit::(_ in (:probit, :plugin)) end +@doc """ + + function MMI.fit(m::LaplaceClassifier, verbosity, X, y) + + Description: + This function fits a LaplaceClassifier model using the provided data. It first preprocesses the input data `X` and target labels `y`, + then trains a neural network model using the Flux library. After training, it fits a Laplace approximation to the trained model. + + Arguments: + - `m::LaplaceClassifier`: The LaplaceClassifier model to be fitted. + - `verbosity`: Verbosity level for logging. + - `X`: Input data features. + - `y`: Target labels. + + Returns: + - A tuple containing: + - `(la, decode)`: The fitted Laplace model and the decode function for the target labels. + - `cache`: A placeholder for any cached data (currently `nothing`). + - `report`: A report dictionary containing the status and message of the fitting process. + + Notes: + - The function uses the Flux library for neural network training and the LaplaceRedux library for fitting the Laplace approximation. + - The `optimize_prior!` function is called to optimize the prior parameters of the Laplace model. + +""" function MMI.fit(m::LaplaceClassifier, verbosity, X, y) X = MLJBase.matrix(X) |> permutedims decode = y[1] @@ -167,6 +236,20 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) return ((la, decode), cache, report) end +@doc """ +Predicts the class probabilities for new data using a Laplace classifier. + + # Arguments + - `m::LaplaceClassifier`: The Laplace classifier model. + - `(fitresult, decode)`: A tuple containing the fitted model result and the decode function. + - `Xnew`: The new data for which predictions are to be made. + + # Returns + - `MLJBase.UnivariateFinite`: The predicted class probabilities for the new data. + +The function transforms the new data `Xnew` into a matrix, applies the LaplaceRedux +prediction function, and then returns the predictions as a `MLJBase.UnivariateFinite` object. +""" function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew) la = fitresult Xnew = MLJBase.matrix(Xnew) |> permutedims From 9731297057c1324f4c6e9990d477435a0cb44fe3 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Sat, 21 Sep 2024 09:38:12 +0200 Subject: [PATCH 14/60] removed predict_proba and ret_Distr from the struct --- src/direct_mlj.jl | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 77409e16..3e166288 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -13,7 +13,9 @@ using Distributions: Normal A mutable struct representing a Laplace regression model. It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. It has the following Hyperparameters: -- `flux_model`: A flux model provided by the user and compatible with the dataset. +- `flux_model`: A Flux model provided by the user and compatible with the dataset. +- `flux_loss` : a Flux loss function +- `optimiser` = a Flux optimiser - `epochs`: The number of training epochs. - `batch_size`: The batch size. - `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. 
@@ -23,12 +25,12 @@ It has the following Hyperparameters: - `σ`: the standard deviation of the prior distribution. - `μ₀`: the mean of the prior distribution. - `P₀`: the covariance matrix of the prior distribution. -- `ret_distr`: a boolean that tells predict to either return distributions (true) objects from Distributions.jl or just the probabilities. - `fit_prior_nsteps`: the number of steps used to fit the priors. """ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic flux_model::Flux.Chain = nothing flux_loss = Flux.Losses.mse + optimiser = Adam() epochs::Integer = 1000::(_ > 0) batch_size::Integer = 32::(_ > 0) subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) @@ -39,7 +41,6 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilist σ::Float64 = 1.0 μ₀::Float64 = 0.0 P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing - #ret_distr::Bool = false::(_ in (true, false)) fit_prior_nsteps::Int = 100::(_ > 0) end @@ -72,12 +73,11 @@ This function performs the following steps: 9. Returns the fitted Laplace model, a cache (currently `nothing`), and a report indicating success. """ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) - #features = Tables.schema(X).names X = MLJBase.matrix(X) |> permutedims y = reshape(y, 1, :) data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size) - opt_state = Flux.setup(Adam(), m.flux_model) + opt_state = Flux.setup(m.optimiser(), m.flux_model) for epoch in 1:(m.epochs) Flux.train!(m.flux_model, data_loader, opt_state) do model, X, y @@ -139,11 +139,11 @@ end A mutable struct representing a Laplace Classification model. It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. - - The model also has the following parameters: -- `flux_model`: A flux model provided by the user and compatible with the dataset. +- `flux_model`: A Flux model provided by the user and compatible with the dataset. +- `flux_loss` : a Flux loss function +- `optimiser` = a Flux optimiser - `epochs`: The number of training epochs. - `batch_size`: The batch size. - `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. @@ -154,13 +154,12 @@ The model also has the following parameters: - `μ₀`: the mean of the prior distribution. - `P₀`: the covariance matrix of the prior distribution. - `link_approx`: the link approximation to use, either `:probit` or `:plugin`. -- `predict_proba`: a boolean that select whether to predict probabilities or not. -- `ret_distr`: a boolean that tells predict to either return distributions (true) objects from Distributions.jl or just the probabilities. - `fit_prior_nsteps`: the number of steps used to fit the priors. 
""" MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic flux_model::Flux.Chain = nothing flux_loss = Flux.Losses.logitcrossentropy + optimiser = Adam() epochs::Integer = 1000::(_ > 0) batch_size::Integer = 32::(_ > 0) subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) @@ -171,7 +170,6 @@ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilis σ::Float64 = 1.0 μ₀::Float64 = 0.0 P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing - #ret_distr::Bool = false::(_ in (true, false)) fit_prior_nsteps::Int = 100::(_ > 0) link_approx::Symbol = :probit::(_ in (:probit, :plugin)) end @@ -207,7 +205,7 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) y_plain = MLJBase.int(y) .- 1 y_onehot = Flux.onehotbatch(y_plain, unique(y_plain)) data_loader = Flux.DataLoader((X, y_onehot); batchsize=m.batch_size) - opt_state = Flux.setup(Adam(), m.flux_model) + opt_state = Flux.setup(m.optimiser, m.flux_model) for epoch in 1:(m.epochs) Flux.train!(m.flux_model, data_loader, opt_state) do model, X, y_onehot From f70d239adb8307492f40a6a172f02868c7de8ac8 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Sat, 21 Sep 2024 11:47:51 +0200 Subject: [PATCH 15/60] mlj docstring in progress --- src/direct_mlj.jl | 139 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 139 insertions(+) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 3e166288..37e0b4b0 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -277,3 +277,142 @@ MLJBase.metadata_model( target_scitype=AbstractArray{MLJBase.Continuous}, load_path="LaplaceRedux.LaplaceRegressor", ) + +const DOC_LAPLACE_REDUX = "[Laplace Redux – Effortless Bayesian Deep Learning]" + "(https://proceedings.neurips.cc/paper/2021/hash/a3923dbe2f702eff254d67b48ae2f06e-Abstract.html), originally published in " + "Daxberger, E., Kristiadi, A., Immer, A., Eschenhagen, R., Bauer, M., Hennig, P. (2021): \"Laplace Redux – Effortless Bayesian Deep Learning.\", *NIPS'21: Proceedings of the 35th International Conference on Neural Information Processing Systems*, Article No. 1537, pp. 20089–20103"; + + """ +$(MMI.doc_header(LaplaceClassifier)) + +`LaplaceClassifier` implements the $DOC_CART. + +# Training data + +In MLJ or MLJBase, bind an instance `model` to data with + + mach = machine(model, X, y) + +where + +- `X`: any table of input features (eg, a `DataFrame`) whose columns + each have one of the following element scitypes: `Continuous`, + `Count`, or `<:OrderedFactor`; check column scitypes with `schema(X)` + +- `y`: is the target, which can be any `AbstractVector` whose element + scitype is `<:OrderedFactor` or `<:Multiclass`; check the scitype + with `scitype(y)` + +Train the machine using `fit!(mach, rows=...)`. + + +# Hyperparameters + +- `max_depth=-1`: max depth of the decision tree (-1=any) + +- `min_samples_leaf=1`: max number of samples each leaf needs to have + +- `min_samples_split=2`: min number of samples needed for a split + +- `min_purity_increase=0`: min purity needed for a split + +- `n_subfeatures=0`: number of features to select at random (0 for all) + +- `post_prune=false`: set to `true` for post-fit pruning + +- `merge_purity_threshold=1.0`: (post-pruning) merge leaves having + combined purity `>= merge_purity_threshold` + +- `display_depth=5`: max depth to show when displaying the tree + +- `feature_importance`: method to use for computing feature importances. 
One of `(:impurity, + :split)` + +- `rng=Random.GLOBAL_RNG`: random number generator or seed + + +# Operations + +- `predict(mach, Xnew)`: return predictions of the target given + features `Xnew` having the same scitype as `X` above. Predictions + are probabilistic, but uncalibrated. + +- `predict_mode(mach, Xnew)`: instead return the mode of each + prediction above. + + +# Fitted parameters + +The fields of `fitted_params(mach)` are: + +- `raw_tree`: the raw `Node`, `Leaf` or `Root` object returned by the core DecisionTree.jl + algorithm + +- `tree`: a visualizable, wrapped version of `raw_tree` implementing the AbstractTrees.jl + interface; see "Examples" below + +- `encoding`: dictionary of target classes keyed on integers used + internally by DecisionTree.jl + +- `features`: the names of the features encountered in training, in an + order consistent with the output of `print_tree` (see below) + + +# Report + +The fields of `report(mach)` are: + +- `classes_seen`: list of target classes actually observed in training + +- `print_tree`: alternative method to print the fitted + tree, with single argument the tree depth; interpretation requires + internal integer-class encoding (see "Fitted parameters" above). + +- `features`: the names of the features encountered in training, in an + order consistent with the output of `print_tree` (see below) + +# Accessor functions + +- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs; + the type of importance is determined by the hyperparameter `feature_importance` (see + above) + +# Examples + +``` +using MLJ +DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree +model = DecisionTreeClassifier(max_depth=3, min_samples_split=3) + +X, y = @load_iris +mach = machine(model, X, y) |> fit! + +Xnew = (sepal_length = [6.4, 7.2, 7.4], + sepal_width = [2.8, 3.0, 2.8], + petal_length = [5.6, 5.8, 6.1], + petal_width = [2.1, 1.6, 1.9],) +yhat = predict(mach, Xnew) # probabilistic predictions +predict_mode(mach, Xnew) # point predictions +pdf.(yhat, "virginica") # probabilities for the "verginica" class + +julia> tree = fitted_params(mach).tree +petal_length < 2.45 +├─ setosa (50/50) +└─ petal_width < 1.75 + ├─ petal_length < 4.95 + │ ├─ versicolor (47/48) + │ └─ virginica (4/6) + └─ petal_length < 4.85 + ├─ virginica (2/3) + └─ virginica (43/43) + +using Plots, TreeRecipe +plot(tree) # for a graphical representation of the tree + +feature_importances(mach) +``` + +See also [LaplaceRedux.jl](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl). + +""" +LaplaceClassifier \ No newline at end of file From 80c65539e09487d9f826b3fc8824a53cfea1a16c Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Sat, 21 Sep 2024 11:53:29 +0200 Subject: [PATCH 16/60] ah fixed constant , added prototype for regression --- src/direct_mlj.jl | 145 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 140 insertions(+), 5 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 37e0b4b0..4ecc1702 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -278,14 +278,14 @@ MLJBase.metadata_model( load_path="LaplaceRedux.LaplaceRegressor", ) -const DOC_LAPLACE_REDUX = "[Laplace Redux – Effortless Bayesian Deep Learning]" - "(https://proceedings.neurips.cc/paper/2021/hash/a3923dbe2f702eff254d67b48ae2f06e-Abstract.html), originally published in " - "Daxberger, E., Kristiadi, A., Immer, A., Eschenhagen, R., Bauer, M., Hennig, P. 
(2021): \"Laplace Redux – Effortless Bayesian Deep Learning.\", *NIPS'21: Proceedings of the 35th International Conference on Neural Information Processing Systems*, Article No. 1537, pp. 20089–20103"; +const DOC_LAPLACE_REDUX = "[Laplace Redux – Effortless Bayesian Deep Learning]"* + "(https://proceedings.neurips.cc/paper/2021/hash/a3923dbe2f702eff254d67b48ae2f06e-Abstract.html), originally published in "* + "Daxberger, E., Kristiadi, A., Immer, A., Eschenhagen, R., Bauer, M., Hennig, P. (2021): \"Laplace Redux – Effortless Bayesian Deep Learning.\", NIPS'21: Proceedings of the 35th International Conference on Neural Information Processing Systems*, Article No. 1537, pp. 20089–20103" """ $(MMI.doc_header(LaplaceClassifier)) -`LaplaceClassifier` implements the $DOC_CART. +`LaplaceClassifier` implements the $DOC_LAPLACE_REDUX for classification models. # Training data @@ -415,4 +415,139 @@ feature_importances(mach) See also [LaplaceRedux.jl](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl). """ -LaplaceClassifier \ No newline at end of file +LaplaceClassifier + +""" +$(MMI.doc_header(LaplaceRegressor)) + +`LaplaceRegressor` implements the $DOC_LAPLACE_REDUX for regression models. + +# Training data + +In MLJ or MLJBase, bind an instance `model` to data with + + mach = machine(model, X, y) + +where + +- `X`: any table of input features (eg, a `DataFrame`) whose columns + each have one of the following element scitypes: `Continuous`, + `Count`, or `<:OrderedFactor`; check column scitypes with `schema(X)` + +- `y`: is the target, which can be any `AbstractVector` whose element + scitype is `<:OrderedFactor` or `<:Multiclass`; check the scitype + with `scitype(y)` + +Train the machine using `fit!(mach, rows=...)`. + + +# Hyperparameters + +- `max_depth=-1`: max depth of the decision tree (-1=any) + +- `min_samples_leaf=1`: max number of samples each leaf needs to have + +- `min_samples_split=2`: min number of samples needed for a split + +- `min_purity_increase=0`: min purity needed for a split + +- `n_subfeatures=0`: number of features to select at random (0 for all) + +- `post_prune=false`: set to `true` for post-fit pruning + +- `merge_purity_threshold=1.0`: (post-pruning) merge leaves having + combined purity `>= merge_purity_threshold` + +- `display_depth=5`: max depth to show when displaying the tree + +- `feature_importance`: method to use for computing feature importances. One of `(:impurity, + :split)` + +- `rng=Random.GLOBAL_RNG`: random number generator or seed + + +# Operations + +- `predict(mach, Xnew)`: return predictions of the target given + features `Xnew` having the same scitype as `X` above. Predictions + are probabilistic, but uncalibrated. + +- `predict_mode(mach, Xnew)`: instead return the mode of each + prediction above. 
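+
+As a minimal usage sketch (assuming a fitted machine `mach` and new data
+`Xnew`): the `predict` method defined in this interface returns
+`Distributions.Normal` objects for regression, so predictive means and
+standard deviations can be read off directly.
+
+```
+using Distributions
+yhat = predict(mach, Xnew)   # vector of Normal distributions
+mean.(yhat)                  # predictive means
+std.(yhat)                   # predictive standard deviations
+```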
+ + +# Fitted parameters + +The fields of `fitted_params(mach)` are: + +- `raw_tree`: the raw `Node`, `Leaf` or `Root` object returned by the core DecisionTree.jl + algorithm + +- `tree`: a visualizable, wrapped version of `raw_tree` implementing the AbstractTrees.jl + interface; see "Examples" below + +- `encoding`: dictionary of target classes keyed on integers used + internally by DecisionTree.jl + +- `features`: the names of the features encountered in training, in an + order consistent with the output of `print_tree` (see below) + + +# Report + +The fields of `report(mach)` are: + +- `classes_seen`: list of target classes actually observed in training + +- `print_tree`: alternative method to print the fitted + tree, with single argument the tree depth; interpretation requires + internal integer-class encoding (see "Fitted parameters" above). + +- `features`: the names of the features encountered in training, in an + order consistent with the output of `print_tree` (see below) + +# Accessor functions + +- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs; + the type of importance is determined by the hyperparameter `feature_importance` (see + above) + +# Examples + +``` +using MLJ +DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree +model = DecisionTreeClassifier(max_depth=3, min_samples_split=3) + +X, y = @load_iris +mach = machine(model, X, y) |> fit! + +Xnew = (sepal_length = [6.4, 7.2, 7.4], + sepal_width = [2.8, 3.0, 2.8], + petal_length = [5.6, 5.8, 6.1], + petal_width = [2.1, 1.6, 1.9],) +yhat = predict(mach, Xnew) # probabilistic predictions +predict_mode(mach, Xnew) # point predictions +pdf.(yhat, "virginica") # probabilities for the "verginica" class + +julia> tree = fitted_params(mach).tree +petal_length < 2.45 +├─ setosa (50/50) +└─ petal_width < 1.75 + ├─ petal_length < 4.95 + │ ├─ versicolor (47/48) + │ └─ virginica (4/6) + └─ petal_length < 4.85 + ├─ virginica (2/3) + └─ virginica (43/43) + +using Plots, TreeRecipe +plot(tree) # for a graphical representation of the tree + +feature_importances(mach) +``` + +See also [LaplaceRedux.jl](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl). + +""" +LaplaceRegressor \ No newline at end of file From d1c895c128cba4088bd9a4af4da1da938c9b3f6c Mon Sep 17 00:00:00 2001 From: "pasquale c." 
<343guiltyspark@outlook.it> Date: Sat, 21 Sep 2024 16:13:44 +0200 Subject: [PATCH 17/60] small stuff here and there in the docstring plus fixed a small mistake in the optimiser --- src/direct_mlj.jl | 58 ++++++++++++++++++----------------------------- 1 file changed, 22 insertions(+), 36 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 4ecc1702..bc1826ab 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -77,7 +77,7 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) X = MLJBase.matrix(X) |> permutedims y = reshape(y, 1, :) data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size) - opt_state = Flux.setup(m.optimiser(), m.flux_model) + opt_state = Flux.setup(m.optimiser, m.flux_model) for epoch in 1:(m.epochs) Flux.train!(m.flux_model, data_loader, opt_state) do model, X, y @@ -381,10 +381,21 @@ The fields of `report(mach)` are: ``` using MLJ -DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree -model = DecisionTreeClassifier(max_depth=3, min_samples_split=3) +LaplaceClassifier = @load LaplaceClassifier pkg=LaplaceRedux X, y = @load_iris + +# Define the Flux Chain model +using Flux +flux_model = Chain( + Dense(4, 10, relu), + Dense(10, 10, relu), + Dense(10, 3) +) + +#Define the LaplaceClassifier +model = LaplaceClassifier(flux_model=flux_model) + mach = machine(model, X, y) |> fit! Xnew = (sepal_length = [6.4, 7.2, 7.4], @@ -395,19 +406,8 @@ yhat = predict(mach, Xnew) # probabilistic predictions predict_mode(mach, Xnew) # point predictions pdf.(yhat, "virginica") # probabilities for the "verginica" class -julia> tree = fitted_params(mach).tree -petal_length < 2.45 -├─ setosa (50/50) -└─ petal_width < 1.75 - ├─ petal_length < 4.95 - │ ├─ versicolor (47/48) - │ └─ virginica (4/6) - └─ petal_length < 4.85 - ├─ virginica (2/3) - └─ virginica (43/43) +julia> tree = fitted_params(mach) -using Plots, TreeRecipe -plot(tree) # for a graphical representation of the tree feature_importances(mach) ``` @@ -516,33 +516,19 @@ The fields of `report(mach)` are: ``` using MLJ -DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree -model = DecisionTreeClassifier(max_depth=3, min_samples_split=3) +LaplaceRegressor = @load LaplaceRegressor pkg=LaplaceRedux +model = LaplaceRegressor(flux_model=flux_model) -X, y = @load_iris +X, y = make_regression(100, 2; noise=0.5, sparse=0.2, outliers=0.1) mach = machine(model, X, y) |> fit! -Xnew = (sepal_length = [6.4, 7.2, 7.4], - sepal_width = [2.8, 3.0, 2.8], - petal_length = [5.6, 5.8, 6.1], - petal_width = [2.1, 1.6, 1.9],) +Xnew, _ = make_regression(3, 2; rng=123) yhat = predict(mach, Xnew) # probabilistic predictions predict_mode(mach, Xnew) # point predictions -pdf.(yhat, "virginica") # probabilities for the "verginica" class -julia> tree = fitted_params(mach).tree -petal_length < 2.45 -├─ setosa (50/50) -└─ petal_width < 1.75 - ├─ petal_length < 4.95 - │ ├─ versicolor (47/48) - │ └─ virginica (4/6) - └─ petal_length < 4.85 - ├─ virginica (2/3) - └─ virginica (43/43) - -using Plots, TreeRecipe -plot(tree) # for a graphical representation of the tree +julia> tree = fitted_params(mach) + + feature_importances(mach) ``` From 19ffa1667d38a83f5ddeac23ec3a2396f297f599 Mon Sep 17 00:00:00 2001 From: "pasquale c." 
<343guiltyspark@outlook.it> Date: Sat, 21 Sep 2024 17:48:08 +0200 Subject: [PATCH 18/60] still writing this long ass docstring --- src/direct_mlj.jl | 127 ++++++++++++++++------------------------------ 1 file changed, 45 insertions(+), 82 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index bc1826ab..5e72712c 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -306,30 +306,35 @@ where Train the machine using `fit!(mach, rows=...)`. -# Hyperparameters +# Hyperparameters (format: name-type-default value-restrictions) -- `max_depth=-1`: max depth of the decision tree (-1=any) +- `flux_model::Flux.Chain = nothing`: a Flux model provided by the user and compatible with the dataset. -- `min_samples_leaf=1`: max number of samples each leaf needs to have +- `flux_loss = Flux.Losses.logitcrossentropy` : a Flux loss function -- `min_samples_split=2`: min number of samples needed for a split +- `optimiser = Adam()` a Flux optimiser -- `min_purity_increase=0`: min purity needed for a split +- `epochs::Integer = 1000::(_ > 0)`: the number of training epochs. -- `n_subfeatures=0`: number of features to select at random (0 for all) +- `batch_size::Integer = 32::(_ > 0)`: the batch size. -- `post_prune=false`: set to `true` for post-fit pruning +- `subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. -- `merge_purity_threshold=1.0`: (post-pruning) merge leaves having - combined purity `>= merge_purity_threshold` +- `subnetwork_indices = nothing`: the indices of the subnetworks. -- `display_depth=5`: max depth to show when displaying the tree +- `hessian_structure::Union{HessianStructure,Symbol,String} = :full::(_ in (:full, :diagonal))`: the structure of the Hessian matrix, either `:full` or `:diagonal`. -- `feature_importance`: method to use for computing feature importances. One of `(:impurity, - :split)` +- `backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))`: the backend to use, either `:GGN` or `:EmpiricalFisher`. -- `rng=Random.GLOBAL_RNG`: random number generator or seed +- `σ::Float64 = 1.0`: the standard deviation of the prior distribution. +- `μ₀::Float64 = 0.0`: the mean of the prior distribution. + +- `P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing`: the covariance matrix of the prior distribution. + +- `fit_prior_nsteps::Int = 100::(_ > 0) `: the number of steps used to fit the priors. + +- `link_approx::Symbol = :probit::(_ in (:probit, :plugin))`: the approximation to adopt to compute the probabilities. # Operations @@ -348,34 +353,8 @@ The fields of `fitted_params(mach)` are: - `raw_tree`: the raw `Node`, `Leaf` or `Root` object returned by the core DecisionTree.jl algorithm -- `tree`: a visualizable, wrapped version of `raw_tree` implementing the AbstractTrees.jl - interface; see "Examples" below - -- `encoding`: dictionary of target classes keyed on integers used - internally by DecisionTree.jl - -- `features`: the names of the features encountered in training, in an - order consistent with the output of `print_tree` (see below) - - -# Report - -The fields of `report(mach)` are: - -- `classes_seen`: list of target classes actually observed in training - -- `print_tree`: alternative method to print the fitted - tree, with single argument the tree depth; interpretation requires - internal integer-class encoding (see "Fitted parameters" above). 
- -- `features`: the names of the features encountered in training, in an - order consistent with the output of `print_tree` (see below) - # Accessor functions -- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs; - the type of importance is determined by the hyperparameter `feature_importance` (see - above) # Examples @@ -408,8 +387,6 @@ pdf.(yhat, "virginica") # probabilities for the "verginica" class julia> tree = fitted_params(mach) - -feature_importances(mach) ``` See also [LaplaceRedux.jl](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl). @@ -435,35 +412,39 @@ where `Count`, or `<:OrderedFactor`; check column scitypes with `schema(X)` - `y`: is the target, which can be any `AbstractVector` whose element - scitype is `<:OrderedFactor` or `<:Multiclass`; check the scitype + scitype is `<:Continuous`; check the scitype with `scitype(y)` Train the machine using `fit!(mach, rows=...)`. -# Hyperparameters +# Hyperparameters (format: name-type-default value-restrictions) + +- `flux_model::Flux.Chain = nothing`: a Flux model provided by the user and compatible with the dataset. -- `max_depth=-1`: max depth of the decision tree (-1=any) +- `flux_loss = Flux.Losses.logitcrossentropy` : a Flux loss function -- `min_samples_leaf=1`: max number of samples each leaf needs to have +- `optimiser = Adam()` a Flux optimiser -- `min_samples_split=2`: min number of samples needed for a split +- `epochs::Integer = 1000::(_ > 0)`: the number of training epochs. -- `min_purity_increase=0`: min purity needed for a split +- `batch_size::Integer = 32::(_ > 0)`: the batch size. -- `n_subfeatures=0`: number of features to select at random (0 for all) +- `subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. -- `post_prune=false`: set to `true` for post-fit pruning +- `subnetwork_indices = nothing`: the indices of the subnetworks. -- `merge_purity_threshold=1.0`: (post-pruning) merge leaves having - combined purity `>= merge_purity_threshold` +- `hessian_structure::Union{HessianStructure,Symbol,String} = :full::(_ in (:full, :diagonal))`: the structure of the Hessian matrix, either `:full` or `:diagonal`. -- `display_depth=5`: max depth to show when displaying the tree +- `backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))`: the backend to use, either `:GGN` or `:EmpiricalFisher`. -- `feature_importance`: method to use for computing feature importances. One of `(:impurity, - :split)` +- `σ::Float64 = 1.0`: the standard deviation of the prior distribution. -- `rng=Random.GLOBAL_RNG`: random number generator or seed +- `μ₀::Float64 = 0.0`: the mean of the prior distribution. + +- `P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing`: the covariance matrix of the prior distribution. + +- `fit_prior_nsteps::Int = 100::(_ > 0) `: the number of steps used to fit the priors. 
# Operations @@ -483,46 +464,28 @@ The fields of `fitted_params(mach)` are: - `raw_tree`: the raw `Node`, `Leaf` or `Root` object returned by the core DecisionTree.jl algorithm -- `tree`: a visualizable, wrapped version of `raw_tree` implementing the AbstractTrees.jl - interface; see "Examples" below - -- `encoding`: dictionary of target classes keyed on integers used - internally by DecisionTree.jl - -- `features`: the names of the features encountered in training, in an - order consistent with the output of `print_tree` (see below) - - -# Report - -The fields of `report(mach)` are: - -- `classes_seen`: list of target classes actually observed in training - -- `print_tree`: alternative method to print the fitted - tree, with single argument the tree depth; interpretation requires - internal integer-class encoding (see "Fitted parameters" above). - -- `features`: the names of the features encountered in training, in an - order consistent with the output of `print_tree` (see below) # Accessor functions -- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs; - the type of importance is determined by the hyperparameter `feature_importance` (see - above) + # Examples ``` using MLJ +using Flux LaplaceRegressor = @load LaplaceRegressor pkg=LaplaceRedux +flux_model = Chain( + Dense(4, 10, relu), + Dense(10, 10, relu), + Dense(10, 1) +) model = LaplaceRegressor(flux_model=flux_model) -X, y = make_regression(100, 2; noise=0.5, sparse=0.2, outliers=0.1) +X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1) mach = machine(model, X, y) |> fit! -Xnew, _ = make_regression(3, 2; rng=123) +Xnew, _ = make_regression(3, 4; rng=123) yhat = predict(mach, Xnew) # probabilistic predictions predict_mode(mach, Xnew) # point predictions From de0bd91ac850adcf6595fe8a8952c31705dce0c6 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Sun, 22 Sep 2024 06:32:37 +0200 Subject: [PATCH 19/60] added fit_params functions --- src/direct_mlj.jl | 123 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 119 insertions(+), 4 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 5e72712c..4bd86eb0 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -107,6 +107,48 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) return (fitresult, cache, report) end + +@doc """ + + function MMI.fitted_params(model::LaplaceRegressor, fitresult) + + + This function extracts the fitted parameters from a `LaplaceRegressor` model. + + # Arguments + - `model::LaplaceRegressor`: The Laplace regression model. + - `fitresult`: the Laplace approximation (`la`). + + # Returns + A named tuple containing: + - `μ`: The mean of the posterior distribution. + - `H`: The Hessian of the posterior distribution. + - `P`: The precision matrix of the posterior distribution. + - `Σ`: The covariance matrix of the posterior distribution. + - `n_data`: The number of data points. + - `n_params`: The number of parameters. + - `n_out`: The number of outputs. + - `loss`: The loss value of the posterior distribution. 
+ +""" +function MMI.fitted_params(model::LaplaceRegressor, fitresult) + la = fitresult + posterior = la.posterior + return ( + μ = posterior.μ, + H = posterior.H, + P = posterior.P, + Σ = posterior.Σ, + n_data = posterior.n_data, + n_params = posterior.n_params, + n_out = posterior.n_out, + loss = posterior.loss + ) +end + + + + @doc """ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) @@ -234,6 +276,53 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) return ((la, decode), cache, report) end + + + + + +@doc """ + + function MMI.fitted_params(model::LaplaceClassifier, fitresult) + + + This function extracts the fitted parameters from a `LaplaceClassifier` model. + + # Arguments + - `model::LaplaceClassifier`: The Laplace classifier model. + - `fitresult`: A tuple containing the Laplace approximation (`la`) and a decode function. + + # Returns + A named tuple containing: + - `μ`: The mean of the posterior distribution. + - `H`: The Hessian of the posterior distribution. + - `P`: The precision matrix of the posterior distribution. + - `Σ`: The covariance matrix of the posterior distribution. + - `n_data`: The number of data points. + - `n_params`: The number of parameters. + - `n_out`: The number of outputs. + - `loss`: The loss value of the posterior distribution. + +""" +function MMI.fitted_params(model::LaplaceClassifier, fitresult) + la, decode = fitresult + posterior = la.posterior + return ( + μ = posterior.μ, + H = posterior.H, + P = posterior.P, + Σ = posterior.Σ, + n_data = posterior.n_data, + n_params = posterior.n_params, + n_out = posterior.n_out, + loss = posterior.loss + ) +end + + + + + @doc """ Predicts the class probabilities for new data using a Laplace classifier. @@ -350,8 +439,21 @@ Train the machine using `fit!(mach, rows=...)`. The fields of `fitted_params(mach)` are: -- `raw_tree`: the raw `Node`, `Leaf` or `Root` object returned by the core DecisionTree.jl - algorithm + - `μ`: The mean of the posterior distribution. + + - `H`: The Hessian of the posterior distribution. + + - `P`: The precision matrix of the posterior distribution. + + - `Σ`: The covariance matrix of the posterior distribution. + + - `n_data`: The number of data points. + + - `n_params`: The number of parameters. + + - `n_out`: The number of outputs. + + - `loss`: The loss value of the posterior distribution. # Accessor functions @@ -461,8 +563,21 @@ Train the machine using `fit!(mach, rows=...)`. The fields of `fitted_params(mach)` are: -- `raw_tree`: the raw `Node`, `Leaf` or `Root` object returned by the core DecisionTree.jl - algorithm + - `μ`: The mean of the posterior distribution. + + - `H`: The Hessian of the posterior distribution. + + - `P`: The precision matrix of the posterior distribution. + + - `Σ`: The covariance matrix of the posterior distribution. + + - `n_data`: The number of data points. + + - `n_params`: The number of parameters. + + - `n_out`: The number of outputs. + + - `loss`: The loss value of the posterior distribution. # Accessor functions From 87df85f7c7f931f25710c9ec7fcb345380808620 Mon Sep 17 00:00:00 2001 From: "pasquale c." 
<343guiltyspark@outlook.it> Date: Sun, 22 Sep 2024 14:52:07 +0200 Subject: [PATCH 20/60] switched to customized loop --- src/direct_mlj.jl | 83 +++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 70 insertions(+), 13 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 4bd86eb0..9e6e5921 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -78,10 +78,40 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) y = reshape(y, 1, :) data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size) opt_state = Flux.setup(m.optimiser, m.flux_model) + loss_history=[] for epoch in 1:(m.epochs) - Flux.train!(m.flux_model, data_loader, opt_state) do model, X, y - m.flux_loss(model(X), y) + + loss_per_epoch= 0.0 + + + for (X_batch, y_batch) in data_loader + # Forward pass: compute predictions + y_pred = m.flux_model(X_batch) + + # Compute loss + loss = m.flux_loss(y_pred, y_batch) + + # Compute gradients explicitly + grads = gradient(m.flux_model) do model + # Recompute predictions inside gradient context + y_pred = model(X_batch) + m.flux_loss(y_pred, y_batch) + end + + # Update parameters using the optimizer and computed gradients + Flux.Optimise.update!(opt_state ,m.flux_model , grads[1]) + + # Accumulate the loss for this batch + loss_per_epoch += sum(loss) # Summing the batch loss + + end + + push!(loss_history,loss_per_epoch ) + + # Print loss every 100 epochs if verbosity is 1 or more + if verbosity >= 1 && epoch % 100 == 0 + println("Epoch $epoch: Loss: $loss_per_epoch ") end end @@ -102,7 +132,7 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) fitresult = la - report = (status="success", message="Model fitted successfully") + report = (status="success", loss_history = loss_history) cache = nothing return (fitresult, cache, report) end @@ -243,17 +273,44 @@ end """ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) X = MLJBase.matrix(X) |> permutedims - decode = y[1] - y_plain = MLJBase.int(y) .- 1 - y_onehot = Flux.onehotbatch(y_plain, unique(y_plain)) - data_loader = Flux.DataLoader((X, y_onehot); batchsize=m.batch_size) - opt_state = Flux.setup(m.optimiser, m.flux_model) +# Store the first label as decode function +decode = y[1] + +# Convert labels to integer format starting from 0 for one-hot encoding +y_plain = MLJBase.int(y) .- 1 + +# One-hot encoding of labels +unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding +y_onehot = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding + +# Create a data loader for batching the data +data_loader = Flux.DataLoader((X, y_onehot); batchsize=m.batch_size) + +# Set up the optimizer for the model +opt_state = Flux.setup(m.optimiser, m.flux_model) +loss_history = [] + +# Training loop for the specified number of epochs +for epoch in 1:m.epochs + loss_per_epoch = 0.0 # Initialize loss for the current epoch + + # Training function for each batch in the data loader + Flux.train!(m.flux_model, data_loader, opt_state) do model, X_batch, y_batch + # Forward pass and compute the loss + loss = m.flux_loss(model(X_batch), y_batch) + + # Accumulate the loss for the current epoch + loss_per_epoch += sum(loss) + end + + # Record loss history for analysis + push!(loss_history, loss_per_epoch) - for epoch in 1:(m.epochs) - Flux.train!(m.flux_model, data_loader, opt_state) do model, X, y_onehot - m.flux_loss(model(X), y_onehot) - end + # Verbosity: print loss every 100 epochs + if verbosity >= 1 && epoch % 100 == 0 + 
println("Epoch $epoch: Loss: $loss_per_epoch") end +end la = LaplaceRedux.Laplace( m.flux_model; @@ -271,7 +328,7 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) LaplaceRedux.fit!(la, data_loader) optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) - report = (status="success", message="Model fitted successfully") + report = (status="success", loss_history = loss_history) cache = nothing return ((la, decode), cache, report) end From 24459a1b807230e529c93fe80dc14868975f4e45 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Sun, 22 Sep 2024 15:41:59 +0200 Subject: [PATCH 21/60] fixed error in custom loop --- src/direct_mlj.jl | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 9e6e5921..28407f4b 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -294,14 +294,24 @@ loss_history = [] for epoch in 1:m.epochs loss_per_epoch = 0.0 # Initialize loss for the current epoch - # Training function for each batch in the data loader - Flux.train!(m.flux_model, data_loader, opt_state) do model, X_batch, y_batch - # Forward pass and compute the loss - loss = m.flux_loss(model(X_batch), y_batch) + for (X_batch, y_batch) in data_loader + # Compute gradients explicitly + grads = gradient(m.flux_model) do model + # Recompute predictions inside gradient context + y_pred = model(X_batch) + m.flux_loss(y_pred, y_batch) + end + + # Update the model parameters using the computed gradients + Flux.Optimise.update!(opt_state, Flux.params(m.flux_model), grads) + + # Compute the loss for this batch + loss = m.flux_loss(m.flux_model(X_batch), y_batch) # Accumulate the loss for the current epoch loss_per_epoch += sum(loss) end + # Record loss history for analysis push!(loss_history, loss_per_epoch) From 0e2ca0330f35378e09a7696b980ccac34beb6945 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Sun, 22 Sep 2024 16:16:46 +0200 Subject: [PATCH 22/60] various fixes --- src/direct_mlj.jl | 100 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 80 insertions(+), 20 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 28407f4b..87171553 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -13,7 +13,7 @@ using Distributions: Normal A mutable struct representing a Laplace regression model. It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. It has the following Hyperparameters: -- `flux_model`: A Flux model provided by the user and compatible with the dataset. +- `model`: A Flux model provided by the user and compatible with the dataset. - `flux_loss` : a Flux loss function - `optimiser` = a Flux optimiser - `epochs`: The number of training epochs. @@ -28,7 +28,7 @@ It has the following Hyperparameters: - `fit_prior_nsteps`: the number of steps used to fit the priors. 
""" MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic - flux_model::Flux.Chain = nothing + model::Flux.Chain = nothing flux_loss = Flux.Losses.mse optimiser = Adam() epochs::Integer = 1000::(_ > 0) @@ -77,7 +77,7 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) X = MLJBase.matrix(X) |> permutedims y = reshape(y, 1, :) data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size) - opt_state = Flux.setup(m.optimiser, m.flux_model) + opt_state = Flux.setup(m.optimiser, m.model) loss_history=[] for epoch in 1:(m.epochs) @@ -87,20 +87,20 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) for (X_batch, y_batch) in data_loader # Forward pass: compute predictions - y_pred = m.flux_model(X_batch) + y_pred = m.model(X_batch) # Compute loss loss = m.flux_loss(y_pred, y_batch) # Compute gradients explicitly - grads = gradient(m.flux_model) do model + grads = gradient(m.model) do model # Recompute predictions inside gradient context y_pred = model(X_batch) m.flux_loss(y_pred, y_batch) end # Update parameters using the optimizer and computed gradients - Flux.Optimise.update!(opt_state ,m.flux_model , grads[1]) + Flux.Optimise.update!(opt_state ,m.model , grads[1]) # Accumulate the loss for this batch loss_per_epoch += sum(loss) # Summing the batch loss @@ -116,7 +116,7 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) end la = LaplaceRedux.Laplace( - m.flux_model; + m.model; likelihood=:regression, subset_of_weights=m.subset_of_weights, subnetwork_indices=m.subnetwork_indices, @@ -179,6 +179,26 @@ end + +@doc """ + MMI.training_losses(model::LaplaceRegressor, report) + +Retrieve the training loss history from the given `report`. + +# Arguments +- `model::LaplaceRegressor`: The model for which the training losses are being retrieved. +- `report`: An object containing the training report, which includes the loss history. + +# Returns +- A collection representing the loss history from the training report. +""" +function MMI.training_losses(model::LaplaceRegressor, report) + return report.loss_history +end + + + + @doc """ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) @@ -213,7 +233,7 @@ A mutable struct representing a Laplace Classification model. It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. The model also has the following parameters: -- `flux_model`: A Flux model provided by the user and compatible with the dataset. +- `model`: A Flux model provided by the user and compatible with the dataset. - `flux_loss` : a Flux loss function - `optimiser` = a Flux optimiser - `epochs`: The number of training epochs. @@ -229,7 +249,7 @@ The model also has the following parameters: - `fit_prior_nsteps`: the number of steps used to fit the priors. 
""" MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic - flux_model::Flux.Chain = nothing + model::Flux.Chain = nothing flux_loss = Flux.Losses.logitcrossentropy optimiser = Adam() epochs::Integer = 1000::(_ > 0) @@ -287,7 +307,7 @@ y_onehot = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding data_loader = Flux.DataLoader((X, y_onehot); batchsize=m.batch_size) # Set up the optimizer for the model -opt_state = Flux.setup(m.optimiser, m.flux_model) +opt_state = Flux.setup(m.optimiser, m.model) loss_history = [] # Training loop for the specified number of epochs @@ -296,17 +316,17 @@ for epoch in 1:m.epochs for (X_batch, y_batch) in data_loader # Compute gradients explicitly - grads = gradient(m.flux_model) do model + grads = gradient(m.model) do model # Recompute predictions inside gradient context y_pred = model(X_batch) m.flux_loss(y_pred, y_batch) end # Update the model parameters using the computed gradients - Flux.Optimise.update!(opt_state, Flux.params(m.flux_model), grads) + Flux.Optimise.update!(opt_state, Flux.params(m.model), grads) # Compute the loss for this batch - loss = m.flux_loss(m.flux_model(X_batch), y_batch) + loss = m.flux_loss(m.model(X_batch), y_batch) # Accumulate the loss for the current epoch loss_per_epoch += sum(loss) @@ -323,7 +343,7 @@ for epoch in 1:m.epochs end la = LaplaceRedux.Laplace( - m.flux_model; + m.model; likelihood=:classification, subset_of_weights=m.subset_of_weights, subnetwork_indices=m.subnetwork_indices, @@ -387,7 +407,21 @@ function MMI.fitted_params(model::LaplaceClassifier, fitresult) end +@doc """ + MMI.training_losses(model::LaplaceClassifier, report) + +Retrieve the training loss history from the given `report`. +# Arguments +- `model::LaplaceClassifier`: The model for which the training losses are being retrieved. +- `report`: An object containing the training report, which includes the loss history. + +# Returns +- A collection representing the loss history from the training report. 
+""" +function MMI.training_losses(model::LaplaceClassifier, report) + return report.loss_history +end @doc """ @@ -414,6 +448,29 @@ function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew) return MLJBase.UnivariateFinite(MLJBase.classes(decode), predictions) end + +MMI.metadata_pkg( + LaplaceRegressor, + name="LaplaceRedux", + package_uuid="??????", + package_url="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl", + is_pure_julia=true, + is_wrapper=true, + package_license = "https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl/blob/main/LICENSE", +) + +MMI.metadata_pkg( + LaplaceClassifier, + name="LaplaceRedux", + package_uuid="dontknow", + package_url="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl", + is_pure_julia=true, + is_wrapper=true, + package_license = "https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl/blob/main/LICENSE", +) + + + MLJBase.metadata_model( LaplaceClassifier; input_scitype=Union{ @@ -421,6 +478,7 @@ MLJBase.metadata_model( MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types }, target_scitype=AbstractArray{<:MLJBase.Finite}, # ordered factor or multiclass + supports_training_losses = true, load_path="LaplaceRedux.LaplaceClassifier", ) # metadata for each model, @@ -431,7 +489,9 @@ MLJBase.metadata_model( MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types }, target_scitype=AbstractArray{MLJBase.Continuous}, + supports_training_losses = true, load_path="LaplaceRedux.LaplaceRegressor", + ) const DOC_LAPLACE_REDUX = "[Laplace Redux – Effortless Bayesian Deep Learning]"* @@ -464,7 +524,7 @@ Train the machine using `fit!(mach, rows=...)`. # Hyperparameters (format: name-type-default value-restrictions) -- `flux_model::Flux.Chain = nothing`: a Flux model provided by the user and compatible with the dataset. +- `model::Flux.Chain = nothing`: a Flux model provided by the user and compatible with the dataset. - `flux_loss = Flux.Losses.logitcrossentropy` : a Flux loss function @@ -535,14 +595,14 @@ X, y = @load_iris # Define the Flux Chain model using Flux -flux_model = Chain( +model = Chain( Dense(4, 10, relu), Dense(10, 10, relu), Dense(10, 3) ) #Define the LaplaceClassifier -model = LaplaceClassifier(flux_model=flux_model) +model = LaplaceClassifier(model=model) mach = machine(model, X, y) |> fit! @@ -589,7 +649,7 @@ Train the machine using `fit!(mach, rows=...)`. # Hyperparameters (format: name-type-default value-restrictions) -- `flux_model::Flux.Chain = nothing`: a Flux model provided by the user and compatible with the dataset. +- `model::Flux.Chain = nothing`: a Flux model provided by the user and compatible with the dataset. - `flux_loss = Flux.Losses.logitcrossentropy` : a Flux loss function @@ -657,12 +717,12 @@ The fields of `fitted_params(mach)` are: using MLJ using Flux LaplaceRegressor = @load LaplaceRegressor pkg=LaplaceRedux -flux_model = Chain( +model = Chain( Dense(4, 10, relu), Dense(10, 10, relu), Dense(10, 1) ) -model = LaplaceRegressor(flux_model=flux_model) +model = LaplaceRegressor(model=model) X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1) mach = machine(model, X, y) |> fit! From 841d5ebf5898a56700534ef304d0aeee4550c36c Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Sun, 22 Sep 2024 18:08:02 +0200 Subject: [PATCH 23/60] added reformat. must updated again the doc string.... 
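For context, the new `MMI.reformat` overloads are meant to let MLJ's data front
end convert user-supplied data once and hand `fit`/`predict` the layout the
model expects: a features-by-observations matrix, plus a 1-row target for the
regressor. A rough sketch of that layout, with purely illustrative data:

```
using MLJBase

X = (x1 = [1.0, 2.0, 3.0], x2 = [4.0, 5.0, 6.0])   # a table with 3 observations
y = [0.1, 0.2, 0.3]

Xmat = MLJBase.matrix(X) |> permutedims   # 2×3 matrix, one column per observation
yrow = reshape(y, 1, :)                   # 1×3 row vector, as the Flux chain expects
```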
--- src/direct_mlj.jl | 168 ++++++++++++++++++++++++++++------------------ 1 file changed, 104 insertions(+), 64 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 87171553..690c248c 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -44,6 +44,14 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilist fit_prior_nsteps::Int = 100::(_ > 0) end + +# for fit: +MMI.reformat(::LaplaceRegressor, X, y) = (MLJBase.matrix(X) |> permutedims,reshape(y, 1, :)) + +MMI.reformat(::LaplaceRegressor, X) = (MLJBase.matrix(X) |> permutedims,) + + + @doc """ MMI.fit(m::LaplaceRegressor, verbosity, X, y) @@ -74,8 +82,16 @@ This function performs the following steps: """ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) - X = MLJBase.matrix(X) |> permutedims + #X = MLJBase.matrix(X) |> permutedims + #y = reshape(y, 1, :) + + if Tables.istable(X) + X = Tables.matrix(X)|>permutedims + end + + # Reshape y if necessary y = reshape(y, 1, :) + data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size) opt_state = Flux.setup(m.optimiser, m.model) loss_history=[] @@ -216,7 +232,10 @@ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) Finally, it creates Normal distributions from these means and variances and returns them as an array. """ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) - Xnew = MLJBase.matrix(Xnew) |> permutedims + #Xnew = MLJBase.matrix(Xnew) |> permutedims + if Tables.istable(Xnew) + Xnew = Tables.matrix(Xnew)|>permutedims + end la = fitresult yhat = LaplaceRedux.predict(la, Xnew; ret_distr=false) # Extract mean and variance matrices @@ -266,6 +285,7 @@ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilis link_approx::Symbol = :probit::(_ in (:probit, :plugin)) end + @doc """ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) @@ -292,75 +312,85 @@ end """ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) - X = MLJBase.matrix(X) |> permutedims -# Store the first label as decode function -decode = y[1] - -# Convert labels to integer format starting from 0 for one-hot encoding -y_plain = MLJBase.int(y) .- 1 - -# One-hot encoding of labels -unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding -y_onehot = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding - -# Create a data loader for batching the data -data_loader = Flux.DataLoader((X, y_onehot); batchsize=m.batch_size) - -# Set up the optimizer for the model -opt_state = Flux.setup(m.optimiser, m.model) -loss_history = [] - -# Training loop for the specified number of epochs -for epoch in 1:m.epochs - loss_per_epoch = 0.0 # Initialize loss for the current epoch - - for (X_batch, y_batch) in data_loader - # Compute gradients explicitly - grads = gradient(m.model) do model - # Recompute predictions inside gradient context - y_pred = model(X_batch) - m.flux_loss(y_pred, y_batch) - end + #X = MLJBase.matrix(X) |> permutedims + + if Tables.istable(X) + X = Tables.matrix(X)|>permutedims + end - # Update the model parameters using the computed gradients - Flux.Optimise.update!(opt_state, Flux.params(m.model), grads) - # Compute the loss for this batch - loss = m.flux_loss(m.model(X_batch), y_batch) + # Store the first label as decode function + decode = y[1] - # Accumulate the loss for the current epoch - loss_per_epoch += sum(loss) - end + # Convert labels to integer format starting from 0 for one-hot encoding + y_plain = MLJBase.int(y) .- 1 - - # Record loss history for analysis - push!(loss_history, 
loss_per_epoch) + # One-hot encoding of labels + unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding + y_onehot = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding - # Verbosity: print loss every 100 epochs - if verbosity >= 1 && epoch % 100 == 0 - println("Epoch $epoch: Loss: $loss_per_epoch") - end -end + # Create a data loader for batching the data + data_loader = Flux.DataLoader((X, y_onehot); batchsize=m.batch_size) - la = LaplaceRedux.Laplace( - m.model; - likelihood=:classification, - subset_of_weights=m.subset_of_weights, - subnetwork_indices=m.subnetwork_indices, - hessian_structure=m.hessian_structure, - backend=m.backend, - σ=m.σ, - μ₀=m.μ₀, - P₀=m.P₀, - ) + # Set up the optimizer for the model + opt_state = Flux.setup(m.optimiser, m.model) + loss_history = [] - # fit the Laplace model: - LaplaceRedux.fit!(la, data_loader) - optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) + # Training loop for the specified number of epochs + for epoch in 1:(m.epochs) - report = (status="success", loss_history = loss_history) - cache = nothing - return ((la, decode), cache, report) + loss_per_epoch= 0.0 + + + for (X_batch, y_batch) in data_loader + # Forward pass: compute predictions + y_pred = m.model(X_batch) + + # Compute loss + loss = m.flux_loss(y_pred, y_batch) + + # Compute gradients explicitly + grads = gradient(m.model) do model + # Recompute predictions inside gradient context + y_pred = model(X_batch) + m.flux_loss(y_pred, y_batch) + end + + # Update parameters using the optimizer and computed gradients + Flux.Optimise.update!(opt_state ,m.model , grads[1]) + + # Accumulate the loss for this batch + loss_per_epoch += sum(loss) # Summing the batch loss + + end + + push!(loss_history,loss_per_epoch ) + + # Print loss every 100 epochs if verbosity is 1 or more + if verbosity >= 1 && epoch % 100 == 0 + println("Epoch $epoch: Loss: $loss_per_epoch ") + end + end + + la = LaplaceRedux.Laplace( + m.model; + likelihood=:classification, + subset_of_weights=m.subset_of_weights, + subnetwork_indices=m.subnetwork_indices, + hessian_structure=m.hessian_structure, + backend=m.backend, + σ=m.σ, + μ₀=m.μ₀, + P₀=m.P₀, + ) + + # fit the Laplace model: + LaplaceRedux.fit!(la, data_loader) + optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) + + report = (status="success", loss_history = loss_history) + cache = nothing + return ((la, decode), cache, report) end @@ -439,8 +469,11 @@ The function transforms the new data `Xnew` into a matrix, applies the LaplaceRe prediction function, and then returns the predictions as a `MLJBase.UnivariateFinite` object. """ function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew) + if Tables.istable(Xnew) + Xnew = Tables.matrix(Xnew)|>permutedims + end la = fitresult - Xnew = MLJBase.matrix(Xnew) |> permutedims + #Xnew = MLJBase.matrix(Xnew) |> permutedims predictions = LaplaceRedux.predict(la, Xnew; link_approx=m.link_approx, ret_distr=false) |> permutedims @@ -449,6 +482,13 @@ function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew) end + + +# for fit: +MMI.reformat(::LaplaceClassifier, X, y) = (MLJBase.matrix(X) |> permutedims,y) + +MMI.reformat(::LaplaceClassifier, X) = (MLJBase.matrix(X) |> permutedims,) + MMI.metadata_pkg( LaplaceRegressor, name="LaplaceRedux", From de784f1c09fcd7dda9665fa3abde43aa26cb79f0 Mon Sep 17 00:00:00 2001 From: "pasquale c." 
<343guiltyspark@outlook.it> Date: Sun, 22 Sep 2024 19:14:12 +0200 Subject: [PATCH 24/60] work on the docstring and then made it in a module --- src/direct_mlj.jl | 71 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 48 insertions(+), 23 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 690c248c..dd33e3c0 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -1,3 +1,4 @@ +module MLJLaplaceRedux using Flux using Random using Tables @@ -66,11 +67,11 @@ Fit a LaplaceRegressor model using the provided features and target values. # Returns - `fitresult`: The fitted Laplace model. - `cache`: Currently unused, returns `nothing`. -- `report`: A tuple containing the status and message of the fitting process. +- `report`: A tuple containing the loss history of the fitting process. # Description This function performs the following steps: -1. Converts the input features `X` to a matrix and transposes it. +1. If X is a Table: converts the input features `X` to a matrix and transposes it. 2. Reshapes the target values `y` to shape (1,:). 3. Creates a data loader for batching the data. 4. Sets up the optimizer state using the Adam optimizer. @@ -78,7 +79,7 @@ This function performs the following steps: 6. Initializes a Laplace model with the trained Flux model and specified parameters. 7. Fits the Laplace model using the data loader. 8. Optimizes the prior of the Laplace model. -9. Returns the fitted Laplace model, a cache (currently `nothing`), and a report indicating success. +9. Returns the fitted Laplace model, a cache (currently `nothing`), and a report listing training related statistics. """ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) @@ -95,6 +96,7 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size) opt_state = Flux.setup(m.optimiser, m.model) loss_history=[] + push!(loss_history, m.flux_loss(m.model(X), y )) for epoch in 1:(m.epochs) @@ -148,7 +150,7 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) fitresult = la - report = (status="success", loss_history = loss_history) + report = (loss_history = loss_history,) cache = nothing return (fitresult, cache, report) end @@ -304,11 +306,7 @@ end - A tuple containing: - `(la, decode)`: The fitted Laplace model and the decode function for the target labels. - `cache`: A placeholder for any cached data (currently `nothing`). - - `report`: A report dictionary containing the status and message of the fitting process. - - Notes: - - The function uses the Flux library for neural network training and the LaplaceRedux library for fitting the Laplace approximation. - - The `optimize_prior!` function is called to optimize the prior parameters of the Laplace model. + - `report`: A NamedTuple containing statistics related to the fitting process. 
""" function MMI.fit(m::LaplaceClassifier, verbosity, X, y) @@ -388,7 +386,7 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) LaplaceRedux.fit!(la, data_loader) optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) - report = (status="success", loss_history = loss_history) + report = (loss_history = loss_history,) cache = nothing return ((la, decode), cache, report) end @@ -545,9 +543,13 @@ $(MMI.doc_header(LaplaceClassifier)) # Training data -In MLJ or MLJBase, bind an instance `model` to data with +In MLJ or MLJBase, given a dataset X,y and a Flux Chain adapt to the dataset, pass the chain to the model + +laplace_model = LaplaceClassifier(model = Flux_Chain,kwargs...) + +then bind an instance `laplace_model` to data with - mach = machine(model, X, y) + mach = machine(laplace_model, X, y) where @@ -601,6 +603,8 @@ Train the machine using `fit!(mach, rows=...)`. - `predict_mode(mach, Xnew)`: instead return the mode of each prediction above. +- `training_losses(mach)`: return the loss history from report + # Fitted parameters @@ -622,6 +626,14 @@ The fields of `fitted_params(mach)` are: - `loss`: The loss value of the posterior distribution. + + + # Report + +The fields of `report(mach)` are: + +- `loss_history`: an array containing the total loss per epoch. + # Accessor functions @@ -652,9 +664,9 @@ Xnew = (sepal_length = [6.4, 7.2, 7.4], petal_width = [2.1, 1.6, 1.9],) yhat = predict(mach, Xnew) # probabilistic predictions predict_mode(mach, Xnew) # point predictions +training_losses(mach) # loss history per epoch pdf.(yhat, "virginica") # probabilities for the "verginica" class - -julia> tree = fitted_params(mach) +fitted_params(mach) # NamedTuple with the fitted params of Laplace ``` @@ -670,10 +682,13 @@ $(MMI.doc_header(LaplaceRegressor)) # Training data -In MLJ or MLJBase, bind an instance `model` to data with +In MLJ or MLJBase, given a dataset X,y and a Flux Chain adapt to the dataset, pass the chain to the model + +laplace_model = LaplaceRegressor(model = Flux_Chain,kwargs...) - mach = machine(model, X, y) +then bind an instance `laplace_model` to data with + mach = machine(laplace_model, X, y) where - `X`: any table of input features (eg, a `DataFrame`) whose columns @@ -689,7 +704,7 @@ Train the machine using `fit!(mach, rows=...)`. # Hyperparameters (format: name-type-default value-restrictions) -- `model::Flux.Chain = nothing`: a Flux model provided by the user and compatible with the dataset. +- `model::Flux.Chain = nothing`: a Flux model provided by the user and compatible with the dataset. - `flux_loss = Flux.Losses.logitcrossentropy` : a Flux loss function @@ -725,6 +740,8 @@ Train the machine using `fit!(mach, rows=...)`. - `predict_mode(mach, Xnew)`: instead return the mode of each prediction above. +- `training_losses(mach)`: return the loss history from report + # Fitted parameters @@ -747,6 +764,15 @@ The fields of `fitted_params(mach)` are: - `loss`: The loss value of the posterior distribution. +# Report + +The fields of `report(mach)` are: + +- `loss_history`: an array containing the total loss per epoch. + + + + # Accessor functions @@ -770,15 +796,14 @@ mach = machine(model, X, y) |> fit! 
Xnew, _ = make_regression(3, 4; rng=123) yhat = predict(mach, Xnew) # probabilistic predictions predict_mode(mach, Xnew) # point predictions +training_losses(mach) # loss history per epoch +fitted_params(mach) # NamedTuple with the fitted params of Laplace -julia> tree = fitted_params(mach) - - - -feature_importances(mach) ``` See also [LaplaceRedux.jl](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl). """ -LaplaceRegressor \ No newline at end of file +LaplaceRegressor + +end # module \ No newline at end of file From b7a99f6e40eb216dec08c34edd19a614e7d265de Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Mon, 23 Sep 2024 17:42:32 +0200 Subject: [PATCH 25/60] fixed uuid, made test file.for direct_mlj. shut down the tests for mljflux. --- src/direct_mlj.jl | 14 +++--- test/direct_mlj_interface.jl | 98 ++++++++++++++++++++++++++++++++++++ test/runtests.jl | 8 ++- 3 files changed, 111 insertions(+), 9 deletions(-) create mode 100644 test/direct_mlj_interface.jl diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index dd33e3c0..30c57045 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -1,4 +1,4 @@ -module MLJLaplaceRedux +#module MLJLaplaceRedux using Flux using Random using Tables @@ -304,7 +304,7 @@ end Returns: - A tuple containing: - - `(la, decode)`: The fitted Laplace model and the decode function for the target labels. + - `(la, y[1])`: The fitted Laplace model and the decode function for the target labels. - `cache`: A placeholder for any cached data (currently `nothing`). - `report`: A NamedTuple containing statistics related to the fitting process. @@ -318,7 +318,7 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) # Store the first label as decode function - decode = y[1] + #decode = y[1] # Convert labels to integer format starting from 0 for one-hot encoding y_plain = MLJBase.int(y) .- 1 @@ -388,7 +388,7 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) report = (loss_history = loss_history,) cache = nothing - return ((la, decode), cache, report) + return ((la, y[1]), cache, report) end @@ -490,7 +490,7 @@ MMI.reformat(::LaplaceClassifier, X) = (MLJBase.matrix(X) |> permutedims,) MMI.metadata_pkg( LaplaceRegressor, name="LaplaceRedux", - package_uuid="??????", + package_uuid="c52c1a26-f7c5-402b-80be-ba1e638ad478", package_url="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl", is_pure_julia=true, is_wrapper=true, @@ -500,7 +500,7 @@ MMI.metadata_pkg( MMI.metadata_pkg( LaplaceClassifier, name="LaplaceRedux", - package_uuid="dontknow", + package_uuid="c52c1a26-f7c5-402b-80be-ba1e638ad478", package_url="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl", is_pure_julia=true, is_wrapper=true, @@ -806,4 +806,4 @@ See also [LaplaceRedux.jl](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl """ LaplaceRegressor -end # module \ No newline at end of file +#end # module \ No newline at end of file diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl new file mode 100644 index 00000000..13c91a21 --- /dev/null +++ b/test/direct_mlj_interface.jl @@ -0,0 +1,98 @@ +using Random: Random +import Random.seed! 
+using MLJBase: MLJBase, categorical +using Flux +using StableRNGs + + +@testset "Regression" begin + function basictest_regression(X, y, builder, optimiser, threshold) + optimiser = deepcopy(optimiser) + + stable_rng = StableRNGs.StableRNG(123) + + model = LaplaceRegression(; + builder=builder, + optimiser=optimiser, + acceleration=MLJBase.CPUThreads(), + loss=Flux.Losses.mse, + rng=stable_rng, + lambda=-1.0, + alpha=-1.0, + epochs=-1, + batch_size=-1, + subset_of_weights=:incorrect, + hessian_structure=:incorrect, + backend=:incorrect, + ret_distr=true, + ) + + fitresult, cache, _report = MLJBase.fit(model, 0, X, y) + + history = _report.training_losses + @test length(history) == model.epochs + 1 + + # increase iterations and check update is incremental: + model.epochs = model.epochs + 3 + + fitresult, cache, _report = @test_logs( + (:info, r""), # one line of :info per extra epoch + (:info, r""), + (:info, r""), + MLJBase.update(model, 2, fitresult, cache, X, y) + ) + + @test :chain in keys(MLJBase.fitted_params(model, fitresult)) + + history = _report.training_losses + @test length(history) == model.epochs + 1 + + yhat = MLJBase.predict(model, fitresult, X) + + # start fresh with small epochs: + model = LaplaceRegression(; + builder=builder, + optimiser=optimiser, + epochs=2, + acceleration=CPU1(), + rng=stable_rng, + ) + + fitresult, cache, _report = MLJBase.fit(model, 0, X, y) + + # change batch_size and check it performs cold restart: + model.batch_size = 2 + fitresult, cache, _report = @test_logs( + (:info, r""), # one line of :info per extra epoch + (:info, r""), + MLJBase.update(model, 2, fitresult, cache, X, y) + ) + + # change learning rate and check it does *not* restart: + model.optimiser.eta /= 2 + fitresult, cache, _report = @test_logs( + MLJBase.update(model, 2, fitresult, cache, X, y) + ) + + # set `optimiser_changes_trigger_retraining = true` and change + # learning rate and check it does restart: + model.optimiser_changes_trigger_retraining = true + model.optimiser.eta /= 2 + @test_logs( + (:info, r""), # one line of :info per extra epoch + (:info, r""), + MLJBase.update(model, 2, fitresult, cache, X, y) + ) + + return true + end + + seed!(1234) + N = 300 + X = MLJBase.table(rand(Float32, N, 4)) + ycont = 2 * X.x1 - X.x3 + 0.1 * rand(N) + builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu) + optimiser = Flux.Optimise.Adam(0.03) + + @test basictest_regression(X, ycont, builder, optimiser, 0.9) +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 0459cf5a..4995c6f9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,7 +35,11 @@ using Test include("krondecomposed.jl") end - @testset "MLJFlux" begin - include("mlj_flux_interfacing.jl") + #@testset "MLJFlux" begin + #include("mlj_flux_interfacing.jl") + #end + @testset "ML" begin + include("direct_mlj_interface.jl") end + end From c44b8d83795690e32786a13c805c139728acb190 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Mon, 23 Sep 2024 18:54:51 +0200 Subject: [PATCH 26/60] added tests. should be good.... 
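A sketch of extra assertions the new test sets could also make, based on the
interface defined earlier in this series (regression predictions are `Normal`
distributions and the report carries a `loss_history`); `mach`, `model` and
`yhat` refer to the objects built inside the test sets, and `Test` is assumed
to be loaded by the test runner:

```
using Distributions: Normal

@test yhat isa AbstractVector{<:Normal}                          # probabilistic regression output
@test length(MLJBase.report(mach).loss_history) >= model.epochs  # roughly one loss entry per epoch
```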
--- test/Project.toml | 1 + test/direct_mlj_interface.jl | 133 +++++++++++------------------------ 2 files changed, 44 insertions(+), 90 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 750ea47e..2d4f4647 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -7,6 +7,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl index 13c91a21..66fe020d 100644 --- a/test/direct_mlj_interface.jl +++ b/test/direct_mlj_interface.jl @@ -3,96 +3,49 @@ import Random.seed! using MLJBase: MLJBase, categorical using Flux using StableRNGs +using MLJ +using LaplaceRedux @testset "Regression" begin - function basictest_regression(X, y, builder, optimiser, threshold) - optimiser = deepcopy(optimiser) - - stable_rng = StableRNGs.StableRNG(123) - - model = LaplaceRegression(; - builder=builder, - optimiser=optimiser, - acceleration=MLJBase.CPUThreads(), - loss=Flux.Losses.mse, - rng=stable_rng, - lambda=-1.0, - alpha=-1.0, - epochs=-1, - batch_size=-1, - subset_of_weights=:incorrect, - hessian_structure=:incorrect, - backend=:incorrect, - ret_distr=true, - ) - - fitresult, cache, _report = MLJBase.fit(model, 0, X, y) - - history = _report.training_losses - @test length(history) == model.epochs + 1 - - # increase iterations and check update is incremental: - model.epochs = model.epochs + 3 - - fitresult, cache, _report = @test_logs( - (:info, r""), # one line of :info per extra epoch - (:info, r""), - (:info, r""), - MLJBase.update(model, 2, fitresult, cache, X, y) - ) - - @test :chain in keys(MLJBase.fitted_params(model, fitresult)) - - history = _report.training_losses - @test length(history) == model.epochs + 1 - - yhat = MLJBase.predict(model, fitresult, X) - - # start fresh with small epochs: - model = LaplaceRegression(; - builder=builder, - optimiser=optimiser, - epochs=2, - acceleration=CPU1(), - rng=stable_rng, - ) - - fitresult, cache, _report = MLJBase.fit(model, 0, X, y) - - # change batch_size and check it performs cold restart: - model.batch_size = 2 - fitresult, cache, _report = @test_logs( - (:info, r""), # one line of :info per extra epoch - (:info, r""), - MLJBase.update(model, 2, fitresult, cache, X, y) - ) - - # change learning rate and check it does *not* restart: - model.optimiser.eta /= 2 - fitresult, cache, _report = @test_logs( - MLJBase.update(model, 2, fitresult, cache, X, y) - ) - - # set `optimiser_changes_trigger_retraining = true` and change - # learning rate and check it does restart: - model.optimiser_changes_trigger_retraining = true - model.optimiser.eta /= 2 - @test_logs( - (:info, r""), # one line of :info per extra epoch - (:info, r""), - MLJBase.update(model, 2, fitresult, cache, X, y) - ) - - return true - end - - seed!(1234) - N = 300 - X = MLJBase.table(rand(Float32, N, 4)) - ycont = 2 * X.x1 - X.x3 + 0.1 * rand(N) - builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu) - optimiser = Flux.Optimise.Adam(0.03) - - @test basictest_regression(X, ycont, builder, optimiser, 0.9) -end \ No newline at end of file + flux_model = Chain( + Dense(4, 10, relu), + Dense(10, 10, relu), + Dense(10, 1) + ) + model = 
LaplaceRegressor(model=flux_model,epochs=50) + + X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1) + mach = machine(model, X, y) #|> MLJBase.fit! #|> (fitresult,cache,report) + MLJBase.fit!(mach) + Xnew, _ = make_regression(3, 4; rng=123) + yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions + MLJBase.predict_mode(mach, Xnew) # point predictions + MLJBase.fitted_params(mach) #fitted params function +end + + + +@testset "Classification" begin +# Define the model +flux_model = Chain( + Dense(4, 10, relu), + Dense(10, 10, relu), + Dense(10, 3) +) + +model = LaplaceClassifier(model=flux_model,epochs=50) + +X, y = @load_iris +mach = machine(model, X, y) +MLJBase.fit!(mach) +Xnew = (sepal_length = [6.4, 7.2, 7.4], + sepal_width = [2.8, 3.0, 2.8], + petal_length = [5.6, 5.8, 6.1], + petal_width = [2.1, 1.6, 1.9],) +yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions +predict_mode(mach, Xnew) # point predictions +pdf.(yhat, "virginica") # probabilities for the "verginica" class +MLJBase.fitted_params(mach) # fitted params + +end From b762185fb93cf81fd220718b435d60eeb817bc8f Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Mon, 23 Sep 2024 20:07:01 +0200 Subject: [PATCH 27/60] added mlj to the dependency in test --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 2028e0ba..f9cb74ea 100644 --- a/Project.toml +++ b/Project.toml @@ -25,6 +25,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Aqua = "0.8" +CategoricalDistributions = "0.1.15" ChainRulesCore = "1.23.0" Compat = "4.7.0" ComputationalResources = "0.3.2" From ced3da056e0c46b5eaf910b3dc705659a2f2f026 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Thu, 26 Sep 2024 00:45:43 +0200 Subject: [PATCH 28/60] prep for update + added mljmodelinterface to doc env --- docs/Project.toml | 1 + src/direct_mlj.jl | 27 ++++++++++++++++++--------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 7118d5fe..624220c4 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -6,6 +6,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" +MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 30c57045..102196a2 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -9,7 +9,7 @@ import MLJModelInterface as MMI using Distributions: Normal """ - MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic + MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJBase.Probabilistic A mutable struct representing a Laplace regression model. It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. @@ -28,8 +28,12 @@ It has the following Hyperparameters: - `P₀`: the covariance matrix of the prior distribution. - `fit_prior_nsteps`: the number of steps used to fit the priors. 
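A hypothetical construction example for the hyperparameters listed above (the chain is only a placeholder; any Flux model whose input and output sizes match the data will do, and `LaplaceRegressor` is assumed to be in scope via `using LaplaceRedux`, as in the test file):

```
using Flux, LaplaceRedux

flux_model = Chain(Dense(4, 10, relu), Dense(10, 1))

model = LaplaceRegressor(
    model=flux_model,
    epochs=500,
    subset_of_weights=:last_layer,   # Laplace approximation over the last layer only
    hessian_structure=:diagonal,     # cheaper diagonal Hessian approximation
    backend=:GGN,
)
```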
""" -MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic - model::Flux.Chain = nothing +MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJBase.Probabilistic + model::Flux.Chain = Chain( + Dense(4, 10, relu), + Dense(10, 10, relu), + Dense(10, 1) + ) flux_loss = Flux.Losses.mse optimiser = Adam() epochs::Integer = 1000::(_ > 0) @@ -110,7 +114,7 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) # Compute loss loss = m.flux_loss(y_pred, y_batch) - # Compute gradients explicitly + # Compute gradients grads = gradient(m.model) do model # Recompute predictions inside gradient context y_pred = model(X_batch) @@ -151,11 +155,16 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) fitresult = la report = (loss_history = loss_history,) - cache = nothing - return (fitresult, cache, report) + cache = (deepcopy(m),opt_state, loss_history) + return fitresult, cache, report end + + + + + @doc """ function MMI.fitted_params(model::LaplaceRegressor, fitresult) @@ -248,7 +257,7 @@ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) end """ - MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic + MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJBase.Probabilistic A mutable struct representing a Laplace Classification model. It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. @@ -269,7 +278,7 @@ The model also has the following parameters: - `link_approx`: the link approximation to use, either `:probit` or `:plugin`. - `fit_prior_nsteps`: the number of steps used to fit the priors. """ -MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic +MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJBase.Probabilistic model::Flux.Chain = nothing flux_loss = Flux.Losses.logitcrossentropy optimiser = Adam() @@ -387,7 +396,7 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) report = (loss_history = loss_history,) - cache = nothing + cache = (deepcopy(m),opt_state,loss_history) return ((la, y[1]), cache, report) end From b700f851570b56e54b95389e4e43af8f5eeccb3a Mon Sep 17 00:00:00 2001 From: "pasquale c." 
<343guiltyspark@outlook.it> Date: Wed, 2 Oct 2024 01:42:17 +0200 Subject: [PATCH 29/60] changed the loop so that it nows uses optimisers from optimisers.jl --- docs/Project.toml | 1 + src/direct_mlj.jl | 168 ++++++++++++++++++++++++++++------------------ 2 files changed, 102 insertions(+), 67 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 624220c4..c8d6234a 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -7,6 +7,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 102196a2..1379aee7 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -1,4 +1,5 @@ #module MLJLaplaceRedux +using Optimisers: Optimisers using Flux using Random using Tables @@ -8,34 +9,11 @@ using MLJBase import MLJModelInterface as MMI using Distributions: Normal -""" - MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJBase.Probabilistic - -A mutable struct representing a Laplace regression model. -It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. -It has the following Hyperparameters: -- `model`: A Flux model provided by the user and compatible with the dataset. -- `flux_loss` : a Flux loss function -- `optimiser` = a Flux optimiser -- `epochs`: The number of training epochs. -- `batch_size`: The batch size. -- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. -- `subnetwork_indices`: the indices of the subnetworks. -- `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`. -- `backend`: the backend to use, either `:GGN` or `:EmpiricalFisher`. -- `σ`: the standard deviation of the prior distribution. -- `μ₀`: the mean of the prior distribution. -- `P₀`: the covariance matrix of the prior distribution. -- `fit_prior_nsteps`: the number of steps used to fit the priors. 
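This patch swaps the training loop over to the explicit Optimisers.jl API (`Optimisers.setup` / `Optimisers.update!`, with the gradient taken with respect to the model itself). A minimal, self-contained sketch of that pattern on a toy model (data, sizes and the `train!` helper are invented purely for illustration):

```
using Flux
using Optimisers

# toy data and model, for illustration only
X = rand(Float32, 4, 100)
y = rand(Float32, 1, 100)
model = Chain(Dense(4, 8, relu), Dense(8, 1))

function train!(model, X, y; epochs=10)
    state = Optimisers.setup(Optimisers.Adam(0.01), model)
    for _ in 1:epochs
        # gradient with respect to the model, as in the patched fit loop
        grads, _ = Flux.gradient(model, X) do m, x
            Flux.Losses.mse(m(x), y)
        end
        # update! returns the new optimiser state and the updated model
        state, model = Optimisers.update!(state, model, grads)
    end
    return model
end

model = train!(model, X, y)
```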
-""" + MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJBase.Probabilistic - model::Flux.Chain = Chain( - Dense(4, 10, relu), - Dense(10, 10, relu), - Dense(10, 1) - ) + model::Union{Flux.Chain,Nothing} = nothing flux_loss = Flux.Losses.mse - optimiser = Adam() + optimiser = Optimisers.Adam() epochs::Integer = 1000::(_ > 0) batch_size::Integer = 32::(_ > 0) subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) @@ -52,7 +30,7 @@ end # for fit: MMI.reformat(::LaplaceRegressor, X, y) = (MLJBase.matrix(X) |> permutedims,reshape(y, 1, :)) - +#for predict: MMI.reformat(::LaplaceRegressor, X) = (MLJBase.matrix(X) |> permutedims,) @@ -96,11 +74,13 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) # Reshape y if necessary y = reshape(y, 1, :) + # Make a copy of the model because Flux does not allow to mutate hyperparameters + copied_model = deepcopy(m.model) data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size) - opt_state = Flux.setup(m.optimiser, m.model) + state_tree = Optimisers.setup(m.optimiser, copied_model) loss_history=[] - push!(loss_history, m.flux_loss(m.model(X), y )) + push!(loss_history, m.flux_loss(copied_model(X), y )) for epoch in 1:(m.epochs) @@ -109,20 +89,20 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) for (X_batch, y_batch) in data_loader # Forward pass: compute predictions - y_pred = m.model(X_batch) + y_pred = copied_model(X_batch) # Compute loss loss = m.flux_loss(y_pred, y_batch) # Compute gradients - grads = gradient(m.model) do model + grads,_ = gradient(copied_model,X_batch) do model, X # Recompute predictions inside gradient context - y_pred = model(X_batch) + y_pred = model(X) m.flux_loss(y_pred, y_batch) end # Update parameters using the optimizer and computed gradients - Flux.Optimise.update!(opt_state ,m.model , grads[1]) + state_tree, model = Optimisers.update!(state_tree ,copied_model, grads) # Accumulate the loss for this batch loss_per_epoch += sum(loss) # Summing the batch loss @@ -138,7 +118,7 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) end la = LaplaceRedux.Laplace( - m.model; + copied_model; likelihood=:regression, subset_of_weights=m.subset_of_weights, subnetwork_indices=m.subnetwork_indices, @@ -155,12 +135,85 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) fitresult = la report = (loss_history = loss_history,) - cache = (deepcopy(m),opt_state, loss_history) + cache = (deepcopy(m),state_tree, loss_history) return fitresult, cache, report end +# Define the function is_same_except +function MMI.is_same_except(m1::LaplaceRegressor, m2::LaplaceRegressor, exceptions::Symbol...) 
+ typeof(m1) === typeof(m2) || return false + names = propertynames(m1) + propertynames(m2) === names || return false + + for name in names + if !(name in exceptions) + if !_isdefined(m1, name) + !_isdefined(m2, name) || return false + elseif _isdefined(m2, name) + if name in deep_properties(LaplaceRegressor) + _equal_to_depth_one( + getproperty(m1,name), + getproperty(m2, name) + ) || return false + else + ( + is_same_except( + getproperty(m1, name), + getproperty(m2, name) + ) || + getproperty(m1, name) isa AbstractRNG || + getproperty(m2, name) isa AbstractRNG || + (getproperty(m1, name) isa Flux.Chain && getproperty(m2, name) isa Flux.Chain && _equal_flux_chain(getproperty(m1, name), getproperty(m2, name))) + ) || return false + end + else + return false + end + end + end + return true +end + +# Define helper functions used in is_same_except +function _isdefined(obj, name) + return hasproperty(obj, name) +end + +function deep_properties(::Type) + return Set{Symbol}() +end + +function _equal_to_depth_one(a, b) + return a == b +end + +function _equal_flux_chain(chain1::Flux.Chain, chain2::Flux.Chain) + if length(chain1.layers) != length(chain2.layers) + println("no length chain") + return false + end + params1 = Flux.params(chain1) + params2 = Flux.params(chain2) + if length(params1) != length(params2) + println("no length params") + return false + end + for (p1, p2) in zip(params1, params2) + if !isequal(p1, p2) + println(" params differs") + return false + end + end + for (layer1, layer2) in zip(chain1.layers, chain2.layers) + if typeof(layer1) != typeof(layer2) + println("layer differ") + return false + end + end + return true +end @@ -256,32 +309,11 @@ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) return [Normal(μ, sqrt(σ)) for (μ, σ) in zip(means, variances)] end -""" - MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJBase.Probabilistic - -A mutable struct representing a Laplace Classification model. -It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. -The model also has the following parameters: - -- `model`: A Flux model provided by the user and compatible with the dataset. -- `flux_loss` : a Flux loss function -- `optimiser` = a Flux optimiser -- `epochs`: The number of training epochs. -- `batch_size`: The batch size. -- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. -- `subnetwork_indices`: the indices of the subnetworks. -- `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`. -- `backend`: the backend to use, either `:GGN` or `:EmpiricalFisher`. -- `σ`: the standard deviation of the prior distribution. -- `μ₀`: the mean of the prior distribution. -- `P₀`: the covariance matrix of the prior distribution. -- `link_approx`: the link approximation to use, either `:probit` or `:plugin`. -- `fit_prior_nsteps`: the number of steps used to fit the priors. 
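The `is_same_except` overload and the `_equal_flux_chain` helper above let the update path distinguish a pure change of `epochs` from a change that must trigger a cold restart. A hypothetical illustration of the intended behaviour (both regressors deliberately share the same chain object so that the weight comparison succeeds):

```
import MLJModelInterface as MMI
using Flux, LaplaceRedux

chain = Chain(Dense(4, 8, relu), Dense(8, 1))
m1 = LaplaceRegressor(model=chain, epochs=100)
m2 = LaplaceRegressor(model=chain, epochs=200)

MMI.is_same_except(m1, m2)            # expected: false, because epochs differ
MMI.is_same_except(m1, m2, :epochs)   # expected: true, everything else matches
```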
-""" + MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJBase.Probabilistic - model::Flux.Chain = nothing + model::Union{Flux.Chain,Nothing} = nothing flux_loss = Flux.Losses.logitcrossentropy - optimiser = Adam() + optimiser = Optimisers.Adam() epochs::Integer = 1000::(_ > 0) batch_size::Integer = 32::(_ > 0) subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) @@ -335,12 +367,14 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) # One-hot encoding of labels unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding y_onehot = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding + #copy model + copied_model = deepcopy(m.model) # Create a data loader for batching the data data_loader = Flux.DataLoader((X, y_onehot); batchsize=m.batch_size) # Set up the optimizer for the model - opt_state = Flux.setup(m.optimiser, m.model) + state_tree = Optimisers.setup(m.optimiser, copied_model) loss_history = [] # Training loop for the specified number of epochs @@ -351,20 +385,20 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) for (X_batch, y_batch) in data_loader # Forward pass: compute predictions - y_pred = m.model(X_batch) + y_pred = copied_model(X_batch) # Compute loss loss = m.flux_loss(y_pred, y_batch) - # Compute gradients explicitly - grads = gradient(m.model) do model + # Compute gradients + grads,_ = gradient(copied_model,X_batch) do model, X # Recompute predictions inside gradient context - y_pred = model(X_batch) + y_pred = model(X) m.flux_loss(y_pred, y_batch) end # Update parameters using the optimizer and computed gradients - Flux.Optimise.update!(opt_state ,m.model , grads[1]) + state_tree, model = Optimisers.update!(state_tree ,copied_model, grads) # Accumulate the loss for this batch loss_per_epoch += sum(loss) # Summing the batch loss @@ -396,7 +430,7 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) report = (loss_history = loss_history,) - cache = (deepcopy(m),opt_state,loss_history) + cache = (deepcopy(m),state_tree,loss_history) return ((la, y[1]), cache, report) end @@ -493,7 +527,7 @@ end # for fit: MMI.reformat(::LaplaceClassifier, X, y) = (MLJBase.matrix(X) |> permutedims,y) - +# for predict: MMI.reformat(::LaplaceClassifier, X) = (MLJBase.matrix(X) |> permutedims,) MMI.metadata_pkg( From da6fc76b345f66987d7b700241299b13ee50f4ed Mon Sep 17 00:00:00 2001 From: "pasquale c." 
<343guiltyspark@outlook.it> Date: Thu, 3 Oct 2024 22:40:37 +0200 Subject: [PATCH 30/60] started joining the functions in a single common function for both models --- src/direct_mlj.jl | 257 ++++++++++++++++++++++++++++------------------ 1 file changed, 157 insertions(+), 100 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 1379aee7..3b5b33af 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -9,6 +9,23 @@ using MLJBase import MLJModelInterface as MMI using Distributions: Normal +MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJBase.Probabilistic + model::Union{Flux.Chain,Nothing} = nothing + flux_loss = Flux.Losses.logitcrossentropy + optimiser = Optimisers.Adam() + epochs::Integer = 1000::(_ > 0) + batch_size::Integer = 32::(_ > 0) + subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) + subnetwork_indices = nothing + hessian_structure::Union{HessianStructure,Symbol,String} = + :full::(_ in (:full, :diagonal)) + backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher)) + σ::Float64 = 1.0 + μ₀::Float64 = 0.0 + P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing + fit_prior_nsteps::Int = 100::(_ > 0) + link_approx::Symbol = :probit::(_ in (:probit, :plugin)) +end MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJBase.Probabilistic model::Union{Flux.Chain,Nothing} = nothing @@ -27,11 +44,17 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJBase.Probabilistic fit_prior_nsteps::Int = 100::(_ > 0) end +Const_Models = Union{LaplaceRegressor,LaplaceClassifier} # for fit: -MMI.reformat(::LaplaceRegressor, X, y) = (MLJBase.matrix(X) |> permutedims,reshape(y, 1, :)) +MMI.reformat(::Const_Models, X, y) = (MLJBase.matrix(X) |> permutedims, reshape(y, 1, :)) #for predict: -MMI.reformat(::LaplaceRegressor, X) = (MLJBase.matrix(X) |> permutedims,) +MMI.reformat(::Const_Models, X) = (MLJBase.matrix(X) |> permutedims,) + +# for fit: +MMI.reformat(::LaplaceClassifier, X, y) = (MLJBase.matrix(X) |> permutedims,y) +# for predict: +#MMI.reformat(::LaplaceClassifier, X) = (MLJBase.matrix(X) |> permutedims,) @@ -68,19 +91,18 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) #X = MLJBase.matrix(X) |> permutedims #y = reshape(y, 1, :) - if Tables.istable(X) - X = Tables.matrix(X)|>permutedims - end + #if Tables.istable(X) + #X = Tables.matrix(X)|>permutedims + #end # Reshape y if necessary - y = reshape(y, 1, :) + #y = reshape(y, 1, :) # Make a copy of the model because Flux does not allow to mutate hyperparameters copied_model = deepcopy(m.model) data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size) state_tree = Optimisers.setup(m.optimiser, copied_model) loss_history=[] - push!(loss_history, m.flux_loss(copied_model(X), y )) for epoch in 1:(m.epochs) @@ -95,14 +117,14 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) loss = m.flux_loss(y_pred, y_batch) # Compute gradients - grads,_ = gradient(copied_model,X_batch) do model, X + grads,_ = gradient(copied_model,X_batch) do grad_model, X # Recompute predictions inside gradient context - y_pred = model(X) + y_pred = grad_model(X) m.flux_loss(y_pred, y_batch) end # Update parameters using the optimizer and computed gradients - state_tree, model = Optimisers.update!(state_tree ,copied_model, grads) + state_tree, copied_model = Optimisers.update!(state_tree ,copied_model, grads) # Accumulate the loss for this batch loss_per_epoch += sum(loss) # Summing the batch loss @@ -133,15 +155,126 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) LaplaceRedux.fit!(la, 
data_loader) optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) - fitresult = la + fitresult = (la, y[1]) report = (loss_history = loss_history,) - cache = (deepcopy(m),state_tree, loss_history) + cache = (deepcopy(m),state_tree,loss_history) + return fitresult, cache, report +end + +function MMI.update(m::LaplaceRegressor, verbosity, old_fitresult, old_cache, X, y) + println("we are in the update function") + + data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size) + old_model = old_cache[1] + old_state_tree = old_cache[2] + old_loss_history = old_cache[3] + old_la = old_fitresult[1] + + epochs = m.epochs + + if MMI.is_same_except(m, old_model,:epochs) + + + if epochs > old_model.epochs + + + for epoch in (old_model.epochs+1):(epochs) + + loss_per_epoch= 0.0 + + + for (X_batch, y_batch) in data_loader + # Forward pass: compute predictions + y_pred = old_la.model(X_batch) + + # Compute loss + loss = m.flux_loss(y_pred, y_batch) + + # Compute gradients + grads,_ = gradient(old_la.model,X_batch) do grad_model, X + # Recompute predictions inside gradient context + y_pred = grad_model(X) + m.flux_loss(y_pred, y_batch) + end + + # Update parameters using the optimizer and computed gradients + old_state_tree,old_la.model = Optimisers.update!(old_state_tree,old_la.model, grads) + + # Accumulate the loss for this batch + loss_per_epoch += sum(loss) # Summing the batch loss + + end + + push!(old_loss_history,loss_per_epoch ) + + # Print loss every 100 epochs if verbosity is 1 or more + if verbosity >= 1 && epoch % 100 == 0 + println("Epoch $epoch: Loss: $loss_per_epoch ") + end + end + + la = LaplaceRedux.Laplace( + old_la.model; + likelihood=:regression, + subset_of_weights=m.subset_of_weights, + subnetwork_indices=m.subnetwork_indices, + hessian_structure=m.hessian_structure, + backend=m.backend, + σ=m.σ, + μ₀=m.μ₀, + P₀=m.P₀, + ) + + # fit the Laplace model: + LaplaceRedux.fit!(la, data_loader) + optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) + + fitresult = (la, y[1]) + report = (loss_history = old_loss_history,) + cache = (deepcopy(m),old_state_tree,old_loss_history) + + else + + nothing + + end + + end + + if MMI.is_same_except(m, old_model,:fit_prior_nsteps,:subset_of_weights,:subnetwork_indices,:hessian_structure,:backend,:σ,:μ₀,:P₀) + + println(" changing only the laplace optimization part") + + la = LaplaceRedux.Laplace( + old_la.model; + likelihood=:regression, + subset_of_weights=m.subset_of_weights, + subnetwork_indices=m.subnetwork_indices, + hessian_structure=m.hessian_structure, + backend=m.backend, + σ=m.σ, + μ₀=m.μ₀, + P₀=m.P₀, + ) + + # fit the Laplace model: + LaplaceRedux.fit!(la, data_loader) + optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) + + fitresult = la + report = (loss_history = old_loss_history,) + cache = (deepcopy(m),old_state_tree,old_loss_history) + + end + + return fitresult, cache, report end # Define the function is_same_except -function MMI.is_same_except(m1::LaplaceRegressor, m2::LaplaceRegressor, exceptions::Symbol...) +function MMI.is_same_except(m1::Const_Models, m2::Const_Models, exceptions::Symbol...) + println("overloaded") typeof(m1) === typeof(m2) || return false names = propertynames(m1) propertynames(m2) === names || return false @@ -241,8 +374,8 @@ end - `loss`: The loss value of the posterior distribution. 
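A hypothetical access pattern for these fields through the MLJ machine interface (assuming `MLJBase` is loaded and `flux_model`, `X`, `y` are defined as in the test file):

```
mach = MLJBase.machine(LaplaceRegressor(model=flux_model, epochs=100), X, y)
MLJBase.fit!(mach)

params = MLJBase.fitted_params(mach)
params.μ         # posterior mean of the (subset of) network weights
params.Σ         # posterior covariance matrix
params.n_params  # number of parameters covered by the Laplace approximation
params.loss      # loss evaluated at the posterior mode
```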
""" -function MMI.fitted_params(model::LaplaceRegressor, fitresult) - la = fitresult +function MMI.fitted_params(model::Const_Models, fitresult) + la,decode = fitresult posterior = la.posterior return ( μ = posterior.μ, @@ -261,18 +394,18 @@ end @doc """ - MMI.training_losses(model::LaplaceRegressor, report) + MMI.training_losses(model::Union{LaplaceRegressor,LaplaceClassifier}, report) Retrieve the training loss history from the given `report`. # Arguments -- `model::LaplaceRegressor`: The model for which the training losses are being retrieved. +- `model`: The model for which the training losses are being retrieved. - `report`: An object containing the training report, which includes the loss history. # Returns - A collection representing the loss history from the training report. """ -function MMI.training_losses(model::LaplaceRegressor, report) +function MMI.training_losses(model::Const_Models, report) return report.loss_history end @@ -282,11 +415,11 @@ end @doc """ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) - Predicts the response for new data using a fitted LaplaceRegressor model. + Predicts the response for new data using a fitted Laplace model. # Arguments - - `m::LaplaceRegressor`: The LaplaceRegressor model. - - `fitresult`: The result of fitting the LaplaceRegressor model. + - `m::LaplaceRegressor`: The Laplace model. + - `fitresult`: The result of the fitting procedure. - `Xnew`: The new data for which predictions are to be made. # Returns @@ -300,7 +433,7 @@ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) if Tables.istable(Xnew) Xnew = Tables.matrix(Xnew)|>permutedims end - la = fitresult + la, y = fitresult yhat = LaplaceRedux.predict(la, Xnew; ret_distr=false) # Extract mean and variance matrices means, variances = yhat @@ -310,23 +443,7 @@ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) end -MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJBase.Probabilistic - model::Union{Flux.Chain,Nothing} = nothing - flux_loss = Flux.Losses.logitcrossentropy - optimiser = Optimisers.Adam() - epochs::Integer = 1000::(_ > 0) - batch_size::Integer = 32::(_ > 0) - subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) - subnetwork_indices = nothing - hessian_structure::Union{HessianStructure,Symbol,String} = - :full::(_ in (:full, :diagonal)) - backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher)) - σ::Float64 = 1.0 - μ₀::Float64 = 0.0 - P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing - fit_prior_nsteps::Int = 100::(_ > 0) - link_approx::Symbol = :probit::(_ in (:probit, :plugin)) -end + @doc """ @@ -435,64 +552,7 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) end - - - - -@doc """ - - function MMI.fitted_params(model::LaplaceClassifier, fitresult) - - This function extracts the fitted parameters from a `LaplaceClassifier` model. - - # Arguments - - `model::LaplaceClassifier`: The Laplace classifier model. - - `fitresult`: A tuple containing the Laplace approximation (`la`) and a decode function. - - # Returns - A named tuple containing: - - `μ`: The mean of the posterior distribution. - - `H`: The Hessian of the posterior distribution. - - `P`: The precision matrix of the posterior distribution. - - `Σ`: The covariance matrix of the posterior distribution. - - `n_data`: The number of data points. - - `n_params`: The number of parameters. - - `n_out`: The number of outputs. - - `loss`: The loss value of the posterior distribution. 
- -""" -function MMI.fitted_params(model::LaplaceClassifier, fitresult) - la, decode = fitresult - posterior = la.posterior - return ( - μ = posterior.μ, - H = posterior.H, - P = posterior.P, - Σ = posterior.Σ, - n_data = posterior.n_data, - n_params = posterior.n_params, - n_out = posterior.n_out, - loss = posterior.loss - ) -end - - -@doc """ - MMI.training_losses(model::LaplaceClassifier, report) - -Retrieve the training loss history from the given `report`. - -# Arguments -- `model::LaplaceClassifier`: The model for which the training losses are being retrieved. -- `report`: An object containing the training report, which includes the loss history. - -# Returns -- A collection representing the loss history from the training report. -""" -function MMI.training_losses(model::LaplaceClassifier, report) - return report.loss_history -end @doc """ @@ -525,10 +585,7 @@ end -# for fit: -MMI.reformat(::LaplaceClassifier, X, y) = (MLJBase.matrix(X) |> permutedims,y) -# for predict: -MMI.reformat(::LaplaceClassifier, X) = (MLJBase.matrix(X) |> permutedims,) + MMI.metadata_pkg( LaplaceRegressor, From 70df56889d86fdfdb8979e6552b74fbb5d6966a8 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Fri, 4 Oct 2024 02:18:30 +0200 Subject: [PATCH 31/60] various fixes --- src/direct_mlj.jl | 53 ++++++++--------------------------------------- 1 file changed, 9 insertions(+), 44 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 3b5b33af..8ffd65b6 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -51,11 +51,6 @@ MMI.reformat(::Const_Models, X, y) = (MLJBase.matrix(X) |> permutedims, reshape( #for predict: MMI.reformat(::Const_Models, X) = (MLJBase.matrix(X) |> permutedims,) -# for fit: -MMI.reformat(::LaplaceClassifier, X, y) = (MLJBase.matrix(X) |> permutedims,y) -# for predict: -#MMI.reformat(::LaplaceClassifier, X) = (MLJBase.matrix(X) |> permutedims,) - @doc """ @@ -87,16 +82,6 @@ This function performs the following steps: 9. Returns the fitted Laplace model, a cache (currently `nothing`), and a report listing training related statistics. """ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) - - #X = MLJBase.matrix(X) |> permutedims - #y = reshape(y, 1, :) - - #if Tables.istable(X) - #X = Tables.matrix(X)|>permutedims - #end - - # Reshape y if necessary - #y = reshape(y, 1, :) # Make a copy of the model because Flux does not allow to mutate hyperparameters copied_model = deepcopy(m.model) @@ -243,7 +228,7 @@ function MMI.update(m::LaplaceRegressor, verbosity, old_fitresult, old_cache, X, if MMI.is_same_except(m, old_model,:fit_prior_nsteps,:subset_of_weights,:subnetwork_indices,:hessian_structure,:backend,:σ,:μ₀,:P₀) - println(" changing only the laplace optimization part") + println(" updating only the laplace optimization part") la = LaplaceRedux.Laplace( old_la.model; @@ -274,7 +259,6 @@ end # Define the function is_same_except function MMI.is_same_except(m1::Const_Models, m2::Const_Models, exceptions::Symbol...) 
- println("overloaded") typeof(m1) === typeof(m2) || return false names = propertynames(m1) propertynames(m2) === names || return false @@ -323,24 +307,20 @@ end function _equal_flux_chain(chain1::Flux.Chain, chain2::Flux.Chain) if length(chain1.layers) != length(chain2.layers) - println("no length chain") return false end params1 = Flux.params(chain1) params2 = Flux.params(chain2) if length(params1) != length(params2) - println("no length params") return false end for (p1, p2) in zip(params1, params2) if !isequal(p1, p2) - println(" params differs") return false end end for (layer1, layer2) in zip(chain1.layers, chain2.layers) if typeof(layer1) != typeof(layer2) - println("layer differ") return false end @@ -375,7 +355,7 @@ end """ function MMI.fitted_params(model::Const_Models, fitresult) - la,decode = fitresult + la, decode = fitresult posterior = la.posterior return ( μ = posterior.μ, @@ -429,10 +409,6 @@ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) Finally, it creates Normal distributions from these means and variances and returns them as an array. """ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) - #Xnew = MLJBase.matrix(Xnew) |> permutedims - if Tables.istable(Xnew) - Xnew = Tables.matrix(Xnew)|>permutedims - end la, y = fitresult yhat = LaplaceRedux.predict(la, Xnew; ret_distr=false) # Extract mean and variance matrices @@ -468,18 +444,11 @@ end """ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) - #X = MLJBase.matrix(X) |> permutedims - - if Tables.istable(X) - X = Tables.matrix(X)|>permutedims - end - - - # Store the first label as decode function - #decode = y[1] + + # Convert labels to integer format starting from 0 for one-hot encoding - y_plain = MLJBase.int(y) .- 1 + y_plain = MLJBase.int(y[1,:]) .- 1 # One-hot encoding of labels unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding @@ -508,9 +477,9 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) loss = m.flux_loss(y_pred, y_batch) # Compute gradients - grads,_ = gradient(copied_model,X_batch) do model, X + grads,_ = gradient(copied_model,X_batch) do grad_model, X # Recompute predictions inside gradient context - y_pred = model(X) + y_pred = grad_model(X) m.flux_loss(y_pred, y_batch) end @@ -569,12 +538,8 @@ Predicts the class probabilities for new data using a Laplace classifier. The function transforms the new data `Xnew` into a matrix, applies the LaplaceRedux prediction function, and then returns the predictions as a `MLJBase.UnivariateFinite` object. """ -function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew) - if Tables.istable(Xnew) - Xnew = Tables.matrix(Xnew)|>permutedims - end - la = fitresult - #Xnew = MLJBase.matrix(Xnew) |> permutedims +function MMI.predict(m::LaplaceClassifier, fitresult, Xnew) + la,decode = fitresult predictions = LaplaceRedux.predict(la, Xnew; link_approx=m.link_approx, ret_distr=false) |> permutedims From 988987229838c30797c56f3b885ff1b8317c3a53 Mon Sep 17 00:00:00 2001 From: "pasquale c." 
<343guiltyspark@outlook.it> Date: Fri, 4 Oct 2024 14:57:06 +0200 Subject: [PATCH 32/60] merged functions for both cases --- src/direct_mlj.jl | 277 +++++++++++++++++++--------------------------- 1 file changed, 113 insertions(+), 164 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 8ffd65b6..3d592c9c 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -44,44 +44,48 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJBase.Probabilistic fit_prior_nsteps::Int = 100::(_ > 0) end -Const_Models = Union{LaplaceRegressor,LaplaceClassifier} +Laplace_Models = Union{LaplaceRegressor,LaplaceClassifier} # for fit: -MMI.reformat(::Const_Models, X, y) = (MLJBase.matrix(X) |> permutedims, reshape(y, 1, :)) +MMI.reformat(::Laplace_Models, X, y) = (MLJBase.matrix(X) |> permutedims, reshape(y, 1, :)) #for predict: -MMI.reformat(::Const_Models, X) = (MLJBase.matrix(X) |> permutedims,) +MMI.reformat(::Laplace_Models, X) = (MLJBase.matrix(X) |> permutedims,) @doc """ - MMI.fit(m::LaplaceRegressor, verbosity, X, y) + MMI.fit(m::Union{LaplaceRegressor,LaplaceClassifier}, verbosity, X, y) -Fit a LaplaceRegressor model using the provided features and target values. +Fit a Laplace model using the provided features and target values. # Arguments -- `m::LaplaceRegressor`: The LaplaceRegressor model to be fitted. +- `m::Laplace`: The Laplace (LaplaceRegressor or LaplaceClassifier) model to be fitted. - `verbosity`: Verbosity level for logging. - `X`: Input features, expected to be in a format compatible with MLJBase.matrix. - `y`: Target values. # Returns -- `fitresult`: The fitted Laplace model. -- `cache`: Currently unused, returns `nothing`. -- `report`: A tuple containing the loss history of the fitting process. - -# Description -This function performs the following steps: -1. If X is a Table: converts the input features `X` to a matrix and transposes it. -2. Reshapes the target values `y` to shape (1,:). -3. Creates a data loader for batching the data. -4. Sets up the optimizer state using the Adam optimizer. -5. Trains the model for a specified number of epochs. -6. Initializes a Laplace model with the trained Flux model and specified parameters. -7. Fits the Laplace model using the data loader. -8. Optimizes the prior of the Laplace model. -9. Returns the fitted Laplace model, a cache (currently `nothing`), and a report listing training related statistics. +- `fitresult`: a tuple (la,decode) cointaing the fitted Laplace model and y[1],the first element of the categorical y vector. +- `cache`: a tuple containing a deepcopy of the model, the current state of the optimiser and the training loss history. +- `report`: A Namedtuple containing the loss history of the fitting process. 
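By the time this `fit` runs, the `reformat` methods defined above have already turned the input table into a feature-by-observation matrix and `y` into a row vector, which is why the body indexes `y[1]` and `y[1, :]`. A small sketch of the shapes involved (toy table, expected shapes in the comments):

```
import MLJModelInterface as MMI

X_table = (x1 = rand(100), x2 = rand(100), x3 = rand(100))  # any Tables.jl table
y = rand(100)

Xmat, yrow = MMI.reformat(LaplaceRegressor(), X_table, y)
size(Xmat)  # (3, 100) -- one column per observation, as Flux expects
size(yrow)  # (1, 100)
```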
""" -function MMI.fit(m::LaplaceRegressor, verbosity, X, y) +function MMI.fit(m::Laplace_Models, verbosity, X, y) + + decode = y[1] + + + if typeof(m) == LaplaceRegressor + nothing + else + # Convert labels to integer format starting from 0 for one-hot encoding + y_plain = MLJBase.int(y[1,:]) .- 1 + + # One-hot encoding of labels + unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding + y = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding + + end + # Make a copy of the model because Flux does not allow to mutate hyperparameters copied_model = deepcopy(m.model) @@ -136,18 +140,50 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) P₀=m.P₀, ) + if typeof(m) == LaplaceClassifier + la.likelihood = :classification + end + # fit the Laplace model: LaplaceRedux.fit!(la, data_loader) optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) - fitresult = (la, y[1]) + fitresult = (la, decode) report = (loss_history = loss_history,) cache = (deepcopy(m),state_tree,loss_history) return fitresult, cache, report end -function MMI.update(m::LaplaceRegressor, verbosity, old_fitresult, old_cache, X, y) - println("we are in the update function") +@doc """ + MMI.update(m::Union{LaplaceRegressor,LaplaceClassifier}, verbosity, X, y) + +Update the Laplace model using the provided new data points. + +# Arguments +- `m`: The Laplace (LaplaceRegressor or LaplaceClassifier) model to be fitted. +- `verbosity`: Verbosity level for logging. +- `X`: New input features, expected to be in a format compatible with MLJBase.matrix. +- `y`: New target values. + +# Returns +- `fitresult`: a tuple (la,decode) cointaing the updated fitted Laplace model and y[1],the first element of the categorical y vector. +- `cache`: a tuple containing a deepcopy of the model, the updated current state of the optimiser and training loss history. +- `report`: A Namedtuple containing the complete loss history of the fitting process. +""" +function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y) + + + if typeof(m) == LaplaceRegressor + nothing + else + # Convert labels to integer format starting from 0 for one-hot encoding + y_plain = MLJBase.int(y[1,:]) .- 1 + + # One-hot encoding of labels + unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding + y = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding + + end data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size) old_model = old_cache[1] @@ -209,6 +245,9 @@ function MMI.update(m::LaplaceRegressor, verbosity, old_fitresult, old_cache, X, μ₀=m.μ₀, P₀=m.P₀, ) + if typeof(m) == LaplaceClassifier + la.likelihood = :classification + end # fit the Laplace model: LaplaceRedux.fit!(la, data_loader) @@ -241,6 +280,9 @@ function MMI.update(m::LaplaceRegressor, verbosity, old_fitresult, old_cache, X, μ₀=m.μ₀, P₀=m.P₀, ) + if typeof(m) == LaplaceClassifier + la.likelihood = :classification + end # fit the Laplace model: LaplaceRedux.fit!(la, data_loader) @@ -257,8 +299,35 @@ function MMI.update(m::LaplaceRegressor, verbosity, old_fitresult, old_cache, X, end -# Define the function is_same_except -function MMI.is_same_except(m1::Const_Models, m2::Const_Models, exceptions::Symbol...) +@doc """ + function MMI.is_same_except(m1::Laplace_Models, m2::Laplace_Models, exceptions::Symbol...) 
+ +If both `m1` and `m2` are of `MLJType`, return `true` if the +following conditions all hold, and `false` otherwise: + +- `typeof(m1) === typeof(m2)` + +- `propertynames(m1) === propertynames(m2)` + +- with the exception of properties listed as `exceptions` or bound to + an `AbstractRNG`, each pair of corresponding property values is + either "equal" or both undefined. (If a property appears as a + `propertyname` but not a `fieldname`, it is deemed as always defined.) + +The meaining of "equal" depends on the type of the property value: + +- values that are themselves of `MLJType` are "equal" if they are + equal in the sense of `is_same_except` with no exceptions. + +- values that are not of `MLJType` are "equal" if they are `==`. + +In the special case of a "deep" property, "equal" has a different +meaning; see [`deep_properties`](@ref)) for details. + +If `m1` or `m2` are not `MLJType` objects, then return `==(m1, m2)`. + +""" +function MMI.is_same_except(m1::Laplace_Models, m2::Laplace_Models, exceptions::Symbol...) typeof(m1) === typeof(m2) || return false names = propertynames(m1) propertynames(m2) === names || return false @@ -354,7 +423,7 @@ end - `loss`: The loss value of the posterior distribution. """ -function MMI.fitted_params(model::Const_Models, fitresult) +function MMI.fitted_params(model::Laplace_Models, fitresult) la, decode = fitresult posterior = la.posterior return ( @@ -385,7 +454,7 @@ Retrieve the training loss history from the given `report`. # Returns - A collection representing the loss history from the training report. """ -function MMI.training_losses(model::Const_Models, report) +function MMI.training_losses(model::Laplace_Models, report) return report.loss_history end @@ -403,149 +472,29 @@ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) - `Xnew`: The new data for which predictions are to be made. # Returns - - An array of Normal distributions, each centered around the predicted mean and variance for the corresponding input in `Xnew`. - - The function first converts `Xnew` to a matrix and permutes its dimensions. It then uses the `LaplaceRedux.predict` function to obtain the predicted means and variances. -Finally, it creates Normal distributions from these means and variances and returns them as an array. -""" -function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) - la, y = fitresult - yhat = LaplaceRedux.predict(la, Xnew; ret_distr=false) - # Extract mean and variance matrices - means, variances = yhat - - # Create Normal distributions from the means and variances - return [Normal(μ, sqrt(σ)) for (μ, σ) in zip(means, variances)] -end - - - - - -@doc """ - - function MMI.fit(m::LaplaceClassifier, verbosity, X, y) - - Description: - This function fits a LaplaceClassifier model using the provided data. It first preprocesses the input data `X` and target labels `y`, - then trains a neural network model using the Flux library. After training, it fits a Laplace approximation to the trained model. - - Arguments: - - `m::LaplaceClassifier`: The LaplaceClassifier model to be fitted. - - `verbosity`: Verbosity level for logging. - - `X`: Input data features. - - `y`: Target labels. - - Returns: - - A tuple containing: - - `(la, y[1])`: The fitted Laplace model and the decode function for the target labels. - - `cache`: A placeholder for any cached data (currently `nothing`). - - `report`: A NamedTuple containing statistics related to the fitting process. 
- + for LaplaceRegressor: + - An array of Normal distributions, each centered around the predicted mean and variance for the corresponding input in `Xnew`. + for LaplaceClassifier: + - `MLJBase.UnivariateFinite`: The predicted class probabilities for the new data. """ -function MMI.fit(m::LaplaceClassifier, verbosity, X, y) - - - - # Convert labels to integer format starting from 0 for one-hot encoding - y_plain = MLJBase.int(y[1,:]) .- 1 - - # One-hot encoding of labels - unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding - y_onehot = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding - #copy model - copied_model = deepcopy(m.model) - - # Create a data loader for batching the data - data_loader = Flux.DataLoader((X, y_onehot); batchsize=m.batch_size) - - # Set up the optimizer for the model - state_tree = Optimisers.setup(m.optimiser, copied_model) - loss_history = [] - - # Training loop for the specified number of epochs - for epoch in 1:(m.epochs) - - loss_per_epoch= 0.0 - - - for (X_batch, y_batch) in data_loader - # Forward pass: compute predictions - y_pred = copied_model(X_batch) - - # Compute loss - loss = m.flux_loss(y_pred, y_batch) - - # Compute gradients - grads,_ = gradient(copied_model,X_batch) do grad_model, X - # Recompute predictions inside gradient context - y_pred = grad_model(X) - m.flux_loss(y_pred, y_batch) - end - - # Update parameters using the optimizer and computed gradients - state_tree, model = Optimisers.update!(state_tree ,copied_model, grads) +function MMI.predict(m::Laplace_Models, fitresult, Xnew) + la, decode = fitresult + if typeof(m)== LaplaceRegressor + yhat = LaplaceRedux.predict(la, Xnew; ret_distr=false) + # Extract mean and variance matrices + means, variances = yhat - # Accumulate the loss for this batch - loss_per_epoch += sum(loss) # Summing the batch loss - - end + # Create Normal distributions from the means and variances + return [Normal(μ, sqrt(σ)) for (μ, σ) in zip(means, variances)] - push!(loss_history,loss_per_epoch ) + else + predictions = LaplaceRedux.predict(la, Xnew; link_approx=m.link_approx, ret_distr=false) |> permutedims - # Print loss every 100 epochs if verbosity is 1 or more - if verbosity >= 1 && epoch % 100 == 0 - println("Epoch $epoch: Loss: $loss_per_epoch ") - end + return MLJBase.UnivariateFinite(MLJBase.classes(decode), predictions) end - - la = LaplaceRedux.Laplace( - m.model; - likelihood=:classification, - subset_of_weights=m.subset_of_weights, - subnetwork_indices=m.subnetwork_indices, - hessian_structure=m.hessian_structure, - backend=m.backend, - σ=m.σ, - μ₀=m.μ₀, - P₀=m.P₀, - ) - - # fit the Laplace model: - LaplaceRedux.fit!(la, data_loader) - optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) - - report = (loss_history = loss_history,) - cache = (deepcopy(m),state_tree,loss_history) - return ((la, y[1]), cache, report) end - - - -@doc """ -Predicts the class probabilities for new data using a Laplace classifier. - - # Arguments - - `m::LaplaceClassifier`: The Laplace classifier model. - - `(fitresult, decode)`: A tuple containing the fitted model result and the decode function. - - `Xnew`: The new data for which predictions are to be made. - - # Returns - - `MLJBase.UnivariateFinite`: The predicted class probabilities for the new data. - -The function transforms the new data `Xnew` into a matrix, applies the LaplaceRedux -prediction function, and then returns the predictions as a `MLJBase.UnivariateFinite` object. 
-""" -function MMI.predict(m::LaplaceClassifier, fitresult, Xnew) - la,decode = fitresult - predictions = - LaplaceRedux.predict(la, Xnew; link_approx=m.link_approx, ret_distr=false) |> - permutedims - - return MLJBase.UnivariateFinite(MLJBase.classes(decode), predictions) -end From 0f46fd6452eac5813b3f24218a0c27e8eac56e88 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Fri, 4 Oct 2024 15:02:47 +0200 Subject: [PATCH 33/60] julia formatter --- src/direct_mlj.jl | 257 ++++++++++++++++++++-------------------------- 1 file changed, 113 insertions(+), 144 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 3d592c9c..15d3b39f 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -10,7 +10,7 @@ import MLJModelInterface as MMI using Distributions: Normal MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJBase.Probabilistic - model::Union{Flux.Chain,Nothing} = nothing + model::Union{Flux.Chain,Nothing} = nothing flux_loss = Flux.Losses.logitcrossentropy optimiser = Optimisers.Adam() epochs::Integer = 1000::(_ > 0) @@ -28,7 +28,7 @@ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJBase.Probabilistic end MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJBase.Probabilistic - model::Union{Flux.Chain,Nothing} = nothing + model::Union{Flux.Chain,Nothing} = nothing flux_loss = Flux.Losses.mse optimiser = Optimisers.Adam() epochs::Integer = 1000::(_ > 0) @@ -51,8 +51,6 @@ MMI.reformat(::Laplace_Models, X, y) = (MLJBase.matrix(X) |> permutedims, reshap #for predict: MMI.reformat(::Laplace_Models, X) = (MLJBase.matrix(X) |> permutedims,) - - @doc """ MMI.fit(m::Union{LaplaceRegressor,LaplaceClassifier}, verbosity, X, y) @@ -70,20 +68,17 @@ Fit a Laplace model using the provided features and target values. - `report`: A Namedtuple containing the loss history of the fitting process. 
""" function MMI.fit(m::Laplace_Models, verbosity, X, y) - decode = y[1] - if typeof(m) == LaplaceRegressor nothing else # Convert labels to integer format starting from 0 for one-hot encoding - y_plain = MLJBase.int(y[1,:]) .- 1 + y_plain = MLJBase.int(y[1, :]) .- 1 # One-hot encoding of labels unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding y = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding - end # Make a copy of the model because Flux does not allow to mutate hyperparameters @@ -91,12 +86,10 @@ function MMI.fit(m::Laplace_Models, verbosity, X, y) data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size) state_tree = Optimisers.setup(m.optimiser, copied_model) - loss_history=[] + loss_history = [] for epoch in 1:(m.epochs) - - loss_per_epoch= 0.0 - + loss_per_epoch = 0.0 for (X_batch, y_batch) in data_loader # Forward pass: compute predictions @@ -106,21 +99,20 @@ function MMI.fit(m::Laplace_Models, verbosity, X, y) loss = m.flux_loss(y_pred, y_batch) # Compute gradients - grads,_ = gradient(copied_model,X_batch) do grad_model, X + grads, _ = gradient(copied_model, X_batch) do grad_model, X # Recompute predictions inside gradient context y_pred = grad_model(X) m.flux_loss(y_pred, y_batch) end - + # Update parameters using the optimizer and computed gradients - state_tree, copied_model = Optimisers.update!(state_tree ,copied_model, grads) + state_tree, copied_model = Optimisers.update!(state_tree, copied_model, grads) # Accumulate the loss for this batch loss_per_epoch += sum(loss) # Summing the batch loss - end - push!(loss_history,loss_per_epoch ) + push!(loss_history, loss_per_epoch) # Print loss every 100 epochs if verbosity is 1 or more if verbosity >= 1 && epoch % 100 == 0 @@ -149,8 +141,8 @@ function MMI.fit(m::Laplace_Models, verbosity, X, y) optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) fitresult = (la, decode) - report = (loss_history = loss_history,) - cache = (deepcopy(m),state_tree,loss_history) + report = (loss_history=loss_history,) + cache = (deepcopy(m), state_tree, loss_history) return fitresult, cache, report end @@ -170,19 +162,16 @@ Update the Laplace model using the provided new data points. - `cache`: a tuple containing a deepcopy of the model, the updated current state of the optimiser and training loss history. - `report`: A Namedtuple containing the complete loss history of the fitting process. 
""" -function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y) - - +function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y) if typeof(m) == LaplaceRegressor nothing else # Convert labels to integer format starting from 0 for one-hot encoding - y_plain = MLJBase.int(y[1,:]) .- 1 + y_plain = MLJBase.int(y[1, :]) .- 1 # One-hot encoding of labels unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding y = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding - end data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size) @@ -193,80 +182,82 @@ function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y epochs = m.epochs - if MMI.is_same_except(m, old_model,:epochs) - - + if MMI.is_same_except(m, old_model, :epochs) if epochs > old_model.epochs + for epoch in (old_model.epochs + 1):(epochs) + loss_per_epoch = 0.0 - - for epoch in (old_model.epochs+1):(epochs) - - loss_per_epoch= 0.0 - - for (X_batch, y_batch) in data_loader # Forward pass: compute predictions y_pred = old_la.model(X_batch) - + # Compute loss loss = m.flux_loss(y_pred, y_batch) - + # Compute gradients - grads,_ = gradient(old_la.model,X_batch) do grad_model, X + grads, _ = gradient(old_la.model, X_batch) do grad_model, X # Recompute predictions inside gradient context y_pred = grad_model(X) m.flux_loss(y_pred, y_batch) end - + # Update parameters using the optimizer and computed gradients - old_state_tree,old_la.model = Optimisers.update!(old_state_tree,old_la.model, grads) - + old_state_tree, old_la.model = Optimisers.update!( + old_state_tree, old_la.model, grads + ) + # Accumulate the loss for this batch loss_per_epoch += sum(loss) # Summing the batch loss - end - - push!(old_loss_history,loss_per_epoch ) - + + push!(old_loss_history, loss_per_epoch) + # Print loss every 100 epochs if verbosity is 1 or more if verbosity >= 1 && epoch % 100 == 0 println("Epoch $epoch: Loss: $loss_per_epoch ") end end - la = LaplaceRedux.Laplace( - old_la.model; - likelihood=:regression, - subset_of_weights=m.subset_of_weights, - subnetwork_indices=m.subnetwork_indices, - hessian_structure=m.hessian_structure, - backend=m.backend, - σ=m.σ, - μ₀=m.μ₀, - P₀=m.P₀, - ) - if typeof(m) == LaplaceClassifier - la.likelihood = :classification - end + la = LaplaceRedux.Laplace( + old_la.model; + likelihood=:regression, + subset_of_weights=m.subset_of_weights, + subnetwork_indices=m.subnetwork_indices, + hessian_structure=m.hessian_structure, + backend=m.backend, + σ=m.σ, + μ₀=m.μ₀, + P₀=m.P₀, + ) + if typeof(m) == LaplaceClassifier + la.likelihood = :classification + end - # fit the Laplace model: - LaplaceRedux.fit!(la, data_loader) - optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) + # fit the Laplace model: + LaplaceRedux.fit!(la, data_loader) + optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) - fitresult = (la, y[1]) - report = (loss_history = old_loss_history,) - cache = (deepcopy(m),old_state_tree,old_loss_history) + fitresult = (la, y[1]) + report = (loss_history=old_loss_history,) + cache = (deepcopy(m), old_state_tree, old_loss_history) else - nothing - end - end - if MMI.is_same_except(m, old_model,:fit_prior_nsteps,:subset_of_weights,:subnetwork_indices,:hessian_structure,:backend,:σ,:μ₀,:P₀) - + if MMI.is_same_except( + m, + old_model, + :fit_prior_nsteps, + :subset_of_weights, + :subnetwork_indices, + :hessian_structure, + :backend, + :σ, + :μ₀, + :P₀, + ) println(" updating only the laplace optimization part") la 
= LaplaceRedux.Laplace( @@ -279,26 +270,23 @@ function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y σ=m.σ, μ₀=m.μ₀, P₀=m.P₀, - ) - if typeof(m) == LaplaceClassifier - la.likelihood = :classification - end - - # fit the Laplace model: - LaplaceRedux.fit!(la, data_loader) - optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) - - fitresult = la - report = (loss_history = old_loss_history,) - cache = (deepcopy(m),old_state_tree,old_loss_history) + ) + if typeof(m) == LaplaceClassifier + la.likelihood = :classification + end - end + # fit the Laplace model: + LaplaceRedux.fit!(la, data_loader) + optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) + fitresult = la + report = (loss_history=old_loss_history,) + cache = (deepcopy(m), old_state_tree, old_loss_history) + end return fitresult, cache, report end - @doc """ function MMI.is_same_except(m1::Laplace_Models, m2::Laplace_Models, exceptions::Symbol...) @@ -327,7 +315,7 @@ meaning; see [`deep_properties`](@ref)) for details. If `m1` or `m2` are not `MLJType` objects, then return `==(m1, m2)`. """ -function MMI.is_same_except(m1::Laplace_Models, m2::Laplace_Models, exceptions::Symbol...) +function MMI.is_same_except(m1::Laplace_Models, m2::Laplace_Models, exceptions::Symbol...) typeof(m1) === typeof(m2) || return false names = propertynames(m1) propertynames(m2) === names || return false @@ -335,22 +323,21 @@ function MMI.is_same_except(m1::Laplace_Models, m2::Laplace_Models, exceptions:: for name in names if !(name in exceptions) if !_isdefined(m1, name) - !_isdefined(m2, name) || return false + !_isdefined(m2, name) || return false elseif _isdefined(m2, name) if name in deep_properties(LaplaceRegressor) - _equal_to_depth_one( - getproperty(m1,name), - getproperty(m2, name) - ) || return false + _equal_to_depth_one(getproperty(m1, name), getproperty(m2, name)) || + return false else ( - is_same_except( - getproperty(m1, name), - getproperty(m2, name) - ) || + is_same_except(getproperty(m1, name), getproperty(m2, name)) || getproperty(m1, name) isa AbstractRNG || getproperty(m2, name) isa AbstractRNG || - (getproperty(m1, name) isa Flux.Chain && getproperty(m2, name) isa Flux.Chain && _equal_flux_chain(getproperty(m1, name), getproperty(m2, name))) + ( + getproperty(m1, name) isa Flux.Chain && + getproperty(m2, name) isa Flux.Chain && + _equal_flux_chain(getproperty(m1, name), getproperty(m2, name)) + ) ) || return false end else @@ -392,14 +379,10 @@ function _equal_flux_chain(chain1::Flux.Chain, chain2::Flux.Chain) if typeof(layer1) != typeof(layer2) return false end - end return true end - - - @doc """ function MMI.fitted_params(model::LaplaceRegressor, fitresult) @@ -427,21 +410,17 @@ function MMI.fitted_params(model::Laplace_Models, fitresult) la, decode = fitresult posterior = la.posterior return ( - μ = posterior.μ, - H = posterior.H, - P = posterior.P, - Σ = posterior.Σ, - n_data = posterior.n_data, - n_params = posterior.n_params, - n_out = posterior.n_out, - loss = posterior.loss + μ=posterior.μ, + H=posterior.H, + P=posterior.P, + Σ=posterior.Σ, + n_data=posterior.n_data, + n_params=posterior.n_params, + n_out=posterior.n_out, + loss=posterior.loss, ) end - - - - @doc """ MMI.training_losses(model::Union{LaplaceRegressor,LaplaceClassifier}, report) @@ -458,9 +437,6 @@ function MMI.training_losses(model::Laplace_Models, report) return report.loss_history end - - - @doc """ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) @@ -479,7 +455,7 @@ function 
MMI.predict(m::LaplaceRegressor, fitresult, Xnew) """ function MMI.predict(m::Laplace_Models, fitresult, Xnew) la, decode = fitresult - if typeof(m)== LaplaceRegressor + if typeof(m) == LaplaceRegressor yhat = LaplaceRedux.predict(la, Xnew; ret_distr=false) # Extract mean and variance matrices means, variances = yhat @@ -488,41 +464,34 @@ function MMI.predict(m::Laplace_Models, fitresult, Xnew) return [Normal(μ, sqrt(σ)) for (μ, σ) in zip(means, variances)] else - predictions = LaplaceRedux.predict(la, Xnew; link_approx=m.link_approx, ret_distr=false) |> permutedims + predictions = + LaplaceRedux.predict(la, Xnew; link_approx=m.link_approx, ret_distr=false) |> + permutedims return MLJBase.UnivariateFinite(MLJBase.classes(decode), predictions) end end - - - - - - - MMI.metadata_pkg( - LaplaceRegressor, - name="LaplaceRedux", - package_uuid="c52c1a26-f7c5-402b-80be-ba1e638ad478", - package_url="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl", - is_pure_julia=true, - is_wrapper=true, - package_license = "https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl/blob/main/LICENSE", + LaplaceRegressor; + name="LaplaceRedux", + package_uuid="c52c1a26-f7c5-402b-80be-ba1e638ad478", + package_url="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl", + is_pure_julia=true, + is_wrapper=true, + package_license="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl/blob/main/LICENSE", ) MMI.metadata_pkg( - LaplaceClassifier, - name="LaplaceRedux", - package_uuid="c52c1a26-f7c5-402b-80be-ba1e638ad478", - package_url="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl", - is_pure_julia=true, - is_wrapper=true, - package_license = "https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl/blob/main/LICENSE", + LaplaceClassifier; + name="LaplaceRedux", + package_uuid="c52c1a26-f7c5-402b-80be-ba1e638ad478", + package_url="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl", + is_pure_julia=true, + is_wrapper=true, + package_license="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl/blob/main/LICENSE", ) - - MLJBase.metadata_model( LaplaceClassifier; input_scitype=Union{ @@ -530,7 +499,7 @@ MLJBase.metadata_model( MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types }, target_scitype=AbstractArray{<:MLJBase.Finite}, # ordered factor or multiclass - supports_training_losses = true, + supports_training_losses=true, load_path="LaplaceRedux.LaplaceClassifier", ) # metadata for each model, @@ -541,16 +510,16 @@ MLJBase.metadata_model( MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types }, target_scitype=AbstractArray{MLJBase.Continuous}, - supports_training_losses = true, + supports_training_losses=true, load_path="LaplaceRedux.LaplaceRegressor", - ) -const DOC_LAPLACE_REDUX = "[Laplace Redux – Effortless Bayesian Deep Learning]"* - "(https://proceedings.neurips.cc/paper/2021/hash/a3923dbe2f702eff254d67b48ae2f06e-Abstract.html), originally published in "* +const DOC_LAPLACE_REDUX = + "[Laplace Redux – Effortless Bayesian Deep Learning]" * + "(https://proceedings.neurips.cc/paper/2021/hash/a3923dbe2f702eff254d67b48ae2f06e-Abstract.html), originally published in " * "Daxberger, E., Kristiadi, A., Immer, A., Eschenhagen, R., Bauer, M., Hennig, P. (2021): \"Laplace Redux – Effortless Bayesian Deep Learning.\", NIPS'21: Proceedings of the 35th International Conference on Neural Information Processing Systems*, Article No. 1537, pp. 
20089–20103" - """ +""" $(MMI.doc_header(LaplaceClassifier)) `LaplaceClassifier` implements the $DOC_LAPLACE_REDUX for classification models. @@ -820,4 +789,4 @@ See also [LaplaceRedux.jl](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl """ LaplaceRegressor -#end # module \ No newline at end of file +#end # module From ab8b6bfb38f003219afc876ed649580da1797ab0 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Tue, 15 Oct 2024 18:02:10 +0200 Subject: [PATCH 34/60] added unit tests --- test/direct_mlj_interface.jl | 10 ++++++++++ test/mlj_flux_interfacing.jl | 2 ++ 2 files changed, 12 insertions(+) diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl index 66fe020d..c42546a0 100644 --- a/test/direct_mlj_interface.jl +++ b/test/direct_mlj_interface.jl @@ -22,6 +22,11 @@ using LaplaceRedux yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions MLJBase.predict_mode(mach, Xnew) # point predictions MLJBase.fitted_params(mach) #fitted params function + MLJBase.training_losses(mach) #training loss history + model.epochs= 100 #changing number of epochs + MLJBase.fit!(mach) #testing update function + model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps + MLJBase.fit!(mach) #testing update function (the laplace part) end @@ -47,5 +52,10 @@ yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions predict_mode(mach, Xnew) # point predictions pdf.(yhat, "virginica") # probabilities for the "verginica" class MLJBase.fitted_params(mach) # fitted params +MLJBase.training_losses(mach) #training loss history +model.epochs= 100 #changing number of epochs +MLJBase.fit!(mach) #testing update function +model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps +MLJBase.fit!(mach) #testing update function (the laplace part) end diff --git a/test/mlj_flux_interfacing.jl b/test/mlj_flux_interfacing.jl index cb238daf..292397a8 100644 --- a/test/mlj_flux_interfacing.jl +++ b/test/mlj_flux_interfacing.jl @@ -1,3 +1,5 @@ +#deactivated in runtests.jl + using Random: Random import Random.seed! using MLJBase: MLJBase, categorical From 263cc67a68de24f1ec041a1f9d2185cf61eafb87 Mon Sep 17 00:00:00 2001 From: "pasquale c." 
<343guiltyspark@outlook.it> Date: Tue, 15 Oct 2024 18:15:50 +0200 Subject: [PATCH 35/60] more units --- test/direct_mlj_interface.jl | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl index c42546a0..16956064 100644 --- a/test/direct_mlj_interface.jl +++ b/test/direct_mlj_interface.jl @@ -25,6 +25,8 @@ using LaplaceRedux MLJBase.training_losses(mach) #training loss history model.epochs= 100 #changing number of epochs MLJBase.fit!(mach) #testing update function + model.epochs= 50 #changing number of epochs to a lower number + MLJBase.fit!(mach) #testing update function model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps MLJBase.fit!(mach) #testing update function (the laplace part) end @@ -35,7 +37,6 @@ end # Define the model flux_model = Chain( Dense(4, 10, relu), - Dense(10, 10, relu), Dense(10, 3) ) @@ -55,7 +56,19 @@ MLJBase.fitted_params(mach) # fitted params MLJBase.training_losses(mach) #training loss history model.epochs= 100 #changing number of epochs MLJBase.fit!(mach) #testing update function +model.epochs= 50 #changing number of epochs to a lower number +MLJBase.fit!(mach) #testing update function model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps MLJBase.fit!(mach) #testing update function (the laplace part) + +# Define a different model +flux_model_two = Chain( + Dense(4, 6, relu), + Dense(6, 3) +) + +model_two = LaplaceClassifier(model=flux_model_two,epochs=70) + +MLJBase.is_same_except(model_two, model, :epochs) end From 453b49fc19e2fa42e5a228f9f43759a8ed2da613 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Tue, 15 Oct 2024 19:07:37 +0200 Subject: [PATCH 36/60] fix --- src/direct_mlj.jl | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 15d3b39f..fb366f0c 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -239,14 +239,20 @@ function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y fitresult = (la, y[1]) report = (loss_history=old_loss_history,) - cache = (deepcopy(m), old_state_tree, old_loss_history) + cache = old_cache else - nothing + fitresult = old_fitresult + report = (loss_history=old_loss_history,) + cache = (deepcopy(m), old_state_tree, old_loss_history) end - end + + + - if MMI.is_same_except( + return fitresult, cache, report + + elseif MMI.is_same_except( m, old_model, :fit_prior_nsteps, @@ -282,9 +288,23 @@ function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y fitresult = la report = (loss_history=old_loss_history,) cache = (deepcopy(m), old_state_tree, old_loss_history) + + return fitresult, cache, report + + + + else + + fitresult, cache, report = MMI.fit(m, verbosity, X, y) + + + return fitresult, cache, report + + + + end - return fitresult, cache, report end @doc """ From 656b24ea1c0b2a1f13397e17899edfc5afdf5238 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Wed, 16 Oct 2024 05:14:53 +0200 Subject: [PATCH 37/60] changed unit test and a minor fix in the update function. there is still the if statement problem. 
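The remaining "if statement" issue: when the user swaps the Flux chain on an
already fitted machine, `update` falls back to the plain `fit` (the final `else`
branch), but `y` has already been transformed during the previous `fit!` run, so
the fallback receives the target in an unexpected layout. The triggering
scenario, paraphrasing the new classifier test added below (all names are taken
from that test):

    flux_model_two = Chain(Dense(4, 6, relu), Dense(6, 3))
    model.model = flux_model_two   # change the chain on a fitted machine
    MLJBase.fit!(mach)             # update falls back to fit

One way to address this is to do the data preparation entirely in `reformat`,
so that `fit` and `update` always see identically prepared inputs.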
--- src/direct_mlj.jl | 9 ++++++++- test/direct_mlj_interface.jl | 8 ++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index fb366f0c..b6bc3023 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -178,11 +178,12 @@ function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y old_model = old_cache[1] old_state_tree = old_cache[2] old_loss_history = old_cache[3] - old_la = old_fitresult[1] + #old_la = old_fitresult[1] epochs = m.epochs if MMI.is_same_except(m, old_model, :epochs) + old_la = old_fitresult[1] if epochs > old_model.epochs for epoch in (old_model.epochs + 1):(epochs) loss_per_epoch = 0.0 @@ -264,7 +265,9 @@ function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y :μ₀, :P₀, ) + println(" updating only the laplace optimization part") + old_la = old_fitresult[1] la = LaplaceRedux.Laplace( old_la.model; @@ -294,6 +297,10 @@ function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y else + println(" I believe this error is provoked by that if statement in the fit function. This case should address the possibility that \n + the user change the flux chain. In this case the update! function should revert to the fallback fit! function, \n + however y has already been transformed during the previous fit! run . One way to solve this issue is to make sure that the \n + data preparation is completely done in reformat") fitresult, cache, report = MMI.fit(m, verbosity, X, y) diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl index 16956064..ac7b4847 100644 --- a/test/direct_mlj_interface.jl +++ b/test/direct_mlj_interface.jl @@ -17,7 +17,7 @@ using LaplaceRedux X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1) mach = machine(model, X, y) #|> MLJBase.fit! #|> (fitresult,cache,report) - MLJBase.fit!(mach) + MLJBase.fit!(mach,verbosity=1) Xnew, _ = make_regression(3, 4; rng=123) yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions MLJBase.predict_mode(mach, Xnew) # point predictions @@ -44,7 +44,7 @@ model = LaplaceClassifier(model=flux_model,epochs=50) X, y = @load_iris mach = machine(model, X, y) -MLJBase.fit!(mach) +MLJBase.fit!(mach,verbosity=1) Xnew = (sepal_length = [6.4, 7.2, 7.4], sepal_width = [2.8, 3.0, 2.8], petal_length = [5.6, 5.8, 6.1], @@ -67,8 +67,8 @@ flux_model_two = Chain( Dense(6, 3) ) -model_two = LaplaceClassifier(model=flux_model_two,epochs=70) +model.model = flux_model_two -MLJBase.is_same_except(model_two, model, :epochs) +MLJBase.fit!(mach) end From 7c4d74406c23413273271c7c06ad83f0d7587866 Mon Sep 17 00:00:00 2001 From: "pasquale c." 
<343guiltyspark@outlook.it> Date: Wed, 16 Oct 2024 09:03:27 +0200 Subject: [PATCH 38/60] only things left to fix are the selectrows functions --- .../mlj-interfacing/direct_mlj.ipynb | 1137 +++++++++++++++++ src/direct_mlj.jl | 79 +- test/direct_mlj_interface.jl | 2 + 3 files changed, 1172 insertions(+), 46 deletions(-) create mode 100644 dev/notebooks/mlj-interfacing/direct_mlj.ipynb diff --git a/dev/notebooks/mlj-interfacing/direct_mlj.ipynb b/dev/notebooks/mlj-interfacing/direct_mlj.ipynb new file mode 100644 index 00000000..4169dc0e --- /dev/null +++ b/dev/notebooks/mlj-interfacing/direct_mlj.ipynb @@ -0,0 +1,1137 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "using Revise\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m\u001b[1m Activating\u001b[22m\u001b[39m project at `C:\\Users\\Pasqu\\Documents\\julia_projects\\LaplaceRedux.jl\\docs`\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32m\u001b[1mStatus\u001b[22m\u001b[39m `C:\\Users\\Pasqu\\Documents\\julia_projects\\LaplaceRedux.jl\\docs\\Project.toml`\n", + " \u001b[90m[324d7699] \u001b[39mCategoricalArrays v0.10.8\n", + "\u001b[32m⌃\u001b[39m \u001b[90m[a93c6f00] \u001b[39mDataFrames v1.6.1\n", + "\u001b[32m⌃\u001b[39m \u001b[90m[31c24e10] \u001b[39mDistributions v0.25.111\n", + "\u001b[32m⌃\u001b[39m \u001b[90m[e30172f5] \u001b[39mDocumenter v1.6.0\n", + "\u001b[32m⌃\u001b[39m \u001b[90m[587475ba] \u001b[39mFlux v0.14.19\n", + " \u001b[90m[c52c1a26] \u001b[39mLaplaceRedux v1.1.1 `C:\\Users\\Pasqu\\Documents\\julia_projects\\LaplaceRedux.jl`\n", + "\u001b[33m⌅\u001b[39m \u001b[90m[094fc8d1] \u001b[39mMLJFlux v0.5.1\n", + " \u001b[90m[e80e1ace] \u001b[39mMLJModelInterface v1.11.0\n", + " \u001b[90m[3bd65402] \u001b[39mOptimisers v0.3.3\n", + " \u001b[90m[91a5bcdd] \u001b[39mPlots v1.40.8\n", + " \u001b[90m[ce6b1742] \u001b[39mRDatasets v0.7.7\n", + " \u001b[90m[860ef19b] \u001b[39mStableRNGs v1.0.2\n", + " \u001b[90m[2913bbd2] \u001b[39mStatsBase v0.34.3\n", + " \u001b[90m[bd369af6] \u001b[39mTables v1.12.0\n", + "\u001b[32m⌃\u001b[39m \u001b[90m[bd7198b4] \u001b[39mTaijaPlotting v1.2.0\n", + " \u001b[90m[592b5752] \u001b[39mTrapz v2.0.3\n", + "\u001b[32m⌃\u001b[39m \u001b[90m[e88e6eb3] \u001b[39mZygote v0.6.70\n", + "\u001b[32m⌃\u001b[39m \u001b[90m[02a925ec] \u001b[39mcuDNN v1.3.2\n", + " \u001b[90m[9a3f8284] \u001b[39mRandom\n", + " \u001b[90m[10745b16] \u001b[39mStatistics v1.10.0\n", + "\u001b[36m\u001b[1mInfo\u001b[22m\u001b[39m Packages marked with \u001b[32m⌃\u001b[39m and \u001b[33m⌅\u001b[39m have new versions available. Those with \u001b[32m⌃\u001b[39m may be upgradable, but those with \u001b[33m⌅\u001b[39m are restricted by compatibility constraints from upgrading. To see why use `status --outdated`\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: using MLJBase.predict in module LaplaceRedux conflicts with an existing identifier.\n", + "WARNING: using MLJBase.fit! 
in module LaplaceRedux conflicts with an existing identifier.\n" + ] + } + ], + "source": [ + "using Pkg\n", + "\n", + "\n", + "Pkg.activate(\"C:/Users/Pasqu/Documents/julia_projects/LaplaceRedux.jl/docs\")\n", + "Pkg.status()\n", + "using LaplaceRedux\n", + "using MLJBase\n", + "using Random\n", + "using DataFrames\n", + "using Flux" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(Tables.MatrixTable{Matrix{Float64}} with 100 rows, 2 columns, and schema:\n", + " :x1 Float64\n", + " :x2 Float64, [-0.15053240354230857, -0.16143107735113107, -0.28104782384528254, 0.8905842690519058, -0.2716955057136559, 0.9606721208381163, 0.14403243794060133, 0.13743002853667605, 0.820641892942472, 0.2270783932443115 … 0.4639933046961763, 0.30449384622096687, 0.2744588755171263, -31.785240173822324, 2.58951832655098, 0.10969223924903307, 0.1600255529817666, 0.5913011997917647, -0.39253898214541977, -0.7247243478011852])" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "X, y = make_regression(100, 2; noise=0.5, sparse=0.2, outliers=0.1)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of rows: 100\n", + "Number of columns: 2\n" + ] + } + ], + "source": [ + "using Tables\n", + "# Get the number of rows\n", + "num_rows = Tables.rows(X) |> length\n", + "\n", + "# Get the number of columns\n", + "num_columns = Tables.columnnames(X) |> length\n", + "\n", + "# Display the dimensions\n", + "println(\"Number of rows: \", num_rows)\n", + "println(\"Number of columns: \", num_columns)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Chain(\n", + " Dense(2 => 10, relu), \u001b[90m# 30 parameters\u001b[39m\n", + " Dense(10 => 10, relu), \u001b[90m# 110 parameters\u001b[39m\n", + " Dense(10 => 1), \u001b[90m# 11 parameters\u001b[39m\n", + ") \u001b[90m # Total: 6 arrays, \u001b[39m151 parameters, 988 bytes." 
+ ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "using Flux\n", + "# Define the model\n", + "flux_model = Chain(\n", + " Dense(2, 10, relu),\n", + " Dense(10, 10, relu),\n", + " Dense(10, 1)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LaplaceRegressor(\n", + " model = Chain(Dense(2 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), \n", + " flux_loss = Flux.Losses.mse, \n", + " optimiser = Adam(0.001, (0.9, 0.999), 1.0e-8), \n", + " epochs = 1000, \n", + " batch_size = 32, \n", + " subset_of_weights = :all, \n", + " subnetwork_indices = nothing, \n", + " hessian_structure = :full, \n", + " backend = :GGN, \n", + " σ = 1.0, \n", + " μ₀ = 0.0, \n", + " P₀ = nothing, \n", + " fit_prior_nsteps = 100)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "\n", + "# Create an instance of LaplaceRegressor with the Flux model\n", + "laplace_regressor = LaplaceRegressor(\n", + " model = flux_model,\n", + " subset_of_weights = :all,\n", + " subnetwork_indices = nothing,\n", + " hessian_structure = :full,\n", + " backend = :GGN,\n", + " σ = 1.0,\n", + " μ₀ = 0.0,\n", + " P₀ = nothing,\n", + " fit_prior_nsteps = 100\n", + ")\n", + "\n", + "# Display the LaplaceRegressor object\n", + "laplace_regressor" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Tables.MatrixTable{Matrix{Float64}}" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "typeof(X)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
17×5 DataFrame
Rowx1x2x3x4y
Float64Float64Float64Float64Cat…
15.8941912.99796.185128.992862
2-8.148346.23246-1.684978.969051
3-4.882295.35276-0.458768.207561
44.025856.9476913.4032-0.04192232
5-5.536356.55656-1.670638.770411
6-6.618584.65032-1.151988.348971
710.234411.427813.05448.530252
8-6.050528.12027-3.687088.787321
9-5.067694.8631-3.583468.413711
1010.83736.324729.791636.659622
11-6.632265.45149-0.388619.00071
121.628124.6107311.660211.72412
13-6.486796.68166-3.325699.196181
147.955962.2392812.68971.778572
15-6.364665.82985-0.7025028.449761
168.032943.859015.507412.20142
17-7.450677.01011-1.961877.843361
" + ], + "text/latex": [ + "\\begin{tabular}{r|ccccc}\n", + "\t& x1 & x2 & x3 & x4 & y\\\\\n", + "\t\\hline\n", + "\t& Float64 & Float64 & Float64 & Float64 & Cat…\\\\\n", + "\t\\hline\n", + "\t1 & 5.89419 & 12.9979 & 6.18512 & 8.99286 & 2 \\\\\n", + "\t2 & -8.14834 & 6.23246 & -1.68497 & 8.96905 & 1 \\\\\n", + "\t3 & -4.88229 & 5.35276 & -0.45876 & 8.20756 & 1 \\\\\n", + "\t4 & 4.02585 & 6.94769 & 13.4032 & -0.0419223 & 2 \\\\\n", + "\t5 & -5.53635 & 6.55656 & -1.67063 & 8.77041 & 1 \\\\\n", + "\t6 & -6.61858 & 4.65032 & -1.15198 & 8.34897 & 1 \\\\\n", + "\t7 & 10.2344 & 11.4278 & 13.0544 & 8.53025 & 2 \\\\\n", + "\t8 & -6.05052 & 8.12027 & -3.68708 & 8.78732 & 1 \\\\\n", + "\t9 & -5.06769 & 4.8631 & -3.58346 & 8.41371 & 1 \\\\\n", + "\t10 & 10.8373 & 6.32472 & 9.79163 & 6.65962 & 2 \\\\\n", + "\t11 & -6.63226 & 5.45149 & -0.38861 & 9.0007 & 1 \\\\\n", + "\t12 & 1.62812 & 4.61073 & 11.6602 & 11.7241 & 2 \\\\\n", + "\t13 & -6.48679 & 6.68166 & -3.32569 & 9.19618 & 1 \\\\\n", + "\t14 & 7.95596 & 2.23928 & 12.6897 & 1.77857 & 2 \\\\\n", + "\t15 & -6.36466 & 5.82985 & -0.702502 & 8.44976 & 1 \\\\\n", + "\t16 & 8.03294 & 3.85901 & 5.50741 & 2.2014 & 2 \\\\\n", + "\t17 & -7.45067 & 7.01011 & -1.96187 & 7.84336 & 1 \\\\\n", + "\\end{tabular}\n" + ], + "text/plain": [ + "\u001b[1m17×5 DataFrame\u001b[0m\n", + "\u001b[1m Row \u001b[0m│\u001b[1m x1 \u001b[0m\u001b[1m x2 \u001b[0m\u001b[1m x3 \u001b[0m\u001b[1m x4 \u001b[0m\u001b[1m y \u001b[0m\n", + " │\u001b[90m Float64 \u001b[0m\u001b[90m Float64 \u001b[0m\u001b[90m Float64 \u001b[0m\u001b[90m Float64 \u001b[0m\u001b[90m Cat… \u001b[0m\n", + "─────┼─────────────────────────────────────────────────\n", + " 1 │ 5.89419 12.9979 6.18512 8.99286 2\n", + " 2 │ -8.14834 6.23246 -1.68497 8.96905 1\n", + " 3 │ -4.88229 5.35276 -0.45876 8.20756 1\n", + " 4 │ 4.02585 6.94769 13.4032 -0.0419223 2\n", + " 5 │ -5.53635 6.55656 -1.67063 8.77041 1\n", + " 6 │ -6.61858 4.65032 -1.15198 8.34897 1\n", + " 7 │ 10.2344 11.4278 13.0544 8.53025 2\n", + " 8 │ -6.05052 8.12027 -3.68708 8.78732 1\n", + " 9 │ -5.06769 4.8631 -3.58346 8.41371 1\n", + " 10 │ 10.8373 6.32472 9.79163 6.65962 2\n", + " 11 │ -6.63226 5.45149 -0.38861 9.0007 1\n", + " 12 │ 1.62812 4.61073 11.6602 11.7241 2\n", + " 13 │ -6.48679 6.68166 -3.32569 9.19618 1\n", + " 14 │ 7.95596 2.23928 12.6897 1.77857 2\n", + " 15 │ -6.36466 5.82985 -0.702502 8.44976 1\n", + " 16 │ 8.03294 3.85901 5.50741 2.2014 2\n", + " 17 │ -7.45067 7.01011 -1.96187 7.84336 1" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "using MLJ, DataFrames\n", + "X, y = make_blobs(100, 4; centers=2, cluster_std=[1.0, 3.0 ])\n", + "dfBlobs = DataFrame(X)\n", + "dfBlobs.y = y\n", + "first(dfBlobs, 17)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Chain(\n", + " Dense(4 => 10, relu), \u001b[90m# 50 parameters\u001b[39m\n", + " Dense(10 => 10, relu), \u001b[90m# 110 parameters\u001b[39m\n", + " Dense(10 => 2), \u001b[90m# 22 parameters\u001b[39m\n", + ") \u001b[90m # Total: 6 arrays, \u001b[39m182 parameters, 1.086 KiB." 
+ ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "using Flux\n", + "# Define the model\n", + "flux_model = Chain(\n", + " Dense(4, 10, relu),\n", + " Dense(10, 10, relu),\n", + " Dense(10, 2)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LaplaceClassifier(\n", + " model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 2)), \n", + " flux_loss = Flux.Losses.logitcrossentropy, \n", + " optimiser = Adam(0.001, (0.9, 0.999), 1.0e-8), \n", + " epochs = 1000, \n", + " batch_size = 32, \n", + " subset_of_weights = :all, \n", + " subnetwork_indices = nothing, \n", + " hessian_structure = :full, \n", + " backend = :GGN, \n", + " σ = 1.0, \n", + " μ₀ = 0.0, \n", + " P₀ = nothing, \n", + " fit_prior_nsteps = 100, \n", + " link_approx = :probit)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Create an instance of LaplaceRegressor with the Flux model\n", + "laplace_classifier = LaplaceClassifier(\n", + " model = flux_model,\n", + " subset_of_weights = :all,\n", + " subnetwork_indices = nothing,\n", + " hessian_structure = :full,\n", + " backend = :GGN,\n", + " σ = 1.0,\n", + " μ₀ = 0.0,\n", + " P₀ = nothing,\n", + " fit_prior_nsteps = 100\n", + ")\n", + "\n", + "# Display the LaplaceRegressor object\n", + "laplace_classifier" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "([5.894190948229072 -8.148338397036424 … 4.6901216310216896 8.754850004556681; 12.99791729188348 6.232458473633034 … 10.904369405648966 12.187158715203248; 6.185122057812649 -1.6849673572899138 … 10.09295616393206 15.616135680604215; 8.992861668498659 8.969052840985539 … 8.673387526088336 3.444822170655238], (Bool[1 0 … 1 1; 0 1 … 0 0], CategoricalArrays.CategoricalValue{Int64, UInt32} 2))" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "X_test_reformat,y_test_reformat= MLJBase.reformat(laplace_classifier,X,y)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2×100 OneHotMatrix(::Vector{UInt32}) with eltype Bool:\n", + " 1 ⋅ ⋅ 1 ⋅ ⋅ 1 ⋅ ⋅ 1 ⋅ 1 ⋅ … 1 1 1 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 1\n", + " ⋅ 1 1 ⋅ 1 1 ⋅ 1 1 ⋅ 1 ⋅ 1 ⋅ ⋅ ⋅ ⋅ 1 1 1 1 1 1 ⋅ ⋅" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "y_test_reformat[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2×3 view(OneHotMatrix(::Vector{UInt32}), :, [2, 3, 4]) with eltype Bool:\n", + " 0 0 1\n", + " 1 1 0" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "view(y_test_reformat[1],:, [2,3,4])" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4×3 view(::Matrix{Float64}, :, [2, 3, 4]) with eltype Float64:\n", + " -8.14834 -4.88229 4.02585\n", + " 6.23246 5.35276 6.94769\n", + " -1.68497 -0.45876 13.4032\n", + " 8.96905 8.20756 -0.0419223" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "view(X_test_reformat, :, [2,3,4])" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((sepal_length = [5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9 … 6.7, 6.9, 5.8, 6.8, 6.7, 6.7, 6.3, 
6.5, 6.2, 5.9], sepal_width = [3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1 … 3.1, 3.1, 2.7, 3.2, 3.3, 3.0, 2.5, 3.0, 3.4, 3.0], petal_length = [1.4, 1.4, 1.3, 1.5, 1.4, 1.7, 1.4, 1.5, 1.4, 1.5 … 5.6, 5.1, 5.1, 5.9, 5.7, 5.2, 5.0, 5.2, 5.4, 5.1], petal_width = [0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1 … 2.4, 2.3, 1.9, 2.3, 2.5, 2.3, 1.9, 2.0, 2.3, 1.8]), CategoricalArrays.CategoricalValue{String, UInt32}[\"setosa\", \"setosa\", \"setosa\", \"setosa\", \"setosa\", \"setosa\", \"setosa\", \"setosa\", \"setosa\", \"setosa\" … \"virginica\", \"virginica\", \"virginica\", \"virginica\", \"virginica\", \"virginica\", \"virginica\", \"virginica\", \"virginica\", \"virginica\"])" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "using MLJ\n", + "#DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree\n", + "\n", + "using Flux\n", + "# Define the model\n", + "flux_model = Chain(\n", + " Dense(4, 10, relu),\n", + " Dense(10, 10, relu),\n", + " Dense(10, 3)\n", + ")\n", + "\n", + "model = LaplaceClassifier(model=flux_model)\n", + "\n", + "X, y = @load_iris\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "3-element Vector{String}:\n", + " \"setosa\"\n", + " \"versicolor\"\n", + " \"virginica\"" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "levels(y)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Info: Training machine(LaplaceClassifier(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 3)), …), …).\n", + "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:499\n", + "┌ Warning: Layer with Float32 parameters got Float64 input.\n", + "│ The input will be converted, but any earlier layers may be very slow.\n", + "│ layer = Dense(4 => 10, relu)\n", + "│ summary(x) = 4×32 Matrix{Float64}\n", + "└ @ Flux C:\\Users\\Pasqu\\.julia\\packages\\Flux\\HBF2N\\src\\layers\\stateless.jl:60\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 100: Loss: 4.042439937591553 \n", + "Epoch 200: Loss: 3.533539354801178 \n", + "Epoch 300: Loss: 3.2006053030490875 \n", + "Epoch 400: Loss: 2.9201304614543915 \n", + "Epoch 500: Loss: 2.647124856710434 \n", + "Epoch 600: Loss: 2.3577273190021515 \n", + "Epoch 700: Loss: 2.0284294933080673 \n", + "Epoch 800: Loss: 1.6497879326343536 \n", + "Epoch 900: Loss: 1.3182911276817322 \n", + "Epoch 1000: Loss: 1.077766239643097 \n" + ] + }, + { + "data": { + "text/plain": [ + "trained Machine; caches model-specific representations of data\n", + " model: LaplaceClassifier(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 3)), …)\n", + " args: \n", + " 1:\tSource @378 ⏎ Table{AbstractVector{Continuous}}\n", + " 2:\tSource @187 ⏎ AbstractVector{Multiclass{3}}\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "mach = machine(model, X, y) |> MLJBase.fit!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "3-element Vector{Float64}:\n", + " 0.9488185211171959\n", + " 0.7685919895442062\n", + " 0.9454937120287679" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "Xnew = (sepal_length = [6.4, 7.2, 7.4],\n", + " sepal_width = [2.8, 3.0, 2.8],\n", + " petal_length = [5.6, 5.8, 6.1],\n", + " petal_width = [2.1, 1.6, 1.9],)\n", + "yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions\n", + "predict_mode(mach, Xnew) # point predictions\n", + "pdf.(yhat, \"virginica\") # probabilities for the \"verginica\" class" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "CV(\n", + " nfolds = 3, \n", + " shuffle = false, \n", + " rng = TaskLocalRNG())" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "using Random\n", + "cv = CV(nfolds=3)\n", + "CV(\n", + " nfolds = 3,\n", + " shuffle = false,\n", + " rng = Random.default_rng())" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(μ = Float32[0.36796558, -0.590969, -0.0005966219, -0.1918786, -0.21046183, 0.33172563, -0.6454487, -0.43057427, 0.09267368, -0.31252357 … 0.31747338, -0.42893136, -0.036560666, 0.12848923, -0.022290554, -0.17562295, 0.10335993, 2.0584264, -0.27860576, -0.9749556],\n", + " H = [19221.80854034424 0.0 … 175.01133728027344 416.5687526464462; 0.0 0.0 … 0.0 0.0; … ; 175.01133728027344 0.0 … 150.0 0.0; 416.5687526464462 0.0 … 0.0 150.0],\n", + " P = [19222.80854034424 0.0 … 175.01133728027344 416.5687526464462; 0.0 1.0 … 0.0 0.0; … ; 175.01133728027344 0.0 … 151.0 0.0; 416.5687526464462 0.0 … 0.0 151.0],\n", + " Σ = [0.4942412474220449 0.0 … 0.00012692197672594825 -0.00029493662610333626; 0.0 1.0 … -0.0 -0.0; … ; 0.00012692197672495198 0.0 … 0.019500175331854216 2.549733600297664e-5; -0.000294936626108007 0.0 … 2.5497336002974752e-5 0.019426448013349775],\n", + " n_data = 160,\n", + " n_params = 193,\n", + " n_out = 3,\n", + " loss = 76.21374633908272,)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "MLJBase.fitted_params(mach)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1100: Loss: 0.9005963057279587 \n", + "Epoch 1200: Loss: 0.7674828618764877 \n", + "Epoch 1300: Loss: 0.6668129041790962 " + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Info: Updating machine(LaplaceClassifier(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 3)), …), …).\n", + "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:500\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Epoch 1400: Loss: 0.5900264084339142 \n", + "Epoch 1500: Loss: 0.5311783701181412 \n" + ] + }, + { + "data": { + "text/plain": [ + "trained Machine; caches model-specific representations of data\n", + " model: LaplaceClassifier(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 3)), …)\n", + " args: \n", + " 1:\tSource @378 ⏎ Table{AbstractVector{Continuous}}\n", + " 2:\tSource @187 ⏎ AbstractVector{Multiclass{3}}\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model.epochs= 1500\n", + 
"uff=MLJBase.fit!(mach)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " updating only the laplace optimization part\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Info: Updating machine(LaplaceClassifier(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 3)), …), …).\n", + "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:500\n" + ] + }, + { + "data": { + "text/plain": [ + "trained Machine; caches model-specific representations of data\n", + " model: LaplaceClassifier(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 3)), …)\n", + " args: \n", + " 1:\tSource @378 ⏎ Table{AbstractVector{Continuous}}\n", + " 2:\tSource @187 ⏎ AbstractVector{Multiclass{3}}\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model.fit_prior_nsteps = 200\n", + "uff=MLJBase.fit!(mach)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Info: Training machine(LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …), …).\n", + "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:499\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 100: Loss: 123.99619626934783 \n", + "Epoch 200: Loss: 110.70134709925446 \n", + "Epoch 300: Loss: 101.54836914035226 \n", + "Epoch 400: Loss: 93.59849316803738 \n", + "Epoch 500: Loss: 85.18775669356168 \n", + "Epoch 600: Loss: 75.76182416873147 \n", + "Epoch 700: Loss: 67.35201676247976 \n", + "Epoch 800: Loss: 61.22682561405186 \n", + "Epoch 900: Loss: 56.388587609986764 \n", + "Epoch 1000: Loss: 51.8525101259686 \n" + ] + }, + { + "data": { + "text/plain": [ + "1×3 Matrix{Float64}:\n", + " -2.0295 -1.03661 1.29323" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "using MLJ\n", + "#LaplaceRegressor = @load LaplaceRegressor pkg=LaplaceRedux\n", + "flux_model = Chain(\n", + " Dense(4, 10, relu),\n", + " Dense(10, 10, relu),\n", + " Dense(10, 1)\n", + ")\n", + "model = LaplaceRegressor(model=flux_model)\n", + "\n", + "X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)\n", + "mach = machine(model, X, y) \n", + "uff=MLJBase.fit!(mach)\n", + "Xnew, _ = make_regression(3, 4; rng=123)\n", + "yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions\n", + "MLJBase.predict_mode(mach, Xnew) # point predictions\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(μ = Float32[1.0187618, 1.1801981, 1.5047036, 0.09397529, -1.0688609, -0.59974676, 1.2977643, -1.8645017, -0.6493797, 0.8598072 … -3.1050463, -2.267896, 1.018275, -2.0316908, 0.72548866, -2.7309833, -2.1938443, 0.8200029, 1.2517836, 0.3313799],\n", + " H = [8124.873962402344 1238.1352081298828 … 1980.7237396240234 324.216007232666; 1238.1352081298828 496.2319145202637 … 451.8493309020996 62.39444828033447; … ; 1980.7237396240234 451.8493309020996 … 1277.6469421386719 212.83043670654297; 324.216007232666 62.39444828033447 … 212.83043670654297 100.0],\n", + " P = [8125.873962402344 1238.1352081298828 … 1980.7237396240234 324.216007232666; 1238.1352081298828 497.2319145202637 … 451.8493309020996 62.39444828033447; … ; 
1980.7237396240234 451.8493309020996 … 1278.6469421386719 212.83043670654297; 324.216007232666 62.39444828033447 … 212.83043670654297 101.0],\n", + " Σ = [0.08818398027482557 0.01068994964485469 … -0.028911073548363073 0.0029380249323247053; 0.01068994964484666 0.3740392430965252 … -0.03321343036791264 0.03808189340124685; … ; -0.02891107354841918 -0.03321343036788035 … 0.37993018751981145 -0.007007167070441715; 0.002938024932320219 0.03808189340124439 … -0.007007167070367749 0.8779610764649773],\n", + " n_data = 128,\n", + " n_params = 171,\n", + " n_out = 1,\n", + " loss = 25.89132467732224,)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "MLJBase.fitted_params(mach)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1000-element Vector{Any}:\n", + " 132.4757012476366\n", + " 132.34841820836348\n", + " 132.2594675536756\n", + " 132.1764167781873\n", + " 132.09720400291872\n", + " 132.02269947440837\n", + " 131.95014241005182\n", + " 131.87929837827625\n", + " 131.8104252672236\n", + " 131.74319014461662\n", + " ⋮\n", + " 52.175198365270234\n", + " 52.148401306572445\n", + " 52.10149557862212\n", + " 52.032551867920525\n", + " 51.99690913470439\n", + " 51.994054497037524\n", + " 51.97321907106968\n", + " 51.932579915942064\n", + " 51.8525101259686" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "a= MLJBase.training_losses(mach)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Info: Updating machine(LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …), …).\n", + "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:500\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1100: Loss: 47.37399118448011 \n", + "Epoch 1200: Loss: 42.67007994813126 \n", + "Epoch 1300: Loss: 38.057605481177234 \n", + "Epoch 1400: Loss: 33.846777755430416 \n", + "Epoch 1500: Loss: 29.514834265471144 \n" + ] + }, + { + "data": { + "text/plain": [ + "trained Machine; caches model-specific representations of data\n", + " model: LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …)\n", + " args: \n", + " 1:\tSource @447 ⏎ Table{AbstractVector{Continuous}}\n", + " 2:\tSource @645 ⏎ AbstractVector{Continuous}\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model.epochs= 1500\n", + "uff=MLJBase.fit!(mach)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1500-element Vector{Any}:\n", + " 132.4757012476366\n", + " 132.34841820836348\n", + " 132.2594675536756\n", + " 132.1764167781873\n", + " 132.09720400291872\n", + " 132.02269947440837\n", + " 131.95014241005182\n", + " 131.87929837827625\n", + " 131.8104252672236\n", + " 131.74319014461662\n", + " ⋮\n", + " 29.979098347946184\n", + " 29.896050775635096\n", + " 29.83609799111873\n", + " 29.763058304997937\n", + " 29.69908817371154\n", + " 29.646864259231556\n", + " 29.623977677653713\n", + " 29.549562760589527\n", + " 29.514834265471144" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "MLJBase.training_losses(mach)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + 
"outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The number of epochs inserted is lower than the number of epochs already been trained. No update is necessary\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Info: Updating machine(LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …), …).\n", + "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:500\n" + ] + }, + { + "data": { + "text/plain": [ + "trained Machine; caches model-specific representations of data\n", + " model: LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …)\n", + " args: \n", + " 1:\tSource @447 ⏎ Table{AbstractVector{Continuous}}\n", + " 2:\tSource @645 ⏎ AbstractVector{Continuous}\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model.epochs= 1200\n", + "uff=MLJBase.fit!(mach)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " updating only the laplace optimization part\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Info: Updating machine(LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …), …).\n", + "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:500\n" + ] + }, + { + "data": { + "text/plain": [ + "trained Machine; caches model-specific representations of data\n", + " model: LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …)\n", + " args: \n", + " 1:\tSource @447 ⏎ Table{AbstractVector{Continuous}}\n", + " 2:\tSource @645 ⏎ AbstractVector{Continuous}\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model.fit_prior_nsteps = 200\n", + "uff=MLJBase.fit!(mach)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "MLJBase.selectrows(model, I, Xmatrix, y) = (view(Xmatrix, :, I), view(y[1], I))" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "#MLJBase.evaluate(model, X, y, resampling=cv, measure=l2, verbosity=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "ename": "BoundsError", + "evalue": "BoundsError: attempt to access 4×100 Matrix{Float64} at index [[35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100], 1:100]", + "output_type": "error", + "traceback": [ + "BoundsError: attempt to access 4×100 Matrix{Float64} at index [[35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100], 1:100]\n", + "\n", + "Stacktrace:\n", + " [1] throw_boundserror(A::Matrix{Float64}, I::Tuple{Vector{Int64}, Base.Slice{Base.OneTo{Int64}}})\n", + " @ Base .\\abstractarray.jl:737\n", + " [2] checkbounds\n", + " @ .\\abstractarray.jl:702 [inlined]\n", + " [3] _getindex\n", + " @ .\\multidimensional.jl:888 [inlined]\n", + " [4] getindex\n", + " @ 
.\\abstractarray.jl:1291 [inlined]\n", + " [5] _selectrows(::MLJModelInterface.FullInterface, ::Val{:other}, X::Matrix{Float64}, r::Vector{Int64})\n", + " @ MLJModelInterface C:\\Users\\Pasqu\\.julia\\packages\\MLJModelInterface\\y9x5A\\src\\data_utils.jl:350\n", + " [6] selectrows\n", + " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJModelInterface\\y9x5A\\src\\data_utils.jl:340 [inlined]\n", + " [7] selectrows\n", + " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJModelInterface\\y9x5A\\src\\data_utils.jl:336 [inlined]\n", + " [8] #29\n", + " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJModelInterface\\y9x5A\\src\\model_api.jl:88 [inlined]\n", + " [9] map\n", + " @ .\\tuple.jl:292 [inlined]\n", + " [10] selectrows(::LaplaceRegressor, ::Vector{Int64}, ::Matrix{Float64}, ::Tuple{Matrix{Float64}, Nothing})\n", + " @ MLJModelInterface C:\\Users\\Pasqu\\.julia\\packages\\MLJModelInterface\\y9x5A\\src\\model_api.jl:88\n", + " [11] fit_only!(mach::Machine{LaplaceRegressor, LaplaceRegressor, true}; rows::Vector{Int64}, verbosity::Int64, force::Bool, composite::Nothing)\n", + " @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:676\n", + " [12] fit_only!\n", + " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:617 [inlined]\n", + " [13] #fit!#63\n", + " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:789 [inlined]\n", + " [14] fit!\n", + " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:786 [inlined]\n", + " [15] fit_and_extract_on_fold\n", + " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\resampling.jl:1463 [inlined]\n", + " [16] (::MLJBase.var\"#277#278\"{MLJBase.var\"#fit_and_extract_on_fold#304\"{Vector{Tuple{Vector{Int64}, UnitRange{Int64}}}, Nothing, Nothing, Int64, Vector{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.FussyMeasure{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.Multimeasure{StatisticalMeasuresBase.SupportsMissingsMeasure{StatisticalMeasures.LPLossOnScalars{Int64}}, Nothing, StatisticalMeasuresBase.Mean, typeof(identity)}}, Nothing}}}, Vector{typeof(predict_mean)}, Bool, Bool, Vector{Float64}}, Machine{LaplaceRegressor, LaplaceRegressor, true}, Int64})(k::Int64)\n", + " @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\resampling.jl:1289\n", + " [17] _mapreduce(f::MLJBase.var\"#277#278\"{MLJBase.var\"#fit_and_extract_on_fold#304\"{Vector{Tuple{Vector{Int64}, UnitRange{Int64}}}, Nothing, Nothing, Int64, Vector{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.FussyMeasure{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.Multimeasure{StatisticalMeasuresBase.SupportsMissingsMeasure{StatisticalMeasures.LPLossOnScalars{Int64}}, Nothing, StatisticalMeasuresBase.Mean, typeof(identity)}}, Nothing}}}, Vector{typeof(predict_mean)}, Bool, Bool, Vector{Float64}}, Machine{LaplaceRegressor, LaplaceRegressor, true}, Int64}, op::typeof(vcat), ::IndexLinear, A::UnitRange{Int64})\n", + " @ Base .\\reduce.jl:440\n", + " [18] _mapreduce_dim\n", + " @ .\\reducedim.jl:365 [inlined]\n", + " [19] mapreduce\n", + " @ .\\reducedim.jl:357 [inlined]\n", + " [20] _evaluate!(func::MLJBase.var\"#fit_and_extract_on_fold#304\"{Vector{Tuple{Vector{Int64}, UnitRange{Int64}}}, Nothing, Nothing, Int64, Vector{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.FussyMeasure{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.Multimeasure{StatisticalMeasuresBase.SupportsMissingsMeasure{StatisticalMeasures.LPLossOnScalars{Int64}}, 
Nothing, StatisticalMeasuresBase.Mean, typeof(identity)}}, Nothing}}}, Vector{typeof(predict_mean)}, Bool, Bool, Vector{Float64}}, mach::Machine{LaplaceRegressor, LaplaceRegressor, true}, ::CPU1{Nothing}, nfolds::Int64, verbosity::Int64)\n", + " @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\resampling.jl:1288\n", + " [21] evaluate!(mach::Machine{LaplaceRegressor, LaplaceRegressor, true}, resampling::Vector{Tuple{Vector{Int64}, UnitRange{Int64}}}, weights::Nothing, class_weights::Nothing, rows::Nothing, verbosity::Int64, repeats::Int64, measures::Vector{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.FussyMeasure{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.Multimeasure{StatisticalMeasuresBase.SupportsMissingsMeasure{StatisticalMeasures.LPLossOnScalars{Int64}}, Nothing, StatisticalMeasuresBase.Mean, typeof(identity)}}, Nothing}}}, operations::Vector{typeof(predict_mean)}, acceleration::CPU1{Nothing}, force::Bool, per_observation_flag::Bool, logger::Nothing, user_resampling::CV, compact::Bool)\n", + " @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\resampling.jl:1510\n", + " [22] evaluate!(::Machine{LaplaceRegressor, LaplaceRegressor, true}, ::CV, ::Nothing, ::Nothing, ::Nothing, ::Int64, ::Int64, ::Vector{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.FussyMeasure{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.Multimeasure{StatisticalMeasuresBase.SupportsMissingsMeasure{StatisticalMeasures.LPLossOnScalars{Int64}}, Nothing, StatisticalMeasuresBase.Mean, typeof(identity)}}, Nothing}}}, ::Vector{typeof(predict_mean)}, ::CPU1{Nothing}, ::Bool, ::Bool, ::Nothing, ::CV, ::Bool)\n", + " @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\resampling.jl:1603\n", + " [23] evaluate!(mach::Machine{LaplaceRegressor, LaplaceRegressor, true}; resampling::CV, measures::Nothing, measure::StatisticalMeasuresBase.FussyMeasure{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.Multimeasure{StatisticalMeasuresBase.SupportsMissingsMeasure{StatisticalMeasures.LPLossOnScalars{Int64}}, Nothing, StatisticalMeasuresBase.Mean, typeof(identity)}}, Nothing}, weights::Nothing, class_weights::Nothing, operations::Nothing, operation::Nothing, acceleration::CPU1{Nothing}, rows::Nothing, repeats::Int64, force::Bool, check_measure::Bool, per_observation::Bool, verbosity::Int64, logger::Nothing, compact::Bool)\n", + " @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\resampling.jl:1232\n", + " [24] top-level scope\n", + " @ c:\\Users\\Pasqu\\Documents\\julia_projects\\LaplaceRedux.jl\\dev\\notebooks\\mlj-interfacing\\jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X54sZmlsZQ==.jl:1" + ] + } + ], + "source": [ + "evaluate!(mach, resampling=cv, measure=l2, verbosity=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "#flux_model = Chain(\n", + " #Dense(4, 10, relu),\n", + " #Dense(10, 10, relu),\n", + " #Dense(10, 1))\n", + "\n", + "\n", + "#nested_flux_model = Chain(\n", + " #Chain(Dense(10, 5, relu), Dense(5, 5, relu)),\n", + " #Chain(Dense(5, 3, relu), Dense(3, 2)))\n", + "#model = LaplaceRegressor()\n", + "#model.model= nested_flux_model\n", + "\n", + "#copy_model= deepcopy(model)\n", + "\n", + "#copy_model.epochs= 2000\n", + "#copy_model.optimiser = Descent()\n", + "#MLJBase.is_same_except(model , copy_model,:epochs )" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Julia 1.10.5", + "language": 
"julia", + "name": "julia-1.10" + }, + "language_info": { + "file_extension": ".jl", + "mimetype": "application/julia", + "name": "julia", + "version": "1.10.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index b6bc3023..1781fb6c 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -47,7 +47,25 @@ end Laplace_Models = Union{LaplaceRegressor,LaplaceClassifier} # for fit: -MMI.reformat(::Laplace_Models, X, y) = (MLJBase.matrix(X) |> permutedims, reshape(y, 1, :)) +MMI.reformat(::LaplaceRegressor, X, y) = (MLJBase.matrix(X) |> permutedims, (reshape(y, 1, :),nothing)) + + +function MMI.reformat(::LaplaceClassifier, X, y) + + X = MLJBase.matrix(X) |> permutedims + + + y= reshape(y, 1, :) + # Convert labels to integer format starting from 0 for one-hot encoding + y_plain = MLJBase.int(y[1, :]) .- 1 + # One-hot encoding of labels + unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding + y_onehot = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding + return X,(y_onehot, y[1]) +end + +#MMI.selectrows(::LaplaceClassifier, I, Xmatrix, y) = (view(Xmatrix, :, I), (view(y[1],I),y[2])) +#MMI.selectrows(::LaplaceRegressor, I, Xmatrix, y) = (view(Xmatrix, :,I), (view(y[1],I),y[2]) ) #for predict: MMI.reformat(::Laplace_Models, X) = (MLJBase.matrix(X) |> permutedims,) @@ -68,18 +86,10 @@ Fit a Laplace model using the provided features and target values. - `report`: A Namedtuple containing the loss history of the fitting process. """ function MMI.fit(m::Laplace_Models, verbosity, X, y) - decode = y[1] + decode = y[2] - if typeof(m) == LaplaceRegressor - nothing - else - # Convert labels to integer format starting from 0 for one-hot encoding - y_plain = MLJBase.int(y[1, :]) .- 1 + y= y[1] - # One-hot encoding of labels - unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding - y = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding - end # Make a copy of the model because Flux does not allow to mutate hyperparameters copied_model = deepcopy(m.model) @@ -163,22 +173,14 @@ Update the Laplace model using the provided new data points. - `report`: A Namedtuple containing the complete loss history of the fitting process. """ function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y) - if typeof(m) == LaplaceRegressor - nothing - else - # Convert labels to integer format starting from 0 for one-hot encoding - y_plain = MLJBase.int(y[1, :]) .- 1 - - # One-hot encoding of labels - unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding - y = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding - end + decode = y[2] + y_up=y[1] - data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size) + data_loader = Flux.DataLoader((X, y_up); batchsize=m.batch_size) old_model = old_cache[1] old_state_tree = old_cache[2] old_loss_history = old_cache[3] - #old_la = old_fitresult[1] + epochs = m.epochs @@ -238,21 +240,18 @@ function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y LaplaceRedux.fit!(la, data_loader) optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) - fitresult = (la, y[1]) + fitresult = (la, decode) report = (loss_history=old_loss_history,) - cache = old_cache + cache = (deepcopy(m), old_state_tree, old_loss_history) else - fitresult = old_fitresult + println("The number of epochs inserted is lower than the number of epochs already been trained. 
No update is necessary") + fitresult = (old_la, decode) report = (loss_history=old_loss_history,) cache = (deepcopy(m), old_state_tree, old_loss_history) end - - - return fitresult, cache, report - elseif MMI.is_same_except( m, old_model, @@ -265,7 +264,6 @@ function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y :μ₀, :P₀, ) - println(" updating only the laplace optimization part") old_la = old_fitresult[1] @@ -288,30 +286,19 @@ function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y LaplaceRedux.fit!(la, data_loader) optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) - fitresult = la + fitresult = (la,decode) report = (loss_history=old_loss_history,) cache = (deepcopy(m), old_state_tree, old_loss_history) - return fitresult, cache, report - - else - println(" I believe this error is provoked by that if statement in the fit function. This case should address the possibility that \n - the user change the flux chain. In this case the update! function should revert to the fallback fit! function, \n - however y has already been transformed during the previous fit! run . One way to solve this issue is to make sure that the \n - data preparation is completely done in reformat") - - fitresult, cache, report = MMI.fit(m, verbosity, X, y) - - - return fitresult, cache, report - - + fitresult, cache, report = MLJBase.fit(m, verbosity,X,y) + end + return fitresult, cache, report end @doc """ diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl index ac7b4847..c225bfde 100644 --- a/test/direct_mlj_interface.jl +++ b/test/direct_mlj_interface.jl @@ -4,6 +4,7 @@ using MLJBase: MLJBase, categorical using Flux using StableRNGs using MLJ +using MLJ:predict,fit! using LaplaceRedux @@ -29,6 +30,7 @@ using LaplaceRedux MLJBase.fit!(mach) #testing update function model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps MLJBase.fit!(mach) #testing update function (the laplace part) + evaluate!(mach, resampling=cv, measure=l2, verbosity=0) end From f872d965fdb546c2105ed8582e51b01e7ac63b3c Mon Sep 17 00:00:00 2001 From: pat-alt Date: Wed, 16 Oct 2024 13:49:27 +0200 Subject: [PATCH 39/60] returning one-hot encoded directly --- src/direct_mlj.jl | 36 ++++++++------------ test/direct_mlj_interface.jl | 64 ++++++++++++++++++------------------ 2 files changed, 46 insertions(+), 54 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 1781fb6c..e5b611d1 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -44,30 +44,22 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJBase.Probabilistic fit_prior_nsteps::Int = 100::(_ > 0) end -Laplace_Models = Union{LaplaceRegressor,LaplaceClassifier} +LaplaceModels = Union{LaplaceRegressor,LaplaceClassifier} # for fit: -MMI.reformat(::LaplaceRegressor, X, y) = (MLJBase.matrix(X) |> permutedims, (reshape(y, 1, :),nothing)) - +MMI.reformat(::LaplaceRegressor, X, y) = (MLJBase.matrix(X) |> permutedims, reshape(y, 1, :)) function MMI.reformat(::LaplaceClassifier, X, y) X = MLJBase.matrix(X) |> permutedims + y = categorical(y) + unique_labels = y.pool.levels + y = Flux.onehotbatch(y, unique_labels) # One-hot encoding - - y= reshape(y, 1, :) - # Convert labels to integer format starting from 0 for one-hot encoding - y_plain = MLJBase.int(y[1, :]) .- 1 - # One-hot encoding of labels - unique_labels = unique(y_plain) # Ensure unique labels for one-hot encoding - y_onehot = Flux.onehotbatch(y_plain, unique_labels) # One-hot encoding - return X,(y_onehot, y[1]) + return X, y 
end -#MMI.selectrows(::LaplaceClassifier, I, Xmatrix, y) = (view(Xmatrix, :, I), (view(y[1],I),y[2])) -#MMI.selectrows(::LaplaceRegressor, I, Xmatrix, y) = (view(Xmatrix, :,I), (view(y[1],I),y[2]) ) -#for predict: -MMI.reformat(::Laplace_Models, X) = (MLJBase.matrix(X) |> permutedims,) +MMI.reformat(::LaplaceModels, X) = (MLJBase.matrix(X) |> permutedims,) @doc """ MMI.fit(m::Union{LaplaceRegressor,LaplaceClassifier}, verbosity, X, y) @@ -85,7 +77,7 @@ Fit a Laplace model using the provided features and target values. - `cache`: a tuple containing a deepcopy of the model, the current state of the optimiser and the training loss history. - `report`: A Namedtuple containing the loss history of the fitting process. """ -function MMI.fit(m::Laplace_Models, verbosity, X, y) +function MMI.fit(m::LaplaceModels, verbosity, X, y) decode = y[2] y= y[1] @@ -172,7 +164,7 @@ Update the Laplace model using the provided new data points. - `cache`: a tuple containing a deepcopy of the model, the updated current state of the optimiser and training loss history. - `report`: A Namedtuple containing the complete loss history of the fitting process. """ -function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y) +function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y) decode = y[2] y_up=y[1] @@ -302,7 +294,7 @@ function MMI.update(m::Laplace_Models, verbosity, old_fitresult, old_cache, X, y end @doc """ - function MMI.is_same_except(m1::Laplace_Models, m2::Laplace_Models, exceptions::Symbol...) + function MMI.is_same_except(m1::LaplaceModels, m2::LaplaceModels, exceptions::Symbol...) If both `m1` and `m2` are of `MLJType`, return `true` if the following conditions all hold, and `false` otherwise: @@ -329,7 +321,7 @@ meaning; see [`deep_properties`](@ref)) for details. If `m1` or `m2` are not `MLJType` objects, then return `==(m1, m2)`. """ -function MMI.is_same_except(m1::Laplace_Models, m2::Laplace_Models, exceptions::Symbol...) +function MMI.is_same_except(m1::LaplaceModels, m2::LaplaceModels, exceptions::Symbol...) typeof(m1) === typeof(m2) || return false names = propertynames(m1) propertynames(m2) === names || return false @@ -420,7 +412,7 @@ end - `loss`: The loss value of the posterior distribution. """ -function MMI.fitted_params(model::Laplace_Models, fitresult) +function MMI.fitted_params(model::LaplaceModels, fitresult) la, decode = fitresult posterior = la.posterior return ( @@ -447,7 +439,7 @@ Retrieve the training loss history from the given `report`. # Returns - A collection representing the loss history from the training report. """ -function MMI.training_losses(model::Laplace_Models, report) +function MMI.training_losses(model::LaplaceModels, report) return report.loss_history end @@ -467,7 +459,7 @@ function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) for LaplaceClassifier: - `MLJBase.UnivariateFinite`: The predicted class probabilities for the new data. """ -function MMI.predict(m::Laplace_Models, fitresult, Xnew) +function MMI.predict(m::LaplaceModels, fitresult, Xnew) la, decode = fitresult if typeof(m) == LaplaceRegressor yhat = LaplaceRedux.predict(la, Xnew; ret_distr=false) diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl index c225bfde..6254a7de 100644 --- a/test/direct_mlj_interface.jl +++ b/test/direct_mlj_interface.jl @@ -7,6 +7,7 @@ using MLJ using MLJ:predict,fit! 
using LaplaceRedux +cv = CV(; nfolds=3) @testset "Regression" begin flux_model = Chain( @@ -36,41 +37,40 @@ end @testset "Classification" begin -# Define the model -flux_model = Chain( - Dense(4, 10, relu), - Dense(10, 3) -) + # Define the model + flux_model = Chain( + Dense(4, 10, relu), + Dense(10, 3) + ) -model = LaplaceClassifier(model=flux_model,epochs=50) + model = LaplaceClassifier(model=flux_model,epochs=50) -X, y = @load_iris -mach = machine(model, X, y) -MLJBase.fit!(mach,verbosity=1) -Xnew = (sepal_length = [6.4, 7.2, 7.4], - sepal_width = [2.8, 3.0, 2.8], - petal_length = [5.6, 5.8, 6.1], - petal_width = [2.1, 1.6, 1.9],) -yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions -predict_mode(mach, Xnew) # point predictions -pdf.(yhat, "virginica") # probabilities for the "verginica" class -MLJBase.fitted_params(mach) # fitted params -MLJBase.training_losses(mach) #training loss history -model.epochs= 100 #changing number of epochs -MLJBase.fit!(mach) #testing update function -model.epochs= 50 #changing number of epochs to a lower number -MLJBase.fit!(mach) #testing update function -model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps -MLJBase.fit!(mach) #testing update function (the laplace part) + X, y = @load_iris + mach = machine(model, X, y) + MLJBase.fit!(mach,verbosity=1) + Xnew = (sepal_length = [6.4, 7.2, 7.4], + sepal_width = [2.8, 3.0, 2.8], + petal_length = [5.6, 5.8, 6.1], + petal_width = [2.1, 1.6, 1.9],) + yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions + predict_mode(mach, Xnew) # point predictions + pdf.(yhat, "virginica") # probabilities for the "verginica" class + MLJBase.fitted_params(mach) # fitted params + MLJBase.training_losses(mach) #training loss history + model.epochs= 100 #changing number of epochs + MLJBase.fit!(mach) #testing update function + model.epochs= 50 #changing number of epochs to a lower number + MLJBase.fit!(mach) #testing update function + model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps + MLJBase.fit!(mach) #testing update function (the laplace part) -# Define a different model -flux_model_two = Chain( - Dense(4, 6, relu), - Dense(6, 3) -) + # Define a different model + flux_model_two = Chain( + Dense(4, 6, relu), + Dense(6, 3) + ) -model.model = flux_model_two + model.model = flux_model_two -MLJBase.fit!(mach) - + MLJBase.fit!(mach) end From 71a361135980f1a812fd6de63aa71d55d13698b7 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Wed, 16 Oct 2024 15:12:50 +0200 Subject: [PATCH 40/60] nearly there I think --- src/direct_mlj.jl | 19 +- test/Manifest.toml | 539 +++++++++++++++++++++++++++++---------------- test/runtests.jl | 3 - 3 files changed, 357 insertions(+), 204 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index e5b611d1..cc6133fd 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -47,20 +47,23 @@ end LaplaceModels = Union{LaplaceRegressor,LaplaceClassifier} # for fit: -MMI.reformat(::LaplaceRegressor, X, y) = (MLJBase.matrix(X) |> permutedims, reshape(y, 1, :)) +MMI.reformat(::LaplaceRegressor, X, y) = (MLJBase.matrix(X) |> permutedims, (reshape(y, 1, :), nothing)) function MMI.reformat(::LaplaceClassifier, X, y) X = MLJBase.matrix(X) |> permutedims y = categorical(y) - unique_labels = y.pool.levels - y = Flux.onehotbatch(y, unique_labels) # One-hot encoding + labels = y.pool.levels + y = Flux.onehotbatch(y, labels) # One-hot encoding - return X, y + return X, (y, labels) end MMI.reformat(::LaplaceModels, X) = (MLJBase.matrix(X) |> permutedims,) 
+MMI.selectrows(::LaplaceModels, I, Xmatrix, y) = (Xmatrix[:, I], (y[1][:,I], y[2])) +MMI.selectrows(::LaplaceModels, I, Xmatrix) = (Xmatrix[:, I],) + @doc """ MMI.fit(m::Union{LaplaceRegressor,LaplaceClassifier}, verbosity, X, y) @@ -78,10 +81,7 @@ Fit a Laplace model using the provided features and target values. - `report`: A Namedtuple containing the loss history of the fitting process. """ function MMI.fit(m::LaplaceModels, verbosity, X, y) - decode = y[2] - - y= y[1] - + y, decode = y # Make a copy of the model because Flux does not allow to mutate hyperparameters copied_model = deepcopy(m.model) @@ -165,8 +165,7 @@ Update the Laplace model using the provided new data points. - `report`: A Namedtuple containing the complete loss history of the fitting process. """ function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y) - decode = y[2] - y_up=y[1] + y_up, decode = y data_loader = Flux.DataLoader((X, y_up); batchsize=m.batch_size) old_model = old_cache[1] diff --git a/test/Manifest.toml b/test/Manifest.toml index 7f9e18a9..3448bc9a 100644 --- a/test/Manifest.toml +++ b/test/Manifest.toml @@ -1,8 +1,14 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.3" +julia_version = "1.10.5" manifest_format = "2.0" -project_hash = "30dc96d6146892242111894ebf221bf701ee0fdd" +project_hash = "48e3a5a4625c4493599b02acbfe0e972463bd78f" + +[[deps.ARFFFiles]] +deps = ["CategoricalArrays", "Dates", "Parsers", "Tables"] +git-tree-sha1 = "e8c8e0a2be6eb4f56b1672e46004463033daa409" +uuid = "da404889-ca92-49ff-9e8b-0aa6b4d38dc8" +version = "1.4.1" [[deps.AbstractFFTs]] deps = ["LinearAlgebra"] @@ -21,24 +27,28 @@ uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" version = "0.4.5" [[deps.Accessors]] -deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown", "Test"] -git-tree-sha1 = "f61b15be1d76846c0ce31d3fcfac5380ae53db6a" +deps = ["CompositionsBase", "ConstructionBase", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown"] +git-tree-sha1 = "b392ede862e506d451fc1616e79aa6f4c673dab8" uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -version = "0.1.37" +version = "0.1.38" [deps.Accessors.extensions] AccessorsAxisKeysExt = "AxisKeys" + AccessorsDatesExt = "Dates" AccessorsIntervalSetsExt = "IntervalSets" AccessorsStaticArraysExt = "StaticArrays" AccessorsStructArraysExt = "StructArrays" + AccessorsTestExt = "Test" AccessorsUnitfulExt = "Unitful" [deps.Accessors.weakdeps] AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" + Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" Requires = "ae029012-a4dd-5104-9daa-d747884805df" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [[deps.Adapt]] @@ -59,9 +69,9 @@ version = "1.1.3" [[deps.Aqua]] deps = ["Compat", "Pkg", "Test"] -git-tree-sha1 = "12e575f31a6f233ba2485ed86b9325b85df37c61" +git-tree-sha1 = "49b1d7a9870c87ba13dc63f8ccfcf578cb266f95" uuid = "4c88cf16-eb10-579e-8560-4a9242c79595" -version = "0.8.7" +version = "0.8.9" [[deps.ArgCheck]] git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" @@ -111,10 +121,10 @@ uuid = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" version = "0.3.9" [[deps.BangBang]] -deps = ["Accessors", "Compat", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires"] -git-tree-sha1 = 
"08e5fc6620a8d83534bf6149795054f1b1e8370a" +deps = ["Accessors", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires"] +git-tree-sha1 = "e2144b631226d9eeab2d746ca8880b7ccff504ae" uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" -version = "0.4.2" +version = "0.4.3" [deps.BangBang.extensions] BangBangChainRulesCoreExt = "ChainRulesCore" @@ -146,9 +156,9 @@ uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" version = "0.1.9" [[deps.BufferedStreams]] -git-tree-sha1 = "4ae47f9a4b1dc19897d3743ff13685925c5202ec" +git-tree-sha1 = "6863c5b7fc997eadcabdbaf6c5f201dc30032643" uuid = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" -version = "1.2.1" +version = "1.2.2" [[deps.Bzip2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -169,15 +179,9 @@ version = "0.10.14" [[deps.Cairo_jll]] deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "a2f1c8c668c8e3cb4cca4e57a8efdb09067bb3fd" +git-tree-sha1 = "009060c9a6168704143100f36ab08f06c2af4642" uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a" -version = "1.18.0+2" - -[[deps.Calculus]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad" -uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" -version = "0.5.1" +version = "1.18.2+1" [[deps.CategoricalArrays]] deps = ["DataAPI", "Future", "Missings", "Printf", "Requires", "Statistics", "Unicode"] @@ -206,15 +210,15 @@ version = "0.1.15" [[deps.ChainRules]] deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] -git-tree-sha1 = "227985d885b4dbce5e18a96f9326ea1e836e5a03" +git-tree-sha1 = "be227d253d132a6d57f9ccf5f67c0fb6488afd87" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.69.0" +version = "1.71.0" [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra"] -git-tree-sha1 = "71acdbf594aab5bbb2cec89b208c41b4c411e49f" +git-tree-sha1 = "3e4b134270b372f2ed4d4d0e936aabaefc1802bc" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.24.0" +version = "1.25.0" weakdeps = ["SparseArrays"] [deps.ChainRulesCore.extensions] @@ -234,9 +238,9 @@ version = "0.10.4+0" [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] -git-tree-sha1 = "b8fe8546d52ca154ac556809e10c75e6e7430ac8" +git-tree-sha1 = "bce6804e5e6044c6daab27bb533d1295e4a2e759" uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.7.5" +version = "0.7.6" [[deps.ColorSchemes]] deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] @@ -272,16 +276,16 @@ uuid = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" version = "1.0.2" [[deps.CommonSubexpressions]] -deps = ["MacroTools", "Test"] -git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +deps = ["MacroTools"] +git-tree-sha1 = "cda2cfaebb4be89c9084adaca7dd7333369715c5" uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" -version = "0.3.0" +version = "0.3.1" [[deps.Compat]] deps = ["TOML", "UUIDs"] -git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" +git-tree-sha1 = "8ae8d32e09f0dcf42a36b90d4e17f5dd2e4c4215" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.15.0" +version = "4.16.0" weakdeps = ["Dates", "LinearAlgebra"] [deps.Compat.extensions] @@ -313,17 +317,18 @@ uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" version = "2.4.2" 
[[deps.ConstructionBase]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "d8a9c0b6ac2d9081bf76324b39c78ca3ce4f0c98" +git-tree-sha1 = "76219f1ed5771adbb096743bff43fb5fdd4c1157" uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.5.6" +version = "1.5.8" [deps.ConstructionBase.extensions] ConstructionBaseIntervalSetsExt = "IntervalSets" + ConstructionBaseLinearAlgebraExt = "LinearAlgebra" ConstructionBaseStaticArraysExt = "StaticArrays" [deps.ConstructionBase.weakdeps] IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [[deps.ContextVariablesX]] @@ -370,10 +375,10 @@ uuid = "124859b0-ceae-595e-8997-d05f6a7a8dfe" version = "0.7.13" [[deps.DataFrames]] -deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] -git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" +deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] +git-tree-sha1 = "fb61b4812c49343d7ef0b533ba982c46021938a6" uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -version = "1.6.1" +version = "1.7.0" [[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] @@ -390,6 +395,12 @@ version = "1.0.0" deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" +[[deps.Dbus_jll]] +deps = ["Artifacts", "Expat_jll", "JLLWrappers", "Libdl"] +git-tree-sha1 = "fc173b380865f70627d7dd1190dc2fce6cc105af" +uuid = "ee1fde0b-3d02-5ea6-8484-8dfef6360eab" +version = "1.14.10+0" + [[deps.DecisionTree]] deps = ["AbstractTrees", "DelimitedFiles", "LinearAlgebra", "Random", "ScikitLearnBase", "Statistics"] git-tree-sha1 = "526ca14aaaf2d5a0e242f3a8a7966eb9065d7d78" @@ -436,9 +447,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[deps.Distributions]] deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] -git-tree-sha1 = "9c405847cc7ecda2dc921ccf18b47ca150d7317e" +git-tree-sha1 = "d7477ecdafb813ddee2ae727afa94e9dcb5f3fb0" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.109" +version = "0.25.112" [deps.Distributions.extensions] DistributionsChainRulesCoreExt = "ChainRulesCore" @@ -461,11 +472,11 @@ deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" version = "1.6.0" -[[deps.DualNumbers]] -deps = ["Calculus", "NaNMath", "SpecialFunctions"] -git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" -uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" -version = "0.6.8" +[[deps.EarlyStopping]] +deps = ["Dates", "Statistics"] +git-tree-sha1 = "98fdf08b707aaf69f524a6cd0a67858cefe0cfb6" +uuid = "792122b4-ca99-40de-a6bc-6742525f08b6" +version = "0.3.0" [[deps.EpollShim_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] @@ -487,9 +498,9 @@ version = "2.6.2+0" [[deps.FFMPEG]] deps = ["FFMPEG_jll"] -git-tree-sha1 = "b57e3acbe22f8484b4b5ff66a7499717fe1a9cc8" +git-tree-sha1 = 
"53ebe7511fa11d33bec688a9178fac4e49eeee00" uuid = "c87230d0-a227-11e9-1b43-d7ebe4e7570a" -version = "0.4.1" +version = "0.4.2" [[deps.FFMPEG_jll]] deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "JLLWrappers", "LAME_jll", "Libdl", "Ogg_jll", "OpenSSL_jll", "Opus_jll", "PCRE2_jll", "Zlib_jll", "libaom_jll", "libass_jll", "libfdk_aac_jll", "libvorbis_jll", "x264_jll", "x265_jll"] @@ -511,24 +522,29 @@ version = "0.1.1" [[deps.FileIO]] deps = ["Pkg", "Requires", "UUIDs"] -git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322" +git-tree-sha1 = "62ca0547a14c57e98154423419d8a342dca75ca9" uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" -version = "1.16.3" +version = "1.16.4" [[deps.FilePathsBase]] -deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] -git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa" +deps = ["Compat", "Dates"] +git-tree-sha1 = "7878ff7172a8e6beedd1dea14bd27c3c6340d361" uuid = "48062228-2e41-5def-b9a4-89aafe57970f" -version = "0.9.21" +version = "0.9.22" +weakdeps = ["Mmap", "Test"] + + [deps.FilePathsBase.extensions] + FilePathsBaseMmapExt = "Mmap" + FilePathsBaseTestExt = "Test" [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" [[deps.FillArrays]] deps = ["LinearAlgebra"] -git-tree-sha1 = "0653c0a2396a6da5bc4766c43041ef5fd3efbe57" +git-tree-sha1 = "6a70198746448456524cb442b8af316927ff3e1a" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.11.0" +version = "1.13.0" weakdeps = ["PDMats", "SparseArrays", "Statistics"] [deps.FillArrays.extensions] @@ -543,21 +559,27 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" version = "0.8.5" [[deps.Flux]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] -git-tree-sha1 = "edacf029ed6276301e455e34d7ceeba8cc34078a" +deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLDataDevices", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "Setfield", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] +git-tree-sha1 = "37fa32a50c69c10c6ea1465d3054d98c75bd7777" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.16" +version = "0.14.22" [deps.Flux.extensions] FluxAMDGPUExt = "AMDGPU" FluxCUDAExt = "CUDA" FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] + FluxEnzymeExt = "Enzyme" + FluxMPIExt = "MPI" + FluxMPINCCLExt = ["CUDA", "MPI", "NCCL"] FluxMetalExt = "Metal" [deps.Flux.weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" + MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [[deps.Fontconfig_jll]] @@ -595,25 +617,25 @@ version = "1.0.14+0" [[deps.Functors]] deps = ["LinearAlgebra"] -git-tree-sha1 = "8a66c07630d6428eaab3506a0eabfcf4a9edea05" +git-tree-sha1 = "64d8e93700c7a3f28f717d265382d52fac9fa1c1" uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.4.11" +version = "0.4.12" [[deps.Future]] deps = ["Random"] uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" [[deps.GLFW_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Libglvnd_jll", "Xorg_libXcursor_jll", "Xorg_libXi_jll", "Xorg_libXinerama_jll", "Xorg_libXrandr_jll", "xkbcommon_jll"] 
-git-tree-sha1 = "3f74912a156096bd8fdbef211eff66ab446e7297" +deps = ["Artifacts", "JLLWrappers", "Libdl", "Libglvnd_jll", "Xorg_libXcursor_jll", "Xorg_libXi_jll", "Xorg_libXinerama_jll", "Xorg_libXrandr_jll", "libdecor_jll", "xkbcommon_jll"] +git-tree-sha1 = "532f9126ad901533af1d4f5c198867227a7bb077" uuid = "0656b61e-2033-5cc2-a64a-77c0f6c09b89" -version = "3.4.0+0" +version = "3.4.0+1" [[deps.GPUArrays]] deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "a74c3f1cf56a3dfcdef0605f8cdb7015926aae30" +git-tree-sha1 = "62ee71528cca49be797076a76bdc654a170a523e" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "10.3.0" +version = "10.3.1" [[deps.GPUArraysCore]] deps = ["Adapt"] @@ -623,15 +645,15 @@ version = "0.1.6" [[deps.GR]] deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Preferences", "Printf", "Qt6Wayland_jll", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "p7zip_jll"] -git-tree-sha1 = "629693584cef594c3f6f99e76e7a7ad17e60e8d5" +git-tree-sha1 = "ee28ddcd5517d54e417182fec3886e7412d3926f" uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71" -version = "0.73.7" +version = "0.73.8" [[deps.GR_jll]] deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "FreeType2_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Qt6Base_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "a8863b69c2a0859f2c2c87ebdc4c6712e88bdf0d" +git-tree-sha1 = "f31929b9e67066bee48eec8b03c0df47d31a74b3" uuid = "d2c73de3-f751-5644-a686-071e5b155ba9" -version = "0.73.7+0" +version = "0.73.8+0" [[deps.GZip]] deps = ["Libdl", "Zlib_jll"] @@ -647,9 +669,9 @@ version = "0.21.0+0" [[deps.Glib_jll]] deps = ["Artifacts", "Gettext_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Libiconv_jll", "Libmount_jll", "PCRE2_jll", "Zlib_jll"] -git-tree-sha1 = "7c82e6a6cd34e9d935e9aa4051b66c6ff3af59ba" +git-tree-sha1 = "674ff0db93fffcd11a3573986e550d66cd4fd71f" uuid = "7746bdde-850d-59dc-9ae8-88ece973131d" -version = "2.80.2+0" +version = "2.80.5+0" [[deps.Glob]] git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" @@ -680,10 +702,10 @@ version = "0.17.2" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" [[deps.HDF5_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"] -git-tree-sha1 = "82a471768b513dc39e471540fdadc84ff80ff997" +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"] +git-tree-sha1 = "38c8874692d48d5440d5752d6c74b0c6b0b60739" uuid = "0234f1f7-429e-5d53-9886-15a909be8d59" -version = "1.14.3+3" +version = "1.14.2+1" [[deps.HTTP]] deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] @@ -692,22 +714,22 @@ uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" version = "1.10.8" [[deps.HarfBuzz_jll]] -deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg"] -git-tree-sha1 = 
"129acf094d168394e80ee1dc4bc06ec835e510a3" +deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll"] +git-tree-sha1 = "401e4f3f30f43af2c8478fc008da50096ea5240f" uuid = "2e76f6c2-a576-52d4-95c1-20adfe4de566" -version = "2.8.1+1" +version = "8.3.1+0" [[deps.Hwloc_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "5e19e1e4fa3e71b774ce746274364aef0234634e" +git-tree-sha1 = "dd3b49277ec2bb2c6b94eb1604d4d0616016f7a6" uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8" -version = "2.11.1+0" +version = "2.11.2+0" [[deps.HypergeometricFunctions]] -deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] -git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" +deps = ["LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] +git-tree-sha1 = "7c4195be1649ae622304031ed46a2f4df989f1eb" uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" -version = "0.3.23" +version = "0.3.24" [[deps.IRTools]] deps = ["InteractiveUtils", "MacroTools"] @@ -762,14 +784,14 @@ uuid = "7d512f48-7fb1-5a58-b986-67e6dc259f01" version = "0.7.0" [[deps.InverseFunctions]] -deps = ["Test"] -git-tree-sha1 = "18c59411ece4838b18cd7f537e56cf5e41ce5bfd" +git-tree-sha1 = "a779299d77cd080bf77b97535acecd73e1c5e5cb" uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.15" -weakdeps = ["Dates"] +version = "0.1.17" +weakdeps = ["Dates", "Test"] [deps.InverseFunctions.extensions] - DatesExt = "Dates" + InverseFunctionsDatesExt = "Dates" + InverseFunctionsTestExt = "Test" [[deps.InvertedIndices]] git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" @@ -781,28 +803,34 @@ git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" version = "0.2.2" +[[deps.IterationControl]] +deps = ["EarlyStopping", "InteractiveUtils"] +git-tree-sha1 = "e663925ebc3d93c1150a7570d114f9ea2f664726" +uuid = "b3c1a2ee-3fec-4384-bf48-272ea71de57c" +version = "0.5.4" + [[deps.IteratorInterfaceExtensions]] git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" uuid = "82899510-4779-5014-852e-03e436cf321d" version = "1.0.0" [[deps.JLD2]] -deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "PrecompileTools", "Reexport", "Requires", "TranscodingStreams", "UUIDs", "Unicode"] -git-tree-sha1 = "5fe858cb863e211c6dedc8cce2dc0752d4ab6e2b" +deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Requires", "TranscodingStreams"] +git-tree-sha1 = "aeab5c68eb2cf326619bf71235d8f4561c62fe22" uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" -version = "0.4.50" +version = "0.5.5" [[deps.JLFzf]] deps = ["Pipe", "REPL", "Random", "fzf_jll"] -git-tree-sha1 = "a53ebe394b71470c7f97c2e7e170d51df21b17af" +git-tree-sha1 = "39d64b09147620f5ffbf6b2d3255be3c901bec63" uuid = "1019f520-868f-41f5-a6de-eb00f4b6a39c" -version = "0.1.7" +version = "0.1.8" [[deps.JLLWrappers]] deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" +git-tree-sha1 = "be3dc50a92e5a386872a493a10050136d4703f9b" uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.5.0" +version = "1.6.1" [[deps.JSON]] deps = ["Dates", "Mmap", "Parsers", "Unicode"] @@ -824,9 +852,9 @@ version = "1.14.0" [[deps.JpegTurbo_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "c84a835e1a09b289ffcd2271bf2a337bbdda6637" +git-tree-sha1 = "25ee0be4d43d0269027024d75a24c24d6c6e590c" uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" -version = "3.0.3+0" +version = "3.0.4+0" 
[[deps.JuliaVariables]] deps = ["MLStyle", "NameResolution"] @@ -835,16 +863,20 @@ uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" version = "0.2.4" [[deps.KernelAbstractions]] -deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "d0448cebd5919e06ca5edc7a264631790de810ec" +deps = ["Adapt", "Atomix", "InteractiveUtils", "MacroTools", "PrecompileTools", "Requires", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] +git-tree-sha1 = "04e52f596d0871fa3890170fa79cb15e481e4cd8" uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.22" +version = "0.9.28" [deps.KernelAbstractions.extensions] EnzymeExt = "EnzymeCore" + LinearAlgebraExt = "LinearAlgebra" + SparseArraysExt = "SparseArrays" [deps.KernelAbstractions.weakdeps] EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[deps.LAME_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] @@ -853,16 +885,16 @@ uuid = "c1c5ebd0-6772-5130-a774-d5fcae4a789d" version = "3.100.2+0" [[deps.LERC_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "bf36f528eec6634efc60d7ec062008f171071434" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "36bdbc52f13a7d1dcb0f3cd694e01677a515655b" uuid = "88015f11-f218-50d7-93a8-a6af411a945d" -version = "3.0.0+1" +version = "4.0.0+0" [[deps.LLVM]] deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] -git-tree-sha1 = "020abd49586480c1be84f57da0017b5d3db73f7c" +git-tree-sha1 = "4ad43cb0a4bb5e5b1506e1d1f48646d7e0c80363" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "8.0.0" +version = "9.1.2" weakdeps = ["BFloat16s"] [deps.LLVM.extensions] @@ -870,41 +902,49 @@ weakdeps = ["BFloat16s"] [[deps.LLVMExtra_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "c2636c264861edc6d305e6b4d528f09566d24c5e" +git-tree-sha1 = "05a8bd5a42309a9ec82f700876903abce1017dd3" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.30+0" +version = "0.0.34+0" [[deps.LLVMOpenMP_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "d986ce2d884d49126836ea94ed5bfb0f12679713" +git-tree-sha1 = "78211fb6cbc872f77cad3fc0b6cf647d923f4929" uuid = "1d63c593-3942-5779-bab2-d838dc0a180e" -version = "15.0.7+0" +version = "18.1.7+0" [[deps.LZO_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "70c5da094887fd2cae843b8db33920bac4b6f07d" +git-tree-sha1 = "854a9c268c43b77b0a27f22d7fab8d33cdb3a731" uuid = "dd4b983a-f0e5-5f8d-a1b7-129d4a5fb1ac" -version = "2.10.2+0" +version = "2.10.2+1" [[deps.LaTeXStrings]] -git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" +git-tree-sha1 = "dda21b8cbd6a6c40d9d02a73230f9d70fed6918c" uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" -version = "1.3.1" +version = "1.4.0" [[deps.Latexify]] deps = ["Format", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Requires"] -git-tree-sha1 = "5b0d630f3020b82c0775a51d05895852f8506f50" +git-tree-sha1 = "ce5f5621cac23a86011836badfedf664a612cee4" uuid = "23fbe1c1-3f47-55db-b15f-69d7ec21a316" -version = "0.16.4" +version = "0.16.5" [deps.Latexify.extensions] DataFramesExt = "DataFrames" + SparseArraysExt = "SparseArrays" SymEngineExt = "SymEngine" [deps.Latexify.weakdeps] DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" 
+ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SymEngine = "123dc426-2d89-5057-bbad-38513e3affd8" +[[deps.LatinHypercubeSampling]] +deps = ["Random", "StableRNGs", "StatsBase", "Test"] +git-tree-sha1 = "825289d43c753c7f1bf9bed334c253e9913997f8" +uuid = "a5e1c1ea-c99a-51d3-a14d-a9a37257b02d" +version = "1.9.0" + [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" @@ -985,9 +1025,9 @@ version = "2.40.1+0" [[deps.Libtiff_jll]] deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "LERC_jll", "Libdl", "XZ_jll", "Zlib_jll", "Zstd_jll"] -git-tree-sha1 = "2da088d113af58221c52828a80378e16be7d037a" +git-tree-sha1 = "b404131d06f7886402758c9ce2214b636eb4d54a" uuid = "89763e89-9b03-5906-acba-b20f662cd828" -version = "4.5.1+1" +version = "4.7.0+0" [[deps.Libuuid_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] @@ -1030,35 +1070,109 @@ git-tree-sha1 = "1d2dd9b186742b0f317f2530ddcbf00eebb18e96" uuid = "23992714-dd62-5051-b70f-ba57cb901cac" version = "0.10.7" +[[deps.MLDataDevices]] +deps = ["Adapt", "Functors", "Preferences", "Random"] +git-tree-sha1 = "e16288e37e76d68c3f1c418e0a2bec88d98d55fc" +uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" +version = "1.2.0" + + [deps.MLDataDevices.extensions] + MLDataDevicesAMDGPUExt = "AMDGPU" + MLDataDevicesCUDAExt = "CUDA" + MLDataDevicesChainRulesCoreExt = "ChainRulesCore" + MLDataDevicesFillArraysExt = "FillArrays" + MLDataDevicesGPUArraysExt = "GPUArrays" + MLDataDevicesMLUtilsExt = "MLUtils" + MLDataDevicesMetalExt = ["GPUArrays", "Metal"] + MLDataDevicesReactantExt = "Reactant" + MLDataDevicesRecursiveArrayToolsExt = "RecursiveArrayTools" + MLDataDevicesReverseDiffExt = "ReverseDiff" + MLDataDevicesSparseArraysExt = "SparseArrays" + MLDataDevicesTrackerExt = "Tracker" + MLDataDevicesZygoteExt = "Zygote" + MLDataDevicescuDNNExt = ["CUDA", "cuDNN"] + MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"] + + [deps.MLDataDevices.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" + GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" + MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" + Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" + RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" + [[deps.MLDatasets]] deps = ["CSV", "Chemfiles", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "ImageShow", "JLD2", "JSON3", "LazyModules", "MAT", "MLUtils", "NPZ", "Pickle", "Printf", "Requires", "SparseArrays", "Statistics", "Tables"] -git-tree-sha1 = "55ed5f79697232389d894d05e93633a03e774818" +git-tree-sha1 = "361c2692ee730944764945859f1a6b31072e275d" uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458" -version = "0.7.16" +version = "0.7.18" + +[[deps.MLFlowClient]] +deps = ["Dates", "FilePathsBase", "HTTP", "JSON", "ShowCases", "URIs", "UUIDs"] +git-tree-sha1 = "9abb12b62debc27261c008daa13627255bf79967" +uuid = "64a0f543-368b-4a9a-827a-e71edb2a0b83" +version = "0.5.1" + +[[deps.MLJ]] +deps = ["CategoricalArrays", "ComputationalResources", "Distributed", 
"Distributions", "LinearAlgebra", "MLJBalancing", "MLJBase", "MLJEnsembles", "MLJFlow", "MLJIteration", "MLJModels", "MLJTuning", "OpenML", "Pkg", "ProgressMeter", "Random", "Reexport", "ScientificTypes", "StatisticalMeasures", "Statistics", "StatsBase", "Tables"] +git-tree-sha1 = "bd2072e9cd65be0a3cb841f3d8cda1d2cacfe5db" +uuid = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" +version = "0.20.5" + +[[deps.MLJBalancing]] +deps = ["MLJBase", "MLJModelInterface", "MLUtils", "OrderedCollections", "Random", "StatsBase"] +git-tree-sha1 = "f707a01a92d664479522313907c07afa5d81df19" +uuid = "45f359ea-796d-4f51-95a5-deb1a414c586" +version = "0.1.5" [[deps.MLJBase]] deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LearnAPI", "LinearAlgebra", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "RecipesBase", "Reexport", "ScientificTypes", "Serialization", "StatisticalMeasuresBase", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] git-tree-sha1 = "6f45e12073bc2f2e73ed0473391db38c31e879c9" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" version = "1.7.0" +weakdeps = ["StatisticalMeasures"] [deps.MLJBase.extensions] DefaultMeasuresExt = "StatisticalMeasures" - [deps.MLJBase.weakdeps] - StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541" - [[deps.MLJDecisionTreeInterface]] deps = ["CategoricalArrays", "DecisionTree", "MLJModelInterface", "Random", "Tables"] git-tree-sha1 = "90ef4d3b6cacec631c57cc034e1e61b4aa0ce511" uuid = "c6f25543-311c-4c74-83dc-3ea6d1015661" version = "0.4.2" +[[deps.MLJEnsembles]] +deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Distributed", "Distributions", "MLJModelInterface", "ProgressMeter", "Random", "ScientificTypesBase", "StatisticalMeasuresBase", "StatsBase"] +git-tree-sha1 = "84a5be55a364bb6b6dc7780bbd64317ebdd3ad1e" +uuid = "50ed68f4-41fd-4504-931a-ed422449fee0" +version = "0.4.3" + +[[deps.MLJFlow]] +deps = ["MLFlowClient", "MLJBase", "MLJModelInterface"] +git-tree-sha1 = "508bff8071d7d1902d6f1b9d1e868d58821f1cfe" +uuid = "7b7b8358-b45c-48ea-a8ef-7ca328ad328f" +version = "0.5.0" + [[deps.MLJFlux]] deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables"] -git-tree-sha1 = "50c7f24b84005a2a80875c10d4f4059df17a0f68" +git-tree-sha1 = "98fd05da1bc1527f7849efb645ef1921ccf97c9a" uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845" -version = "0.5.1" +version = "0.6.0" + +[[deps.MLJIteration]] +deps = ["IterationControl", "MLJBase", "Random", "Serialization"] +git-tree-sha1 = "ad16cfd261e28204847f509d1221a581286448ae" +uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55" +version = "0.6.3" [[deps.MLJModelInterface]] deps = ["Random", "ScientificTypesBase", "StatisticalTraits"] @@ -1072,6 +1186,12 @@ git-tree-sha1 = "410da88e0e6ece5467293d2c76b51b7c6df7d072" uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7" version = "0.16.17" +[[deps.MLJTuning]] +deps = ["ComputationalResources", "Distributed", "Distributions", "LatinHypercubeSampling", "MLJBase", "ProgressMeter", "Random", "RecipesBase", "StatisticalMeasuresBase"] +git-tree-sha1 = "38aab60b1274ce7d6da784808e3be69e585dbbf6" +uuid = "03970b2e-30c4-11ea-3135-d1576263f10f" +version = "0.8.8" + [[deps.MLStyle]] git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" uuid = 
"d8e11817-5142-5d16-987a-aa16d5891078" @@ -1085,9 +1205,9 @@ version = "0.4.4" [[deps.MPICH_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "19d4bd098928a3263693991500d05d74dbdc2004" +git-tree-sha1 = "7715e65c47ba3941c502bffb7f266a41a7f54423" uuid = "7cb0a576-ebde-5e09-9194-50597f1243b4" -version = "4.2.2+0" +version = "4.2.3+0" [[deps.MPIPreferences]] deps = ["Libdl", "Preferences"] @@ -1097,9 +1217,9 @@ version = "0.1.11" [[deps.MPItrampoline_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "8c35d5420193841b2f367e658540e8d9e0601ed0" +git-tree-sha1 = "70e830dab5d0775183c99fc75e4c24c614ed7142" uuid = "f1f71cc9-e9ae-5b93-9b94-4fe0e1ad3748" -version = "5.4.0+0" +version = "5.5.1+0" [[deps.MacroTools]] deps = ["Markdown", "Random"] @@ -1134,9 +1254,9 @@ version = "0.3.2" [[deps.Metalhead]] deps = ["Artifacts", "BSON", "ChainRulesCore", "Flux", "Functors", "JLD2", "LazyArtifacts", "MLUtils", "NNlib", "PartialFunctions", "Random", "Statistics"] -git-tree-sha1 = "5aac9a2b511afda7bf89df5044a2e0b429f83152" +git-tree-sha1 = "aef476e4958303f5ea9e1deb81a1ba2f510d4e11" uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" -version = "0.9.3" +version = "0.9.4" [deps.Metalhead.extensions] MetalheadCUDAExt = "CUDA" @@ -1182,10 +1302,10 @@ uuid = "6f286f6a-111f-5878-ab1e-185364afe411" version = "0.10.3" [[deps.NNlib]] -deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "190dcada8cf9520198058c4544862b1f88c6c577" +deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "da09a1e112fd75f9af2a5229323f01b56ec96a4c" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.21" +version = "0.9.24" [deps.NNlib.extensions] NNlibAMDGPUExt = "AMDGPU" @@ -1193,12 +1313,14 @@ version = "0.9.21" NNlibCUDAExt = "CUDA" NNlibEnzymeCoreExt = "EnzymeCore" NNlibFFTWExt = "FFTW" + NNlibForwardDiffExt = "ForwardDiff" [deps.NNlib.weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [[deps.NPZ]] @@ -1254,11 +1376,17 @@ deps = ["Artifacts", "Libdl"] uuid = "05823500-19ac-5b8b-9628-191a04bc5112" version = "0.8.1+2" +[[deps.OpenML]] +deps = ["ARFFFiles", "HTTP", "JSON", "Markdown", "Pkg", "Scratch"] +git-tree-sha1 = "6efb039ae888699d5a74fb593f6f3e10c7193e33" +uuid = "8b6db2d4-7670-4922-a472-f9537c81ab66" +version = "0.3.1" + [[deps.OpenMPI_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "e25c1778a98e34219a00455d6e4384e017ea9762" +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML", "Zlib_jll"] +git-tree-sha1 = "bfce6d523861a6c562721b262c0d1aaeead2647f" uuid = "fe0851c0-eecd-5654-98d4-656369965a5c" -version = "4.1.6+0" +version = "5.0.5+0" [[deps.OpenSSL]] deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] @@ -1268,9 +1396,9 @@ version = "1.4.3" [[deps.OpenSSL_jll]] deps = ["Artifacts", "JLLWrappers", 
"Libdl"] -git-tree-sha1 = "a028ee3cb5641cccc4c24e90c36b0a4f7707bdf5" +git-tree-sha1 = "7493f61f55a6cce7325f197443aa80d32554ba10" uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "3.0.14+0" +version = "3.0.15+1" [[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] @@ -1285,10 +1413,10 @@ uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" version = "0.3.3" [[deps.Opus_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "51a08fb14ec28da2ec7a927c4337e4332c2a4720" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "6703a85cb3781bd5909d48730a67205f3f31a575" uuid = "91d4177d-7536-5919-b921-800302f37372" -version = "1.3.2+0" +version = "1.3.3+0" [[deps.OrderedCollections]] git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" @@ -1318,6 +1446,12 @@ git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f" uuid = "5432bcbf-9aad-5242-b902-cca2824c8663" version = "0.5.12" +[[deps.Pango_jll]] +deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "FriBidi_jll", "Glib_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl"] +git-tree-sha1 = "e127b609fb9ecba6f201ba7ab753d5a605d53801" +uuid = "36c8627f-9965-5494-a995-c6b170f724f3" +version = "1.54.1+0" + [[deps.Parameters]] deps = ["OrderedCollections", "UnPack"] git-tree-sha1 = "34c0e9ad262e5f7fc75b10a9952ca7692cfc5fbe" @@ -1378,9 +1512,9 @@ version = "1.4.1" [[deps.Plots]] deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "PrecompileTools", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "RelocatableFolders", "Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "TOML", "UUIDs", "UnicodeFun", "UnitfulLatexify", "Unzip"] -git-tree-sha1 = "082f0c4b70c202c37784ce4bfbc33c9f437685bf" +git-tree-sha1 = "45470145863035bb124ca51b320ed35d071cc6c2" uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -version = "1.40.5" +version = "1.40.8" [deps.Plots.extensions] FileIOExt = "FileIO" @@ -1426,9 +1560,9 @@ version = "0.4.2" [[deps.PrettyTables]] deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] -git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7" +git-tree-sha1 = "1101cd475833706e4d0e7b122218257178f48f34" uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "2.3.2" +version = "2.4.0" [[deps.Printf]] deps = ["Unicode"] @@ -1447,9 +1581,9 @@ uuid = "92933f4c-e287-5a05-a399-4b506db050ca" version = "1.10.2" [[deps.PtrArrays]] -git-tree-sha1 = "f011fbb92c4d401059b2212c05c0601b70f8b759" +git-tree-sha1 = "77a42d78b6a92df47ab37e177b2deac405e1c88f" uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d" -version = "1.2.0" +version = "1.2.1" [[deps.Qt6Base_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Fontconfig_jll", "Glib_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "OpenSSL_jll", "Vulkan_Loader_jll", "Xorg_libSM_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Xorg_libxcb_jll", "Xorg_xcb_util_cursor_jll", "Xorg_xcb_util_image_jll", "Xorg_xcb_util_keysyms_jll", "Xorg_xcb_util_renderutil_jll", "Xorg_xcb_util_wm_jll", "Zlib_jll", "libinput_jll", "xkbcommon_jll"] @@ -1477,9 +1611,15 @@ version = "6.7.1+1" [[deps.QuadGK]] deps = ["DataStructures", "LinearAlgebra"] -git-tree-sha1 = "9b23c31e76e333e6fb4c1595ae6afa74966a729e" +git-tree-sha1 = "cda3b045cf9ef07a08ad46731f5a3165e56cf3da" uuid = 
"1fd47b50-473d-5c70-9696-f719f8f3bcdc" -version = "2.9.4" +version = "2.11.1" + + [deps.QuadGK.extensions] + QuadGKEnzymeExt = "Enzyme" + + [deps.QuadGK.weakdeps] + Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" [[deps.REPL]] deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] @@ -1526,15 +1666,15 @@ version = "1.3.0" [[deps.Rmath]] deps = ["Random", "Rmath_jll"] -git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b" +git-tree-sha1 = "852bd0f55565a9e973fcfee83a84413270224dc4" uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" -version = "0.7.1" +version = "0.8.0" [[deps.Rmath_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "d483cd324ce5cf5d61b77930f0bbd6cb61927d21" +git-tree-sha1 = "58cdd8fb2201a6267e1db87ff148dd6c1dbd8ad8" uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" -version = "0.4.2+0" +version = "0.5.1+0" [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" @@ -1590,9 +1730,9 @@ uuid = "992d4aef-0814-514b-bc4d-f2e9a6c4116f" version = "1.0.3" [[deps.SimpleBufferStream]] -git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" +git-tree-sha1 = "f305871d2f381d21527c770d4788c06c097c9bc1" uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" -version = "1.1.0" +version = "1.2.0" [[deps.SimpleTraits]] deps = ["InteractiveUtils", "MacroTools"] @@ -1664,6 +1804,20 @@ git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" version = "1.4.3" +[[deps.StatisticalMeasures]] +deps = ["CategoricalArrays", "CategoricalDistributions", "Distributions", "LearnAPI", "LinearAlgebra", "MacroTools", "OrderedCollections", "PrecompileTools", "ScientificTypesBase", "StatisticalMeasuresBase", "Statistics", "StatsBase"] +git-tree-sha1 = "c1d4318fa41056b839dfbb3ee841f011fa6e8518" +uuid = "a19d573c-0a75-4610-95b3-7071388c7541" +version = "0.1.7" + + [deps.StatisticalMeasures.extensions] + LossFunctionsExt = "LossFunctions" + ScientificTypesExt = "ScientificTypes" + + [deps.StatisticalMeasures.weakdeps] + LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" + ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81" + [[deps.StatisticalMeasuresBase]] deps = ["CategoricalArrays", "InteractiveUtils", "MLUtils", "MacroTools", "OrderedCollections", "PrecompileTools", "ScientificTypesBase", "Statistics"] git-tree-sha1 = "17dfb22e2e4ccc9cd59b487dce52883e0151b4d3" @@ -1695,9 +1849,9 @@ version = "0.34.3" [[deps.StatsFuns]] deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] -git-tree-sha1 = "cef0472124fab0695b58ca35a77c6fb942fdab8a" +git-tree-sha1 = "b423576adc27097764a90e163157bcfc9acf0f46" uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -version = "1.3.1" +version = "1.3.2" weakdeps = ["ChainRulesCore", "InverseFunctions"] [deps.StatsFuns.extensions] @@ -1724,9 +1878,9 @@ version = "0.3.7" [[deps.StringManipulation]] deps = ["PrecompileTools"] -git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" +git-tree-sha1 = "a6b1675a536c5ad1a60e5a5153e1fee12eb146e3" uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" -version = "0.3.4" +version = "0.4.0" [[deps.StructArrays]] deps = ["ConstructionBase", "DataAPI", "Tables"] @@ -1743,9 +1897,9 @@ weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"] [[deps.StructTypes]] deps = ["Dates", "UUIDs"] -git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70" +git-tree-sha1 = "159331b30e94d7b11379037feeb9b690950cace8" uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" -version = "1.10.0" +version = "1.11.0" [[deps.SuiteSparse]] deps 
= ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] @@ -1774,10 +1928,9 @@ uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" version = "1.12.0" [[deps.TaijaBase]] -deps = ["CategoricalArrays", "Distributions", "Flux", "MLUtils", "Optimisers", "StatsBase", "Tables"] -git-tree-sha1 = "1c80c4472c6ab6e8c9fa544a22d907295b388dd0" +git-tree-sha1 = "4076f60078b12095ca71a2c26e2e4515e3a6a5e5" uuid = "10284c91-9f28-4c9a-abbf-ee43576dfff6" -version = "1.2.2" +version = "1.2.3" [[deps.TaijaData]] deps = ["CSV", "CounterfactualExplanations", "DataAPI", "DataFrames", "Flux", "LazyArtifacts", "MLDatasets", "MLJBase", "MLJModels", "Random", "StatsBase"] @@ -1801,21 +1954,18 @@ deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[deps.TranscodingStreams]] -git-tree-sha1 = "96612ac5365777520c3c5396314c8cf7408f436a" +git-tree-sha1 = "0c45878dcfdcfa8480052b6ab162cdd138781742" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.11.1" -weakdeps = ["Random", "Test"] - - [deps.TranscodingStreams.extensions] - TestExt = ["Test", "Random"] +version = "0.11.3" [[deps.Transducers]] -deps = ["Accessors", "Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"] -git-tree-sha1 = "5215a069867476fc8e3469602006b9670e68da23" +deps = ["Accessors", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"] +git-tree-sha1 = "7deeab4ff96b85c5f72c824cae53a1398da3d1cb" uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.82" +version = "0.4.84" [deps.Transducers.extensions] + TransducersAdaptExt = "Adapt" TransducersBlockArraysExt = "BlockArrays" TransducersDataFramesExt = "DataFrames" TransducersLazyArraysExt = "LazyArrays" @@ -1823,6 +1973,7 @@ version = "0.4.82" TransducersReferenceablesExt = "Referenceables" [deps.Transducers.weakdeps] + Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" @@ -1905,9 +2056,9 @@ version = "0.2.1" [[deps.UnsafeAtomicsLLVM]] deps = ["LLVM", "UnsafeAtomics"] -git-tree-sha1 = "bf2c553f25e954a9b38c9c0593a59bb13113f9e5" +git-tree-sha1 = "2d17fabcd17e67d7625ce9c531fb9f40b7c42ce4" uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" -version = "0.1.5" +version = "0.2.1" [[deps.Unzip]] git-tree-sha1 = "ca0969166a028236229f63514992fc073799bb78" @@ -1945,9 +2096,9 @@ version = "1.6.1" [[deps.XML2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Zlib_jll"] -git-tree-sha1 = "d9717ce3518dc68a99e6b96300813760d887a01d" +git-tree-sha1 = "1165b0443d0eca63ac1e32b8c0eb69ed2f4f8127" uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" -version = "2.13.1+0" +version = "2.13.3+0" [[deps.XSLT_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "XML2_jll", "Zlib_jll"] @@ -2118,15 +2269,15 @@ version = "1.2.13+1" [[deps.Zstd_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "e678132f07ddb5bfa46857f0d7620fb9be675d3b" +git-tree-sha1 = "555d1076590a6cc2fdee2ef1469451f872d8b41b" uuid = "3161d3a3-bdf6-5164-811a-617609db77b4" -version = "1.5.6+0" +version = "1.5.6+1" [[deps.Zygote]] deps = ["AbstractFFTs", 
"ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54" +git-tree-sha1 = "f816633be6dc5c0ed9ffedda157ecfda0b3b6a69" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.70" +version = "0.6.72" [deps.Zygote.extensions] ZygoteColorsExt = "Colors" @@ -2152,9 +2303,9 @@ version = "3.2.9+0" [[deps.fzf_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "a68c9655fbe6dfcab3d972808f1aafec151ce3f8" +git-tree-sha1 = "936081b536ae4aa65415d869287d43ef3cb576b2" uuid = "214eeab7-80f7-51ab-84ad-2988db7cef09" -version = "0.43.0+0" +version = "0.53.0+0" [[deps.gperf_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -2175,15 +2326,21 @@ uuid = "a4ae2306-e953-59d6-aa16-d00cac43593b" version = "3.9.0+0" [[deps.libass_jll]] -deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"] -git-tree-sha1 = "5982a94fcba20f02f42ace44b9894ee2b140fe47" +deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl", "Zlib_jll"] +git-tree-sha1 = "e17c115d55c5fbb7e52ebedb427a0dca79d4484e" uuid = "0ac62f75-1d6f-5e53-bd7c-93b484bb37c0" -version = "0.15.1+0" +version = "0.15.2+0" [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+1" +version = "5.11.0+0" + +[[deps.libdecor_jll]] +deps = ["Artifacts", "Dbus_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "Pango_jll", "Wayland_jll", "xkbcommon_jll"] +git-tree-sha1 = "9bf7903af251d2050b467f76bdbe57ce541f7f4f" +uuid = "1183f4f0-6f2a-5f1a-908b-139f9cdfea6f" +version = "0.2.2+0" [[deps.libevdev_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -2192,10 +2349,10 @@ uuid = "2db6ffa8-e38f-5e21-84af-90c45d0032cc" version = "1.11.0+0" [[deps.libfdk_aac_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "daacc84a041563f965be61859a36e17c4e4fcd55" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "8a22cf860a7d27e4f3498a0fe0811a7957badb38" uuid = "f638f0a6-7fb0-5443-88ba-1cc74229b280" -version = "2.0.2+0" +version = "2.0.3+0" [[deps.libinput_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "eudev_jll", "libevdev_jll", "mtdev_jll"] @@ -2205,15 +2362,15 @@ version = "1.18.0+0" [[deps.libpng_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Zlib_jll"] -git-tree-sha1 = "d7015d2e18a5fd9a4f47de711837e980519781a4" +git-tree-sha1 = "b70c870239dc3d7bc094eb2d6be9b73d27bef280" uuid = "b53b4c65-9356-5827-b1ea-8c7a1a84506f" -version = "1.6.43+1" +version = "1.6.44+0" [[deps.libvorbis_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Ogg_jll", "Pkg"] -git-tree-sha1 = "b910cb81ef3fe6e78bf6acee440bda86fd6ae00c" +git-tree-sha1 = "490376214c4721cdaca654041f635213c6165cb3" uuid = "f27f6e37-5d2b-51aa-960f-b287f2bc3b7a" -version = "1.3.7+1" +version = "1.3.7+2" [[deps.mtdev_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] diff --git a/test/runtests.jl b/test/runtests.jl index 4995c6f9..031224ea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,9 +35,6 @@ using Test include("krondecomposed.jl") end - #@testset "MLJFlux" begin - #include("mlj_flux_interfacing.jl") - #end @testset "ML" begin 
include("direct_mlj_interface.jl") end From 74d778e0fd47e5dc7703454a63b8a26106cd23b8 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Wed, 16 Oct 2024 15:25:01 +0200 Subject: [PATCH 41/60] one more issue with regression --- src/direct_mlj.jl | 2 +- test/direct_mlj_interface.jl | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index cc6133fd..819f9b85 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -473,7 +473,7 @@ function MMI.predict(m::LaplaceModels, fitresult, Xnew) LaplaceRedux.predict(la, Xnew; link_approx=m.link_approx, ret_distr=false) |> permutedims - return MLJBase.UnivariateFinite(MLJBase.classes(decode), predictions) + return MLJBase.UnivariateFinite(decode, predictions; pool=missing) end end diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl index 6254a7de..629bce86 100644 --- a/test/direct_mlj_interface.jl +++ b/test/direct_mlj_interface.jl @@ -31,7 +31,7 @@ cv = CV(; nfolds=3) MLJBase.fit!(mach) #testing update function model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps MLJBase.fit!(mach) #testing update function (the laplace part) - evaluate!(mach, resampling=cv, measure=l2, verbosity=0) + # evaluate!(mach, resampling=cv, measure=l2, verbosity=0) end @@ -63,6 +63,7 @@ end MLJBase.fit!(mach) #testing update function model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps MLJBase.fit!(mach) #testing update function (the laplace part) + evaluate!(mach; resampling=cv, measure=brier_loss, verbosity=0) # Define a different model flux_model_two = Chain( From 80784bb02c4af72a4186bed9cd4aed8f8769f463 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Fri, 18 Oct 2024 09:11:47 +0200 Subject: [PATCH 42/60] fixed predict so that it return a vector of distributions-> fixed evaluate! 
--- docs/Manifest.toml | 8 ++++---- test/direct_mlj_interface.jl | 18 +++++++++++++----- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 6d42a1e2..a97bfd27 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.10.5" manifest_format = "2.0" -project_hash = "0bd11d5fa58aad2714bf7893e520fc7c086ef3ca" +project_hash = "616a9e89f5c520a58672ad91b5525001e0dadab3" [[deps.ANSIColoredPrinters]] git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c" @@ -955,10 +955,10 @@ uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" version = "1.3.1" [[deps.LaplaceRedux]] -deps = ["ChainRulesCore", "Compat", "ComputationalResources", "Distributions", "Flux", "LinearAlgebra", "MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"] -git-tree-sha1 = "a84b72a27c93c72a6af5d22216eb81a419b1b97a" +deps = ["CategoricalDistributions", "ChainRulesCore", "Compat", "ComputationalResources", "Distributions", "Flux", "LinearAlgebra", "MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"] +path = "C:\\Users\\Pasqu\\Documents\\julia_projects\\LaplaceRedux.jl" uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478" -version = "1.0.2" +version = "1.1.1" [[deps.Latexify]] deps = ["Format", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Requires"] diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl index 629bce86..c69f9cc5 100644 --- a/test/direct_mlj_interface.jl +++ b/test/direct_mlj_interface.jl @@ -18,11 +18,12 @@ cv = CV(; nfolds=3) model = LaplaceRegressor(model=flux_model,epochs=50) X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1) + #train, test = partition(eachindex(y), 0.7); # 70:30 split mach = machine(model, X, y) #|> MLJBase.fit! #|> (fitresult,cache,report) - MLJBase.fit!(mach,verbosity=1) - Xnew, _ = make_regression(3, 4; rng=123) - yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions - MLJBase.predict_mode(mach, Xnew) # point predictions + MLJBase.fit!(mach, verbosity=1) + #Xnew, ynew = make_regression(3, 4; rng=123) + yhat = MLJBase.predict(mach, X) # probabilistic predictions + MLJBase.predict_mode(mach, X) # point predictions MLJBase.fitted_params(mach) #fitted params function MLJBase.training_losses(mach) #training loss history model.epochs= 100 #changing number of epochs @@ -31,7 +32,14 @@ cv = CV(; nfolds=3) MLJBase.fit!(mach) #testing update function model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps MLJBase.fit!(mach) #testing update function (the laplace part) - # evaluate!(mach, resampling=cv, measure=l2, verbosity=0) + yhat = MLJBase.predict(mach, X) # probabilistic predictions + println( typeof(yhat) ) + println( size(yhat) ) + println( typeof(y) ) + println( size(y) ) + + evaluate!(mach, resampling=cv, measure=log_loss, verbosity=0) + end From be80e32e291897a1ceef508c7bcc485a0cb29751 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Fri, 18 Oct 2024 09:11:47 +0200 Subject: [PATCH 43/60] amend: fixed predict so that it return a vector of distributions-> fixed evaluate! 
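End to end, the workflow the amended test exercises looks roughly like this (condensed from test/direct_mlj_interface.jl; assumes MLJ, Flux and LaplaceRedux are installed):

    using MLJ, Flux
    import LaplaceRedux: LaplaceRegressor

    flux_model = Chain(Dense(4, 10, relu), Dense(10, 10, relu), Dense(10, 1))
    model = LaplaceRegressor(model=flux_model, epochs=50)

    X, y = make_regression(100, 4; noise=0.5)
    mach = machine(model, X, y)
    fit!(mach; verbosity=0)

    yhat = predict(mach, X)        # Vector of Normal distributions
    predict_mode(mach, X)          # point predictions

    cv = CV(nfolds=3)
    evaluate!(mach; resampling=cv, measure=log_loss, verbosity=0)
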
--- docs/Manifest.toml | 8 ++++---- src/direct_mlj.jl | 2 +- test/direct_mlj_interface.jl | 18 +++++++++++++----- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 6d42a1e2..a97bfd27 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.10.5" manifest_format = "2.0" -project_hash = "0bd11d5fa58aad2714bf7893e520fc7c086ef3ca" +project_hash = "616a9e89f5c520a58672ad91b5525001e0dadab3" [[deps.ANSIColoredPrinters]] git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c" @@ -955,10 +955,10 @@ uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" version = "1.3.1" [[deps.LaplaceRedux]] -deps = ["ChainRulesCore", "Compat", "ComputationalResources", "Distributions", "Flux", "LinearAlgebra", "MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"] -git-tree-sha1 = "a84b72a27c93c72a6af5d22216eb81a419b1b97a" +deps = ["CategoricalDistributions", "ChainRulesCore", "Compat", "ComputationalResources", "Distributions", "Flux", "LinearAlgebra", "MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"] +path = "C:\\Users\\Pasqu\\Documents\\julia_projects\\LaplaceRedux.jl" uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478" -version = "1.0.2" +version = "1.1.1" [[deps.Latexify]] deps = ["Format", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Requires"] diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 819f9b85..14c5c8fd 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -466,7 +466,7 @@ function MMI.predict(m::LaplaceModels, fitresult, Xnew) means, variances = yhat # Create Normal distributions from the means and variances - return [Normal(μ, sqrt(σ)) for (μ, σ) in zip(means, variances)] + return vec([Normal(μ, sqrt(σ)) for (μ, σ) in zip(means, variances)]) else predictions = diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl index 629bce86..c69f9cc5 100644 --- a/test/direct_mlj_interface.jl +++ b/test/direct_mlj_interface.jl @@ -18,11 +18,12 @@ cv = CV(; nfolds=3) model = LaplaceRegressor(model=flux_model,epochs=50) X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1) + #train, test = partition(eachindex(y), 0.7); # 70:30 split mach = machine(model, X, y) #|> MLJBase.fit! 
#|> (fitresult,cache,report) - MLJBase.fit!(mach,verbosity=1) - Xnew, _ = make_regression(3, 4; rng=123) - yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions - MLJBase.predict_mode(mach, Xnew) # point predictions + MLJBase.fit!(mach, verbosity=1) + #Xnew, ynew = make_regression(3, 4; rng=123) + yhat = MLJBase.predict(mach, X) # probabilistic predictions + MLJBase.predict_mode(mach, X) # point predictions MLJBase.fitted_params(mach) #fitted params function MLJBase.training_losses(mach) #training loss history model.epochs= 100 #changing number of epochs @@ -31,7 +32,14 @@ cv = CV(; nfolds=3) MLJBase.fit!(mach) #testing update function model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps MLJBase.fit!(mach) #testing update function (the laplace part) - # evaluate!(mach, resampling=cv, measure=l2, verbosity=0) + yhat = MLJBase.predict(mach, X) # probabilistic predictions + println( typeof(yhat) ) + println( size(yhat) ) + println( typeof(y) ) + println( size(y) ) + + evaluate!(mach, resampling=cv, measure=log_loss, verbosity=0) + end From f4fcd958ff6f0722d051b7c2f1d10f0ac0c5e410 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Fri, 18 Oct 2024 09:36:38 +0200 Subject: [PATCH 44/60] madea mess with commits.... bah --- test/direct_mlj_interface.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl index c69f9cc5..83d4e86f 100644 --- a/test/direct_mlj_interface.jl +++ b/test/direct_mlj_interface.jl @@ -33,11 +33,6 @@ cv = CV(; nfolds=3) model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps MLJBase.fit!(mach) #testing update function (the laplace part) yhat = MLJBase.predict(mach, X) # probabilistic predictions - println( typeof(yhat) ) - println( size(yhat) ) - println( typeof(y) ) - println( size(y) ) - evaluate!(mach, resampling=cv, measure=log_loss, verbosity=0) end From 851784fabea41be5c8c5cee183f897610a86c00e Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Fri, 18 Oct 2024 11:27:05 +0200 Subject: [PATCH 45/60] trying to increase patch coverage --- test/direct_mlj_interface.jl | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl index 83d4e86f..da0782e7 100644 --- a/test/direct_mlj_interface.jl +++ b/test/direct_mlj_interface.jl @@ -4,8 +4,7 @@ using MLJBase: MLJBase, categorical using Flux using StableRNGs using MLJ -using MLJ:predict,fit! -using LaplaceRedux +import LaplaceRedux: LaplaceClassifier, LaplaceRegressor cv = CV(; nfolds=3) @@ -35,6 +34,18 @@ cv = CV(; nfolds=3) yhat = MLJBase.predict(mach, X) # probabilistic predictions evaluate!(mach, resampling=cv, measure=log_loss, verbosity=0) + + # Define a different model + flux_model_two = Chain( + Dense(4, 6, relu), + Dense(6, 1) + ) + # test update! fallback to fit! + model.model = flux_model_two + MLJBase.fit!(mach) + + + end From 0752b83180b7282d3726da8baafebb48bc503f9d Mon Sep 17 00:00:00 2001 From: "pasquale c." 
<343guiltyspark@outlook.it> Date: Fri, 18 Oct 2024 11:46:16 +0200 Subject: [PATCH 46/60] fkn hell this codecov bot is worse than the inquisition --- test/direct_mlj_interface.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl index da0782e7..84d0440a 100644 --- a/test/direct_mlj_interface.jl +++ b/test/direct_mlj_interface.jl @@ -44,6 +44,9 @@ cv = CV(; nfolds=3) model.model = flux_model_two MLJBase.fit!(mach) + model_two = LaplaceRegressor(model=flux_model_two,epochs=100) + @test !MLJBase.is_same_except(model,model_two) + end From 573ffd8956af000acfd6ee5bf33336960c66156a Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Mon, 21 Oct 2024 12:30:27 +0200 Subject: [PATCH 47/60] uhmmmmmm --- src/direct_mlj.jl | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 14c5c8fd..98165616 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -353,18 +353,11 @@ function MMI.is_same_except(m1::LaplaceModels, m2::LaplaceModels, exceptions::Sy return true end -# Define helper functions used in is_same_except function _isdefined(obj, name) return hasproperty(obj, name) end -function deep_properties(::Type) - return Set{Symbol}() -end -function _equal_to_depth_one(a, b) - return a == b -end function _equal_flux_chain(chain1::Flux.Chain, chain2::Flux.Chain) if length(chain1.layers) != length(chain2.layers) From db14b84f30ae256521150fd070f843d6890eb46e Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Mon, 21 Oct 2024 13:35:30 +0200 Subject: [PATCH 48/60] fixed _isdefined --- src/direct_mlj.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 98165616..b1ffa328 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -352,13 +352,16 @@ function MMI.is_same_except(m1::LaplaceModels, m2::LaplaceModels, exceptions::Sy end return true end - -function _isdefined(obj, name) - return hasproperty(obj, name) +function _isdefined(object, name) + pnames = propertynames(object) + fnames = fieldnames(typeof(object)) + name in pnames && !(name in fnames) && return true + isdefined(object, name) end + function _equal_flux_chain(chain1::Flux.Chain, chain2::Flux.Chain) if length(chain1.layers) != length(chain2.layers) return false From 82c5714681b8ab0aa5e8b39581ef8ef8125f7939 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Tue, 22 Oct 2024 13:07:51 +0200 Subject: [PATCH 49/60] trying to fix docs issue and no longer importing MLJ nor MLJBase namespace --- src/direct_mlj.jl | 6 +++--- test/direct_mlj_interface.jl | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index b1ffa328..317d4b78 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -5,7 +5,7 @@ using Random using Tables using LinearAlgebra using LaplaceRedux -using MLJBase +using MLJBase: MLJBase import MLJModelInterface as MMI using Distributions: Normal @@ -315,7 +315,7 @@ The meaining of "equal" depends on the type of the property value: - values that are not of `MLJType` are "equal" if they are `==`. In the special case of a "deep" property, "equal" has a different -meaning; see [`deep_properties`](@ref)) for details. +meaning; see [`MMI.StatTraits.deep_properties`](@ref)) for details. If `m1` or `m2` are not `MLJType` objects, then return `==(m1, m2)`. 
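`is_same_except` is what steers `MMI.update` between the warm-restart and full-refit branches: models differing only in an excepted hyperparameter such as `:epochs` count as "the same", whereas swapping in a structurally different Flux chain does not. A rough illustration, assuming all other hyperparameters stay at their defaults (results marked as expected, not asserted):

    using Flux
    import LaplaceRedux: LaplaceRegressor
    import MLJModelInterface as MMI

    chain = Chain(Dense(4, 6, relu), Dense(6, 1))
    m1 = LaplaceRegressor(model=chain, epochs=50)
    m2 = LaplaceRegressor(model=chain, epochs=100)

    # only :epochs differs and it is excepted -> update can keep the fitted
    # chain and simply train for the extra epochs
    MMI.is_same_except(m1, m2, :epochs)    # expected: true

    # a structurally different chain -> the chain comparison fails and
    # update falls back to a full fit
    m3 = LaplaceRegressor(model=Chain(Dense(4, 3, relu), Dense(3, 1)), epochs=50)
    MMI.is_same_except(m1, m3, :epochs)    # expected: false
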
@@ -330,7 +330,7 @@ function MMI.is_same_except(m1::LaplaceModels, m2::LaplaceModels, exceptions::Sy if !_isdefined(m1, name) !_isdefined(m2, name) || return false elseif _isdefined(m2, name) - if name in deep_properties(LaplaceRegressor) + if name in MMI.StatTraits.deep_properties(LaplaceRegressor) _equal_to_depth_one(getproperty(m1, name), getproperty(m2, name)) || return false else diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl index 84d0440a..8853ec4f 100644 --- a/test/direct_mlj_interface.jl +++ b/test/direct_mlj_interface.jl @@ -3,7 +3,6 @@ import Random.seed! using MLJBase: MLJBase, categorical using Flux using StableRNGs -using MLJ import LaplaceRedux: LaplaceClassifier, LaplaceRegressor cv = CV(; nfolds=3) From 72020130bb669fe287f955001b0b39b02354a108 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Tue, 22 Oct 2024 13:08:23 +0200 Subject: [PATCH 50/60] formatting --- src/direct_mlj.jl | 28 +++++++---------- test/direct_mlj_interface.jl | 61 +++++++++++++----------------------- test/runtests.jl | 1 - 3 files changed, 33 insertions(+), 57 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 317d4b78..b76cdd93 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -47,10 +47,11 @@ end LaplaceModels = Union{LaplaceRegressor,LaplaceClassifier} # for fit: -MMI.reformat(::LaplaceRegressor, X, y) = (MLJBase.matrix(X) |> permutedims, (reshape(y, 1, :), nothing)) - -function MMI.reformat(::LaplaceClassifier, X, y) +function MMI.reformat(::LaplaceRegressor, X, y) + return (MLJBase.matrix(X) |> permutedims, (reshape(y, 1, :), nothing)) +end +function MMI.reformat(::LaplaceClassifier, X, y) X = MLJBase.matrix(X) |> permutedims y = categorical(y) labels = y.pool.levels @@ -61,7 +62,7 @@ end MMI.reformat(::LaplaceModels, X) = (MLJBase.matrix(X) |> permutedims,) -MMI.selectrows(::LaplaceModels, I, Xmatrix, y) = (Xmatrix[:, I], (y[1][:,I], y[2])) +MMI.selectrows(::LaplaceModels, I, Xmatrix, y) = (Xmatrix[:, I], (y[1][:, I], y[2])) MMI.selectrows(::LaplaceModels, I, Xmatrix) = (Xmatrix[:, I],) @doc """ @@ -171,7 +172,6 @@ function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y) old_model = old_cache[1] old_state_tree = old_cache[2] old_loss_history = old_cache[3] - epochs = m.epochs @@ -236,13 +236,14 @@ function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y) cache = (deepcopy(m), old_state_tree, old_loss_history) else - println("The number of epochs inserted is lower than the number of epochs already been trained. No update is necessary") + println( + "The number of epochs inserted is lower than the number of epochs already been trained. 
No update is necessary", + ) fitresult = (old_la, decode) report = (loss_history=old_loss_history,) cache = (deepcopy(m), old_state_tree, old_loss_history) end - elseif MMI.is_same_except( m, old_model, @@ -277,16 +278,12 @@ function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y) LaplaceRedux.fit!(la, data_loader) optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) - fitresult = (la,decode) + fitresult = (la, decode) report = (loss_history=old_loss_history,) cache = (deepcopy(m), old_state_tree, old_loss_history) - else - - fitresult, cache, report = MLJBase.fit(m, verbosity,X,y) - - + fitresult, cache, report = MLJBase.fit(m, verbosity, X, y) end return fitresult, cache, report @@ -356,12 +353,9 @@ function _isdefined(object, name) pnames = propertynames(object) fnames = fieldnames(typeof(object)) name in pnames && !(name in fnames) && return true - isdefined(object, name) + return isdefined(object, name) end - - - function _equal_flux_chain(chain1::Flux.Chain, chain2::Flux.Chain) if length(chain1.layers) != length(chain2.layers) return false diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl index 8853ec4f..c9e5b689 100644 --- a/test/direct_mlj_interface.jl +++ b/test/direct_mlj_interface.jl @@ -8,84 +8,67 @@ import LaplaceRedux: LaplaceClassifier, LaplaceRegressor cv = CV(; nfolds=3) @testset "Regression" begin - flux_model = Chain( - Dense(4, 10, relu), - Dense(10, 10, relu), - Dense(10, 1) - ) - model = LaplaceRegressor(model=flux_model,epochs=50) - + flux_model = Chain(Dense(4, 10, relu), Dense(10, 10, relu), Dense(10, 1)) + model = LaplaceRegressor(; model=flux_model, epochs=50) + X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1) #train, test = partition(eachindex(y), 0.7); # 70:30 split mach = machine(model, X, y) #|> MLJBase.fit! #|> (fitresult,cache,report) - MLJBase.fit!(mach, verbosity=1) + MLJBase.fit!(mach; verbosity=1) #Xnew, ynew = make_regression(3, 4; rng=123) yhat = MLJBase.predict(mach, X) # probabilistic predictions MLJBase.predict_mode(mach, X) # point predictions MLJBase.fitted_params(mach) #fitted params function MLJBase.training_losses(mach) #training loss history - model.epochs= 100 #changing number of epochs + model.epochs = 100 #changing number of epochs MLJBase.fit!(mach) #testing update function - model.epochs= 50 #changing number of epochs to a lower number + model.epochs = 50 #changing number of epochs to a lower number MLJBase.fit!(mach) #testing update function model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps MLJBase.fit!(mach) #testing update function (the laplace part) yhat = MLJBase.predict(mach, X) # probabilistic predictions - evaluate!(mach, resampling=cv, measure=log_loss, verbosity=0) - + evaluate!(mach; resampling=cv, measure=log_loss, verbosity=0) # Define a different model - flux_model_two = Chain( - Dense(4, 6, relu), - Dense(6, 1) - ) + flux_model_two = Chain(Dense(4, 6, relu), Dense(6, 1)) # test update! fallback to fit! 
model.model = flux_model_two MLJBase.fit!(mach) - model_two = LaplaceRegressor(model=flux_model_two,epochs=100) - @test !MLJBase.is_same_except(model,model_two) - - - + model_two = LaplaceRegressor(; model=flux_model_two, epochs=100) + @test !MLJBase.is_same_except(model, model_two) end - - @testset "Classification" begin # Define the model - flux_model = Chain( - Dense(4, 10, relu), - Dense(10, 3) - ) + flux_model = Chain(Dense(4, 10, relu), Dense(10, 3)) - model = LaplaceClassifier(model=flux_model,epochs=50) + model = LaplaceClassifier(; model=flux_model, epochs=50) X, y = @load_iris mach = machine(model, X, y) - MLJBase.fit!(mach,verbosity=1) - Xnew = (sepal_length = [6.4, 7.2, 7.4], - sepal_width = [2.8, 3.0, 2.8], - petal_length = [5.6, 5.8, 6.1], - petal_width = [2.1, 1.6, 1.9],) + MLJBase.fit!(mach; verbosity=1) + Xnew = ( + sepal_length=[6.4, 7.2, 7.4], + sepal_width=[2.8, 3.0, 2.8], + petal_length=[5.6, 5.8, 6.1], + petal_width=[2.1, 1.6, 1.9], + ) yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions predict_mode(mach, Xnew) # point predictions pdf.(yhat, "virginica") # probabilities for the "verginica" class MLJBase.fitted_params(mach) # fitted params MLJBase.training_losses(mach) #training loss history - model.epochs= 100 #changing number of epochs + model.epochs = 100 #changing number of epochs MLJBase.fit!(mach) #testing update function - model.epochs= 50 #changing number of epochs to a lower number + model.epochs = 50 #changing number of epochs to a lower number MLJBase.fit!(mach) #testing update function model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps MLJBase.fit!(mach) #testing update function (the laplace part) evaluate!(mach; resampling=cv, measure=brier_loss, verbosity=0) # Define a different model - flux_model_two = Chain( - Dense(4, 6, relu), - Dense(6, 3) - ) + flux_model_two = Chain(Dense(4, 6, relu), Dense(6, 3)) model.model = flux_model_two diff --git a/test/runtests.jl b/test/runtests.jl index 031224ea..ed16d37b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,5 +38,4 @@ using Test @testset "ML" begin include("direct_mlj_interface.jl") end - end From a05e25fe69db103959f1720243e7e4a4e1064805 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Tue, 22 Oct 2024 13:12:51 +0200 Subject: [PATCH 51/60] removing mlj_flux --- src/LaplaceRedux.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/LaplaceRedux.jl b/src/LaplaceRedux.jl index 15ab3d4f..08496e0f 100644 --- a/src/LaplaceRedux.jl +++ b/src/LaplaceRedux.jl @@ -19,7 +19,6 @@ export fit!, predict export optimize_prior!, glm_predictive_distribution, posterior_covariance, posterior_precision -include("mlj_flux.jl") export LaplaceClassification export LaplaceRegression From 05df2e1408db43ac1eff64931efb4522fe1fffa9 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Tue, 22 Oct 2024 15:11:42 +0200 Subject: [PATCH 52/60] fixed issues --- .github/workflows/CI.yml | 1 + Project.toml | 8 +- src/LaplaceRedux.jl | 2 - src/direct_mlj.jl | 4 +- src/mlj_flux.jl | 494 ----------------------------------- test/Manifest.toml | 101 ++++--- test/direct_mlj_interface.jl | 19 +- test/runtests.jl | 2 +- 8 files changed, 85 insertions(+), 546 deletions(-) delete mode 100644 src/mlj_flux.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 64bf7af1..22e67adc 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -22,6 +22,7 @@ jobs: version: - '1.9' - '1.10' + - '1' os: - ubuntu-latest - windows-latest diff --git a/Project.toml b/Project.toml index f9cb74ea..d78ce06a 100644 --- 
a/Project.toml +++ b/Project.toml @@ -12,7 +12,6 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" -MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" @@ -33,18 +32,17 @@ Distributions = "0.25.109" Flux = "0.12, 0.13, 0.14" LinearAlgebra = "1.7, 1.10" MLJBase = "1" -MLJFlux = "0.5" MLJModelInterface = "1.8.0" MLUtils = "0.4" Optimisers = "0.2, 0.3" ProgressMeter = "1.7.2" -Random = "1.9, 1.10" +Random = "1" Statistics = "1" Tables = "1.10.1" -Test = "1.9, 1.10" +Test = "1" Tullio = "0.3.5" Zygote = "0.6" -julia = "1.9, 1.10" +julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" diff --git a/src/LaplaceRedux.jl b/src/LaplaceRedux.jl index 08496e0f..b6324f95 100644 --- a/src/LaplaceRedux.jl +++ b/src/LaplaceRedux.jl @@ -19,8 +19,6 @@ export fit!, predict export optimize_prior!, glm_predictive_distribution, posterior_covariance, posterior_precision -export LaplaceClassification -export LaplaceRegression include("calibration_functions.jl") export empirical_frequency_binary_classification, diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index b76cdd93..42ec720a 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -53,7 +53,7 @@ end function MMI.reformat(::LaplaceClassifier, X, y) X = MLJBase.matrix(X) |> permutedims - y = categorical(y) + y = MLJBase.categorical(y) labels = y.pool.levels y = Flux.onehotbatch(y, labels) # One-hot encoding @@ -332,7 +332,7 @@ function MMI.is_same_except(m1::LaplaceModels, m2::LaplaceModels, exceptions::Sy return false else ( - is_same_except(getproperty(m1, name), getproperty(m2, name)) || + MMI.is_same_except(getproperty(m1, name), getproperty(m2, name)) || getproperty(m1, name) isa AbstractRNG || getproperty(m2, name) isa AbstractRNG || ( diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl deleted file mode 100644 index 0b1fa68f..00000000 --- a/src/mlj_flux.jl +++ /dev/null @@ -1,494 +0,0 @@ -using Flux -using MLJFlux -using ProgressMeter: Progress, next!, BarGlyphs -using Random -using Tables -using LinearAlgebra -using LaplaceRedux -using ComputationalResources -using MLJBase: MLJBase -import MLJModelInterface as MMI -using Optimisers: Optimisers - -""" - MLJBase.@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic - -A mutable struct representing a Laplace regression model that extends the `MLJFlux.MLJFluxProbabilistic` abstract type. -It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. - -The model is defined by the following default parameters for all `MLJFlux` models: - -- `builder`: a Flux model that constructs the neural network. -- `optimiser`: a Flux optimiser. -- `loss`: a loss function that takes the predicted output and the true output as arguments. -- `epochs`: the number of epochs. -- `batch_size`: the size of a batch. -- `lambda`: the regularization strength. -- `alpha`: the regularization mix (0 for all l2, 1 for all l1). -- `rng`: a random number generator. -- `optimiser_changes_trigger_retraining`: a boolean indicating whether changes in the optimiser trigger retraining. -- `acceleration`: the computational resource to use. 
- -The model also has the following parameters, which are specific to the Laplace approximation: - -- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. -- `subnetwork_indices`: the indices of the subnetworks. -- `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`. -- `backend`: the backend to use, either `:GGN` or `:EmpiricalFisher`. -- `σ`: the standard deviation of the prior distribution. -- `μ₀`: the mean of the prior distribution. -- `P₀`: the covariance matrix of the prior distribution. -- `ret_distr`: a boolean that tells predict to either return distributions (true) objects from Distributions.jl or just the probabilities. -- `fit_prior_nsteps`: the number of steps used to fit the priors. -""" -MLJBase.@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic - builder = MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish) - optimiser = Optimisers.Adam() - loss = Flux.Losses.mse - epochs::Int = 10::(_ > 0) - batch_size::Int = 1::(_ > 0) - lambda::Float64 = 1.0 - alpha::Float64 = 0.0 - rng::Union{AbstractRNG,Int64} = Random.GLOBAL_RNG - optimiser_changes_trigger_retraining::Bool = false::(_ in (true, false)) - acceleration = CPU1()::(_ in (CPU1(), CUDALibs())) - subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) - subnetwork_indices = nothing - hessian_structure::Union{HessianStructure,Symbol,String} = - :full::(_ in (:full, :diagonal)) - backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher)) - σ::Float64 = 1.0 - μ₀::Float64 = 0.0 - P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing - ret_distr::Bool = false::(_ in (true, false)) - fit_prior_nsteps::Int = 100::(_ > 0) -end - -""" - MLJBase.@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic - -A mutable struct representing a Laplace Classification model that extends the MLJFluxProbabilistic abstract type. -It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. - -The model is defined by the following default parameters for all `MLJFlux` models: -- `builder`: a Flux model that constructs the neural network. -- `finaliser`: a Flux model that processes the output of the neural network. -- `optimiser`: a Flux optimiser. -- `loss`: a loss function that takes the predicted output and the true output as arguments. -- `epochs`: the number of epochs. -- `batch_size`: the size of a batch. -- `lambda`: the regularization strength. -- `alpha`: the regularization mix (0 for all l2, 1 for all l1). -- `rng`: a random number generator. -- `optimiser_changes_trigger_retraining`: a boolean indicating whether changes in the optimiser trigger retraining. -- `acceleration`: the computational resource to use. - -The model also has the following parameters, which are specific to the Laplace approximation: - -- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. -- `subnetwork_indices`: the indices of the subnetworks. -- `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`. -- `backend`: the backend to use, either `:GGN` or `:EmpiricalFisher`. -- `σ`: the standard deviation of the prior distribution. -- `μ₀`: the mean of the prior distribution. -- `P₀`: the covariance matrix of the prior distribution. -- `link_approx`: the link approximation to use, either `:probit` or `:plugin`. -- `predict_proba`: a boolean that select whether to predict probabilities or not. 
-- `ret_distr`: a boolean that tells predict to either return distributions (true) objects from Distributions.jl or just the probabilities. -- `fit_prior_nsteps`: the number of steps used to fit the priors. -""" -MLJBase.@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic - builder = MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish) - finaliser = Flux.softmax - optimiser = Optimisers.Adam() - loss = Flux.crossentropy - epochs::Int = 10::(_ > 0) - batch_size::Int = 1::(_ > 0) - lambda::Float64 = 1.0 - alpha::Float64 = 0.0 - rng::Union{AbstractRNG,Int64} = Random.GLOBAL_RNG - optimiser_changes_trigger_retraining::Bool = false::(_ in (true, false)) - acceleration = CPU1()::(_ in (CPU1(), CUDALibs())) - subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) - subnetwork_indices::Vector{Vector{Int}} = Vector{Vector{Int}}([]) - hessian_structure::Union{HessianStructure,Symbol,String} = - :full::(_ in (:full, :diagonal)) - backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher)) - σ::Float64 = 1.0 - μ₀::Float64 = 0.0 - P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing - link_approx::Symbol = :probit::(_ in (:probit, :plugin)) - predict_proba::Bool = true::(_ in (true, false)) - ret_distr::Bool = false::(_ in (true, false)) - fit_prior_nsteps::Int = 100::(_ > 0) -end - -const MLJ_Laplace = Union{LaplaceClassification,LaplaceRegression} - -""" - MLJFlux.shape(model::LaplaceRegression, X, y) - -Compute the the number of features of the X input dataset and the number of variables to predict from the y output dataset. - -# Arguments -- `model::LaplaceRegression`: The LaplaceRegression model to fit. -- `X`: The input data for training. -- `y`: The target labels for training one-hot encoded. - -# Returns -- (input size, output size) -""" -function MLJFlux.shape(model::LaplaceRegression, X, y) - X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X - n_input = size(X, 2) - dims = size(y) - if length(dims) == 1 - n_output = 1 - else - n_output = dims[1] - end - return (n_input, n_output) -end - -""" - MLJFlux.build(model::LaplaceRegression, rng, shape) - -Builds an MLJFlux model for Laplace regression compatible with the dimensions of the input and output layers specified by `shape`. - -# Arguments -- `model::LaplaceRegression`: The Laplace regression model. -- `rng`: A random number generator to ensure reproducibility. -- `shape`: A tuple or array specifying the dimensions of the input and output layers. - -# Returns -- The constructed MLJFlux model, compatible with the specified input and output dimensions. -""" -function MLJFlux.build(model::LaplaceRegression, rng, shape) - chain = MLJFlux.build(model.builder, rng, shape...) - return chain -end - -""" - MLJFlux.fitresult(model::LaplaceRegression, chain, y) - -Computes the fit result for a Laplace Regression model, returning the model chain and the number of output variables in the target data. - -# Arguments -- `model::LaplaceRegression`: The Laplace Regression model to be evaluated. -- `chain`: The trained model chain. -- `y`: The target data, typically a vector of class labels. - -# Returns - A tuple containing: - - The trained Flux chain. - - a deepcopy of the laplace model. -""" -function MLJFlux.fitresult(model::LaplaceRegression, chain, y) - return (chain, deepcopy(model)) -end - -""" - MLJFlux.train(model::LaplaceRegression, chain, regularized_optimiser, optimiser_state, epochs, verbosity, X, y) - -Fit the LaplaceRegression model using Flux.jl. 
- -# Arguments -- `model::LaplaceRegression`: The LaplaceRegression model. -- `regularized_optimiser`: the regularized optimiser to apply to the loss function. -- `optimiser_state`: thestate of the optimiser. -- `epochs`: The number of epochs for training. -- `verbosity`: The verbosity level for training. -- `X`: The input data for training. -- `y`: The target labels for training. - -# Returns (la, optimiser_state, history ) -where -- `la`: the fitted Laplace model. -- `optimiser_state`: the state of the optimiser. -- `history`: the training loss history. -""" -function MLJFlux.train( - model::LaplaceRegression, - chain, - regularized_optimiser, - optimiser_state, - epochs, - verbosity, - X, - y, -) - X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X - - # Initialize history: - history = [] - verbose_laplace = false - # intitialize and start progress meter: - meter = Progress( - epochs + 1; - dt=1.0, - desc="Optimising neural net:", - barglyphs=BarGlyphs("[=> ]"), - barlen=25, - color=:yellow, - ) - verbosity != 1 || next!(meter) - - # initiate history: - loss = model.loss - losses = (loss(chain(X[i]), y[i]) for i in 1:length(y)) - history = [mean(losses)] - - for i in 1:epochs - chain, optimiser_state, current_loss = MLJFlux.train_epoch( - model, chain, regularized_optimiser, optimiser_state, X, y - ) - verbosity < 2 || @info "Loss is $(round(current_loss; sigdigits=4))" - verbosity != 1 || next!(meter) - push!(history, current_loss) - end - - if !isa(chain, AbstractLaplace) - la = LaplaceRedux.Laplace( - chain; - likelihood=:regression, - subset_of_weights=model.subset_of_weights, - subnetwork_indices=model.subnetwork_indices, - hessian_structure=model.hessian_structure, - backend=model.backend, - σ=model.σ, - μ₀=model.μ₀, - P₀=model.P₀, - ) - else - la = chain - end - - # fit the Laplace model: - LaplaceRedux.fit!(la, zip(X, y)) - optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) - - return la, optimiser_state, history -end - -""" - predict(model::LaplaceRegression, Xnew) - -Predict the output for new input data using a Laplace regression model. - -# Arguments -- `model::LaplaceRegression`: The trained Laplace regression model. -- the fitresult output produced by MLJFlux.fit! -- `Xnew`: The new input data. - -# Returns -- The predicted output for the new input data. - -""" -function MLJFlux.predict(model::LaplaceRegression, fitresult, Xnew) - Xnew = MLJBase.matrix(Xnew) |> permutedims - la = fitresult[1] - yhat = LaplaceRedux.predict(la, Xnew; ret_distr=model.ret_distr) - return yhat -end - -""" - MLJFlux.shape(model::LaplaceClassification, X, y) - -Compute the the number of features of the dataset X and the number of unique classes in y. - -# Arguments -- `model::LaplaceClassification`: The LaplaceClassification model to fit. -- `X`: The input data for training. -- `y`: The target labels for training one-hot encoded. - -# Returns -- (input size, output size) -""" - -function MLJFlux.shape(model::LaplaceClassification, X, y) - X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X - n_input = size(X, 2) - levels = unique(y) - n_output = length(levels) - return (n_input, n_output) -end - -""" - MLJFlux.build(model::LaplaceClassification, rng, shape) - -Builds an MLJFlux model for Laplace classification compatible with the dimensions of the input and output layers specified by `shape`. - -# Arguments -- `model::LaplaceClassification`: The Laplace classification model. -- `rng`: A random number generator to ensure reproducibility. 
-- `shape`: A tuple or array specifying the dimensions of the input and output layers. - -# Returns -- The constructed MLJFlux model, compatible with the specified input and output dimensions. -""" -function MLJFlux.build(model::LaplaceClassification, rng, shape) - chain = Flux.Chain(MLJFlux.build(model.builder, rng, shape...), model.finaliser) - - return chain -end - -""" - MLJFlux.fitresult(model::LaplaceClassification, chain, y) - -Computes the fit result for a Laplace classification model, returning the model chain and the number of unique classes in the target data. - -# Arguments -- `model::LaplaceClassification`: The Laplace classification model to be evaluated. -- `chain`: The trained model chain. -- `y`: The target data, typically a vector of class labels. - -# Returns -# Returns - A tuple containing: - - The trained Flux chain. - - a deepcopy of the laplace model. -""" -function MLJFlux.fitresult(model::LaplaceClassification, chain, y) - return (chain, deepcopy(model)) -end - -""" - MLJFlux.train(model::LaplaceClassification, chain, regularized_optimiser, optimiser_state, epochs, verbosity, X, y) - -Fit the LaplaceRegression model using Flux.jl. - -# Arguments -- `model::LaplaceClassification`: The LaplaceClassification model. -- `regularized_optimiser`: the regularized optimiser to apply to the loss function. -- `optimiser_state`: thestate of the optimiser. -- `epochs`: The number of epochs for training. -- `verbosity`: The verbosity level for training. -- `X`: The input data for training. -- `y`: The target labels for training. - -# Returns (fitresult, cache, report ) -where -- `la`: the fitted Laplace model. -- `optimiser_state`: the state of the optimiser. -- `history`: the training loss history. -""" -function MLJFlux.train( - model::LaplaceClassification, - chain, - regularized_optimiser, - optimiser_state, - epochs, - verbosity, - X, - y, -) - X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X - - # Initialize history: - history = [] - verbose_laplace = false - # intitialize and start progress meter: - meter = Progress( - epochs + 1; - dt=1.0, - desc="Optimising neural net:", - barglyphs=BarGlyphs("[=> ]"), - barlen=25, - color=:yellow, - ) - verbosity != 1 || next!(meter) - - # initiate history: - loss = model.loss - n_batches = length(y) - losses = (loss(chain(X[i]), y[i]) for i in 1:n_batches) - history = [mean(losses)] - - for i in 1:epochs - chain, optimiser_state, current_loss = MLJFlux.train_epoch( - model, chain, regularized_optimiser, optimiser_state, X, y - ) - verbosity < 2 || @info "Loss is $(round(current_loss; sigdigits=4))" - verbosity != 1 || next!(meter) - push!(history, current_loss) - end - - if !isa(chain, AbstractLaplace) - la = LaplaceRedux.Laplace( - chain; - likelihood=:regression, - subset_of_weights=model.subset_of_weights, - subnetwork_indices=model.subnetwork_indices, - hessian_structure=model.hessian_structure, - backend=model.backend, - σ=model.σ, - μ₀=model.μ₀, - P₀=model.P₀, - ) - else - la = chain - end - - # fit the Laplace model: - LaplaceRedux.fit!(la, zip(X, y)) - optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) - - return la, optimiser_state, history -end - -""" - predict(model::LaplaceClassification, Xnew) - -Predicts the class labels for new data using the LaplaceClassification model. - -# Arguments -- `model::LaplaceClassification`: The trained LaplaceClassification model. -- fitresult: the fitresult output produced by MLJFlux.fit! -- `Xnew`: The new data to make predictions on. 
- -# Returns -An array of predicted class labels. - -""" -function MLJFlux.predict(model::LaplaceClassification, fitresult, Xnew) - la = fitresult[1] - Xnew = MLJBase.matrix(Xnew) |> permutedims - predictions = LaplaceRedux.predict( - la, - Xnew; - link_approx=model.link_approx, - predict_proba=model.predict_proba, - ret_distr=model.ret_distr, - ) - - return predictions -end - -# metadata for each model, -MLJBase.metadata_model( - LaplaceClassification; - input=Union{ - AbstractMatrix{MLJBase.Finite}, - MLJBase.Table(MLJBase.Finite), - AbstractMatrix{MLJBase.Continuous}, - MLJBase.Table(MLJBase.Continuous), - MLJBase.Table{AbstractVector{MLJBase.Continuous}}, - MLJBase.Table{AbstractVector{MLJBase.Finite}}, - }, - target=Union{AbstractArray{MLJBase.Finite},AbstractArray{MLJBase.Continuous}}, - path="MLJFlux.LaplaceClassification", -) -# metadata for each model, -MLJBase.metadata_model( - LaplaceRegression; - input=Union{ - AbstractMatrix{MLJBase.Continuous}, - MLJBase.Table(MLJBase.Continuous), - AbstractMatrix{MLJBase.Finite}, - MLJBase.Table(MLJBase.Finite), - MLJBase.Table{AbstractVector{MLJBase.Continuous}}, - MLJBase.Table{AbstractVector{MLJBase.Finite}}, - }, - target=Union{AbstractArray{MLJBase.Finite},AbstractArray{MLJBase.Continuous}}, - path="MLJFlux.LaplaceRegression", -) diff --git a/test/Manifest.toml b/test/Manifest.toml index 3448bc9a..bed7d79f 100644 --- a/test/Manifest.toml +++ b/test/Manifest.toml @@ -1,14 +1,14 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.5" +julia_version = "1.11.1" manifest_format = "2.0" project_hash = "48e3a5a4625c4493599b02acbfe0e972463bd78f" [[deps.ARFFFiles]] deps = ["CategoricalArrays", "Dates", "Parsers", "Tables"] -git-tree-sha1 = "e8c8e0a2be6eb4f56b1672e46004463033daa409" +git-tree-sha1 = "678eb18590a8bc6674363da4d5faa4ac09c40a18" uuid = "da404889-ca92-49ff-9e8b-0aa6b4d38dc8" -version = "1.4.1" +version = "1.5.0" [[deps.AbstractFFTs]] deps = ["LinearAlgebra"] @@ -80,7 +80,7 @@ version = "2.3.0" [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" +version = "1.1.2" [[deps.Arpack]] deps = ["Arpack_jll", "Libdl", "LinearAlgebra", "Logging"] @@ -96,6 +96,7 @@ version = "3.5.1+1" [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" +version = "1.11.0" [[deps.Atomix]] deps = ["UnsafeAtomics"] @@ -144,6 +145,7 @@ version = "0.4.3" [[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" +version = "1.11.0" [[deps.Baselet]] git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" @@ -162,9 +164,9 @@ version = "1.2.2" [[deps.Bzip2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "9e2a6b69137e6969bab0152632dcb3bc108c8bdd" +git-tree-sha1 = "8873e196c2eb87962a2048b3b8e08946535864a1" uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" -version = "1.0.8+1" +version = "1.0.8+2" [[deps.CEnum]] git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" @@ -173,9 +175,9 @@ version = "0.5.0" [[deps.CSV]] deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"] -git-tree-sha1 = "6c834533dc1fabd820c1db03c839bf97e45a3fab" +git-tree-sha1 = "deddd8725e5e1cc49ee205a1964256043720a6c3" uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" -version = "0.10.14" +version = "0.10.15" [[deps.Cairo_jll]] deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", 
"LZO_jll", "Libdl", "Pixman_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] @@ -394,6 +396,7 @@ version = "1.0.0" [[deps.Dates]] deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" +version = "1.11.0" [[deps.Dbus_jll]] deps = ["Artifacts", "Expat_jll", "JLLWrappers", "Libdl"] @@ -444,6 +447,7 @@ weakdeps = ["ChainRulesCore", "SparseArrays"] [[deps.Distributed]] deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" +version = "1.11.0" [[deps.Distributions]] deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] @@ -539,6 +543,7 @@ weakdeps = ["Mmap", "Test"] [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" +version = "1.11.0" [[deps.FillArrays]] deps = ["LinearAlgebra"] @@ -624,6 +629,7 @@ version = "0.4.12" [[deps.Future]] deps = ["Random"] uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" +version = "1.11.0" [[deps.GLFW_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libglvnd_jll", "Xorg_libXcursor_jll", "Xorg_libXi_jll", "Xorg_libXinerama_jll", "Xorg_libXrandr_jll", "libdecor_jll", "xkbcommon_jll"] @@ -776,6 +782,7 @@ version = "1.4.2" [[deps.InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +version = "1.11.0" [[deps.InternedStrings]] deps = ["Random", "Test"] @@ -816,9 +823,9 @@ version = "1.0.0" [[deps.JLD2]] deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Requires", "TranscodingStreams"] -git-tree-sha1 = "aeab5c68eb2cf326619bf71235d8f4561c62fe22" +git-tree-sha1 = "b464b9b461ee989b435a689a4f7d870b68d467ed" uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" -version = "0.5.5" +version = "0.5.6" [[deps.JLFzf]] deps = ["Pipe", "REPL", "Random", "fzf_jll"] @@ -840,9 +847,9 @@ version = "0.21.4" [[deps.JSON3]] deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] -git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b" +git-tree-sha1 = "1d322381ef7b087548321d3f878cb4c9bd8f8f9b" uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" -version = "1.14.0" +version = "1.14.1" [deps.JSON3.extensions] JSON3ArrowExt = ["ArrowTypes"] @@ -948,6 +955,7 @@ version = "1.9.0" [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" +version = "1.11.0" [[deps.LazyModules]] git-tree-sha1 = "a560dd966b386ac9ae60bdd3a3d3a326062d3c3e" @@ -968,16 +976,17 @@ version = "0.6.4" [[deps.LibCURL_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "8.4.0+0" +version = "8.6.0+0" [[deps.LibGit2]] deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" +version = "1.11.0" [[deps.LibGit2_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" -version = "1.6.4+0" +version = "1.7.2+0" [[deps.LibSSH2_jll]] deps = ["Artifacts", "Libdl", "MbedTLS_jll"] @@ -986,6 +995,7 @@ version = "1.11.0+1" [[deps.Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" +version = "1.11.0" [[deps.Libffi_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -1038,6 +1048,7 @@ version = "2.40.1+0" [[deps.LinearAlgebra]] deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +version = "1.11.0" [[deps.LogExpFunctions]] deps = ["DocStringExtensions", 
"IrrationalConstants", "LinearAlgebra"] @@ -1057,6 +1068,7 @@ version = "0.3.28" [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" +version = "1.11.0" [[deps.LoggingExtras]] deps = ["Dates", "Logging"] @@ -1071,10 +1083,10 @@ uuid = "23992714-dd62-5051-b70f-ba57cb901cac" version = "0.10.7" [[deps.MLDataDevices]] -deps = ["Adapt", "Functors", "Preferences", "Random"] -git-tree-sha1 = "e16288e37e76d68c3f1c418e0a2bec88d98d55fc" +deps = ["Adapt", "Compat", "Functors", "Preferences", "Random"] +git-tree-sha1 = "f19f2629ad20176e524c71d06e1c29689ab002fa" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" -version = "1.2.0" +version = "1.4.0" [deps.MLDataDevices.extensions] MLDataDevicesAMDGPUExt = "AMDGPU" @@ -1235,6 +1247,7 @@ version = "0.4.2" [[deps.Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" +version = "1.11.0" [[deps.MbedTLS]] deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "NetworkOptions", "Random", "Sockets"] @@ -1245,7 +1258,7 @@ version = "1.1.9" [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+1" +version = "2.28.6+0" [[deps.Measures]] git-tree-sha1 = "c13304c81eec1ed3af7fc20e75fb6b26092a1102" @@ -1284,6 +1297,7 @@ version = "1.2.0" [[deps.Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" +version = "1.11.0" [[deps.MosaicViews]] deps = ["MappedArrays", "OffsetArrays", "PaddedViews", "StackViews"] @@ -1293,7 +1307,7 @@ version = "0.3.4" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2023.1.10" +version = "2023.12.12" [[deps.MultivariateStats]] deps = ["Arpack", "Distributions", "LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI", "StatsBase"] @@ -1369,7 +1383,7 @@ version = "0.2.5" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.23+4" +version = "0.3.27+1" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] @@ -1494,9 +1508,13 @@ uuid = "30392449-352a-5448-841d-b1acce4e97dc" version = "0.43.4+0" [[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "Random", "SHA", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.10.0" +version = "1.11.0" +weakdeps = ["REPL"] + + [deps.Pkg.extensions] + REPLExt = "REPL" [[deps.PlotThemes]] deps = ["PlotUtils", "Statistics"] @@ -1505,10 +1523,10 @@ uuid = "ccf2f8ad-2431-5c83-bf29-c5338b663b6a" version = "3.2.0" [[deps.PlotUtils]] -deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "Statistics"] -git-tree-sha1 = "7b1a9df27f072ac4c9c7cbe5efb198489258d1f5" +deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "StableRNGs", "Statistics"] +git-tree-sha1 = "650a022b2ce86c7dcfbdecf00f78afeeb20e5655" uuid = "995b91a9-d308-5afd-9ec6-746e21dbc043" -version = "1.4.1" +version = "1.4.2" [[deps.Plots]] deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "PrecompileTools", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "RelocatableFolders", 
"Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "TOML", "UUIDs", "UnicodeFun", "UnitfulLatexify", "Unzip"] @@ -1567,6 +1585,7 @@ version = "2.4.0" [[deps.Printf]] deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" +version = "1.11.0" [[deps.ProgressLogging]] deps = ["Logging", "SHA", "UUIDs"] @@ -1622,12 +1641,14 @@ version = "2.11.1" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" [[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +deps = ["InteractiveUtils", "Markdown", "Sockets", "StyledStrings", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" +version = "1.11.0" [[deps.Random]] deps = ["SHA"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +version = "1.11.0" [[deps.RealDot]] deps = ["LinearAlgebra"] @@ -1711,6 +1732,7 @@ version = "1.4.5" [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +version = "1.11.0" [[deps.Setfield]] deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] @@ -1742,6 +1764,7 @@ version = "0.9.4" [[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" +version = "1.11.0" [[deps.SortingAlgorithms]] deps = ["DataStructures"] @@ -1752,7 +1775,7 @@ version = "1.2.1" [[deps.SparseArrays]] deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -version = "1.10.0" +version = "1.11.0" [[deps.SparseInverseSubset]] deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] @@ -1790,9 +1813,9 @@ version = "0.1.1" [[deps.StaticArrays]] deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50" +git-tree-sha1 = "777657803913ffc7e8cc20f0fd04b634f871af8f" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.7" +version = "1.9.8" weakdeps = ["ChainRulesCore", "Statistics"] [deps.StaticArrays.extensions] @@ -1831,9 +1854,14 @@ uuid = "64bff920-2084-43da-a3e6-9bb72801c0c9" version = "3.4.0" [[deps.Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] +deps = ["LinearAlgebra"] +git-tree-sha1 = "ae3bb1eb3bba077cd276bc5cfc337cc65c3075c0" uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.10.0" +version = "1.11.1" +weakdeps = ["SparseArrays"] + + [deps.Statistics.extensions] + SparseArraysExt = ["SparseArrays"] [[deps.StatsAPI]] deps = ["LinearAlgebra"] @@ -1901,6 +1929,10 @@ git-tree-sha1 = "159331b30e94d7b11379037feeb9b690950cace8" uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" version = "1.11.0" +[[deps.StyledStrings]] +uuid = "f489334b-da3d-4c2e-b8f0-e476e12c162b" +version = "1.11.0" + [[deps.SuiteSparse]] deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" @@ -1908,7 +1940,7 @@ uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [[deps.SuiteSparse_jll]] deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "7.2.1+1" +version = "7.7.0+0" [[deps.TOML]] deps = ["Dates"] @@ -1952,6 +1984,7 @@ version = "0.1.1" [[deps.Test]] deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +version = "1.11.0" [[deps.TranscodingStreams]] git-tree-sha1 = "0c45878dcfdcfa8480052b6ab162cdd138781742" @@ -2011,6 +2044,7 @@ version = "1.5.1" [[deps.UUIDs]] deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +version = "1.11.0" [[deps.UnPack]] git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b" @@ -2019,6 +2053,7 @@ version = "1.0.2" 
[[deps.Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" +version = "1.11.0" [[deps.UnicodeFun]] deps = ["REPL"] @@ -2381,7 +2416,7 @@ version = "1.1.6+0" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.52.0+1" +version = "1.59.0+0" [[deps.p7zip_jll]] deps = ["Artifacts", "Libdl"] diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl index c9e5b689..f4b6e5fc 100644 --- a/test/direct_mlj_interface.jl +++ b/test/direct_mlj_interface.jl @@ -1,19 +1,20 @@ using Random: Random import Random.seed! using MLJBase: MLJBase, categorical +using MLJ: MLJ using Flux using StableRNGs import LaplaceRedux: LaplaceClassifier, LaplaceRegressor -cv = CV(; nfolds=3) +cv = MLJBase.CV(; nfolds=3) @testset "Regression" begin flux_model = Chain(Dense(4, 10, relu), Dense(10, 10, relu), Dense(10, 1)) model = LaplaceRegressor(; model=flux_model, epochs=50) - X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1) + X, y = MLJ.make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1) #train, test = partition(eachindex(y), 0.7); # 70:30 split - mach = machine(model, X, y) #|> MLJBase.fit! #|> (fitresult,cache,report) + mach = MLJ.machine(model, X, y) #|> MLJBase.fit! #|> (fitresult,cache,report) MLJBase.fit!(mach; verbosity=1) #Xnew, ynew = make_regression(3, 4; rng=123) yhat = MLJBase.predict(mach, X) # probabilistic predictions @@ -27,7 +28,7 @@ cv = CV(; nfolds=3) model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps MLJBase.fit!(mach) #testing update function (the laplace part) yhat = MLJBase.predict(mach, X) # probabilistic predictions - evaluate!(mach; resampling=cv, measure=log_loss, verbosity=0) + MLJ.evaluate!(mach; resampling=cv, measure=MLJ.log_loss, verbosity=0) # Define a different model flux_model_two = Chain(Dense(4, 6, relu), Dense(6, 1)) @@ -45,8 +46,8 @@ end model = LaplaceClassifier(; model=flux_model, epochs=50) - X, y = @load_iris - mach = machine(model, X, y) + X, y = MLJ.@load_iris + mach = MLJ.machine(model, X, y) MLJBase.fit!(mach; verbosity=1) Xnew = ( sepal_length=[6.4, 7.2, 7.4], @@ -55,8 +56,8 @@ end petal_width=[2.1, 1.6, 1.9], ) yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions - predict_mode(mach, Xnew) # point predictions - pdf.(yhat, "virginica") # probabilities for the "verginica" class + MLJBase.predict_mode(mach, Xnew) # point predictions + MLJBase.pdf.(yhat, "virginica") # probabilities for the "verginica" class MLJBase.fitted_params(mach) # fitted params MLJBase.training_losses(mach) #training loss history model.epochs = 100 #changing number of epochs @@ -65,7 +66,7 @@ end MLJBase.fit!(mach) #testing update function model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps MLJBase.fit!(mach) #testing update function (the laplace part) - evaluate!(mach; resampling=cv, measure=brier_loss, verbosity=0) + MLJ.evaluate!(mach; resampling=cv, measure=MLJ.brier_loss, verbosity=0) # Define a different model flux_model_two = Chain(Dense(4, 6, relu), Dense(6, 3)) diff --git a/test/runtests.jl b/test/runtests.jl index ed16d37b..1184e74a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,7 +35,7 @@ using Test include("krondecomposed.jl") end - @testset "ML" begin + @testset "MLJ" begin include("direct_mlj_interface.jl") end end From 02abec26671c0b2f8911e03e06460291301332ce Mon Sep 17 00:00:00 2001 From: pat-alt Date: Wed, 23 Oct 2024 10:03:57 +0200 Subject: [PATCH 53/60] removing reference to deep_propertier --- docs/Manifest.toml | 414 
++++++++++++++++++++++++++++----------------- docs/Project.toml | 1 + src/direct_mlj.jl | 4 +- 3 files changed, 266 insertions(+), 153 deletions(-) diff --git a/docs/Manifest.toml b/docs/Manifest.toml index a97bfd27..d01b6356 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -1,8 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.5" +julia_version = "1.11.1" manifest_format = "2.0" -project_hash = "616a9e89f5c520a58672ad91b5525001e0dadab3" +project_hash = "aafefca0e42df191a3fa625af8a8b686a9db4944" [[deps.ANSIColoredPrinters]] git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c" @@ -26,31 +26,35 @@ uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" version = "0.4.5" [[deps.Accessors]] -deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown", "Test"] -git-tree-sha1 = "f61b15be1d76846c0ce31d3fcfac5380ae53db6a" +deps = ["CompositionsBase", "ConstructionBase", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown"] +git-tree-sha1 = "b392ede862e506d451fc1616e79aa6f4c673dab8" uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -version = "0.1.37" +version = "0.1.38" [deps.Accessors.extensions] AccessorsAxisKeysExt = "AxisKeys" + AccessorsDatesExt = "Dates" AccessorsIntervalSetsExt = "IntervalSets" AccessorsStaticArraysExt = "StaticArrays" AccessorsStructArraysExt = "StructArrays" + AccessorsTestExt = "Test" AccessorsUnitfulExt = "Unitful" [deps.Accessors.weakdeps] AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" + Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" Requires = "ae029012-a4dd-5104-9daa-d747884805df" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [[deps.Adapt]] deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" +git-tree-sha1 = "d80af0733c99ea80575f612813fa6aa71022d33a" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "4.0.4" +version = "4.1.0" weakdeps = ["StaticArrays"] [deps.Adapt.extensions] @@ -69,7 +73,7 @@ version = "2.3.0" [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" +version = "1.1.2" [[deps.Arpack]] deps = ["Arpack_jll", "Libdl", "LinearAlgebra", "Logging"] @@ -115,6 +119,7 @@ version = "7.16.0" [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" +version = "1.11.0" [[deps.Atomix]] deps = ["UnsafeAtomics"] @@ -157,6 +162,7 @@ version = "0.4.3" [[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" +version = "1.11.0" [[deps.Baselet]] git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" @@ -170,9 +176,9 @@ version = "0.1.9" [[deps.Bzip2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "9e2a6b69137e6969bab0152632dcb3bc108c8bdd" +git-tree-sha1 = "8873e196c2eb87962a2048b3b8e08946535864a1" uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" -version = "1.0.8+1" +version = "1.0.8+2" [[deps.CEnum]] git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" @@ -181,15 +187,15 @@ version = "0.5.0" [[deps.CSV]] deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"] -git-tree-sha1 = "6c834533dc1fabd820c1db03c839bf97e45a3fab" +git-tree-sha1 = "deddd8725e5e1cc49ee205a1964256043720a6c3" 
uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" -version = "0.10.14" +version = "0.10.15" [[deps.CUDA]] -deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "Crayons", "DataFrames", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LLVMLoopInfo", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "NVTX", "Preferences", "PrettyTables", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "StaticArrays", "Statistics"] -git-tree-sha1 = "fdd9dfb67dfefd548f51000cc400bb51003de247" +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "Crayons", "DataFrames", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LLVMLoopInfo", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "NVTX", "Preferences", "PrettyTables", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "StaticArrays", "Statistics", "demumble_jll"] +git-tree-sha1 = "e0725a467822697171af4dae15cec10b4fc19053" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "5.4.3" +version = "5.5.2" [deps.CUDA.extensions] ChainRulesCoreExt = "ChainRulesCore" @@ -203,9 +209,9 @@ version = "5.4.3" [[deps.CUDA_Driver_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "325058b426c2b421e3d2df3d5fa646d72d2e3e7e" +git-tree-sha1 = "ccd1e54610c222fadfd4737dac66bff786f63656" uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc" -version = "0.9.2+0" +version = "0.10.3+0" [[deps.CUDA_Runtime_Discovery]] deps = ["Libdl"] @@ -215,21 +221,21 @@ version = "0.3.5" [[deps.CUDA_Runtime_jll]] deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "afea94249b821dc754a8ca6695d3daed851e1f5a" +git-tree-sha1 = "e43727b237b2879a34391eeb81887699a26f8f2f" uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" -version = "0.14.1+0" +version = "0.15.3+0" [[deps.CUDNN_jll]] deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "cbf7d75f8c58b147bdf6acea2e5bc96cececa6d4" +git-tree-sha1 = "9851af16a2f357a793daa0f13634c82bc7e40419" uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645" -version = "9.0.0+1" +version = "9.4.0+0" [[deps.Cairo_jll]] deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "a2f1c8c668c8e3cb4cca4e57a8efdb09067bb3fd" +git-tree-sha1 = "009060c9a6168704143100f36ab08f06c2af4642" uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a" -version = "1.18.0+2" +version = "1.18.2+1" [[deps.CategoricalArrays]] deps = ["DataAPI", "Future", "Missings", "Printf", "Requires", "Statistics", "Unicode"] @@ -263,15 +269,15 @@ version = "0.1.15" [[deps.ChainRules]] deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] -git-tree-sha1 = "227985d885b4dbce5e18a96f9326ea1e836e5a03" +git-tree-sha1 = "be227d253d132a6d57f9ccf5f67c0fb6488afd87" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.69.0" +version = "1.71.0" [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra"] -git-tree-sha1 = "71acdbf594aab5bbb2cec89b208c41b4c411e49f" +git-tree-sha1 = 
"3e4b134270b372f2ed4d4d0e936aabaefc1802bc" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.24.0" +version = "1.25.0" weakdeps = ["SparseArrays"] [deps.ChainRulesCore.extensions] @@ -285,9 +291,9 @@ version = "0.7.6" [[deps.ColorSchemes]] deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] -git-tree-sha1 = "b5278586822443594ff615963b0c09755771b3e0" +git-tree-sha1 = "13951eb68769ad1cd460cdb2e64e5e95f1bf123d" uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" -version = "3.26.0" +version = "3.27.0" [[deps.ColorTypes]] deps = ["FixedPointNumbers", "Random"] @@ -411,10 +417,10 @@ uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" version = "1.16.0" [[deps.DataFrames]] -deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] -git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" +deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] +git-tree-sha1 = "fb61b4812c49343d7ef0b533ba982c46021938a6" uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -version = "1.6.1" +version = "1.7.0" [[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] @@ -430,6 +436,7 @@ version = "1.0.0" [[deps.Dates]] deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" +version = "1.11.0" [[deps.Dbus_jll]] deps = ["Artifacts", "Expat_jll", "JLLWrappers", "Libdl"] @@ -486,12 +493,13 @@ weakdeps = ["ChainRulesCore", "SparseArrays"] [[deps.Distributed]] deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" +version = "1.11.0" [[deps.Distributions]] deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] -git-tree-sha1 = "e6c693a0e4394f8fda0e51a5bdf5aef26f8235e9" +git-tree-sha1 = "d7477ecdafb813ddee2ae727afa94e9dcb5f3fb0" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.111" +version = "0.25.112" weakdeps = ["ChainRulesCore", "DensityInterface", "Test"] [deps.Distributions.extensions] @@ -499,6 +507,12 @@ weakdeps = ["ChainRulesCore", "DensityInterface", "Test"] DistributionsDensityInterfaceExt = "DensityInterface" DistributionsTestExt = "Test" +[[deps.DocInventories]] +deps = ["CodecZlib", "Downloads", "TOML"] +git-tree-sha1 = "e97cfa8680a39396924dcdca4b7ff1014ed5c499" +uuid = "43dc2714-ed3b-44b5-b226-857eda1aa7de" +version = "1.0.0" + [[deps.DocStringExtensions]] deps = ["LibGit2"] git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" @@ -507,9 +521,21 @@ version = "0.9.3" [[deps.Documenter]] deps = ["ANSIColoredPrinters", "AbstractTrees", "Base64", "CodecZlib", "Dates", "DocStringExtensions", "Downloads", "Git", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "MarkdownAST", "Pkg", "PrecompileTools", "REPL", "RegistryInstances", "SHA", "TOML", "Test", "Unicode"] -git-tree-sha1 = "9d29b99b6b2b6bc2382a4c8dbec6eb694f389853" +git-tree-sha1 = "5a1ee886566f2fa9318df1273d8b778b9d42712d" uuid = 
"e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "1.6.0" +version = "1.7.0" + +[[deps.DocumenterInterLinks]] +deps = ["CodecZlib", "DocInventories", "Documenter", "DocumenterInventoryWritingBackport", "Markdown", "MarkdownAST", "TOML"] +git-tree-sha1 = "00dceb038f6cb24f4d8d6a9f2feb85bbe58305fd" +uuid = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656" +version = "1.0.0" + +[[deps.DocumenterInventoryWritingBackport]] +deps = ["CodecZlib", "Documenter", "TOML"] +git-tree-sha1 = "1b89024e375353961bb98b9818b44a4e38961cc4" +uuid = "195adf08-069f-4855-af3e-8933a2cdae94" +version = "0.1.0" [[deps.Downloads]] deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] @@ -541,9 +567,9 @@ version = "0.1.10" [[deps.FFMPEG]] deps = ["FFMPEG_jll"] -git-tree-sha1 = "b57e3acbe22f8484b4b5ff66a7499717fe1a9cc8" +git-tree-sha1 = "53ebe7511fa11d33bec688a9178fac4e49eeee00" uuid = "c87230d0-a227-11e9-1b43-d7ebe4e7570a" -version = "0.4.1" +version = "0.4.2" [[deps.FFMPEG_jll]] deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "JLLWrappers", "LAME_jll", "Libdl", "Ogg_jll", "OpenSSL_jll", "Opus_jll", "PCRE2_jll", "Zlib_jll", "libaom_jll", "libass_jll", "libfdk_aac_jll", "libvorbis_jll", "x264_jll", "x265_jll"] @@ -565,9 +591,9 @@ version = "0.1.1" [[deps.FileIO]] deps = ["Pkg", "Requires", "UUIDs"] -git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322" +git-tree-sha1 = "62ca0547a14c57e98154423419d8a342dca75ca9" uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" -version = "1.16.3" +version = "1.16.4" [[deps.FilePathsBase]] deps = ["Compat", "Dates"] @@ -582,6 +608,7 @@ weakdeps = ["Mmap", "Test"] [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" +version = "1.11.0" [[deps.FillArrays]] deps = ["LinearAlgebra"] @@ -596,19 +623,21 @@ weakdeps = ["PDMats", "SparseArrays", "Statistics"] FillArraysStatisticsExt = "Statistics" [[deps.FiniteDiff]] -deps = ["ArrayInterface", "LinearAlgebra", "Setfield", "SparseArrays"] -git-tree-sha1 = "f9219347ebf700e77ca1d48ef84e4a82a6701882" +deps = ["ArrayInterface", "LinearAlgebra", "Setfield"] +git-tree-sha1 = "b10bdafd1647f57ace3885143936749d61638c3b" uuid = "6a86dc24-6348-571c-b903-95158fe2bd41" -version = "2.24.0" +version = "2.26.0" [deps.FiniteDiff.extensions] FiniteDiffBandedMatricesExt = "BandedMatrices" FiniteDiffBlockBandedMatricesExt = "BlockBandedMatrices" + FiniteDiffSparseArraysExt = "SparseArrays" FiniteDiffStaticArraysExt = "StaticArrays" [deps.FiniteDiff.weakdeps] BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [[deps.FixedPointNumbers]] @@ -618,23 +647,27 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" version = "0.8.5" [[deps.Flux]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] -git-tree-sha1 = "fbf100b4bed74c9b6fac0ebd1031e04977d35b3b" +deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLDataDevices", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "Setfield", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] +git-tree-sha1 = "37fa32a50c69c10c6ea1465d3054d98c75bd7777" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.19" +version = 
"0.14.22" [deps.Flux.extensions] FluxAMDGPUExt = "AMDGPU" FluxCUDAExt = "CUDA" FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] FluxEnzymeExt = "Enzyme" + FluxMPIExt = "MPI" + FluxMPINCCLExt = ["CUDA", "MPI", "NCCL"] FluxMetalExt = "Metal" [deps.Flux.weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" + MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [[deps.Fontconfig_jll]] @@ -679,6 +712,7 @@ version = "0.4.12" [[deps.Future]] deps = ["Random"] uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" +version = "1.11.0" [[deps.GLFW_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libglvnd_jll", "Xorg_libXcursor_jll", "Xorg_libXi_jll", "Xorg_libXinerama_jll", "Xorg_libXrandr_jll", "libdecor_jll", "xkbcommon_jll"] @@ -699,22 +733,22 @@ uuid = "46192b85-c4d5-4398-a991-12ede77f4527" version = "0.1.6" [[deps.GPUCompiler]] -deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Preferences", "Scratch", "Serialization", "TOML", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "ab29216184312f99ff957b32cd63c2fe9c928b91" +deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "PrecompileTools", "Preferences", "Scratch", "Serialization", "TOML", "TimerOutputs", "UUIDs"] +git-tree-sha1 = "1d6f290a5eb1201cd63574fbc4440c788d5cb38f" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.26.7" +version = "0.27.8" [[deps.GR]] deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Preferences", "Printf", "Qt6Wayland_jll", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "p7zip_jll"] -git-tree-sha1 = "629693584cef594c3f6f99e76e7a7ad17e60e8d5" +git-tree-sha1 = "ee28ddcd5517d54e417182fec3886e7412d3926f" uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71" -version = "0.73.7" +version = "0.73.8" [[deps.GR_jll]] deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "FreeType2_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Qt6Base_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "a8863b69c2a0859f2c2c87ebdc4c6712e88bdf0d" +git-tree-sha1 = "f31929b9e67066bee48eec8b03c0df47d31a74b3" uuid = "d2c73de3-f751-5644-a686-071e5b155ba9" -version = "0.73.7+0" +version = "0.73.8+0" [[deps.Gettext_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "XML2_jll"] @@ -730,15 +764,15 @@ version = "1.3.1" [[deps.Git_jll]] deps = ["Artifacts", "Expat_jll", "JLLWrappers", "LibCURL_jll", "Libdl", "Libiconv_jll", "OpenSSL_jll", "PCRE2_jll", "Zlib_jll"] -git-tree-sha1 = "d18fb8a1f3609361ebda9bf029b60fd0f120c809" +git-tree-sha1 = "ea372033d09e4552a04fd38361cd019f9003f4f4" uuid = "f8c6e375-362e-5223-8a59-34ff63f689eb" -version = "2.44.0+2" +version = "2.46.2+0" [[deps.Glib_jll]] deps = ["Artifacts", "Gettext_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Libiconv_jll", "Libmount_jll", "PCRE2_jll", "Zlib_jll"] -git-tree-sha1 = "7c82e6a6cd34e9d935e9aa4051b66c6ff3af59ba" +git-tree-sha1 = "674ff0db93fffcd11a3573986e550d66cd4fd71f" uuid = "7746bdde-850d-59dc-9ae8-88ece973131d" -version = "2.80.2+0" +version = "2.80.5+0" [[deps.Graphite2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -814,11 +848,12 @@ version = "1.4.2" [[deps.InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" 
+version = "1.11.0" [[deps.InverseFunctions]] -git-tree-sha1 = "2787db24f4e03daf859c6509ff87764e4182f7d1" +git-tree-sha1 = "a779299d77cd080bf77b97535acecd73e1c5e5cb" uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.16" +version = "0.1.17" weakdeps = ["Dates", "Test"] [deps.InverseFunctions.extensions] @@ -848,9 +883,9 @@ version = "1.0.0" [[deps.JLD2]] deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Requires", "TranscodingStreams"] -git-tree-sha1 = "a0746c21bdc986d0dc293efa6b1faee112c37c28" +git-tree-sha1 = "b464b9b461ee989b435a689a4f7d870b68d467ed" uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" -version = "0.4.53" +version = "0.5.6" [[deps.JLFzf]] deps = ["Pipe", "REPL", "Random", "fzf_jll"] @@ -860,9 +895,9 @@ version = "0.1.8" [[deps.JLLWrappers]] deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "f389674c99bfcde17dc57454011aa44d5a260a40" +git-tree-sha1 = "be3dc50a92e5a386872a493a10050136d4703f9b" uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.6.0" +version = "1.6.1" [[deps.JSON]] deps = ["Dates", "Mmap", "Parsers", "Unicode"] @@ -872,9 +907,9 @@ version = "0.21.4" [[deps.JpegTurbo_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "c84a835e1a09b289ffcd2271bf2a337bbdda6637" +git-tree-sha1 = "25ee0be4d43d0269027024d75a24c24d6c6e590c" uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" -version = "3.0.3+0" +version = "3.0.4+0" [[deps.JuliaNVTXCallbacks_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -890,9 +925,9 @@ version = "0.2.4" [[deps.KernelAbstractions]] deps = ["Adapt", "Atomix", "InteractiveUtils", "MacroTools", "PrecompileTools", "Requires", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "cb1cff88ef2f3a157cbad75bbe6b229e1975e498" +git-tree-sha1 = "04e52f596d0871fa3890170fa79cb15e481e4cd8" uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.25" +version = "0.9.28" [deps.KernelAbstractions.extensions] EnzymeExt = "EnzymeCore" @@ -911,16 +946,16 @@ uuid = "c1c5ebd0-6772-5130-a774-d5fcae4a789d" version = "3.100.2+0" [[deps.LERC_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "bf36f528eec6634efc60d7ec062008f171071434" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "36bdbc52f13a7d1dcb0f3cd694e01677a515655b" uuid = "88015f11-f218-50d7-93a8-a6af411a945d" -version = "3.0.0+1" +version = "4.0.0+0" [[deps.LLVM]] -deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] -git-tree-sha1 = "2470e69781ddd70b8878491233cd09bc1bd7fc96" +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Unicode"] +git-tree-sha1 = "d422dfd9707bec6617335dc2ea3c5172a87d5908" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "8.1.0" +version = "9.1.3" weakdeps = ["BFloat16s"] [deps.LLVM.extensions] @@ -928,9 +963,9 @@ weakdeps = ["BFloat16s"] [[deps.LLVMExtra_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "597d1c758c9ae5d985ba4202386a607c675ee700" +git-tree-sha1 = "05a8bd5a42309a9ec82f700876903abce1017dd3" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.31+0" +version = "0.0.34+0" [[deps.LLVMLoopInfo]] git-tree-sha1 = "2e5c102cfc41f48ae4740c7eca7743cc7e7b75ea" @@ -939,24 +974,24 @@ version = "1.0.0" [[deps.LLVMOpenMP_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "e16271d212accd09d52ee0ae98956b8a05c4b626" +git-tree-sha1 = "78211fb6cbc872f77cad3fc0b6cf647d923f4929" uuid = "1d63c593-3942-5779-bab2-d838dc0a180e" 
-version = "17.0.6+0" +version = "18.1.7+0" [[deps.LZO_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "70c5da094887fd2cae843b8db33920bac4b6f07d" +git-tree-sha1 = "854a9c268c43b77b0a27f22d7fab8d33cdb3a731" uuid = "dd4b983a-f0e5-5f8d-a1b7-129d4a5fb1ac" -version = "2.10.2+0" +version = "2.10.2+1" [[deps.LaTeXStrings]] -git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" +git-tree-sha1 = "dda21b8cbd6a6c40d9d02a73230f9d70fed6918c" uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" -version = "1.3.1" +version = "1.4.0" [[deps.LaplaceRedux]] -deps = ["CategoricalDistributions", "ChainRulesCore", "Compat", "ComputationalResources", "Distributions", "Flux", "LinearAlgebra", "MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"] -path = "C:\\Users\\Pasqu\\Documents\\julia_projects\\LaplaceRedux.jl" +deps = ["CategoricalDistributions", "ChainRulesCore", "Compat", "ComputationalResources", "Distributions", "Flux", "LinearAlgebra", "MLJBase", "MLJModelInterface", "MLUtils", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"] +path = ".." uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478" version = "1.1.1" @@ -984,6 +1019,7 @@ version = "1.2.2" [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" +version = "1.11.0" [[deps.LearnAPI]] deps = ["InteractiveUtils", "Statistics"] @@ -999,16 +1035,17 @@ version = "0.6.4" [[deps.LibCURL_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "8.4.0+0" +version = "8.6.0+0" [[deps.LibGit2]] deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" +version = "1.11.0" [[deps.LibGit2_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" -version = "1.6.4+0" +version = "1.7.2+0" [[deps.LibSSH2_jll]] deps = ["Artifacts", "Libdl", "MbedTLS_jll"] @@ -1017,6 +1054,7 @@ version = "1.11.0+1" [[deps.Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" +version = "1.11.0" [[deps.Libffi_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -1056,9 +1094,9 @@ version = "2.40.1+0" [[deps.Libtiff_jll]] deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "LERC_jll", "Libdl", "XZ_jll", "Zlib_jll", "Zstd_jll"] -git-tree-sha1 = "2da088d113af58221c52828a80378e16be7d037a" +git-tree-sha1 = "b404131d06f7886402758c9ce2214b636eb4d54a" uuid = "89763e89-9b03-5906-acba-b20f662cd828" -version = "4.5.1+1" +version = "4.7.0+0" [[deps.Libuuid_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] @@ -1075,6 +1113,7 @@ version = "7.3.0" [[deps.LinearAlgebra]] deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +version = "1.11.0" [[deps.LinearMaps]] deps = ["LinearAlgebra"] @@ -1106,6 +1145,7 @@ version = "0.3.28" [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" +version = "1.11.0" [[deps.LoggingExtras]] deps = ["Dates", "Logging"] @@ -1113,6 +1153,46 @@ git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075" uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" version = "1.0.3" +[[deps.MLDataDevices]] +deps = ["Adapt", "Compat", "Functors", "LinearAlgebra", "Preferences", "Random"] +git-tree-sha1 = "3207c2e66164e6366440ad3f0243a8d67abb4a47" +uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" +version = "1.4.1" + + [deps.MLDataDevices.extensions] + 
MLDataDevicesAMDGPUExt = "AMDGPU" + MLDataDevicesCUDAExt = "CUDA" + MLDataDevicesChainRulesCoreExt = "ChainRulesCore" + MLDataDevicesFillArraysExt = "FillArrays" + MLDataDevicesGPUArraysExt = "GPUArrays" + MLDataDevicesMLUtilsExt = "MLUtils" + MLDataDevicesMetalExt = ["GPUArrays", "Metal"] + MLDataDevicesReactantExt = "Reactant" + MLDataDevicesRecursiveArrayToolsExt = "RecursiveArrayTools" + MLDataDevicesReverseDiffExt = "ReverseDiff" + MLDataDevicesSparseArraysExt = "SparseArrays" + MLDataDevicesTrackerExt = "Tracker" + MLDataDevicesZygoteExt = "Zygote" + MLDataDevicescuDNNExt = ["CUDA", "cuDNN"] + MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"] + + [deps.MLDataDevices.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" + GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" + MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" + Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" + RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" + [[deps.MLJBase]] deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LearnAPI", "LinearAlgebra", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "RecipesBase", "Reexport", "ScientificTypes", "Serialization", "StatisticalMeasuresBase", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] git-tree-sha1 = "6f45e12073bc2f2e73ed0473391db38c31e879c9" @@ -1175,6 +1255,7 @@ version = "0.5.13" [[deps.Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" +version = "1.11.0" [[deps.MarkdownAST]] deps = ["AbstractTrees", "Markdown"] @@ -1191,7 +1272,7 @@ version = "1.1.9" [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+1" +version = "2.28.6+0" [[deps.Measures]] git-tree-sha1 = "c13304c81eec1ed3af7fc20e75fb6b26092a1102" @@ -1200,9 +1281,9 @@ version = "0.3.2" [[deps.Metalhead]] deps = ["Artifacts", "BSON", "ChainRulesCore", "Flux", "Functors", "JLD2", "LazyArtifacts", "MLUtils", "NNlib", "PartialFunctions", "Random", "Statistics"] -git-tree-sha1 = "5aac9a2b511afda7bf89df5044a2e0b429f83152" +git-tree-sha1 = "aef476e4958303f5ea9e1deb81a1ba2f510d4e11" uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" -version = "0.9.3" +version = "0.9.4" weakdeps = ["CUDA"] [deps.Metalhead.extensions] @@ -1222,6 +1303,7 @@ version = "1.2.0" [[deps.Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" +version = "1.11.0" [[deps.Mocking]] deps = ["Compat", "ExprTools"] @@ -1231,7 +1313,7 @@ version = "0.8.1" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2023.1.10" +version = "2023.12.12" [[deps.MultivariateStats]] deps = ["Arpack", "Distributions", "LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI", "StatsBase"] @@ -1246,10 +1328,10 @@ uuid = "d41bc354-129a-5804-8e4c-c37616107c6c" version = "7.8.3" [[deps.NNlib]] -deps = ["Adapt", "Atomix", 
"ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "ae52c156a63bb647f80c26319b104e99e5977e51" +deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "da09a1e112fd75f9af2a5229323f01b56ec96a4c" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.22" +version = "0.9.24" [deps.NNlib.extensions] NNlibAMDGPUExt = "AMDGPU" @@ -1257,12 +1339,14 @@ version = "0.9.22" NNlibCUDAExt = "CUDA" NNlibEnzymeCoreExt = "EnzymeCore" NNlibFFTWExt = "FFTW" + NNlibForwardDiffExt = "ForwardDiff" [deps.NNlib.weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [[deps.NVTX]] @@ -1302,9 +1386,9 @@ version = "0.2.3" [[deps.NearestNeighbors]] deps = ["Distances", "StaticArrays"] -git-tree-sha1 = "91a67b4d73842da90b526011fa85c5c4c9343fe0" +git-tree-sha1 = "3cebfc94a0754cc329ebc3bab1e6c89621e791ad" uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" -version = "0.4.18" +version = "0.4.20" [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" @@ -1325,7 +1409,7 @@ version = "0.2.5" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.23+4" +version = "0.3.27+1" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] @@ -1340,9 +1424,9 @@ version = "1.4.3" [[deps.OpenSSL_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "a028ee3cb5641cccc4c24e90c36b0a4f7707bdf5" +git-tree-sha1 = "7493f61f55a6cce7325f197443aa80d32554ba10" uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "3.0.14+0" +version = "3.0.15+1" [[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] @@ -1432,9 +1516,13 @@ uuid = "30392449-352a-5448-841d-b1acce4e97dc" version = "0.43.4+0" [[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "Random", "SHA", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.10.0" +version = "1.11.0" +weakdeps = ["REPL"] + + [deps.Pkg.extensions] + REPLExt = "REPL" [[deps.PlotThemes]] deps = ["PlotUtils", "Statistics"] @@ -1443,10 +1531,10 @@ uuid = "ccf2f8ad-2431-5c83-bf29-c5338b663b6a" version = "3.2.0" [[deps.PlotUtils]] -deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "Statistics"] -git-tree-sha1 = "7b1a9df27f072ac4c9c7cbe5efb198489258d1f5" +deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "StableRNGs", "Statistics"] +git-tree-sha1 = "650a022b2ce86c7dcfbdecf00f78afeeb20e5655" uuid = "995b91a9-d308-5afd-9ec6-746e21dbc043" -version = "1.4.1" +version = "1.4.2" [[deps.Plots]] deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "PrecompileTools", "Printf", "REPL", "Random", "RecipesBase", 
"RecipesPipeline", "Reexport", "RelocatableFolders", "Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "TOML", "UUIDs", "UnicodeFun", "UnitfulLatexify", "Unzip"] @@ -1499,13 +1587,14 @@ version = "0.2.0" [[deps.PrettyTables]] deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] -git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7" +git-tree-sha1 = "1101cd475833706e4d0e7b122218257178f48f34" uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "2.3.2" +version = "2.4.0" [[deps.Printf]] deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" +version = "1.11.0" [[deps.ProgressLogging]] deps = ["Logging", "SHA", "UUIDs"] @@ -1550,9 +1639,9 @@ version = "6.7.1+1" [[deps.QuadGK]] deps = ["DataStructures", "LinearAlgebra"] -git-tree-sha1 = "1d587203cf851a51bf1ea31ad7ff89eff8d625ea" +git-tree-sha1 = "cda3b045cf9ef07a08ad46731f5a3165e56cf3da" uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" -version = "2.11.0" +version = "2.11.1" [deps.QuadGK.extensions] QuadGKEnzymeExt = "Enzyme" @@ -1573,12 +1662,14 @@ uuid = "ce6b1742-4840-55fa-b093-852dadbb1d8b" version = "0.7.7" [[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +deps = ["InteractiveUtils", "Markdown", "Sockets", "StyledStrings", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" +version = "1.11.0" [[deps.Random]] deps = ["SHA"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +version = "1.11.0" [[deps.Random123]] deps = ["Random", "RandomNumbers"] @@ -1647,15 +1738,15 @@ version = "1.3.0" [[deps.Rmath]] deps = ["Random", "Rmath_jll"] -git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b" +git-tree-sha1 = "852bd0f55565a9e973fcfee83a84413270224dc4" uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" -version = "0.7.1" +version = "0.8.0" [[deps.Rmath_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "e60724fd3beea548353984dc61c943ecddb0e29a" +git-tree-sha1 = "58cdd8fb2201a6267e1db87ff148dd6c1dbd8ad8" uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" -version = "0.4.3+0" +version = "0.5.1+0" [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" @@ -1692,6 +1783,7 @@ version = "1.4.5" [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +version = "1.11.0" [[deps.Setfield]] deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] @@ -1711,9 +1803,9 @@ uuid = "992d4aef-0814-514b-bc4d-f2e9a6c4116f" version = "1.0.3" [[deps.SimpleBufferStream]] -git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" +git-tree-sha1 = "f305871d2f381d21527c770d4788c06c097c9bc1" uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" -version = "1.1.0" +version = "1.2.0" [[deps.SimpleTraits]] deps = ["InteractiveUtils", "MacroTools"] @@ -1723,6 +1815,7 @@ version = "0.9.4" [[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" +version = "1.11.0" [[deps.SortingAlgorithms]] deps = ["DataStructures"] @@ -1733,7 +1826,7 @@ version = "1.2.1" [[deps.SparseArrays]] deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -version = "1.10.0" +version = "1.11.0" [[deps.SparseInverseSubset]] deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] @@ -1765,9 +1858,9 @@ version = "1.0.2" [[deps.StaticArrays]] deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50" +git-tree-sha1 = "777657803913ffc7e8cc20f0fd04b634f871af8f" uuid = 
"90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.7" +version = "1.9.8" weakdeps = ["ChainRulesCore", "Statistics"] [deps.StaticArrays.extensions] @@ -1792,9 +1885,14 @@ uuid = "64bff920-2084-43da-a3e6-9bb72801c0c9" version = "3.4.0" [[deps.Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] +deps = ["LinearAlgebra"] +git-tree-sha1 = "ae3bb1eb3bba077cd276bc5cfc337cc65c3075c0" uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.10.0" +version = "1.11.1" +weakdeps = ["SparseArrays"] + + [deps.Statistics.extensions] + SparseArraysExt = ["SparseArrays"] [[deps.StatsAPI]] deps = ["LinearAlgebra"] @@ -1810,9 +1908,9 @@ version = "0.34.3" [[deps.StatsFuns]] deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] -git-tree-sha1 = "cef0472124fab0695b58ca35a77c6fb942fdab8a" +git-tree-sha1 = "b423576adc27097764a90e163157bcfc9acf0f46" uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -version = "1.3.1" +version = "1.3.2" weakdeps = ["ChainRulesCore", "InverseFunctions"] [deps.StatsFuns.extensions] @@ -1821,9 +1919,9 @@ weakdeps = ["ChainRulesCore", "InverseFunctions"] [[deps.StringManipulation]] deps = ["PrecompileTools"] -git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" +git-tree-sha1 = "a6b1675a536c5ad1a60e5a5153e1fee12eb146e3" uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" -version = "0.3.4" +version = "0.4.0" [[deps.StructArrays]] deps = ["ConstructionBase", "DataAPI", "Tables"] @@ -1838,6 +1936,10 @@ weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"] StructArraysSparseArraysExt = "SparseArrays" StructArraysStaticArraysExt = "StaticArrays" +[[deps.StyledStrings]] +uuid = "f489334b-da3d-4c2e-b8f0-e476e12c162b" +version = "1.11.0" + [[deps.SuiteSparse]] deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" @@ -1845,7 +1947,7 @@ uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [[deps.SuiteSparse_jll]] deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "7.2.1+1" +version = "7.7.0+0" [[deps.TOML]] deps = ["Dates"] @@ -1854,9 +1956,9 @@ version = "1.0.3" [[deps.TZJData]] deps = ["Artifacts"] -git-tree-sha1 = "1607ad46cf8d642aa779a1d45af1c8620dbf6915" +git-tree-sha1 = "36b40607bf2bf856828690e097e1c799623b0602" uuid = "dc5dba14-91b3-4cab-a142-028a31da12f7" -version = "1.2.0+2024a" +version = "1.3.0+2024b" [[deps.TableTraits]] deps = ["IteratorInterfaceExtensions"] @@ -1871,16 +1973,15 @@ uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" version = "1.12.0" [[deps.TaijaBase]] -deps = ["CategoricalArrays", "Distributions", "Flux", "MLUtils", "Optimisers", "StatsBase", "Tables"] -git-tree-sha1 = "1c80c4472c6ab6e8c9fa544a22d907295b388dd0" +git-tree-sha1 = "4076f60078b12095ca71a2c26e2e4515e3a6a5e5" uuid = "10284c91-9f28-4c9a-abbf-ee43576dfff6" -version = "1.2.2" +version = "1.2.3" [[deps.TaijaPlotting]] -deps = ["CategoricalArrays", "ConformalPrediction", "CounterfactualExplanations", "DataAPI", "Distributions", "Flux", "LaplaceRedux", "LinearAlgebra", "MLJBase", "MLUtils", "MultivariateStats", "NaturalSort", "NearestNeighborModels", "Plots", "Trapz"] -git-tree-sha1 = "2fc71041e1c215cf6ef3dc2d3b8419499c4b40ff" +deps = ["CategoricalArrays", "ConformalPrediction", "CounterfactualExplanations", "DataAPI", "Distributions", "LaplaceRedux", "LinearAlgebra", "MLJBase", "MLUtils", "MultivariateStats", "NaturalSort", "NearestNeighborModels", "OneHotArrays", "Plots", "RecipesBase", "Trapz"] +git-tree-sha1 = 
"01c76535e8b87b05d0fbc275d44714c78e49fe6e" uuid = "bd7198b4-c7d6-400c-9bab-9a24614b0240" -version = "1.2.0" +version = "1.3.0" [[deps.Tar]] deps = ["ArgTools", "SHA"] @@ -1896,6 +1997,7 @@ version = "0.1.1" [[deps.Test]] deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +version = "1.11.0" [[deps.ThreadsX]] deps = ["Accessors", "ArgCheck", "BangBang", "ConstructionBase", "InitialValues", "MicroCollections", "Referenceables", "SplittablesBase", "Transducers"] @@ -1905,9 +2007,9 @@ version = "0.1.12" [[deps.TimeZones]] deps = ["Dates", "Downloads", "InlineStrings", "Mocking", "Printf", "Scratch", "TZJData", "Unicode", "p7zip_jll"] -git-tree-sha1 = "b92aebdd3555f3a7e3267cf17702033c2814ef48" +git-tree-sha1 = "8323074bc977aa85cf5ad71099a83ac75b0ac107" uuid = "f269a46b-ccf7-5d73-abea-4c690281aa53" -version = "1.18.0" +version = "1.18.1" weakdeps = ["RecipesBase"] [deps.TimeZones.extensions] @@ -1915,22 +2017,23 @@ weakdeps = ["RecipesBase"] [[deps.TimerOutputs]] deps = ["ExprTools", "Printf"] -git-tree-sha1 = "5a13ae8a41237cff5ecf34f73eb1b8f42fff6531" +git-tree-sha1 = "3a6f063d690135f5c1ba351412c82bae4d1402bf" uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.24" +version = "0.5.25" [[deps.TranscodingStreams]] -git-tree-sha1 = "e84b3a11b9bece70d14cce63406bbc79ed3464d2" +git-tree-sha1 = "0c45878dcfdcfa8480052b6ab162cdd138781742" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.11.2" +version = "0.11.3" [[deps.Transducers]] -deps = ["Accessors", "Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"] -git-tree-sha1 = "5215a069867476fc8e3469602006b9670e68da23" +deps = ["Accessors", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"] +git-tree-sha1 = "7deeab4ff96b85c5f72c824cae53a1398da3d1cb" uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.82" +version = "0.4.84" [deps.Transducers.extensions] + TransducersAdaptExt = "Adapt" TransducersBlockArraysExt = "BlockArrays" TransducersDataFramesExt = "DataFrames" TransducersLazyArraysExt = "LazyArrays" @@ -1938,6 +2041,7 @@ version = "0.4.82" TransducersReferenceablesExt = "Referenceables" [deps.Transducers.weakdeps] + Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" @@ -1975,6 +2079,7 @@ version = "1.5.1" [[deps.UUIDs]] deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +version = "1.11.0" [[deps.UnPack]] git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b" @@ -1983,6 +2088,7 @@ version = "1.0.2" [[deps.Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" +version = "1.11.0" [[deps.UnicodeFun]] deps = ["REPL"] @@ -2221,15 +2327,15 @@ version = "1.2.13+1" [[deps.Zstd_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "e678132f07ddb5bfa46857f0d7620fb9be675d3b" +git-tree-sha1 = "555d1076590a6cc2fdee2ef1469451f872d8b41b" uuid = "3161d3a3-bdf6-5164-811a-617609db77b4" -version = "1.5.6+0" +version = "1.5.6+1" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", 
"GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54" +git-tree-sha1 = "f816633be6dc5c0ed9ffedda157ecfda0b3b6a69" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.70" +version = "0.6.72" [deps.Zygote.extensions] ZygoteColorsExt = "Colors" @@ -2249,9 +2355,15 @@ version = "0.2.5" [[deps.cuDNN]] deps = ["CEnum", "CUDA", "CUDA_Runtime_Discovery", "CUDNN_jll"] -git-tree-sha1 = "4909e87d6d62c29a897d54d9001c63932e41cb0e" +git-tree-sha1 = "4b3ac62501ca73263eaa0d034c772f13c647fba6" uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" -version = "1.3.2" +version = "1.4.0" + +[[deps.demumble_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "6498e3581023f8e530f34760d18f75a69e3a4ea8" +uuid = "1e29f10c-031c-5a83-9565-69cddfc27673" +version = "1.3.0+0" [[deps.eudev_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "gperf_jll"] @@ -2314,9 +2426,9 @@ version = "1.18.0+0" [[deps.libpng_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Zlib_jll"] -git-tree-sha1 = "d7015d2e18a5fd9a4f47de711837e980519781a4" +git-tree-sha1 = "b70c870239dc3d7bc094eb2d6be9b73d27bef280" uuid = "b53b4c65-9356-5827-b1ea-8c7a1a84506f" -version = "1.6.43+1" +version = "1.6.44+0" [[deps.libvorbis_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Ogg_jll", "Pkg"] @@ -2333,7 +2445,7 @@ version = "1.1.6+0" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.52.0+1" +version = "1.59.0+0" [[deps.p7zip_jll]] deps = ["Artifacts", "Libdl"] diff --git a/docs/Project.toml b/docs/Project.toml index c8d6234a..e8f6d849 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -3,6 +3,7 @@ CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 42ec720a..c45d996b 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -312,7 +312,7 @@ The meaining of "equal" depends on the type of the property value: - values that are not of `MLJType` are "equal" if they are `==`. In the special case of a "deep" property, "equal" has a different -meaning; see [`MMI.StatTraits.deep_properties`](@ref)) for details. +meaning; see [`MLJBase.deep_properties`](@ref)) for details. If `m1` or `m2` are not `MLJType` objects, then return `==(m1, m2)`. 
@@ -327,7 +327,7 @@ function MMI.is_same_except(m1::LaplaceModels, m2::LaplaceModels, exceptions::Sy if !_isdefined(m1, name) !_isdefined(m2, name) || return false elseif _isdefined(m2, name) - if name in MMI.StatTraits.deep_properties(LaplaceRegressor) + if name in MLJBase.deep_properties(LaplaceRegressor) _equal_to_depth_one(getproperty(m1, name), getproperty(m2, name)) || return false else From 374aca5bf79b3c5bfedcd86d657a4b1c1fef287b Mon Sep 17 00:00:00 2001 From: pat-alt Date: Wed, 23 Oct 2024 10:16:16 +0200 Subject: [PATCH 54/60] hadn't saved file --- src/direct_mlj.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index c45d996b..aa644bd9 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -312,7 +312,7 @@ The meaining of "equal" depends on the type of the property value: - values that are not of `MLJType` are "equal" if they are `==`. In the special case of a "deep" property, "equal" has a different -meaning; see [`MLJBase.deep_properties`](@ref)) for details. +meaning; see `MLJBase.deep_properties` for details. If `m1` or `m2` are not `MLJType` objects, then return `==(m1, m2)`. From 59917f86d3c42a8f3c939ffa6b21b6336ecd6ec2 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Wed, 30 Oct 2024 00:46:40 +0100 Subject: [PATCH 55/60] added default mlp --- src/baselaplace/predicting.jl | 2 +- src/direct_mlj.jl | 70 +++++++++++++++++++++++++++++++++-- test/direct_mlj_interface.jl | 40 ++++++++++++++------ test/runtests.jl | 4 +- 4 files changed, 99 insertions(+), 17 deletions(-) diff --git a/src/baselaplace/predicting.jl b/src/baselaplace/predicting.jl index d26c9a07..feb457b3 100644 --- a/src/baselaplace/predicting.jl +++ b/src/baselaplace/predicting.jl @@ -93,7 +93,7 @@ Computes the Bayesian predictivie distribution from a neural network with a Lapl - `link_approx::Symbol=:probit`: Link function approximation. Options are `:probit` and `:plugin`. - `predict_proba::Bool=true`: If `true` (default) apply a sigmoid or a softmax function to the output of the Flux model. - `return_distr::Bool=false`: if `false` (default), the function outputs either the direct output of the chain or pseudo-probabilities (if `predict_proba=true`). - if `true` predict return a probability distribution. + if `true` predict returns a probability distribution. # Returns diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index aa644bd9..82c6d81b 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -65,6 +65,63 @@ MMI.reformat(::LaplaceModels, X) = (MLJBase.matrix(X) |> permutedims,) MMI.selectrows(::LaplaceModels, I, Xmatrix, y) = (Xmatrix[:, I], (y[1][:, I], y[2])) MMI.selectrows(::LaplaceModels, I, Xmatrix) = (Xmatrix[:, I],) + + +""" + function features_shape(model::LaplaceRegression, X, y) + +Compute the the number of features of the X input dataset and the number of variables to predict from the y output dataset. + +# Arguments +- `model::LaplaceModels`: The Laplace model to fit. +- `X`: The input data for training. +- `y`: The target labels for training one-hot encoded. + +# Returns +- (input size, output size) +""" +function features_shape(model::LaplaceModels, X, y) + #X = X isa Tables.MatrixTable ? 
MLJBase.matrix(X) : X + n_input = size(X, 1) + dims = size(y) + if length(dims) == 1 + n_output = 1 + else + n_output = dims[1] + end + return (n_input, n_output) +end + + +""" + default_build( seed::Int, shape) + +Builds a default MLP Flux model compatible with the dimensions of the dataset, with reproducible initial weights. + +# Arguments +- `seed::Int`: The seed for random number generation. +- `shape`: a tuple containing the dimensions of the input layer and the output layer. + +# Returns +- The constructed Flux model, which consist in a simple MLP with 2 hidden layers with 20 neurons each and an input and output layers compatible with the dataset. +""" +function default_build(seed::Int, shape) + Random.seed!(seed) + (n_input, n_output) = shape + + chain = Chain( + Dense(n_input, 20, relu), + Dense(20, 20, relu), + #Dense(20, 20, relu), + Dense(20, n_output) + ) + + return chain +end + + + + @doc """ MMI.fit(m::Union{LaplaceRegressor,LaplaceClassifier}, verbosity, X, y) @@ -84,6 +141,13 @@ Fit a Laplace model using the provided features and target values. function MMI.fit(m::LaplaceModels, verbosity, X, y) y, decode = y + if (m.model === nothing) + shape = features_shape(m, X, y) + + m.model = default_build(11, shape) + + end + # Make a copy of the model because Flux does not allow to mutate hyperparameters copied_model = deepcopy(m.model) @@ -141,7 +205,7 @@ function MMI.fit(m::LaplaceModels, verbosity, X, y) # fit the Laplace model: LaplaceRedux.fit!(la, data_loader) - optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) + optimize_prior!(la; verbosity= verbosity, n_steps=m.fit_prior_nsteps) fitresult = (la, decode) report = (loss_history=loss_history,) @@ -229,7 +293,7 @@ function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y) # fit the Laplace model: LaplaceRedux.fit!(la, data_loader) - optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) + optimize_prior!(la; verbosity = verbosity, n_steps=m.fit_prior_nsteps) fitresult = (la, decode) report = (loss_history=old_loss_history,) @@ -276,7 +340,7 @@ function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y) # fit the Laplace model: LaplaceRedux.fit!(la, data_loader) - optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps) + optimize_prior!(la; verbosity = verbosity, n_steps=m.fit_prior_nsteps) fitresult = (la, decode) report = (loss_history=old_loss_history,) diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl index f4b6e5fc..e2ec4022 100644 --- a/test/direct_mlj_interface.jl +++ b/test/direct_mlj_interface.jl @@ -9,24 +9,24 @@ import LaplaceRedux: LaplaceClassifier, LaplaceRegressor cv = MLJBase.CV(; nfolds=3) @testset "Regression" begin + @info " testing interface for LaplaceRegressor" flux_model = Chain(Dense(4, 10, relu), Dense(10, 10, relu), Dense(10, 1)) model = LaplaceRegressor(; model=flux_model, epochs=50) X, y = MLJ.make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1) #train, test = partition(eachindex(y), 0.7); # 70:30 split - mach = MLJ.machine(model, X, y) #|> MLJBase.fit! 
#|> (fitresult,cache,report) - MLJBase.fit!(mach; verbosity=1) - #Xnew, ynew = make_regression(3, 4; rng=123) + mach = MLJ.machine(model, X, y) + MLJBase.fit!(mach; verbosity=0) yhat = MLJBase.predict(mach, X) # probabilistic predictions MLJBase.predict_mode(mach, X) # point predictions MLJBase.fitted_params(mach) #fitted params function MLJBase.training_losses(mach) #training loss history model.epochs = 100 #changing number of epochs - MLJBase.fit!(mach) #testing update function + MLJBase.fit!(mach; verbosity=0) #testing update function model.epochs = 50 #changing number of epochs to a lower number - MLJBase.fit!(mach) #testing update function + MLJBase.fit!(mach; verbosity=0) #testing update function model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps - MLJBase.fit!(mach) #testing update function (the laplace part) + MLJBase.fit!(mach; verbosity=0) #testing update function (the laplace part) yhat = MLJBase.predict(mach, X) # probabilistic predictions MLJ.evaluate!(mach; resampling=cv, measure=MLJ.log_loss, verbosity=0) @@ -34,13 +34,29 @@ cv = MLJBase.CV(; nfolds=3) flux_model_two = Chain(Dense(4, 6, relu), Dense(6, 1)) # test update! fallback to fit! model.model = flux_model_two - MLJBase.fit!(mach) + MLJBase.fit!(mach; verbosity=0) model_two = LaplaceRegressor(; model=flux_model_two, epochs=100) @test !MLJBase.is_same_except(model, model_two) + + + #testing default mlp builder + model = LaplaceRegressor(; model=nothing, epochs=50) + mach = MLJ.machine(model, X, y) + MLJBase.fit!(mach; verbosity=0) + yhat = MLJBase.predict(mach, X) # probabilistic predictions + MLJBase.predict_mode(mach, X) # point predictions + MLJBase.fitted_params(mach) #fitted params function + MLJBase.training_losses(mach) #training loss history + model.epochs = 100 #changing number of epochs + MLJBase.fit!(mach; verbosity=0) #testing update function + + + end @testset "Classification" begin + @info " testing interface for LaplaceClassifier" # Define the model flux_model = Chain(Dense(4, 10, relu), Dense(10, 3)) @@ -48,7 +64,7 @@ end X, y = MLJ.@load_iris mach = MLJ.machine(model, X, y) - MLJBase.fit!(mach; verbosity=1) + MLJBase.fit!(mach; verbosity=0) Xnew = ( sepal_length=[6.4, 7.2, 7.4], sepal_width=[2.8, 3.0, 2.8], @@ -61,11 +77,11 @@ end MLJBase.fitted_params(mach) # fitted params MLJBase.training_losses(mach) #training loss history model.epochs = 100 #changing number of epochs - MLJBase.fit!(mach) #testing update function + MLJBase.fit!(mach; verbosity=0) #testing update function model.epochs = 50 #changing number of epochs to a lower number - MLJBase.fit!(mach) #testing update function + MLJBase.fit!(mach; verbosity=0) #testing update function model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps - MLJBase.fit!(mach) #testing update function (the laplace part) + MLJBase.fit!(mach; verbosity=0) #testing update function (the laplace part) MLJ.evaluate!(mach; resampling=cv, measure=MLJ.brier_loss, verbosity=0) # Define a different model @@ -73,5 +89,5 @@ end model.model = flux_model_two - MLJBase.fit!(mach) + MLJBase.fit!(mach; verbosity=0) end diff --git a/test/runtests.jl b/test/runtests.jl index 92d97033..b29d8e14 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,5 +34,7 @@ using Test @testset "KronDecomposed" begin include("krondecomposed.jl") end - + @testset "Interface" begin + include("direct_mlj_interface.jl") + end end From 9da5d7daaff901f63dcc9d4be4108eb68d96d170 Mon Sep 17 00:00:00 2001 From: "pasquale c." 
<343guiltyspark@outlook.it> Date: Wed, 30 Oct 2024 01:09:07 +0100 Subject: [PATCH 56/60] reducing number of epochs and trying to extende patch coverage --- src/direct_mlj.jl | 6 +++--- test/direct_mlj_interface.jl | 25 ++++++++++++++++--------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 82c6d81b..f2d4a204 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -68,7 +68,7 @@ MMI.selectrows(::LaplaceModels, I, Xmatrix) = (Xmatrix[:, I],) """ - function features_shape(model::LaplaceRegression, X, y) + function dataset_shape(model::LaplaceRegression, X, y) Compute the the number of features of the X input dataset and the number of variables to predict from the y output dataset. @@ -80,7 +80,7 @@ Compute the the number of features of the X input dataset and the number of var # Returns - (input size, output size) """ -function features_shape(model::LaplaceModels, X, y) +function dataset_shape(model::LaplaceModels, X, y) #X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X n_input = size(X, 1) dims = size(y) @@ -142,7 +142,7 @@ function MMI.fit(m::LaplaceModels, verbosity, X, y) y, decode = y if (m.model === nothing) - shape = features_shape(m, X, y) + shape = dataset_shape(m, X, y) m.model = default_build(11, shape) diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl index e2ec4022..88e26e37 100644 --- a/test/direct_mlj_interface.jl +++ b/test/direct_mlj_interface.jl @@ -11,8 +11,9 @@ cv = MLJBase.CV(; nfolds=3) @testset "Regression" begin @info " testing interface for LaplaceRegressor" flux_model = Chain(Dense(4, 10, relu), Dense(10, 10, relu), Dense(10, 1)) - model = LaplaceRegressor(; model=flux_model, epochs=50) + model = LaplaceRegressor(; model=flux_model, epochs=20) + #testing more complex dataset X, y = MLJ.make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1) #train, test = partition(eachindex(y), 0.7); # 70:30 split mach = MLJ.machine(model, X, y) @@ -21,9 +22,9 @@ cv = MLJBase.CV(; nfolds=3) MLJBase.predict_mode(mach, X) # point predictions MLJBase.fitted_params(mach) #fitted params function MLJBase.training_losses(mach) #training loss history - model.epochs = 100 #changing number of epochs + model.epochs = 40 #changing number of epochs MLJBase.fit!(mach; verbosity=0) #testing update function - model.epochs = 50 #changing number of epochs to a lower number + model.epochs = 30 #changing number of epochs to a lower number MLJBase.fit!(mach; verbosity=0) #testing update function model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps MLJBase.fit!(mach; verbosity=0) #testing update function (the laplace part) @@ -41,15 +42,21 @@ cv = MLJBase.CV(; nfolds=3) #testing default mlp builder - model = LaplaceRegressor(; model=nothing, epochs=50) + model = LaplaceRegressor(; model=nothing, epochs=20) mach = MLJ.machine(model, X, y) - MLJBase.fit!(mach; verbosity=0) + MLJBase.fit!(mach; verbosity=1) yhat = MLJBase.predict(mach, X) # probabilistic predictions MLJBase.predict_mode(mach, X) # point predictions MLJBase.fitted_params(mach) #fitted params function MLJBase.training_losses(mach) #training loss history model.epochs = 100 #changing number of epochs - MLJBase.fit!(mach; verbosity=0) #testing update function + MLJBase.fit!(mach; verbosity=1) #testing update function + + #testing dataset_shape for one dimensional function + X, y = MLJ.make_regression(100, 1; noise=0.5, sparse=0.2, outliers=0.1) + model = LaplaceRegressor(; model=nothing, epochs=20) + mach = MLJ.machine(model, X, y) + 
MLJBase.fit!(mach; verbosity=0) @@ -60,7 +67,7 @@ end # Define the model flux_model = Chain(Dense(4, 10, relu), Dense(10, 3)) - model = LaplaceClassifier(; model=flux_model, epochs=50) + model = LaplaceClassifier(; model=flux_model, epochs=20) X, y = MLJ.@load_iris mach = MLJ.machine(model, X, y) @@ -76,9 +83,9 @@ end MLJBase.pdf.(yhat, "virginica") # probabilities for the "verginica" class MLJBase.fitted_params(mach) # fitted params MLJBase.training_losses(mach) #training loss history - model.epochs = 100 #changing number of epochs + model.epochs = 40 #changing number of epochs MLJBase.fit!(mach; verbosity=0) #testing update function - model.epochs = 50 #changing number of epochs to a lower number + model.epochs = 30 #changing number of epochs to a lower number MLJBase.fit!(mach; verbosity=0) #testing update function model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps MLJBase.fit!(mach; verbosity=0) #testing update function (the laplace part) From 503763dd3a0c0c91d195d09f31d625914826e01c Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Wed, 30 Oct 2024 17:33:15 +0100 Subject: [PATCH 57/60] removed the else because it seems to have no role. --- src/direct_mlj.jl | 4 ++-- test/direct_mlj_interface.jl | 2 -- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index f2d4a204..3e87a520 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -346,8 +346,8 @@ function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y) report = (loss_history=old_loss_history,) cache = (deepcopy(m), old_state_tree, old_loss_history) - else - fitresult, cache, report = MLJBase.fit(m, verbosity, X, y) + #else + #fitresult, cache, report = MLJBase.fit(m, verbosity, X, y) end return fitresult, cache, report diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl index 88e26e37..5e658c58 100644 --- a/test/direct_mlj_interface.jl +++ b/test/direct_mlj_interface.jl @@ -37,8 +37,6 @@ cv = MLJBase.CV(; nfolds=3) model.model = flux_model_two MLJBase.fit!(mach; verbosity=0) - model_two = LaplaceRegressor(; model=flux_model_two, epochs=100) - @test !MLJBase.is_same_except(model, model_two) #testing default mlp builder From e78e9b89a24b22fd3838023d05d52d4f6db5125f Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Wed, 30 Oct 2024 17:34:12 +0100 Subject: [PATCH 58/60] ops forgot to remove the comment --- src/direct_mlj.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 3e87a520..a1628cad 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -346,8 +346,6 @@ function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y) report = (loss_history=old_loss_history,) cache = (deepcopy(m), old_state_tree, old_loss_history) - #else - #fitresult, cache, report = MLJBase.fit(m, verbosity, X, y) end return fitresult, cache, report From 6a5f26f9478cc6700f1c214520bed3c3a2ba3c30 Mon Sep 17 00:00:00 2001 From: "pasquale c." 
From 6a5f26f9478cc6700f1c214520bed3c3a2ba3c30 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Wed, 30 Oct 2024 18:06:09 +0100
Subject: [PATCH 59/60] various changes in the documentation

---
 src/direct_mlj.jl            | 9 +++------
 test/direct_mlj_interface.jl | 2 ++
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index a1628cad..2a8187a5 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -1,4 +1,3 @@
-#module MLJLaplaceRedux
 using Optimisers: Optimisers
 using Flux
 using Random
@@ -81,7 +80,6 @@ Compute the the number of features of the X input dataset and the number of var
 - (input size, output size)
 """
 function dataset_shape(model::LaplaceModels, X, y)
-    #X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X
     n_input = size(X, 1)
     dims = size(y)
     if length(dims) == 1
@@ -142,6 +140,7 @@ function MMI.fit(m::LaplaceModels, verbosity, X, y)
         y, decode = y
 
     if (m.model === nothing)
+        @warn "Warning: no Flux model has been provided in the model. LaplaceRedux will use a standard MLP with 3 hidden layers with 20 neurons each and input and output layers compatible with the dataset."
         shape = dataset_shape(m, X, y)
 
         m.model = default_build(11, shape)
@@ -606,7 +605,7 @@ Train the machine using `fit!(mach, rows=...)`.
 
 # Hyperparameters (format: name-type-default value-restrictions)
 
-- `model::Flux.Chain = nothing`: a Flux model provided by the user and compatible with the dataset.
+- `model::Union{Flux.Chain,Nothing} = nothing`: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 3 hidden layer with 20 neurons each.
 
 - `flux_loss = Flux.Losses.logitcrossentropy` : a Flux loss function
 
@@ -744,8 +743,7 @@ Train the machine using `fit!(mach, rows=...)`.
 
 # Hyperparameters (format: name-type-default value-restrictions)
 
-- `model::Flux.Chain = nothing`: a Flux model provided by the user and compatible with the dataset.
-
+- `model::Union{Flux.Chain,Nothing} = nothing`: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 3 hidden layer with 20 neurons each.
 - `flux_loss = Flux.Losses.logitcrossentropy` : a Flux loss function
 
 - `optimiser = Adam()` a Flux optimiser
@@ -846,4 +844,3 @@ See also [LaplaceRedux.jl](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl
 
 """
 LaplaceRegressor
-#end # module
diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl
index 5e658c58..8f2b0438 100644
--- a/test/direct_mlj_interface.jl
+++ b/test/direct_mlj_interface.jl
@@ -36,6 +36,8 @@ cv = MLJBase.CV(; nfolds=3)
     # test update! fallback to fit!
     model.model = flux_model_two
     MLJBase.fit!(mach; verbosity=0)
+    model_two = LaplaceRegressor(; model=flux_model_two, epochs=100)
+    @test !MLJBase.is_same_except(model, model_two)
 
 
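The patch above documents the fallback that `MMI.fit` uses when `model === nothing`: it warns the user and builds a small default MLP via `default_build(11, dataset_shape(m, X, y))`, where `dataset_shape` returns the `(n_input, n_output)` pair inferred from the data. The sketch below illustrates that behaviour; the name `example_default_build` and the seeding detail are placeholders for illustration, while the layer sizes mirror the `default_build` chain visible in the next patch.

    using Flux, Random

    # Sketch of the default MLP used when no Flux model is supplied;
    # `shape` plays the role of the `(n_input, n_output)` tuple from `dataset_shape`.
    function example_default_build(seed::Int, shape)
        n_input, n_output = shape
        Random.seed!(seed)             # reproducible default weights (assumed detail)
        return Chain(
            Dense(n_input, 20, relu),  # hidden layer 1
            Dense(20, 20, relu),       # hidden layer 2
            Dense(20, 20, relu),       # hidden layer 3, uncommented in the next patch
            Dense(20, n_output),       # linear output layer
        )
    end

    # For example, a 4-feature regression task with a single target:
    chain = example_default_build(11, (4, 1))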
From 12f2584462f5dbad7090941f3d489071fcb258a8 Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Wed, 30 Oct 2024 18:12:54 +0100
Subject: [PATCH 60/60] ufffffffffffffffffffff

---
 src/direct_mlj.jl | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
index 2a8187a5..61aab7dd 100644
--- a/src/direct_mlj.jl
+++ b/src/direct_mlj.jl
@@ -110,7 +110,7 @@ function default_build(seed::Int, shape)
     chain = Chain(
         Dense(n_input, 20, relu),
         Dense(20, 20, relu),
-        #Dense(20, 20, relu),
+        Dense(20, 20, relu),
         Dense(20, n_output)
     )
 
@@ -140,7 +140,7 @@ function MMI.fit(m::LaplaceModels, verbosity, X, y)
         y, decode = y
 
     if (m.model === nothing)
-        @warn "Warning: no Flux model has been provided in the model. LaplaceRedux will use a standard MLP with 3 hidden layers with 20 neurons each and input and output layers compatible with the dataset."
+        @warn "Warning: no Flux model has been provided in the model. LaplaceRedux will use a standard MLP with 2 hidden layers with 20 neurons each and input and output layers compatible with the dataset."
         shape = dataset_shape(m, X, y)
 
         m.model = default_build(11, shape)
@@ -605,7 +605,7 @@ Train the machine using `fit!(mach, rows=...)`.
 
 # Hyperparameters (format: name-type-default value-restrictions)
 
-- `model::Union{Flux.Chain,Nothing} = nothing`: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 3 hidden layer with 20 neurons each.
+- `model::Union{Flux.Chain,Nothing} = nothing`: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 2 hidden layer with 20 neurons each.
 
 - `flux_loss = Flux.Losses.logitcrossentropy` : a Flux loss function
 
@@ -743,7 +743,7 @@ Train the machine using `fit!(mach, rows=...)`.
 
 # Hyperparameters (format: name-type-default value-restrictions)
 
-- `model::Union{Flux.Chain,Nothing} = nothing`: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 3 hidden layer with 20 neurons each.
+- `model::Union{Flux.Chain,Nothing} = nothing`: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 2 hidden layer with 20 neurons each.
 - `flux_loss = Flux.Losses.logitcrossentropy` : a Flux loss function
 
 - `optimiser = Adam()` a Flux optimiser