diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 64bf7af1..22e67adc 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -22,6 +22,7 @@ jobs:
version:
- '1.9'
- '1.10'
+ - '1'
os:
- ubuntu-latest
- windows-latest
diff --git a/Project.toml b/Project.toml
index a06a996b..0124164e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -4,11 +4,14 @@ authors = ["Patrick Altmeyer"]
version = "1.1.1"
[deps]
+CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
+MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -19,20 +22,23 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
Aqua = "0.8"
+CategoricalDistributions = "0.1.15"
ChainRulesCore = "1.23.0"
Compat = "4.7.0"
Distributions = "0.25.109"
Flux = "0.12, 0.13, 0.14"
LinearAlgebra = "1.7, 1.10"
+MLJBase = "1"
+MLJModelInterface = "1.8.0"
MLUtils = "0.4"
Optimisers = "0.2, 0.3"
Random = "1.9, 1.10"
Statistics = "1"
Tables = "1.10.1"
-Test = "1.9, 1.10"
+Test = "1"
Tullio = "0.3.5"
Zygote = "0.6"
-julia = "1.9, 1.10"
+julia = "1.10"
[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
diff --git a/dev/issues/predict_slow.jl b/dev/issues/predict_slow.jl
index ae713cbd..95c7c1f8 100644
--- a/dev/issues/predict_slow.jl
+++ b/dev/issues/predict_slow.jl
@@ -4,16 +4,16 @@ using Optimisers
X = MLJBase.table(rand(Float32, 100, 3));
y = coerce(rand("abc", 100), Multiclass);
-model = LaplaceClassification(optimiser=Optimisers.Adam(0.1), epochs=100);
+model = LaplaceClassification(; optimiser=Optimisers.Adam(0.1), epochs=100);
fitresult, _, _ = MLJBase.fit(model, 2, X, y);
la = fitresult[1];
Xmat = matrix(X) |> permutedims;
# Single test sample:
-Xtest = Xmat[:,1:10];
+Xtest = Xmat[:, 1:10];
Xtest_tab = MLJBase.table(Xtest');
MLJBase.predict(model, fitresult, Xtest_tab); # warm up
LaplaceRedux.predict(la, Xmat); # warm up
@time MLJBase.predict(model, fitresult, Xtest_tab);
@time LaplaceRedux.predict(la, Xtest);
-@time glm_predictive_distribution(la, Xtest);
\ No newline at end of file
+@time glm_predictive_distribution(la, Xtest);
diff --git a/dev/notebooks/mlj-interfacing/direct_mlj.ipynb b/dev/notebooks/mlj-interfacing/direct_mlj.ipynb
new file mode 100644
index 00000000..4169dc0e
--- /dev/null
+++ b/dev/notebooks/mlj-interfacing/direct_mlj.ipynb
@@ -0,0 +1,1137 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "using Revise\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\u001b[32m\u001b[1m Activating\u001b[22m\u001b[39m project at `C:\\Users\\Pasqu\\Documents\\julia_projects\\LaplaceRedux.jl\\docs`\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[32m\u001b[1mStatus\u001b[22m\u001b[39m `C:\\Users\\Pasqu\\Documents\\julia_projects\\LaplaceRedux.jl\\docs\\Project.toml`\n",
+ " \u001b[90m[324d7699] \u001b[39mCategoricalArrays v0.10.8\n",
+ "\u001b[32m⌃\u001b[39m \u001b[90m[a93c6f00] \u001b[39mDataFrames v1.6.1\n",
+ "\u001b[32m⌃\u001b[39m \u001b[90m[31c24e10] \u001b[39mDistributions v0.25.111\n",
+ "\u001b[32m⌃\u001b[39m \u001b[90m[e30172f5] \u001b[39mDocumenter v1.6.0\n",
+ "\u001b[32m⌃\u001b[39m \u001b[90m[587475ba] \u001b[39mFlux v0.14.19\n",
+ " \u001b[90m[c52c1a26] \u001b[39mLaplaceRedux v1.1.1 `C:\\Users\\Pasqu\\Documents\\julia_projects\\LaplaceRedux.jl`\n",
+ "\u001b[33m⌅\u001b[39m \u001b[90m[094fc8d1] \u001b[39mMLJFlux v0.5.1\n",
+ " \u001b[90m[e80e1ace] \u001b[39mMLJModelInterface v1.11.0\n",
+ " \u001b[90m[3bd65402] \u001b[39mOptimisers v0.3.3\n",
+ " \u001b[90m[91a5bcdd] \u001b[39mPlots v1.40.8\n",
+ " \u001b[90m[ce6b1742] \u001b[39mRDatasets v0.7.7\n",
+ " \u001b[90m[860ef19b] \u001b[39mStableRNGs v1.0.2\n",
+ " \u001b[90m[2913bbd2] \u001b[39mStatsBase v0.34.3\n",
+ " \u001b[90m[bd369af6] \u001b[39mTables v1.12.0\n",
+ "\u001b[32m⌃\u001b[39m \u001b[90m[bd7198b4] \u001b[39mTaijaPlotting v1.2.0\n",
+ " \u001b[90m[592b5752] \u001b[39mTrapz v2.0.3\n",
+ "\u001b[32m⌃\u001b[39m \u001b[90m[e88e6eb3] \u001b[39mZygote v0.6.70\n",
+ "\u001b[32m⌃\u001b[39m \u001b[90m[02a925ec] \u001b[39mcuDNN v1.3.2\n",
+ " \u001b[90m[9a3f8284] \u001b[39mRandom\n",
+ " \u001b[90m[10745b16] \u001b[39mStatistics v1.10.0\n",
+ "\u001b[36m\u001b[1mInfo\u001b[22m\u001b[39m Packages marked with \u001b[32m⌃\u001b[39m and \u001b[33m⌅\u001b[39m have new versions available. Those with \u001b[32m⌃\u001b[39m may be upgradable, but those with \u001b[33m⌅\u001b[39m are restricted by compatibility constraints from upgrading. To see why use `status --outdated`\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING: using MLJBase.predict in module LaplaceRedux conflicts with an existing identifier.\n",
+ "WARNING: using MLJBase.fit! in module LaplaceRedux conflicts with an existing identifier.\n"
+ ]
+ }
+ ],
+ "source": [
+ "using Pkg\n",
+ "\n",
+ "\n",
+ "Pkg.activate(\"C:/Users/Pasqu/Documents/julia_projects/LaplaceRedux.jl/docs\")\n",
+ "Pkg.status()\n",
+ "using LaplaceRedux\n",
+ "using MLJBase\n",
+ "using Random\n",
+ "using DataFrames\n",
+ "using Flux"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(Tables.MatrixTable{Matrix{Float64}} with 100 rows, 2 columns, and schema:\n",
+ " :x1 Float64\n",
+ " :x2 Float64, [-0.15053240354230857, -0.16143107735113107, -0.28104782384528254, 0.8905842690519058, -0.2716955057136559, 0.9606721208381163, 0.14403243794060133, 0.13743002853667605, 0.820641892942472, 0.2270783932443115 … 0.4639933046961763, 0.30449384622096687, 0.2744588755171263, -31.785240173822324, 2.58951832655098, 0.10969223924903307, 0.1600255529817666, 0.5913011997917647, -0.39253898214541977, -0.7247243478011852])"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "X, y = make_regression(100, 2; noise=0.5, sparse=0.2, outliers=0.1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of rows: 100\n",
+ "Number of columns: 2\n"
+ ]
+ }
+ ],
+ "source": [
+ "using Tables\n",
+ "# Get the number of rows\n",
+ "num_rows = Tables.rows(X) |> length\n",
+ "\n",
+ "# Get the number of columns\n",
+ "num_columns = Tables.columnnames(X) |> length\n",
+ "\n",
+ "# Display the dimensions\n",
+ "println(\"Number of rows: \", num_rows)\n",
+ "println(\"Number of columns: \", num_columns)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Chain(\n",
+ " Dense(2 => 10, relu), \u001b[90m# 30 parameters\u001b[39m\n",
+ " Dense(10 => 10, relu), \u001b[90m# 110 parameters\u001b[39m\n",
+ " Dense(10 => 1), \u001b[90m# 11 parameters\u001b[39m\n",
+ ") \u001b[90m # Total: 6 arrays, \u001b[39m151 parameters, 988 bytes."
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "using Flux\n",
+ "# Define the model\n",
+ "flux_model = Chain(\n",
+ " Dense(2, 10, relu),\n",
+ " Dense(10, 10, relu),\n",
+ " Dense(10, 1)\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "LaplaceRegressor(\n",
+ " model = Chain(Dense(2 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), \n",
+ " flux_loss = Flux.Losses.mse, \n",
+ " optimiser = Adam(0.001, (0.9, 0.999), 1.0e-8), \n",
+ " epochs = 1000, \n",
+ " batch_size = 32, \n",
+ " subset_of_weights = :all, \n",
+ " subnetwork_indices = nothing, \n",
+ " hessian_structure = :full, \n",
+ " backend = :GGN, \n",
+ " σ = 1.0, \n",
+ " μ₀ = 0.0, \n",
+ " P₀ = nothing, \n",
+ " fit_prior_nsteps = 100)"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "\n",
+ "# Create an instance of LaplaceRegressor with the Flux model\n",
+ "laplace_regressor = LaplaceRegressor(\n",
+ " model = flux_model,\n",
+ " subset_of_weights = :all,\n",
+ " subnetwork_indices = nothing,\n",
+ " hessian_structure = :full,\n",
+ " backend = :GGN,\n",
+ " σ = 1.0,\n",
+ " μ₀ = 0.0,\n",
+ " P₀ = nothing,\n",
+ " fit_prior_nsteps = 100\n",
+ ")\n",
+ "\n",
+ "# Display the LaplaceRegressor object\n",
+ "laplace_regressor"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Tables.MatrixTable{Matrix{Float64}}"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "typeof(X)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
1 | 5.89419 | 12.9979 | 6.18512 | 8.99286 | 2 |
2 | -8.14834 | 6.23246 | -1.68497 | 8.96905 | 1 |
3 | -4.88229 | 5.35276 | -0.45876 | 8.20756 | 1 |
4 | 4.02585 | 6.94769 | 13.4032 | -0.0419223 | 2 |
5 | -5.53635 | 6.55656 | -1.67063 | 8.77041 | 1 |
6 | -6.61858 | 4.65032 | -1.15198 | 8.34897 | 1 |
7 | 10.2344 | 11.4278 | 13.0544 | 8.53025 | 2 |
8 | -6.05052 | 8.12027 | -3.68708 | 8.78732 | 1 |
9 | -5.06769 | 4.8631 | -3.58346 | 8.41371 | 1 |
10 | 10.8373 | 6.32472 | 9.79163 | 6.65962 | 2 |
11 | -6.63226 | 5.45149 | -0.38861 | 9.0007 | 1 |
12 | 1.62812 | 4.61073 | 11.6602 | 11.7241 | 2 |
13 | -6.48679 | 6.68166 | -3.32569 | 9.19618 | 1 |
14 | 7.95596 | 2.23928 | 12.6897 | 1.77857 | 2 |
15 | -6.36466 | 5.82985 | -0.702502 | 8.44976 | 1 |
16 | 8.03294 | 3.85901 | 5.50741 | 2.2014 | 2 |
17 | -7.45067 | 7.01011 | -1.96187 | 7.84336 | 1 |
"
+ ],
+ "text/latex": [
+ "\\begin{tabular}{r|ccccc}\n",
+ "\t& x1 & x2 & x3 & x4 & y\\\\\n",
+ "\t\\hline\n",
+ "\t& Float64 & Float64 & Float64 & Float64 & Cat…\\\\\n",
+ "\t\\hline\n",
+ "\t1 & 5.89419 & 12.9979 & 6.18512 & 8.99286 & 2 \\\\\n",
+ "\t2 & -8.14834 & 6.23246 & -1.68497 & 8.96905 & 1 \\\\\n",
+ "\t3 & -4.88229 & 5.35276 & -0.45876 & 8.20756 & 1 \\\\\n",
+ "\t4 & 4.02585 & 6.94769 & 13.4032 & -0.0419223 & 2 \\\\\n",
+ "\t5 & -5.53635 & 6.55656 & -1.67063 & 8.77041 & 1 \\\\\n",
+ "\t6 & -6.61858 & 4.65032 & -1.15198 & 8.34897 & 1 \\\\\n",
+ "\t7 & 10.2344 & 11.4278 & 13.0544 & 8.53025 & 2 \\\\\n",
+ "\t8 & -6.05052 & 8.12027 & -3.68708 & 8.78732 & 1 \\\\\n",
+ "\t9 & -5.06769 & 4.8631 & -3.58346 & 8.41371 & 1 \\\\\n",
+ "\t10 & 10.8373 & 6.32472 & 9.79163 & 6.65962 & 2 \\\\\n",
+ "\t11 & -6.63226 & 5.45149 & -0.38861 & 9.0007 & 1 \\\\\n",
+ "\t12 & 1.62812 & 4.61073 & 11.6602 & 11.7241 & 2 \\\\\n",
+ "\t13 & -6.48679 & 6.68166 & -3.32569 & 9.19618 & 1 \\\\\n",
+ "\t14 & 7.95596 & 2.23928 & 12.6897 & 1.77857 & 2 \\\\\n",
+ "\t15 & -6.36466 & 5.82985 & -0.702502 & 8.44976 & 1 \\\\\n",
+ "\t16 & 8.03294 & 3.85901 & 5.50741 & 2.2014 & 2 \\\\\n",
+ "\t17 & -7.45067 & 7.01011 & -1.96187 & 7.84336 & 1 \\\\\n",
+ "\\end{tabular}\n"
+ ],
+ "text/plain": [
+ "\u001b[1m17×5 DataFrame\u001b[0m\n",
+ "\u001b[1m Row \u001b[0m│\u001b[1m x1 \u001b[0m\u001b[1m x2 \u001b[0m\u001b[1m x3 \u001b[0m\u001b[1m x4 \u001b[0m\u001b[1m y \u001b[0m\n",
+ " │\u001b[90m Float64 \u001b[0m\u001b[90m Float64 \u001b[0m\u001b[90m Float64 \u001b[0m\u001b[90m Float64 \u001b[0m\u001b[90m Cat… \u001b[0m\n",
+ "─────┼─────────────────────────────────────────────────\n",
+ " 1 │ 5.89419 12.9979 6.18512 8.99286 2\n",
+ " 2 │ -8.14834 6.23246 -1.68497 8.96905 1\n",
+ " 3 │ -4.88229 5.35276 -0.45876 8.20756 1\n",
+ " 4 │ 4.02585 6.94769 13.4032 -0.0419223 2\n",
+ " 5 │ -5.53635 6.55656 -1.67063 8.77041 1\n",
+ " 6 │ -6.61858 4.65032 -1.15198 8.34897 1\n",
+ " 7 │ 10.2344 11.4278 13.0544 8.53025 2\n",
+ " 8 │ -6.05052 8.12027 -3.68708 8.78732 1\n",
+ " 9 │ -5.06769 4.8631 -3.58346 8.41371 1\n",
+ " 10 │ 10.8373 6.32472 9.79163 6.65962 2\n",
+ " 11 │ -6.63226 5.45149 -0.38861 9.0007 1\n",
+ " 12 │ 1.62812 4.61073 11.6602 11.7241 2\n",
+ " 13 │ -6.48679 6.68166 -3.32569 9.19618 1\n",
+ " 14 │ 7.95596 2.23928 12.6897 1.77857 2\n",
+ " 15 │ -6.36466 5.82985 -0.702502 8.44976 1\n",
+ " 16 │ 8.03294 3.85901 5.50741 2.2014 2\n",
+ " 17 │ -7.45067 7.01011 -1.96187 7.84336 1"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "using MLJ, DataFrames\n",
+ "X, y = make_blobs(100, 4; centers=2, cluster_std=[1.0, 3.0 ])\n",
+ "dfBlobs = DataFrame(X)\n",
+ "dfBlobs.y = y\n",
+ "first(dfBlobs, 17)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Chain(\n",
+ " Dense(4 => 10, relu), \u001b[90m# 50 parameters\u001b[39m\n",
+ " Dense(10 => 10, relu), \u001b[90m# 110 parameters\u001b[39m\n",
+ " Dense(10 => 2), \u001b[90m# 22 parameters\u001b[39m\n",
+ ") \u001b[90m # Total: 6 arrays, \u001b[39m182 parameters, 1.086 KiB."
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "using Flux\n",
+ "# Define the model\n",
+ "flux_model = Chain(\n",
+ " Dense(4, 10, relu),\n",
+ " Dense(10, 10, relu),\n",
+ " Dense(10, 2)\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "LaplaceClassifier(\n",
+ " model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 2)), \n",
+ " flux_loss = Flux.Losses.logitcrossentropy, \n",
+ " optimiser = Adam(0.001, (0.9, 0.999), 1.0e-8), \n",
+ " epochs = 1000, \n",
+ " batch_size = 32, \n",
+ " subset_of_weights = :all, \n",
+ " subnetwork_indices = nothing, \n",
+ " hessian_structure = :full, \n",
+ " backend = :GGN, \n",
+ " σ = 1.0, \n",
+ " μ₀ = 0.0, \n",
+ " P₀ = nothing, \n",
+ " fit_prior_nsteps = 100, \n",
+ " link_approx = :probit)"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Create an instance of LaplaceRegressor with the Flux model\n",
+ "laplace_classifier = LaplaceClassifier(\n",
+ " model = flux_model,\n",
+ " subset_of_weights = :all,\n",
+ " subnetwork_indices = nothing,\n",
+ " hessian_structure = :full,\n",
+ " backend = :GGN,\n",
+ " σ = 1.0,\n",
+ " μ₀ = 0.0,\n",
+ " P₀ = nothing,\n",
+ " fit_prior_nsteps = 100\n",
+ ")\n",
+ "\n",
+ "# Display the LaplaceRegressor object\n",
+ "laplace_classifier"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "([5.894190948229072 -8.148338397036424 … 4.6901216310216896 8.754850004556681; 12.99791729188348 6.232458473633034 … 10.904369405648966 12.187158715203248; 6.185122057812649 -1.6849673572899138 … 10.09295616393206 15.616135680604215; 8.992861668498659 8.969052840985539 … 8.673387526088336 3.444822170655238], (Bool[1 0 … 1 1; 0 1 … 0 0], CategoricalArrays.CategoricalValue{Int64, UInt32} 2))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "X_test_reformat,y_test_reformat= MLJBase.reformat(laplace_classifier,X,y)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "2×100 OneHotMatrix(::Vector{UInt32}) with eltype Bool:\n",
+ " 1 ⋅ ⋅ 1 ⋅ ⋅ 1 ⋅ ⋅ 1 ⋅ 1 ⋅ … 1 1 1 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 1\n",
+ " ⋅ 1 1 ⋅ 1 1 ⋅ 1 1 ⋅ 1 ⋅ 1 ⋅ ⋅ ⋅ ⋅ 1 1 1 1 1 1 ⋅ ⋅"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "y_test_reformat[1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "2×3 view(OneHotMatrix(::Vector{UInt32}), :, [2, 3, 4]) with eltype Bool:\n",
+ " 0 0 1\n",
+ " 1 1 0"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "view(y_test_reformat[1],:, [2,3,4])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "4×3 view(::Matrix{Float64}, :, [2, 3, 4]) with eltype Float64:\n",
+ " -8.14834 -4.88229 4.02585\n",
+ " 6.23246 5.35276 6.94769\n",
+ " -1.68497 -0.45876 13.4032\n",
+ " 8.96905 8.20756 -0.0419223"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "view(X_test_reformat, :, [2,3,4])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "((sepal_length = [5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9 … 6.7, 6.9, 5.8, 6.8, 6.7, 6.7, 6.3, 6.5, 6.2, 5.9], sepal_width = [3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1 … 3.1, 3.1, 2.7, 3.2, 3.3, 3.0, 2.5, 3.0, 3.4, 3.0], petal_length = [1.4, 1.4, 1.3, 1.5, 1.4, 1.7, 1.4, 1.5, 1.4, 1.5 … 5.6, 5.1, 5.1, 5.9, 5.7, 5.2, 5.0, 5.2, 5.4, 5.1], petal_width = [0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1 … 2.4, 2.3, 1.9, 2.3, 2.5, 2.3, 1.9, 2.0, 2.3, 1.8]), CategoricalArrays.CategoricalValue{String, UInt32}[\"setosa\", \"setosa\", \"setosa\", \"setosa\", \"setosa\", \"setosa\", \"setosa\", \"setosa\", \"setosa\", \"setosa\" … \"virginica\", \"virginica\", \"virginica\", \"virginica\", \"virginica\", \"virginica\", \"virginica\", \"virginica\", \"virginica\", \"virginica\"])"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "using MLJ\n",
+ "#DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree\n",
+ "\n",
+ "using Flux\n",
+ "# Define the model\n",
+ "flux_model = Chain(\n",
+ " Dense(4, 10, relu),\n",
+ " Dense(10, 10, relu),\n",
+ " Dense(10, 3)\n",
+ ")\n",
+ "\n",
+ "model = LaplaceClassifier(model=flux_model)\n",
+ "\n",
+ "X, y = @load_iris\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "3-element Vector{String}:\n",
+ " \"setosa\"\n",
+ " \"versicolor\"\n",
+ " \"virginica\""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "levels(y)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "┌ Info: Training machine(LaplaceClassifier(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 3)), …), …).\n",
+ "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:499\n",
+ "┌ Warning: Layer with Float32 parameters got Float64 input.\n",
+ "│ The input will be converted, but any earlier layers may be very slow.\n",
+ "│ layer = Dense(4 => 10, relu)\n",
+ "│ summary(x) = 4×32 Matrix{Float64}\n",
+ "└ @ Flux C:\\Users\\Pasqu\\.julia\\packages\\Flux\\HBF2N\\src\\layers\\stateless.jl:60\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 100: Loss: 4.042439937591553 \n",
+ "Epoch 200: Loss: 3.533539354801178 \n",
+ "Epoch 300: Loss: 3.2006053030490875 \n",
+ "Epoch 400: Loss: 2.9201304614543915 \n",
+ "Epoch 500: Loss: 2.647124856710434 \n",
+ "Epoch 600: Loss: 2.3577273190021515 \n",
+ "Epoch 700: Loss: 2.0284294933080673 \n",
+ "Epoch 800: Loss: 1.6497879326343536 \n",
+ "Epoch 900: Loss: 1.3182911276817322 \n",
+ "Epoch 1000: Loss: 1.077766239643097 \n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "trained Machine; caches model-specific representations of data\n",
+ " model: LaplaceClassifier(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 3)), …)\n",
+ " args: \n",
+ " 1:\tSource @378 ⏎ Table{AbstractVector{Continuous}}\n",
+ " 2:\tSource @187 ⏎ AbstractVector{Multiclass{3}}\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "mach = machine(model, X, y) |> MLJBase.fit!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "3-element Vector{Float64}:\n",
+ " 0.9488185211171959\n",
+ " 0.7685919895442062\n",
+ " 0.9454937120287679"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "Xnew = (sepal_length = [6.4, 7.2, 7.4],\n",
+ " sepal_width = [2.8, 3.0, 2.8],\n",
+ " petal_length = [5.6, 5.8, 6.1],\n",
+ " petal_width = [2.1, 1.6, 1.9],)\n",
+ "yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions\n",
+ "predict_mode(mach, Xnew) # point predictions\n",
+    "pdf.(yhat, \"virginica\") # probabilities for the \"virginica\" class"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "CV(\n",
+ " nfolds = 3, \n",
+ " shuffle = false, \n",
+ " rng = TaskLocalRNG())"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "using Random\n",
+ "cv = CV(nfolds=3)\n",
+ "CV(\n",
+ " nfolds = 3,\n",
+ " shuffle = false,\n",
+ " rng = Random.default_rng())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(μ = Float32[0.36796558, -0.590969, -0.0005966219, -0.1918786, -0.21046183, 0.33172563, -0.6454487, -0.43057427, 0.09267368, -0.31252357 … 0.31747338, -0.42893136, -0.036560666, 0.12848923, -0.022290554, -0.17562295, 0.10335993, 2.0584264, -0.27860576, -0.9749556],\n",
+ " H = [19221.80854034424 0.0 … 175.01133728027344 416.5687526464462; 0.0 0.0 … 0.0 0.0; … ; 175.01133728027344 0.0 … 150.0 0.0; 416.5687526464462 0.0 … 0.0 150.0],\n",
+ " P = [19222.80854034424 0.0 … 175.01133728027344 416.5687526464462; 0.0 1.0 … 0.0 0.0; … ; 175.01133728027344 0.0 … 151.0 0.0; 416.5687526464462 0.0 … 0.0 151.0],\n",
+ " Σ = [0.4942412474220449 0.0 … 0.00012692197672594825 -0.00029493662610333626; 0.0 1.0 … -0.0 -0.0; … ; 0.00012692197672495198 0.0 … 0.019500175331854216 2.549733600297664e-5; -0.000294936626108007 0.0 … 2.5497336002974752e-5 0.019426448013349775],\n",
+ " n_data = 160,\n",
+ " n_params = 193,\n",
+ " n_out = 3,\n",
+ " loss = 76.21374633908272,)"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "MLJBase.fitted_params(mach)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 1100: Loss: 0.9005963057279587 \n",
+ "Epoch 1200: Loss: 0.7674828618764877 \n",
+ "Epoch 1300: Loss: 0.6668129041790962 "
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "┌ Info: Updating machine(LaplaceClassifier(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 3)), …), …).\n",
+ "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:500\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Epoch 1400: Loss: 0.5900264084339142 \n",
+ "Epoch 1500: Loss: 0.5311783701181412 \n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "trained Machine; caches model-specific representations of data\n",
+ " model: LaplaceClassifier(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 3)), …)\n",
+ " args: \n",
+ " 1:\tSource @378 ⏎ Table{AbstractVector{Continuous}}\n",
+ " 2:\tSource @187 ⏎ AbstractVector{Multiclass{3}}\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "model.epochs= 1500\n",
+ "uff=MLJBase.fit!(mach)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " updating only the laplace optimization part\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "┌ Info: Updating machine(LaplaceClassifier(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 3)), …), …).\n",
+ "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:500\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "trained Machine; caches model-specific representations of data\n",
+ " model: LaplaceClassifier(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 3)), …)\n",
+ " args: \n",
+ " 1:\tSource @378 ⏎ Table{AbstractVector{Continuous}}\n",
+ " 2:\tSource @187 ⏎ AbstractVector{Multiclass{3}}\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "model.fit_prior_nsteps = 200\n",
+ "uff=MLJBase.fit!(mach)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "┌ Info: Training machine(LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …), …).\n",
+ "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:499\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 100: Loss: 123.99619626934783 \n",
+ "Epoch 200: Loss: 110.70134709925446 \n",
+ "Epoch 300: Loss: 101.54836914035226 \n",
+ "Epoch 400: Loss: 93.59849316803738 \n",
+ "Epoch 500: Loss: 85.18775669356168 \n",
+ "Epoch 600: Loss: 75.76182416873147 \n",
+ "Epoch 700: Loss: 67.35201676247976 \n",
+ "Epoch 800: Loss: 61.22682561405186 \n",
+ "Epoch 900: Loss: 56.388587609986764 \n",
+ "Epoch 1000: Loss: 51.8525101259686 \n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "1×3 Matrix{Float64}:\n",
+ " -2.0295 -1.03661 1.29323"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "using MLJ\n",
+ "#LaplaceRegressor = @load LaplaceRegressor pkg=LaplaceRedux\n",
+ "flux_model = Chain(\n",
+ " Dense(4, 10, relu),\n",
+ " Dense(10, 10, relu),\n",
+ " Dense(10, 1)\n",
+ ")\n",
+ "model = LaplaceRegressor(model=flux_model)\n",
+ "\n",
+ "X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)\n",
+ "mach = machine(model, X, y) \n",
+ "uff=MLJBase.fit!(mach)\n",
+ "Xnew, _ = make_regression(3, 4; rng=123)\n",
+ "yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions\n",
+ "MLJBase.predict_mode(mach, Xnew) # point predictions\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(μ = Float32[1.0187618, 1.1801981, 1.5047036, 0.09397529, -1.0688609, -0.59974676, 1.2977643, -1.8645017, -0.6493797, 0.8598072 … -3.1050463, -2.267896, 1.018275, -2.0316908, 0.72548866, -2.7309833, -2.1938443, 0.8200029, 1.2517836, 0.3313799],\n",
+ " H = [8124.873962402344 1238.1352081298828 … 1980.7237396240234 324.216007232666; 1238.1352081298828 496.2319145202637 … 451.8493309020996 62.39444828033447; … ; 1980.7237396240234 451.8493309020996 … 1277.6469421386719 212.83043670654297; 324.216007232666 62.39444828033447 … 212.83043670654297 100.0],\n",
+ " P = [8125.873962402344 1238.1352081298828 … 1980.7237396240234 324.216007232666; 1238.1352081298828 497.2319145202637 … 451.8493309020996 62.39444828033447; … ; 1980.7237396240234 451.8493309020996 … 1278.6469421386719 212.83043670654297; 324.216007232666 62.39444828033447 … 212.83043670654297 101.0],\n",
+ " Σ = [0.08818398027482557 0.01068994964485469 … -0.028911073548363073 0.0029380249323247053; 0.01068994964484666 0.3740392430965252 … -0.03321343036791264 0.03808189340124685; … ; -0.02891107354841918 -0.03321343036788035 … 0.37993018751981145 -0.007007167070441715; 0.002938024932320219 0.03808189340124439 … -0.007007167070367749 0.8779610764649773],\n",
+ " n_data = 128,\n",
+ " n_params = 171,\n",
+ " n_out = 1,\n",
+ " loss = 25.89132467732224,)"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "MLJBase.fitted_params(mach)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "1000-element Vector{Any}:\n",
+ " 132.4757012476366\n",
+ " 132.34841820836348\n",
+ " 132.2594675536756\n",
+ " 132.1764167781873\n",
+ " 132.09720400291872\n",
+ " 132.02269947440837\n",
+ " 131.95014241005182\n",
+ " 131.87929837827625\n",
+ " 131.8104252672236\n",
+ " 131.74319014461662\n",
+ " ⋮\n",
+ " 52.175198365270234\n",
+ " 52.148401306572445\n",
+ " 52.10149557862212\n",
+ " 52.032551867920525\n",
+ " 51.99690913470439\n",
+ " 51.994054497037524\n",
+ " 51.97321907106968\n",
+ " 51.932579915942064\n",
+ " 51.8525101259686"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "a= MLJBase.training_losses(mach)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "┌ Info: Updating machine(LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …), …).\n",
+ "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:500\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 1100: Loss: 47.37399118448011 \n",
+ "Epoch 1200: Loss: 42.67007994813126 \n",
+ "Epoch 1300: Loss: 38.057605481177234 \n",
+ "Epoch 1400: Loss: 33.846777755430416 \n",
+ "Epoch 1500: Loss: 29.514834265471144 \n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "trained Machine; caches model-specific representations of data\n",
+ " model: LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …)\n",
+ " args: \n",
+ " 1:\tSource @447 ⏎ Table{AbstractVector{Continuous}}\n",
+ " 2:\tSource @645 ⏎ AbstractVector{Continuous}\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "model.epochs= 1500\n",
+ "uff=MLJBase.fit!(mach)\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "1500-element Vector{Any}:\n",
+ " 132.4757012476366\n",
+ " 132.34841820836348\n",
+ " 132.2594675536756\n",
+ " 132.1764167781873\n",
+ " 132.09720400291872\n",
+ " 132.02269947440837\n",
+ " 131.95014241005182\n",
+ " 131.87929837827625\n",
+ " 131.8104252672236\n",
+ " 131.74319014461662\n",
+ " ⋮\n",
+ " 29.979098347946184\n",
+ " 29.896050775635096\n",
+ " 29.83609799111873\n",
+ " 29.763058304997937\n",
+ " 29.69908817371154\n",
+ " 29.646864259231556\n",
+ " 29.623977677653713\n",
+ " 29.549562760589527\n",
+ " 29.514834265471144"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "MLJBase.training_losses(mach)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The number of epochs inserted is lower than the number of epochs already been trained. No update is necessary\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "┌ Info: Updating machine(LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …), …).\n",
+ "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:500\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "trained Machine; caches model-specific representations of data\n",
+ " model: LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …)\n",
+ " args: \n",
+ " 1:\tSource @447 ⏎ Table{AbstractVector{Continuous}}\n",
+ " 2:\tSource @645 ⏎ AbstractVector{Continuous}\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "model.epochs= 1200\n",
+ "uff=MLJBase.fit!(mach)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " updating only the laplace optimization part\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "┌ Info: Updating machine(LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …), …).\n",
+ "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:500\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "trained Machine; caches model-specific representations of data\n",
+ " model: LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …)\n",
+ " args: \n",
+ " 1:\tSource @447 ⏎ Table{AbstractVector{Continuous}}\n",
+ " 2:\tSource @645 ⏎ AbstractVector{Continuous}\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "model.fit_prior_nsteps = 200\n",
+ "uff=MLJBase.fit!(mach)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "MLJBase.selectrows(model, I, Xmatrix, y) = (view(Xmatrix, :, I), view(y[1], I))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#MLJBase.evaluate(model, X, y, resampling=cv, measure=l2, verbosity=0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "BoundsError",
+ "evalue": "BoundsError: attempt to access 4×100 Matrix{Float64} at index [[35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100], 1:100]",
+ "output_type": "error",
+ "traceback": [
+ "BoundsError: attempt to access 4×100 Matrix{Float64} at index [[35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100], 1:100]\n",
+ "\n",
+ "Stacktrace:\n",
+ " [1] throw_boundserror(A::Matrix{Float64}, I::Tuple{Vector{Int64}, Base.Slice{Base.OneTo{Int64}}})\n",
+ " @ Base .\\abstractarray.jl:737\n",
+ " [2] checkbounds\n",
+ " @ .\\abstractarray.jl:702 [inlined]\n",
+ " [3] _getindex\n",
+ " @ .\\multidimensional.jl:888 [inlined]\n",
+ " [4] getindex\n",
+ " @ .\\abstractarray.jl:1291 [inlined]\n",
+ " [5] _selectrows(::MLJModelInterface.FullInterface, ::Val{:other}, X::Matrix{Float64}, r::Vector{Int64})\n",
+ " @ MLJModelInterface C:\\Users\\Pasqu\\.julia\\packages\\MLJModelInterface\\y9x5A\\src\\data_utils.jl:350\n",
+ " [6] selectrows\n",
+ " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJModelInterface\\y9x5A\\src\\data_utils.jl:340 [inlined]\n",
+ " [7] selectrows\n",
+ " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJModelInterface\\y9x5A\\src\\data_utils.jl:336 [inlined]\n",
+ " [8] #29\n",
+ " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJModelInterface\\y9x5A\\src\\model_api.jl:88 [inlined]\n",
+ " [9] map\n",
+ " @ .\\tuple.jl:292 [inlined]\n",
+ " [10] selectrows(::LaplaceRegressor, ::Vector{Int64}, ::Matrix{Float64}, ::Tuple{Matrix{Float64}, Nothing})\n",
+ " @ MLJModelInterface C:\\Users\\Pasqu\\.julia\\packages\\MLJModelInterface\\y9x5A\\src\\model_api.jl:88\n",
+ " [11] fit_only!(mach::Machine{LaplaceRegressor, LaplaceRegressor, true}; rows::Vector{Int64}, verbosity::Int64, force::Bool, composite::Nothing)\n",
+ " @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:676\n",
+ " [12] fit_only!\n",
+ " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:617 [inlined]\n",
+ " [13] #fit!#63\n",
+ " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:789 [inlined]\n",
+ " [14] fit!\n",
+ " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:786 [inlined]\n",
+ " [15] fit_and_extract_on_fold\n",
+ " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\resampling.jl:1463 [inlined]\n",
+ " [16] (::MLJBase.var\"#277#278\"{MLJBase.var\"#fit_and_extract_on_fold#304\"{Vector{Tuple{Vector{Int64}, UnitRange{Int64}}}, Nothing, Nothing, Int64, Vector{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.FussyMeasure{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.Multimeasure{StatisticalMeasuresBase.SupportsMissingsMeasure{StatisticalMeasures.LPLossOnScalars{Int64}}, Nothing, StatisticalMeasuresBase.Mean, typeof(identity)}}, Nothing}}}, Vector{typeof(predict_mean)}, Bool, Bool, Vector{Float64}}, Machine{LaplaceRegressor, LaplaceRegressor, true}, Int64})(k::Int64)\n",
+ " @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\resampling.jl:1289\n",
+ " [17] _mapreduce(f::MLJBase.var\"#277#278\"{MLJBase.var\"#fit_and_extract_on_fold#304\"{Vector{Tuple{Vector{Int64}, UnitRange{Int64}}}, Nothing, Nothing, Int64, Vector{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.FussyMeasure{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.Multimeasure{StatisticalMeasuresBase.SupportsMissingsMeasure{StatisticalMeasures.LPLossOnScalars{Int64}}, Nothing, StatisticalMeasuresBase.Mean, typeof(identity)}}, Nothing}}}, Vector{typeof(predict_mean)}, Bool, Bool, Vector{Float64}}, Machine{LaplaceRegressor, LaplaceRegressor, true}, Int64}, op::typeof(vcat), ::IndexLinear, A::UnitRange{Int64})\n",
+ " @ Base .\\reduce.jl:440\n",
+ " [18] _mapreduce_dim\n",
+ " @ .\\reducedim.jl:365 [inlined]\n",
+ " [19] mapreduce\n",
+ " @ .\\reducedim.jl:357 [inlined]\n",
+ " [20] _evaluate!(func::MLJBase.var\"#fit_and_extract_on_fold#304\"{Vector{Tuple{Vector{Int64}, UnitRange{Int64}}}, Nothing, Nothing, Int64, Vector{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.FussyMeasure{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.Multimeasure{StatisticalMeasuresBase.SupportsMissingsMeasure{StatisticalMeasures.LPLossOnScalars{Int64}}, Nothing, StatisticalMeasuresBase.Mean, typeof(identity)}}, Nothing}}}, Vector{typeof(predict_mean)}, Bool, Bool, Vector{Float64}}, mach::Machine{LaplaceRegressor, LaplaceRegressor, true}, ::CPU1{Nothing}, nfolds::Int64, verbosity::Int64)\n",
+ " @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\resampling.jl:1288\n",
+ " [21] evaluate!(mach::Machine{LaplaceRegressor, LaplaceRegressor, true}, resampling::Vector{Tuple{Vector{Int64}, UnitRange{Int64}}}, weights::Nothing, class_weights::Nothing, rows::Nothing, verbosity::Int64, repeats::Int64, measures::Vector{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.FussyMeasure{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.Multimeasure{StatisticalMeasuresBase.SupportsMissingsMeasure{StatisticalMeasures.LPLossOnScalars{Int64}}, Nothing, StatisticalMeasuresBase.Mean, typeof(identity)}}, Nothing}}}, operations::Vector{typeof(predict_mean)}, acceleration::CPU1{Nothing}, force::Bool, per_observation_flag::Bool, logger::Nothing, user_resampling::CV, compact::Bool)\n",
+ " @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\resampling.jl:1510\n",
+ " [22] evaluate!(::Machine{LaplaceRegressor, LaplaceRegressor, true}, ::CV, ::Nothing, ::Nothing, ::Nothing, ::Int64, ::Int64, ::Vector{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.FussyMeasure{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.Multimeasure{StatisticalMeasuresBase.SupportsMissingsMeasure{StatisticalMeasures.LPLossOnScalars{Int64}}, Nothing, StatisticalMeasuresBase.Mean, typeof(identity)}}, Nothing}}}, ::Vector{typeof(predict_mean)}, ::CPU1{Nothing}, ::Bool, ::Bool, ::Nothing, ::CV, ::Bool)\n",
+ " @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\resampling.jl:1603\n",
+ " [23] evaluate!(mach::Machine{LaplaceRegressor, LaplaceRegressor, true}; resampling::CV, measures::Nothing, measure::StatisticalMeasuresBase.FussyMeasure{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.Multimeasure{StatisticalMeasuresBase.SupportsMissingsMeasure{StatisticalMeasures.LPLossOnScalars{Int64}}, Nothing, StatisticalMeasuresBase.Mean, typeof(identity)}}, Nothing}, weights::Nothing, class_weights::Nothing, operations::Nothing, operation::Nothing, acceleration::CPU1{Nothing}, rows::Nothing, repeats::Int64, force::Bool, check_measure::Bool, per_observation::Bool, verbosity::Int64, logger::Nothing, compact::Bool)\n",
+ " @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\resampling.jl:1232\n",
+ " [24] top-level scope\n",
+ " @ c:\\Users\\Pasqu\\Documents\\julia_projects\\LaplaceRedux.jl\\dev\\notebooks\\mlj-interfacing\\jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X54sZmlsZQ==.jl:1"
+ ]
+ }
+ ],
+ "source": [
+ "evaluate!(mach, resampling=cv, measure=l2, verbosity=0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#flux_model = Chain(\n",
+ " #Dense(4, 10, relu),\n",
+ " #Dense(10, 10, relu),\n",
+ " #Dense(10, 1))\n",
+ "\n",
+ "\n",
+ "#nested_flux_model = Chain(\n",
+ " #Chain(Dense(10, 5, relu), Dense(5, 5, relu)),\n",
+ " #Chain(Dense(5, 3, relu), Dense(3, 2)))\n",
+ "#model = LaplaceRegressor()\n",
+ "#model.model= nested_flux_model\n",
+ "\n",
+ "#copy_model= deepcopy(model)\n",
+ "\n",
+ "#copy_model.epochs= 2000\n",
+ "#copy_model.optimiser = Descent()\n",
+ "#MLJBase.is_same_except(model , copy_model,:epochs )"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Julia 1.10.5",
+ "language": "julia",
+ "name": "julia-1.10"
+ },
+ "language_info": {
+ "file_extension": ".jl",
+ "mimetype": "application/julia",
+ "name": "julia",
+ "version": "1.10.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/Manifest.toml b/docs/Manifest.toml
index 6d42a1e2..d01b6356 100644
--- a/docs/Manifest.toml
+++ b/docs/Manifest.toml
@@ -1,8 +1,8 @@
# This file is machine-generated - editing it directly is not advised
-julia_version = "1.10.5"
+julia_version = "1.11.1"
manifest_format = "2.0"
-project_hash = "0bd11d5fa58aad2714bf7893e520fc7c086ef3ca"
+project_hash = "aafefca0e42df191a3fa625af8a8b686a9db4944"
[[deps.ANSIColoredPrinters]]
git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c"
@@ -26,31 +26,35 @@ uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
version = "0.4.5"
[[deps.Accessors]]
-deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown", "Test"]
-git-tree-sha1 = "f61b15be1d76846c0ce31d3fcfac5380ae53db6a"
+deps = ["CompositionsBase", "ConstructionBase", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown"]
+git-tree-sha1 = "b392ede862e506d451fc1616e79aa6f4c673dab8"
uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
-version = "0.1.37"
+version = "0.1.38"
[deps.Accessors.extensions]
AccessorsAxisKeysExt = "AxisKeys"
+ AccessorsDatesExt = "Dates"
AccessorsIntervalSetsExt = "IntervalSets"
AccessorsStaticArraysExt = "StaticArrays"
AccessorsStructArraysExt = "StructArrays"
+ AccessorsTestExt = "Test"
AccessorsUnitfulExt = "Unitful"
[deps.Accessors.weakdeps]
AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5"
+ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
+ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
[[deps.Adapt]]
deps = ["LinearAlgebra", "Requires"]
-git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099"
+git-tree-sha1 = "d80af0733c99ea80575f612813fa6aa71022d33a"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
-version = "4.0.4"
+version = "4.1.0"
weakdeps = ["StaticArrays"]
[deps.Adapt.extensions]
@@ -69,7 +73,7 @@ version = "2.3.0"
[[deps.ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
-version = "1.1.1"
+version = "1.1.2"
[[deps.Arpack]]
deps = ["Arpack_jll", "Libdl", "LinearAlgebra", "Logging"]
@@ -115,6 +119,7 @@ version = "7.16.0"
[[deps.Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
+version = "1.11.0"
[[deps.Atomix]]
deps = ["UnsafeAtomics"]
@@ -157,6 +162,7 @@ version = "0.4.3"
[[deps.Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
+version = "1.11.0"
[[deps.Baselet]]
git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e"
@@ -170,9 +176,9 @@ version = "0.1.9"
[[deps.Bzip2_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
-git-tree-sha1 = "9e2a6b69137e6969bab0152632dcb3bc108c8bdd"
+git-tree-sha1 = "8873e196c2eb87962a2048b3b8e08946535864a1"
uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0"
-version = "1.0.8+1"
+version = "1.0.8+2"
[[deps.CEnum]]
git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc"
@@ -181,15 +187,15 @@ version = "0.5.0"
[[deps.CSV]]
deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"]
-git-tree-sha1 = "6c834533dc1fabd820c1db03c839bf97e45a3fab"
+git-tree-sha1 = "deddd8725e5e1cc49ee205a1964256043720a6c3"
uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
-version = "0.10.14"
+version = "0.10.15"
[[deps.CUDA]]
-deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "Crayons", "DataFrames", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LLVMLoopInfo", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "NVTX", "Preferences", "PrettyTables", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "StaticArrays", "Statistics"]
-git-tree-sha1 = "fdd9dfb67dfefd548f51000cc400bb51003de247"
+deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "Crayons", "DataFrames", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LLVMLoopInfo", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "NVTX", "Preferences", "PrettyTables", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "StaticArrays", "Statistics", "demumble_jll"]
+git-tree-sha1 = "e0725a467822697171af4dae15cec10b4fc19053"
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
-version = "5.4.3"
+version = "5.5.2"
[deps.CUDA.extensions]
ChainRulesCoreExt = "ChainRulesCore"
@@ -203,9 +209,9 @@ version = "5.4.3"
[[deps.CUDA_Driver_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
-git-tree-sha1 = "325058b426c2b421e3d2df3d5fa646d72d2e3e7e"
+git-tree-sha1 = "ccd1e54610c222fadfd4737dac66bff786f63656"
uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc"
-version = "0.9.2+0"
+version = "0.10.3+0"
[[deps.CUDA_Runtime_Discovery]]
deps = ["Libdl"]
@@ -215,21 +221,21 @@ version = "0.3.5"
[[deps.CUDA_Runtime_jll]]
deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
-git-tree-sha1 = "afea94249b821dc754a8ca6695d3daed851e1f5a"
+git-tree-sha1 = "e43727b237b2879a34391eeb81887699a26f8f2f"
uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
-version = "0.14.1+0"
+version = "0.15.3+0"
[[deps.CUDNN_jll]]
deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
-git-tree-sha1 = "cbf7d75f8c58b147bdf6acea2e5bc96cececa6d4"
+git-tree-sha1 = "9851af16a2f357a793daa0f13634c82bc7e40419"
uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645"
-version = "9.0.0+1"
+version = "9.4.0+0"
[[deps.Cairo_jll]]
deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"]
-git-tree-sha1 = "a2f1c8c668c8e3cb4cca4e57a8efdb09067bb3fd"
+git-tree-sha1 = "009060c9a6168704143100f36ab08f06c2af4642"
uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a"
-version = "1.18.0+2"
+version = "1.18.2+1"
[[deps.CategoricalArrays]]
deps = ["DataAPI", "Future", "Missings", "Printf", "Requires", "Statistics", "Unicode"]
@@ -263,15 +269,15 @@ version = "0.1.15"
[[deps.ChainRules]]
deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"]
-git-tree-sha1 = "227985d885b4dbce5e18a96f9326ea1e836e5a03"
+git-tree-sha1 = "be227d253d132a6d57f9ccf5f67c0fb6488afd87"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
-version = "1.69.0"
+version = "1.71.0"
[[deps.ChainRulesCore]]
deps = ["Compat", "LinearAlgebra"]
-git-tree-sha1 = "71acdbf594aab5bbb2cec89b208c41b4c411e49f"
+git-tree-sha1 = "3e4b134270b372f2ed4d4d0e936aabaefc1802bc"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-version = "1.24.0"
+version = "1.25.0"
weakdeps = ["SparseArrays"]
[deps.ChainRulesCore.extensions]
@@ -285,9 +291,9 @@ version = "0.7.6"
[[deps.ColorSchemes]]
deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"]
-git-tree-sha1 = "b5278586822443594ff615963b0c09755771b3e0"
+git-tree-sha1 = "13951eb68769ad1cd460cdb2e64e5e95f1bf123d"
uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
-version = "3.26.0"
+version = "3.27.0"
[[deps.ColorTypes]]
deps = ["FixedPointNumbers", "Random"]
@@ -411,10 +417,10 @@ uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.16.0"
[[deps.DataFrames]]
-deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"]
-git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8"
+deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"]
+git-tree-sha1 = "fb61b4812c49343d7ef0b533ba982c46021938a6"
uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
-version = "1.6.1"
+version = "1.7.0"
[[deps.DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
@@ -430,6 +436,7 @@ version = "1.0.0"
[[deps.Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
+version = "1.11.0"
[[deps.Dbus_jll]]
deps = ["Artifacts", "Expat_jll", "JLLWrappers", "Libdl"]
@@ -486,12 +493,13 @@ weakdeps = ["ChainRulesCore", "SparseArrays"]
[[deps.Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
+version = "1.11.0"
[[deps.Distributions]]
deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"]
-git-tree-sha1 = "e6c693a0e4394f8fda0e51a5bdf5aef26f8235e9"
+git-tree-sha1 = "d7477ecdafb813ddee2ae727afa94e9dcb5f3fb0"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
-version = "0.25.111"
+version = "0.25.112"
weakdeps = ["ChainRulesCore", "DensityInterface", "Test"]
[deps.Distributions.extensions]
@@ -499,6 +507,12 @@ weakdeps = ["ChainRulesCore", "DensityInterface", "Test"]
DistributionsDensityInterfaceExt = "DensityInterface"
DistributionsTestExt = "Test"
+[[deps.DocInventories]]
+deps = ["CodecZlib", "Downloads", "TOML"]
+git-tree-sha1 = "e97cfa8680a39396924dcdca4b7ff1014ed5c499"
+uuid = "43dc2714-ed3b-44b5-b226-857eda1aa7de"
+version = "1.0.0"
+
[[deps.DocStringExtensions]]
deps = ["LibGit2"]
git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d"
@@ -507,9 +521,21 @@ version = "0.9.3"
[[deps.Documenter]]
deps = ["ANSIColoredPrinters", "AbstractTrees", "Base64", "CodecZlib", "Dates", "DocStringExtensions", "Downloads", "Git", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "MarkdownAST", "Pkg", "PrecompileTools", "REPL", "RegistryInstances", "SHA", "TOML", "Test", "Unicode"]
-git-tree-sha1 = "9d29b99b6b2b6bc2382a4c8dbec6eb694f389853"
+git-tree-sha1 = "5a1ee886566f2fa9318df1273d8b778b9d42712d"
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
-version = "1.6.0"
+version = "1.7.0"
+
+[[deps.DocumenterInterLinks]]
+deps = ["CodecZlib", "DocInventories", "Documenter", "DocumenterInventoryWritingBackport", "Markdown", "MarkdownAST", "TOML"]
+git-tree-sha1 = "00dceb038f6cb24f4d8d6a9f2feb85bbe58305fd"
+uuid = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656"
+version = "1.0.0"
+
+[[deps.DocumenterInventoryWritingBackport]]
+deps = ["CodecZlib", "Documenter", "TOML"]
+git-tree-sha1 = "1b89024e375353961bb98b9818b44a4e38961cc4"
+uuid = "195adf08-069f-4855-af3e-8933a2cdae94"
+version = "0.1.0"
[[deps.Downloads]]
deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
@@ -541,9 +567,9 @@ version = "0.1.10"
[[deps.FFMPEG]]
deps = ["FFMPEG_jll"]
-git-tree-sha1 = "b57e3acbe22f8484b4b5ff66a7499717fe1a9cc8"
+git-tree-sha1 = "53ebe7511fa11d33bec688a9178fac4e49eeee00"
uuid = "c87230d0-a227-11e9-1b43-d7ebe4e7570a"
-version = "0.4.1"
+version = "0.4.2"
[[deps.FFMPEG_jll]]
deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "JLLWrappers", "LAME_jll", "Libdl", "Ogg_jll", "OpenSSL_jll", "Opus_jll", "PCRE2_jll", "Zlib_jll", "libaom_jll", "libass_jll", "libfdk_aac_jll", "libvorbis_jll", "x264_jll", "x265_jll"]
@@ -565,9 +591,9 @@ version = "0.1.1"
[[deps.FileIO]]
deps = ["Pkg", "Requires", "UUIDs"]
-git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322"
+git-tree-sha1 = "62ca0547a14c57e98154423419d8a342dca75ca9"
uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
-version = "1.16.3"
+version = "1.16.4"
[[deps.FilePathsBase]]
deps = ["Compat", "Dates"]
@@ -582,6 +608,7 @@ weakdeps = ["Mmap", "Test"]
[[deps.FileWatching]]
uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
+version = "1.11.0"
[[deps.FillArrays]]
deps = ["LinearAlgebra"]
@@ -596,19 +623,21 @@ weakdeps = ["PDMats", "SparseArrays", "Statistics"]
FillArraysStatisticsExt = "Statistics"
[[deps.FiniteDiff]]
-deps = ["ArrayInterface", "LinearAlgebra", "Setfield", "SparseArrays"]
-git-tree-sha1 = "f9219347ebf700e77ca1d48ef84e4a82a6701882"
+deps = ["ArrayInterface", "LinearAlgebra", "Setfield"]
+git-tree-sha1 = "b10bdafd1647f57ace3885143936749d61638c3b"
uuid = "6a86dc24-6348-571c-b903-95158fe2bd41"
-version = "2.24.0"
+version = "2.26.0"
[deps.FiniteDiff.extensions]
FiniteDiffBandedMatricesExt = "BandedMatrices"
FiniteDiffBlockBandedMatricesExt = "BlockBandedMatrices"
+ FiniteDiffSparseArraysExt = "SparseArrays"
FiniteDiffStaticArraysExt = "StaticArrays"
[deps.FiniteDiff.weakdeps]
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
+ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
[[deps.FixedPointNumbers]]
@@ -618,23 +647,27 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.8.5"
[[deps.Flux]]
-deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"]
-git-tree-sha1 = "fbf100b4bed74c9b6fac0ebd1031e04977d35b3b"
+deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLDataDevices", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "Setfield", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"]
+git-tree-sha1 = "37fa32a50c69c10c6ea1465d3054d98c75bd7777"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.14.19"
+version = "0.14.22"
[deps.Flux.extensions]
FluxAMDGPUExt = "AMDGPU"
FluxCUDAExt = "CUDA"
FluxCUDAcuDNNExt = ["CUDA", "cuDNN"]
FluxEnzymeExt = "Enzyme"
+ FluxMPIExt = "MPI"
+ FluxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
FluxMetalExt = "Metal"
[deps.Flux.weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+ MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
+ NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
[[deps.Fontconfig_jll]]
@@ -679,6 +712,7 @@ version = "0.4.12"
[[deps.Future]]
deps = ["Random"]
uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"
+version = "1.11.0"
[[deps.GLFW_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Libglvnd_jll", "Xorg_libXcursor_jll", "Xorg_libXi_jll", "Xorg_libXinerama_jll", "Xorg_libXrandr_jll", "libdecor_jll", "xkbcommon_jll"]
@@ -699,22 +733,22 @@ uuid = "46192b85-c4d5-4398-a991-12ede77f4527"
version = "0.1.6"
[[deps.GPUCompiler]]
-deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Preferences", "Scratch", "Serialization", "TOML", "TimerOutputs", "UUIDs"]
-git-tree-sha1 = "ab29216184312f99ff957b32cd63c2fe9c928b91"
+deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "PrecompileTools", "Preferences", "Scratch", "Serialization", "TOML", "TimerOutputs", "UUIDs"]
+git-tree-sha1 = "1d6f290a5eb1201cd63574fbc4440c788d5cb38f"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
-version = "0.26.7"
+version = "0.27.8"
[[deps.GR]]
deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Preferences", "Printf", "Qt6Wayland_jll", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "p7zip_jll"]
-git-tree-sha1 = "629693584cef594c3f6f99e76e7a7ad17e60e8d5"
+git-tree-sha1 = "ee28ddcd5517d54e417182fec3886e7412d3926f"
uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
-version = "0.73.7"
+version = "0.73.8"
[[deps.GR_jll]]
deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "FreeType2_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Qt6Base_jll", "Zlib_jll", "libpng_jll"]
-git-tree-sha1 = "a8863b69c2a0859f2c2c87ebdc4c6712e88bdf0d"
+git-tree-sha1 = "f31929b9e67066bee48eec8b03c0df47d31a74b3"
uuid = "d2c73de3-f751-5644-a686-071e5b155ba9"
-version = "0.73.7+0"
+version = "0.73.8+0"
[[deps.Gettext_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "XML2_jll"]
@@ -730,15 +764,15 @@ version = "1.3.1"
[[deps.Git_jll]]
deps = ["Artifacts", "Expat_jll", "JLLWrappers", "LibCURL_jll", "Libdl", "Libiconv_jll", "OpenSSL_jll", "PCRE2_jll", "Zlib_jll"]
-git-tree-sha1 = "d18fb8a1f3609361ebda9bf029b60fd0f120c809"
+git-tree-sha1 = "ea372033d09e4552a04fd38361cd019f9003f4f4"
uuid = "f8c6e375-362e-5223-8a59-34ff63f689eb"
-version = "2.44.0+2"
+version = "2.46.2+0"
[[deps.Glib_jll]]
deps = ["Artifacts", "Gettext_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Libiconv_jll", "Libmount_jll", "PCRE2_jll", "Zlib_jll"]
-git-tree-sha1 = "7c82e6a6cd34e9d935e9aa4051b66c6ff3af59ba"
+git-tree-sha1 = "674ff0db93fffcd11a3573986e550d66cd4fd71f"
uuid = "7746bdde-850d-59dc-9ae8-88ece973131d"
-version = "2.80.2+0"
+version = "2.80.5+0"
[[deps.Graphite2_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
@@ -814,11 +848,12 @@ version = "1.4.2"
[[deps.InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
+version = "1.11.0"
[[deps.InverseFunctions]]
-git-tree-sha1 = "2787db24f4e03daf859c6509ff87764e4182f7d1"
+git-tree-sha1 = "a779299d77cd080bf77b97535acecd73e1c5e5cb"
uuid = "3587e190-3f89-42d0-90ee-14403ec27112"
-version = "0.1.16"
+version = "0.1.17"
weakdeps = ["Dates", "Test"]
[deps.InverseFunctions.extensions]
@@ -848,9 +883,9 @@ version = "1.0.0"
[[deps.JLD2]]
deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Requires", "TranscodingStreams"]
-git-tree-sha1 = "a0746c21bdc986d0dc293efa6b1faee112c37c28"
+git-tree-sha1 = "b464b9b461ee989b435a689a4f7d870b68d467ed"
uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
-version = "0.4.53"
+version = "0.5.6"
[[deps.JLFzf]]
deps = ["Pipe", "REPL", "Random", "fzf_jll"]
@@ -860,9 +895,9 @@ version = "0.1.8"
[[deps.JLLWrappers]]
deps = ["Artifacts", "Preferences"]
-git-tree-sha1 = "f389674c99bfcde17dc57454011aa44d5a260a40"
+git-tree-sha1 = "be3dc50a92e5a386872a493a10050136d4703f9b"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
-version = "1.6.0"
+version = "1.6.1"
[[deps.JSON]]
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
@@ -872,9 +907,9 @@ version = "0.21.4"
[[deps.JpegTurbo_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
-git-tree-sha1 = "c84a835e1a09b289ffcd2271bf2a337bbdda6637"
+git-tree-sha1 = "25ee0be4d43d0269027024d75a24c24d6c6e590c"
uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8"
-version = "3.0.3+0"
+version = "3.0.4+0"
[[deps.JuliaNVTXCallbacks_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
@@ -890,9 +925,9 @@ version = "0.2.4"
[[deps.KernelAbstractions]]
deps = ["Adapt", "Atomix", "InteractiveUtils", "MacroTools", "PrecompileTools", "Requires", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"]
-git-tree-sha1 = "cb1cff88ef2f3a157cbad75bbe6b229e1975e498"
+git-tree-sha1 = "04e52f596d0871fa3890170fa79cb15e481e4cd8"
uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
-version = "0.9.25"
+version = "0.9.28"
[deps.KernelAbstractions.extensions]
EnzymeExt = "EnzymeCore"
@@ -911,16 +946,16 @@ uuid = "c1c5ebd0-6772-5130-a774-d5fcae4a789d"
version = "3.100.2+0"
[[deps.LERC_jll]]
-deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
-git-tree-sha1 = "bf36f528eec6634efc60d7ec062008f171071434"
+deps = ["Artifacts", "JLLWrappers", "Libdl"]
+git-tree-sha1 = "36bdbc52f13a7d1dcb0f3cd694e01677a515655b"
uuid = "88015f11-f218-50d7-93a8-a6af411a945d"
-version = "3.0.0+1"
+version = "4.0.0+0"
[[deps.LLVM]]
-deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"]
-git-tree-sha1 = "2470e69781ddd70b8878491233cd09bc1bd7fc96"
+deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Unicode"]
+git-tree-sha1 = "d422dfd9707bec6617335dc2ea3c5172a87d5908"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
-version = "8.1.0"
+version = "9.1.3"
weakdeps = ["BFloat16s"]
[deps.LLVM.extensions]
@@ -928,9 +963,9 @@ weakdeps = ["BFloat16s"]
[[deps.LLVMExtra_jll]]
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
-git-tree-sha1 = "597d1c758c9ae5d985ba4202386a607c675ee700"
+git-tree-sha1 = "05a8bd5a42309a9ec82f700876903abce1017dd3"
uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
-version = "0.0.31+0"
+version = "0.0.34+0"
[[deps.LLVMLoopInfo]]
git-tree-sha1 = "2e5c102cfc41f48ae4740c7eca7743cc7e7b75ea"
@@ -939,26 +974,26 @@ version = "1.0.0"
[[deps.LLVMOpenMP_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
-git-tree-sha1 = "e16271d212accd09d52ee0ae98956b8a05c4b626"
+git-tree-sha1 = "78211fb6cbc872f77cad3fc0b6cf647d923f4929"
uuid = "1d63c593-3942-5779-bab2-d838dc0a180e"
-version = "17.0.6+0"
+version = "18.1.7+0"
[[deps.LZO_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
-git-tree-sha1 = "70c5da094887fd2cae843b8db33920bac4b6f07d"
+git-tree-sha1 = "854a9c268c43b77b0a27f22d7fab8d33cdb3a731"
uuid = "dd4b983a-f0e5-5f8d-a1b7-129d4a5fb1ac"
-version = "2.10.2+0"
+version = "2.10.2+1"
[[deps.LaTeXStrings]]
-git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec"
+git-tree-sha1 = "dda21b8cbd6a6c40d9d02a73230f9d70fed6918c"
uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
-version = "1.3.1"
+version = "1.4.0"
[[deps.LaplaceRedux]]
-deps = ["ChainRulesCore", "Compat", "ComputationalResources", "Distributions", "Flux", "LinearAlgebra", "MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"]
-git-tree-sha1 = "a84b72a27c93c72a6af5d22216eb81a419b1b97a"
+deps = ["CategoricalDistributions", "ChainRulesCore", "Compat", "ComputationalResources", "Distributions", "Flux", "LinearAlgebra", "MLJBase", "MLJModelInterface", "MLUtils", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"]
+path = ".."
uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
-version = "1.0.2"
+version = "1.1.1"
[[deps.Latexify]]
deps = ["Format", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Requires"]
@@ -984,6 +1019,7 @@ version = "1.2.2"
[[deps.LazyArtifacts]]
deps = ["Artifacts", "Pkg"]
uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
+version = "1.11.0"
[[deps.LearnAPI]]
deps = ["InteractiveUtils", "Statistics"]
@@ -999,16 +1035,17 @@ version = "0.6.4"
[[deps.LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
-version = "8.4.0+0"
+version = "8.6.0+0"
[[deps.LibGit2]]
deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
+version = "1.11.0"
[[deps.LibGit2_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"]
uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5"
-version = "1.6.4+0"
+version = "1.7.2+0"
[[deps.LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
@@ -1017,6 +1054,7 @@ version = "1.11.0+1"
[[deps.Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
+version = "1.11.0"
[[deps.Libffi_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
@@ -1056,9 +1094,9 @@ version = "2.40.1+0"
[[deps.Libtiff_jll]]
deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "LERC_jll", "Libdl", "XZ_jll", "Zlib_jll", "Zstd_jll"]
-git-tree-sha1 = "2da088d113af58221c52828a80378e16be7d037a"
+git-tree-sha1 = "b404131d06f7886402758c9ce2214b636eb4d54a"
uuid = "89763e89-9b03-5906-acba-b20f662cd828"
-version = "4.5.1+1"
+version = "4.7.0+0"
[[deps.Libuuid_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
@@ -1075,6 +1113,7 @@ version = "7.3.0"
[[deps.LinearAlgebra]]
deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+version = "1.11.0"
[[deps.LinearMaps]]
deps = ["LinearAlgebra"]
@@ -1106,6 +1145,7 @@ version = "0.3.28"
[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
+version = "1.11.0"
[[deps.LoggingExtras]]
deps = ["Dates", "Logging"]
@@ -1113,6 +1153,46 @@ git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075"
uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36"
version = "1.0.3"
+[[deps.MLDataDevices]]
+deps = ["Adapt", "Compat", "Functors", "LinearAlgebra", "Preferences", "Random"]
+git-tree-sha1 = "3207c2e66164e6366440ad3f0243a8d67abb4a47"
+uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
+version = "1.4.1"
+
+ [deps.MLDataDevices.extensions]
+ MLDataDevicesAMDGPUExt = "AMDGPU"
+ MLDataDevicesCUDAExt = "CUDA"
+ MLDataDevicesChainRulesCoreExt = "ChainRulesCore"
+ MLDataDevicesFillArraysExt = "FillArrays"
+ MLDataDevicesGPUArraysExt = "GPUArrays"
+ MLDataDevicesMLUtilsExt = "MLUtils"
+ MLDataDevicesMetalExt = ["GPUArrays", "Metal"]
+ MLDataDevicesReactantExt = "Reactant"
+ MLDataDevicesRecursiveArrayToolsExt = "RecursiveArrayTools"
+ MLDataDevicesReverseDiffExt = "ReverseDiff"
+ MLDataDevicesSparseArraysExt = "SparseArrays"
+ MLDataDevicesTrackerExt = "Tracker"
+ MLDataDevicesZygoteExt = "Zygote"
+ MLDataDevicescuDNNExt = ["CUDA", "cuDNN"]
+ MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"]
+
+ [deps.MLDataDevices.weakdeps]
+ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
+ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
+ GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
+ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+ Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
+ Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
+ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
+ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
+ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
+ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
+ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+ cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
+ oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
+
[[deps.MLJBase]]
deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LearnAPI", "LinearAlgebra", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "RecipesBase", "Reexport", "ScientificTypes", "Serialization", "StatisticalMeasuresBase", "StatisticalTraits", "Statistics", "StatsBase", "Tables"]
git-tree-sha1 = "6f45e12073bc2f2e73ed0473391db38c31e879c9"
@@ -1175,6 +1255,7 @@ version = "0.5.13"
[[deps.Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
+version = "1.11.0"
[[deps.MarkdownAST]]
deps = ["AbstractTrees", "Markdown"]
@@ -1191,7 +1272,7 @@ version = "1.1.9"
[[deps.MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
-version = "2.28.2+1"
+version = "2.28.6+0"
[[deps.Measures]]
git-tree-sha1 = "c13304c81eec1ed3af7fc20e75fb6b26092a1102"
@@ -1200,9 +1281,9 @@ version = "0.3.2"
[[deps.Metalhead]]
deps = ["Artifacts", "BSON", "ChainRulesCore", "Flux", "Functors", "JLD2", "LazyArtifacts", "MLUtils", "NNlib", "PartialFunctions", "Random", "Statistics"]
-git-tree-sha1 = "5aac9a2b511afda7bf89df5044a2e0b429f83152"
+git-tree-sha1 = "aef476e4958303f5ea9e1deb81a1ba2f510d4e11"
uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
-version = "0.9.3"
+version = "0.9.4"
weakdeps = ["CUDA"]
[deps.Metalhead.extensions]
@@ -1222,6 +1303,7 @@ version = "1.2.0"
[[deps.Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
+version = "1.11.0"
[[deps.Mocking]]
deps = ["Compat", "ExprTools"]
@@ -1231,7 +1313,7 @@ version = "0.8.1"
[[deps.MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
-version = "2023.1.10"
+version = "2023.12.12"
[[deps.MultivariateStats]]
deps = ["Arpack", "Distributions", "LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI", "StatsBase"]
@@ -1246,10 +1328,10 @@ uuid = "d41bc354-129a-5804-8e4c-c37616107c6c"
version = "7.8.3"
[[deps.NNlib]]
-deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"]
-git-tree-sha1 = "ae52c156a63bb647f80c26319b104e99e5977e51"
+deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Random", "Statistics"]
+git-tree-sha1 = "da09a1e112fd75f9af2a5229323f01b56ec96a4c"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.9.22"
+version = "0.9.24"
[deps.NNlib.extensions]
NNlibAMDGPUExt = "AMDGPU"
@@ -1257,12 +1339,14 @@ version = "0.9.22"
NNlibCUDAExt = "CUDA"
NNlibEnzymeCoreExt = "EnzymeCore"
NNlibFFTWExt = "FFTW"
+ NNlibForwardDiffExt = "ForwardDiff"
[deps.NNlib.weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
+ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
[[deps.NVTX]]
@@ -1302,9 +1386,9 @@ version = "0.2.3"
[[deps.NearestNeighbors]]
deps = ["Distances", "StaticArrays"]
-git-tree-sha1 = "91a67b4d73842da90b526011fa85c5c4c9343fe0"
+git-tree-sha1 = "3cebfc94a0754cc329ebc3bab1e6c89621e791ad"
uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
-version = "0.4.18"
+version = "0.4.20"
[[deps.NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
@@ -1325,7 +1409,7 @@ version = "0.2.5"
[[deps.OpenBLAS_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
-version = "0.3.23+4"
+version = "0.3.27+1"
[[deps.OpenLibm_jll]]
deps = ["Artifacts", "Libdl"]
@@ -1340,9 +1424,9 @@ version = "1.4.3"
[[deps.OpenSSL_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
-git-tree-sha1 = "a028ee3cb5641cccc4c24e90c36b0a4f7707bdf5"
+git-tree-sha1 = "7493f61f55a6cce7325f197443aa80d32554ba10"
uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95"
-version = "3.0.14+0"
+version = "3.0.15+1"
[[deps.OpenSpecFun_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
@@ -1432,9 +1516,13 @@ uuid = "30392449-352a-5448-841d-b1acce4e97dc"
version = "0.43.4+0"
[[deps.Pkg]]
-deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
+deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "Random", "SHA", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
-version = "1.10.0"
+version = "1.11.0"
+weakdeps = ["REPL"]
+
+ [deps.Pkg.extensions]
+ REPLExt = "REPL"
[[deps.PlotThemes]]
deps = ["PlotUtils", "Statistics"]
@@ -1443,10 +1531,10 @@ uuid = "ccf2f8ad-2431-5c83-bf29-c5338b663b6a"
version = "3.2.0"
[[deps.PlotUtils]]
-deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "Statistics"]
-git-tree-sha1 = "7b1a9df27f072ac4c9c7cbe5efb198489258d1f5"
+deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "StableRNGs", "Statistics"]
+git-tree-sha1 = "650a022b2ce86c7dcfbdecf00f78afeeb20e5655"
uuid = "995b91a9-d308-5afd-9ec6-746e21dbc043"
-version = "1.4.1"
+version = "1.4.2"
[[deps.Plots]]
deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "PrecompileTools", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "RelocatableFolders", "Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "TOML", "UUIDs", "UnicodeFun", "UnitfulLatexify", "Unzip"]
@@ -1499,13 +1587,14 @@ version = "0.2.0"
[[deps.PrettyTables]]
deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"]
-git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7"
+git-tree-sha1 = "1101cd475833706e4d0e7b122218257178f48f34"
uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
-version = "2.3.2"
+version = "2.4.0"
[[deps.Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
+version = "1.11.0"
[[deps.ProgressLogging]]
deps = ["Logging", "SHA", "UUIDs"]
@@ -1550,9 +1639,9 @@ version = "6.7.1+1"
[[deps.QuadGK]]
deps = ["DataStructures", "LinearAlgebra"]
-git-tree-sha1 = "1d587203cf851a51bf1ea31ad7ff89eff8d625ea"
+git-tree-sha1 = "cda3b045cf9ef07a08ad46731f5a3165e56cf3da"
uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
-version = "2.11.0"
+version = "2.11.1"
[deps.QuadGK.extensions]
QuadGKEnzymeExt = "Enzyme"
@@ -1573,12 +1662,14 @@ uuid = "ce6b1742-4840-55fa-b093-852dadbb1d8b"
version = "0.7.7"
[[deps.REPL]]
-deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
+deps = ["InteractiveUtils", "Markdown", "Sockets", "StyledStrings", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
+version = "1.11.0"
[[deps.Random]]
deps = ["SHA"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+version = "1.11.0"
[[deps.Random123]]
deps = ["Random", "RandomNumbers"]
@@ -1647,15 +1738,15 @@ version = "1.3.0"
[[deps.Rmath]]
deps = ["Random", "Rmath_jll"]
-git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b"
+git-tree-sha1 = "852bd0f55565a9e973fcfee83a84413270224dc4"
uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
-version = "0.7.1"
+version = "0.8.0"
[[deps.Rmath_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
-git-tree-sha1 = "e60724fd3beea548353984dc61c943ecddb0e29a"
+git-tree-sha1 = "58cdd8fb2201a6267e1db87ff148dd6c1dbd8ad8"
uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f"
-version = "0.4.3+0"
+version = "0.5.1+0"
[[deps.SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
@@ -1692,6 +1783,7 @@ version = "1.4.5"
[[deps.Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
+version = "1.11.0"
[[deps.Setfield]]
deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"]
@@ -1711,9 +1803,9 @@ uuid = "992d4aef-0814-514b-bc4d-f2e9a6c4116f"
version = "1.0.3"
[[deps.SimpleBufferStream]]
-git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1"
+git-tree-sha1 = "f305871d2f381d21527c770d4788c06c097c9bc1"
uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7"
-version = "1.1.0"
+version = "1.2.0"
[[deps.SimpleTraits]]
deps = ["InteractiveUtils", "MacroTools"]
@@ -1723,6 +1815,7 @@ version = "0.9.4"
[[deps.Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
+version = "1.11.0"
[[deps.SortingAlgorithms]]
deps = ["DataStructures"]
@@ -1733,7 +1826,7 @@ version = "1.2.1"
[[deps.SparseArrays]]
deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
-version = "1.10.0"
+version = "1.11.0"
[[deps.SparseInverseSubset]]
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
@@ -1765,9 +1858,9 @@ version = "1.0.2"
[[deps.StaticArrays]]
deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"]
-git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50"
+git-tree-sha1 = "777657803913ffc7e8cc20f0fd04b634f871af8f"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
-version = "1.9.7"
+version = "1.9.8"
weakdeps = ["ChainRulesCore", "Statistics"]
[deps.StaticArrays.extensions]
@@ -1792,9 +1885,14 @@ uuid = "64bff920-2084-43da-a3e6-9bb72801c0c9"
version = "3.4.0"
[[deps.Statistics]]
-deps = ["LinearAlgebra", "SparseArrays"]
+deps = ["LinearAlgebra"]
+git-tree-sha1 = "ae3bb1eb3bba077cd276bc5cfc337cc65c3075c0"
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
-version = "1.10.0"
+version = "1.11.1"
+weakdeps = ["SparseArrays"]
+
+ [deps.Statistics.extensions]
+ SparseArraysExt = ["SparseArrays"]
[[deps.StatsAPI]]
deps = ["LinearAlgebra"]
@@ -1810,9 +1908,9 @@ version = "0.34.3"
[[deps.StatsFuns]]
deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"]
-git-tree-sha1 = "cef0472124fab0695b58ca35a77c6fb942fdab8a"
+git-tree-sha1 = "b423576adc27097764a90e163157bcfc9acf0f46"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
-version = "1.3.1"
+version = "1.3.2"
weakdeps = ["ChainRulesCore", "InverseFunctions"]
[deps.StatsFuns.extensions]
@@ -1821,9 +1919,9 @@ weakdeps = ["ChainRulesCore", "InverseFunctions"]
[[deps.StringManipulation]]
deps = ["PrecompileTools"]
-git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5"
+git-tree-sha1 = "a6b1675a536c5ad1a60e5a5153e1fee12eb146e3"
uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e"
-version = "0.3.4"
+version = "0.4.0"
[[deps.StructArrays]]
deps = ["ConstructionBase", "DataAPI", "Tables"]
@@ -1838,6 +1936,10 @@ weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"]
StructArraysSparseArraysExt = "SparseArrays"
StructArraysStaticArraysExt = "StaticArrays"
+[[deps.StyledStrings]]
+uuid = "f489334b-da3d-4c2e-b8f0-e476e12c162b"
+version = "1.11.0"
+
[[deps.SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
@@ -1845,7 +1947,7 @@ uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
[[deps.SuiteSparse_jll]]
deps = ["Artifacts", "Libdl", "libblastrampoline_jll"]
uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c"
-version = "7.2.1+1"
+version = "7.7.0+0"
[[deps.TOML]]
deps = ["Dates"]
@@ -1854,9 +1956,9 @@ version = "1.0.3"
[[deps.TZJData]]
deps = ["Artifacts"]
-git-tree-sha1 = "1607ad46cf8d642aa779a1d45af1c8620dbf6915"
+git-tree-sha1 = "36b40607bf2bf856828690e097e1c799623b0602"
uuid = "dc5dba14-91b3-4cab-a142-028a31da12f7"
-version = "1.2.0+2024a"
+version = "1.3.0+2024b"
[[deps.TableTraits]]
deps = ["IteratorInterfaceExtensions"]
@@ -1871,16 +1973,15 @@ uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
version = "1.12.0"
[[deps.TaijaBase]]
-deps = ["CategoricalArrays", "Distributions", "Flux", "MLUtils", "Optimisers", "StatsBase", "Tables"]
-git-tree-sha1 = "1c80c4472c6ab6e8c9fa544a22d907295b388dd0"
+git-tree-sha1 = "4076f60078b12095ca71a2c26e2e4515e3a6a5e5"
uuid = "10284c91-9f28-4c9a-abbf-ee43576dfff6"
-version = "1.2.2"
+version = "1.2.3"
[[deps.TaijaPlotting]]
-deps = ["CategoricalArrays", "ConformalPrediction", "CounterfactualExplanations", "DataAPI", "Distributions", "Flux", "LaplaceRedux", "LinearAlgebra", "MLJBase", "MLUtils", "MultivariateStats", "NaturalSort", "NearestNeighborModels", "Plots", "Trapz"]
-git-tree-sha1 = "2fc71041e1c215cf6ef3dc2d3b8419499c4b40ff"
+deps = ["CategoricalArrays", "ConformalPrediction", "CounterfactualExplanations", "DataAPI", "Distributions", "LaplaceRedux", "LinearAlgebra", "MLJBase", "MLUtils", "MultivariateStats", "NaturalSort", "NearestNeighborModels", "OneHotArrays", "Plots", "RecipesBase", "Trapz"]
+git-tree-sha1 = "01c76535e8b87b05d0fbc275d44714c78e49fe6e"
uuid = "bd7198b4-c7d6-400c-9bab-9a24614b0240"
-version = "1.2.0"
+version = "1.3.0"
[[deps.Tar]]
deps = ["ArgTools", "SHA"]
@@ -1896,6 +1997,7 @@ version = "0.1.1"
[[deps.Test]]
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+version = "1.11.0"
[[deps.ThreadsX]]
deps = ["Accessors", "ArgCheck", "BangBang", "ConstructionBase", "InitialValues", "MicroCollections", "Referenceables", "SplittablesBase", "Transducers"]
@@ -1905,9 +2007,9 @@ version = "0.1.12"
[[deps.TimeZones]]
deps = ["Dates", "Downloads", "InlineStrings", "Mocking", "Printf", "Scratch", "TZJData", "Unicode", "p7zip_jll"]
-git-tree-sha1 = "b92aebdd3555f3a7e3267cf17702033c2814ef48"
+git-tree-sha1 = "8323074bc977aa85cf5ad71099a83ac75b0ac107"
uuid = "f269a46b-ccf7-5d73-abea-4c690281aa53"
-version = "1.18.0"
+version = "1.18.1"
weakdeps = ["RecipesBase"]
[deps.TimeZones.extensions]
@@ -1915,22 +2017,23 @@ weakdeps = ["RecipesBase"]
[[deps.TimerOutputs]]
deps = ["ExprTools", "Printf"]
-git-tree-sha1 = "5a13ae8a41237cff5ecf34f73eb1b8f42fff6531"
+git-tree-sha1 = "3a6f063d690135f5c1ba351412c82bae4d1402bf"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
-version = "0.5.24"
+version = "0.5.25"
[[deps.TranscodingStreams]]
-git-tree-sha1 = "e84b3a11b9bece70d14cce63406bbc79ed3464d2"
+git-tree-sha1 = "0c45878dcfdcfa8480052b6ab162cdd138781742"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
-version = "0.11.2"
+version = "0.11.3"
[[deps.Transducers]]
-deps = ["Accessors", "Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"]
-git-tree-sha1 = "5215a069867476fc8e3469602006b9670e68da23"
+deps = ["Accessors", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"]
+git-tree-sha1 = "7deeab4ff96b85c5f72c824cae53a1398da3d1cb"
uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999"
-version = "0.4.82"
+version = "0.4.84"
[deps.Transducers.extensions]
+ TransducersAdaptExt = "Adapt"
TransducersBlockArraysExt = "BlockArrays"
TransducersDataFramesExt = "DataFrames"
TransducersLazyArraysExt = "LazyArrays"
@@ -1938,6 +2041,7 @@ version = "0.4.82"
TransducersReferenceablesExt = "Referenceables"
[deps.Transducers.weakdeps]
+ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
@@ -1975,6 +2079,7 @@ version = "1.5.1"
[[deps.UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
+version = "1.11.0"
[[deps.UnPack]]
git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b"
@@ -1983,6 +2088,7 @@ version = "1.0.2"
[[deps.Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
+version = "1.11.0"
[[deps.UnicodeFun]]
deps = ["REPL"]
@@ -2221,15 +2327,15 @@ version = "1.2.13+1"
[[deps.Zstd_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
-git-tree-sha1 = "e678132f07ddb5bfa46857f0d7620fb9be675d3b"
+git-tree-sha1 = "555d1076590a6cc2fdee2ef1469451f872d8b41b"
uuid = "3161d3a3-bdf6-5164-811a-617609db77b4"
-version = "1.5.6+0"
+version = "1.5.6+1"
[[deps.Zygote]]
deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"]
-git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54"
+git-tree-sha1 = "f816633be6dc5c0ed9ffedda157ecfda0b3b6a69"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.70"
+version = "0.6.72"
[deps.Zygote.extensions]
ZygoteColorsExt = "Colors"
@@ -2249,9 +2355,15 @@ version = "0.2.5"
[[deps.cuDNN]]
deps = ["CEnum", "CUDA", "CUDA_Runtime_Discovery", "CUDNN_jll"]
-git-tree-sha1 = "4909e87d6d62c29a897d54d9001c63932e41cb0e"
+git-tree-sha1 = "4b3ac62501ca73263eaa0d034c772f13c647fba6"
uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
-version = "1.3.2"
+version = "1.4.0"
+
+[[deps.demumble_jll]]
+deps = ["Artifacts", "JLLWrappers", "Libdl"]
+git-tree-sha1 = "6498e3581023f8e530f34760d18f75a69e3a4ea8"
+uuid = "1e29f10c-031c-5a83-9565-69cddfc27673"
+version = "1.3.0+0"
[[deps.eudev_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "gperf_jll"]
@@ -2314,9 +2426,9 @@ version = "1.18.0+0"
[[deps.libpng_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Zlib_jll"]
-git-tree-sha1 = "d7015d2e18a5fd9a4f47de711837e980519781a4"
+git-tree-sha1 = "b70c870239dc3d7bc094eb2d6be9b73d27bef280"
uuid = "b53b4c65-9356-5827-b1ea-8c7a1a84506f"
-version = "1.6.43+1"
+version = "1.6.44+0"
[[deps.libvorbis_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Ogg_jll", "Pkg"]
@@ -2333,7 +2445,7 @@ version = "1.1.6+0"
[[deps.nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
-version = "1.52.0+1"
+version = "1.59.0+0"
[[deps.p7zip_jll]]
deps = ["Artifacts", "Libdl"]
diff --git a/docs/Project.toml b/docs/Project.toml
index 7118d5fe..e8f6d849 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -3,9 +3,12 @@ CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
+DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
+MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
diff --git a/docs/make.jl b/docs/make.jl
index 8aedd777..f24ea50b 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -24,7 +24,7 @@ makedocs(;
"Calibrated forecasts" => "tutorials/calibration.md",
],
"Reference" => "reference.md",
- "MLJ interface"=> "mlj_interface.md",
+ "MLJ interface" => "mlj_interface.md",
"Additional Resources" => "resources/_resources.md",
],
)
diff --git a/src/LaplaceRedux.jl b/src/LaplaceRedux.jl
index ed931ca7..b6324f95 100644
--- a/src/LaplaceRedux.jl
+++ b/src/LaplaceRedux.jl
@@ -26,5 +26,8 @@ export empirical_frequency_binary_classification,
empirical_frequency_regression,
sharpness_regression,
extract_mean_and_variance,
- sigma_scaling, rescale_stddev
+ sigma_scaling,
+ rescale_stddev
+
+include("direct_mlj.jl")
end
diff --git a/src/baselaplace/predicting.jl b/src/baselaplace/predicting.jl
index 4a0a1d6c..feb457b3 100644
--- a/src/baselaplace/predicting.jl
+++ b/src/baselaplace/predicting.jl
@@ -93,7 +93,7 @@ Computes the Bayesian predictivie distribution from a neural network with a Lapl
- `link_approx::Symbol=:probit`: Link function approximation. Options are `:probit` and `:plugin`.
- `predict_proba::Bool=true`: If `true` (default) apply a sigmoid or a softmax function to the output of the Flux model.
- `return_distr::Bool=false`: if `false` (default), the function outputs either the direct output of the chain or pseudo-probabilities (if `predict_proba=true`).
- if `true` predict return a Bernoulli distribution in binary classification tasks and a categorical distribution in multiclassification tasks.
+ if `true` predict returns a probability distribution.
# Returns
@@ -144,7 +144,6 @@ function predict(
else
return fμ, pred_var
end
-
end
# Classification:
diff --git a/src/calibration_functions.jl b/src/calibration_functions.jl
index 1d313248..d6aa5062 100644
--- a/src/calibration_functions.jl
+++ b/src/calibration_functions.jl
@@ -78,7 +78,7 @@ Outputs: \
"""
function sharpness_classification(y_binary, distributions::Vector{Bernoulli{Float64}})
mean_class_one = mean(mean.(distributions[findall(y_binary .== 1)]))
- mean_class_zero = mean( 1 .- mean.(distributions[findall(y_binary .== 0)]))
+ mean_class_zero = mean(1 .- mean.(distributions[findall(y_binary .== 0)]))
return mean_class_one, mean_class_zero
end
@@ -176,8 +176,9 @@ Inputs: \
Outputs: \
- `sigma`: the scalar that maximize the likelihood.
"""
-function sigma_scaling(distr::Vector{Normal{T}}, y_cal::Vector{<:AbstractFloat}
- ) where T <: AbstractFloat
+function sigma_scaling(
+ distr::Vector{Normal{T}}, y_cal::Vector{<:AbstractFloat}
+) where {T<:AbstractFloat}
means, variances = extract_mean_and_variance(distr)
sigma = sqrt(1 / length(y_cal) * sum(norm.(y_cal .- means) ./ variances))
@@ -198,4 +199,4 @@ Outputs: \
function rescale_stddev(distr::Vector{Normal{T}}, s::T) where {T<:AbstractFloat}
rescaled_distr = [Normal(mean(d), std(d) * s) for d in distr]
return rescaled_distr
-end
\ No newline at end of file
+end
diff --git a/src/data/functions.jl b/src/data/functions.jl
index 99e4440c..5ae83805 100644
--- a/src/data/functions.jl
+++ b/src/data/functions.jl
@@ -1,4 +1,4 @@
-import Random
+using Random: Random
"""
toy_data_linear(N=100)
@@ -42,7 +42,7 @@ toy_data_non_linear()
```
"""
-function toy_data_non_linear( N=100; seed=nothing)
+function toy_data_non_linear(N=100; seed=nothing)
#set seed if available
if seed !== nothing
diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl
new file mode 100644
index 00000000..61aab7dd
--- /dev/null
+++ b/src/direct_mlj.jl
@@ -0,0 +1,846 @@
+using Optimisers: Optimisers
+using Flux
+using Random
+using Tables
+using LinearAlgebra
+using LaplaceRedux
+using MLJBase: MLJBase
+import MLJModelInterface as MMI
+using Distributions: Normal
+
+MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJBase.Probabilistic
+ model::Union{Flux.Chain,Nothing} = nothing
+ flux_loss = Flux.Losses.logitcrossentropy
+ optimiser = Optimisers.Adam()
+ epochs::Integer = 1000::(_ > 0)
+ batch_size::Integer = 32::(_ > 0)
+ subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
+ subnetwork_indices = nothing
+ hessian_structure::Union{HessianStructure,Symbol,String} =
+ :full::(_ in (:full, :diagonal))
+ backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))
+ σ::Float64 = 1.0
+ μ₀::Float64 = 0.0
+ P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
+ fit_prior_nsteps::Int = 100::(_ > 0)
+ link_approx::Symbol = :probit::(_ in (:probit, :plugin))
+end
+
+MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJBase.Probabilistic
+ model::Union{Flux.Chain,Nothing} = nothing
+ flux_loss = Flux.Losses.mse
+ optimiser = Optimisers.Adam()
+ epochs::Integer = 1000::(_ > 0)
+ batch_size::Integer = 32::(_ > 0)
+ subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
+ subnetwork_indices = nothing
+ hessian_structure::Union{HessianStructure,Symbol,String} =
+ :full::(_ in (:full, :diagonal))
+ backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))
+ σ::Float64 = 1.0
+ μ₀::Float64 = 0.0
+ P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
+ fit_prior_nsteps::Int = 100::(_ > 0)
+end
+
+LaplaceModels = Union{LaplaceRegressor,LaplaceClassifier}
+
+# for fit:
+function MMI.reformat(::LaplaceRegressor, X, y)
+ return (MLJBase.matrix(X) |> permutedims, (reshape(y, 1, :), nothing))
+end
+
+function MMI.reformat(::LaplaceClassifier, X, y)
+ X = MLJBase.matrix(X) |> permutedims
+ y = MLJBase.categorical(y)
+ labels = y.pool.levels
+ y = Flux.onehotbatch(y, labels) # One-hot encoding
+
+ return X, (y, labels)
+end
+
+MMI.reformat(::LaplaceModels, X) = (MLJBase.matrix(X) |> permutedims,)
+
+MMI.selectrows(::LaplaceModels, I, Xmatrix, y) = (Xmatrix[:, I], (y[1][:, I], y[2]))
+MMI.selectrows(::LaplaceModels, I, Xmatrix) = (Xmatrix[:, I],)
+
+
+
+"""
+    function dataset_shape(model::LaplaceModels, X, y)
+
+Compute the number of features of the X input dataset and the number of variables to predict from the y output dataset.
+
+# Arguments
+- `model::LaplaceModels`: The Laplace model to fit.
+- `X`: The input data for training.
+- `y`: The target labels for training one-hot encoded.
+
+# Returns
+- (input size, output size)
+"""
+function dataset_shape(model::LaplaceModels, X, y)
+ n_input = size(X, 1)
+ dims = size(y)
+ if length(dims) == 1
+ n_output = 1
+ else
+ n_output = dims[1]
+ end
+ return (n_input, n_output)
+end
+
+
+"""
+    default_build(seed::Int, shape)
+
+Builds a default MLP Flux model compatible with the dimensions of the dataset, with reproducible initial weights.
+
+# Arguments
+- `seed::Int`: The seed for random number generation.
+- `shape`: a tuple containing the dimensions of the input layer and the output layer.
+
+# Returns
+- The constructed Flux model, which consists of a simple MLP with 2 hidden layers of 20 neurons each and input and output layers compatible with the dataset.
+"""
+function default_build(seed::Int, shape)
+ Random.seed!(seed)
+ (n_input, n_output) = shape
+
+ chain = Chain(
+ Dense(n_input, 20, relu),
+ Dense(20, 20, relu),
+ Dense(20, 20, relu),
+ Dense(20, n_output)
+ )
+
+ return chain
+end
+
+
+
+
+@doc """
+ MMI.fit(m::Union{LaplaceRegressor,LaplaceClassifier}, verbosity, X, y)
+
+Fit a Laplace model using the provided features and target values.
+
+# Arguments
+- `m::LaplaceModels`: The Laplace model (`LaplaceRegressor` or `LaplaceClassifier`) to be fitted.
+- `verbosity`: Verbosity level for logging.
+- `X`: Input features, expected to be in a format compatible with MLJBase.matrix.
+- `y`: Target values.
+
+# Returns
+- `fitresult`: a tuple (la, decode) containing the fitted Laplace model and y[1], the first element of the categorical y vector.
+- `cache`: a tuple containing a deepcopy of the model, the current state of the optimiser and the training loss history.
+- `report`: a `NamedTuple` containing the loss history of the fitting process.
+"""
+function MMI.fit(m::LaplaceModels, verbosity, X, y)
+    # `y` arrives as a tuple: the target values and `decode` (y[1], the first
+    # element of the categorical target), kept so predictions can be decoded.
+    y, decode = y
+
+    if m.model === nothing
+        @warn "Warning: no Flux model has been provided in the model. LaplaceRedux will use a standard MLP with 2 hidden layers with 20 neurons each and input and output layers compatible with the dataset."
+        shape = dataset_shape(m, X, y)
+
+        # NOTE(review): `11` looks like a hard-coded seed/size argument for
+        # `default_build` — confirm against its definition.
+        m.model = default_build(11, shape)
+    end
+
+    # Flux does not allow hyperparameters to be mutated, so train on a copy.
+    copied_model = deepcopy(m.model)
+
+    data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size)
+    state_tree = Optimisers.setup(m.optimiser, copied_model)
+    # Concretely typed history (was `[]`, i.e. Vector{Any}).
+    loss_history = Float64[]
+
+    for epoch in 1:(m.epochs)
+        loss_per_epoch = 0.0
+
+        for (X_batch, y_batch) in data_loader
+            # Loss for reporting; the gradient closure below recomputes the
+            # forward pass inside the AD context.
+            loss = m.flux_loss(copied_model(X_batch), y_batch)
+
+            # Compute gradients with respect to the model parameters.
+            grads, _ = gradient(copied_model, X_batch) do grad_model, X
+                m.flux_loss(grad_model(X), y_batch)
+            end
+
+            # Update parameters using the optimiser and computed gradients.
+            state_tree, copied_model = Optimisers.update!(state_tree, copied_model, grads)
+
+            # Accumulate the loss for this batch.
+            loss_per_epoch += sum(loss)
+        end
+
+        push!(loss_history, loss_per_epoch)
+
+        # Print loss every 100 epochs if verbosity is 1 or more.
+        if verbosity >= 1 && epoch % 100 == 0
+            println("Epoch $epoch: Loss: $loss_per_epoch ")
+        end
+    end
+
+    la = LaplaceRedux.Laplace(
+        copied_model;
+        likelihood=:regression,
+        subset_of_weights=m.subset_of_weights,
+        subnetwork_indices=m.subnetwork_indices,
+        hessian_structure=m.hessian_structure,
+        backend=m.backend,
+        σ=m.σ,
+        μ₀=m.μ₀,
+        P₀=m.P₀,
+    )
+
+    # Idiomatic dispatch check (was `typeof(m) == LaplaceClassifier`).
+    if m isa LaplaceClassifier
+        la.likelihood = :classification
+    end
+
+    # Fit the Laplace approximation, then tune the prior precision.
+    LaplaceRedux.fit!(la, data_loader)
+    optimize_prior!(la; verbosity=verbosity, n_steps=m.fit_prior_nsteps)
+
+    fitresult = (la, decode)
+    report = (loss_history=loss_history,)
+    cache = (deepcopy(m), state_tree, loss_history)
+    return fitresult, cache, report
+end
+
+@doc """
+ MMI.update(m::Union{LaplaceRegressor,LaplaceClassifier}, verbosity, X, y)
+
+Update the Laplace model using the provided new data points.
+
+# Arguments
+- `m`: The Laplace (LaplaceRegressor or LaplaceClassifier) model to be fitted.
+- `verbosity`: Verbosity level for logging.
+- `X`: New input features, expected to be in a format compatible with MLJBase.matrix.
+- `y`: New target values.
+
+# Returns
+- `fitresult`: a tuple (la, decode) containing the updated fitted Laplace model and y[1], the first element of the categorical y vector.
+- `cache`: a tuple containing a deepcopy of the model, the updated current state of the optimiser and training loss history.
+- `report`: A Namedtuple containing the complete loss history of the fitting process.
+"""
+function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y)
+    # `y` is the tuple (target, decode); keep the tuple intact for the
+    # full-refit fallback below.
+    y_up, decode = y
+
+    data_loader = Flux.DataLoader((X, y_up); batchsize=m.batch_size)
+    old_model = old_cache[1]
+    old_state_tree = old_cache[2]
+    old_loss_history = old_cache[3]
+
+    epochs = m.epochs
+
+    if MMI.is_same_except(m, old_model, :epochs)
+        # Only `epochs` changed: resume training from the stored optimiser state.
+        old_la = old_fitresult[1]
+        if epochs > old_model.epochs
+            for epoch in (old_model.epochs + 1):(epochs)
+                loss_per_epoch = 0.0
+
+                for (X_batch, y_batch) in data_loader
+                    # Loss for reporting; the gradient closure below recomputes
+                    # the forward pass inside the AD context.
+                    loss = m.flux_loss(old_la.model(X_batch), y_batch)
+
+                    # Compute gradients with respect to the model parameters.
+                    grads, _ = gradient(old_la.model, X_batch) do grad_model, X
+                        m.flux_loss(grad_model(X), y_batch)
+                    end
+
+                    # Update parameters using the optimiser and computed gradients.
+                    old_state_tree, old_la.model = Optimisers.update!(
+                        old_state_tree, old_la.model, grads
+                    )
+
+                    # Accumulate the loss for this batch.
+                    loss_per_epoch += sum(loss)
+                end
+
+                push!(old_loss_history, loss_per_epoch)
+
+                # Print loss every 100 epochs if verbosity is 1 or more.
+                if verbosity >= 1 && epoch % 100 == 0
+                    println("Epoch $epoch: Loss: $loss_per_epoch ")
+                end
+            end
+
+            la = LaplaceRedux.Laplace(
+                old_la.model;
+                likelihood=:regression,
+                subset_of_weights=m.subset_of_weights,
+                subnetwork_indices=m.subnetwork_indices,
+                hessian_structure=m.hessian_structure,
+                backend=m.backend,
+                σ=m.σ,
+                μ₀=m.μ₀,
+                P₀=m.P₀,
+            )
+            if m isa LaplaceClassifier
+                la.likelihood = :classification
+            end
+
+            # Re-fit the Laplace approximation on the extended training run.
+            LaplaceRedux.fit!(la, data_loader)
+            optimize_prior!(la; verbosity=verbosity, n_steps=m.fit_prior_nsteps)
+
+            fitresult = (la, decode)
+            report = (loss_history=old_loss_history,)
+            cache = (deepcopy(m), old_state_tree, old_loss_history)
+        else
+            # Requested epochs ≤ already-trained epochs: nothing to do.
+            println(
+                "The number of epochs inserted is lower than the number of epochs already been trained. No update is necessary",
+            )
+            fitresult = (old_la, decode)
+            report = (loss_history=old_loss_history,)
+            cache = (deepcopy(m), old_state_tree, old_loss_history)
+        end
+    elseif MMI.is_same_except(
+        m,
+        old_model,
+        :fit_prior_nsteps,
+        :subset_of_weights,
+        :subnetwork_indices,
+        :hessian_structure,
+        :backend,
+        :σ,
+        :μ₀,
+        :P₀,
+    )
+        # Only Laplace-approximation hyperparameters changed: keep the trained
+        # network and redo the Laplace fit / prior optimization.
+        println(" updating only the laplace optimization part")
+        old_la = old_fitresult[1]
+
+        la = LaplaceRedux.Laplace(
+            old_la.model;
+            likelihood=:regression,
+            subset_of_weights=m.subset_of_weights,
+            subnetwork_indices=m.subnetwork_indices,
+            hessian_structure=m.hessian_structure,
+            backend=m.backend,
+            σ=m.σ,
+            μ₀=m.μ₀,
+            P₀=m.P₀,
+        )
+        if m isa LaplaceClassifier
+            la.likelihood = :classification
+        end
+
+        # fit the Laplace model:
+        LaplaceRedux.fit!(la, data_loader)
+        optimize_prior!(la; verbosity=verbosity, n_steps=m.fit_prior_nsteps)
+
+        fitresult = (la, decode)
+        report = (loss_history=old_loss_history,)
+        cache = (deepcopy(m), old_state_tree, old_loss_history)
+    else
+        # BUGFIX: previously, if neither branch above matched (e.g. the Flux
+        # model or the optimiser changed), `fitresult`/`cache`/`report` were
+        # never assigned and the `return` below threw an UndefVarError.
+        # Fall back to a full retrain from scratch.
+        fitresult, cache, report = MMI.fit(m, verbosity, X, y)
+    end
+
+    return fitresult, cache, report
+end
+
+@doc """
+ function MMI.is_same_except(m1::LaplaceModels, m2::LaplaceModels, exceptions::Symbol...)
+
+If both `m1` and `m2` are of `MLJType`, return `true` if the
+following conditions all hold, and `false` otherwise:
+
+- `typeof(m1) === typeof(m2)`
+
+- `propertynames(m1) === propertynames(m2)`
+
+- with the exception of properties listed as `exceptions` or bound to
+ an `AbstractRNG`, each pair of corresponding property values is
+ either "equal" or both undefined. (If a property appears as a
+ `propertyname` but not a `fieldname`, it is deemed as always defined.)
+
+The meaning of "equal" depends on the type of the property value:
+
+- values that are themselves of `MLJType` are "equal" if they are
+ equal in the sense of `is_same_except` with no exceptions.
+
+- values that are not of `MLJType` are "equal" if they are `==`.
+
+In the special case of a "deep" property, "equal" has a different
+meaning; see `MLJBase.deep_properties` for details.
+
+If `m1` or `m2` are not `MLJType` objects, then return `==(m1, m2)`.
+
+"""
+function MMI.is_same_except(m1::LaplaceModels, m2::LaplaceModels, exceptions::Symbol...)
+    # Models of different concrete types are never "the same".
+    typeof(m1) === typeof(m2) || return false
+    names = propertynames(m1)
+    propertynames(m2) === names || return false
+
+    for name in names
+        if !(name in exceptions)
+            if !_isdefined(m1, name)
+                # Both must be undefined for the property to count as equal.
+                !_isdefined(m2, name) || return false
+            elseif _isdefined(m2, name)
+                # BUGFIX: use the concrete type of `m1` instead of a hard-coded
+                # `LaplaceRegressor`, so LaplaceClassifier honours its own
+                # deep properties as well.
+                if name in MLJBase.deep_properties(typeof(m1))
+                    # NOTE(review): `_equal_to_depth_one` is unqualified here —
+                    # confirm it is imported from MLJBase.
+                    _equal_to_depth_one(getproperty(m1, name), getproperty(m2, name)) ||
+                        return false
+                else
+                    # A property is "equal" if recursively same, an RNG (always
+                    # ignored), or a pair of Flux chains with identical
+                    # layer types and parameters.
+                    (
+                        MMI.is_same_except(getproperty(m1, name), getproperty(m2, name)) ||
+                        getproperty(m1, name) isa AbstractRNG ||
+                        getproperty(m2, name) isa AbstractRNG ||
+                        (
+                            getproperty(m1, name) isa Flux.Chain &&
+                            getproperty(m2, name) isa Flux.Chain &&
+                            _equal_flux_chain(getproperty(m1, name), getproperty(m2, name))
+                        )
+                    ) || return false
+                end
+            else
+                return false
+            end
+        end
+    end
+    return true
+end
+# Return `true` when property `name` of `object` can safely be read.
+# A property that is not backed by a struct field (i.e. implemented via
+# `getproperty`) is considered always defined; otherwise defer to `isdefined`.
+function _isdefined(object, name)
+    is_field = name in fieldnames(typeof(object))
+    if !is_field && name in propertynames(object)
+        return true
+    end
+    return isdefined(object, name)
+end
+
+# Return `true` when two Flux chains have the same number of layers, the same
+# layer types position by position, and pairwise `isequal` parameter arrays.
+function _equal_flux_chain(chain1::Flux.Chain, chain2::Flux.Chain)
+    length(chain1.layers) == length(chain2.layers) || return false
+
+    ps1 = Flux.params(chain1)
+    ps2 = Flux.params(chain2)
+    length(ps1) == length(ps2) || return false
+
+    # All parameter arrays must match exactly (isequal, so NaNs compare equal).
+    all(isequal(p, q) for (p, q) in zip(ps1, ps2)) || return false
+
+    # Layer types must match position by position.
+    return all(typeof(l1) === typeof(l2) for (l1, l2) in zip(chain1.layers, chain2.layers))
+end
+
+@doc """
+
+ function MMI.fitted_params(model::LaplaceRegressor, fitresult)
+
+
+ This function extracts the fitted parameters from a `LaplaceRegressor` model.
+
+ # Arguments
+ - `model::LaplaceRegressor`: The Laplace regression model.
+ - `fitresult`: the tuple `(la, decode)` returned by `MMI.fit`; only the Laplace approximation `la` is used.
+
+ # Returns
+ A named tuple containing:
+ - `μ`: The mean of the posterior distribution.
+ - `H`: The Hessian of the posterior distribution.
+ - `P`: The precision matrix of the posterior distribution.
+ - `Σ`: The covariance matrix of the posterior distribution.
+ - `n_data`: The number of data points.
+ - `n_params`: The number of parameters.
+ - `n_out`: The number of outputs.
+ - `loss`: The loss value of the posterior distribution.
+
+"""
+# Expose the Laplace posterior's parameters as a flat named tuple.
+# `fitresult` is the `(la, decode)` tuple produced by `MMI.fit`; only the
+# Laplace approximation `la` is needed here.
+function MMI.fitted_params(model::LaplaceModels, fitresult)
+    la = first(fitresult)
+    post = la.posterior
+    return (
+        μ=post.μ,
+        H=post.H,
+        P=post.P,
+        Σ=post.Σ,
+        n_data=post.n_data,
+        n_params=post.n_params,
+        n_out=post.n_out,
+        loss=post.loss,
+    )
+end
+
+@doc """
+ MMI.training_losses(model::Union{LaplaceRegressor,LaplaceClassifier}, report)
+
+Retrieve the training loss history from the given `report`.
+
+# Arguments
+- `model`: The model for which the training losses are being retrieved.
+- `report`: An object containing the training report, which includes the loss history.
+
+# Returns
+- A collection representing the loss history from the training report.
+"""
+# Accessor used by MLJ's `training_losses(mach)`: simply forward the
+# per-epoch loss history stored in the fit report.
+MMI.training_losses(model::LaplaceModels, report) = report.loss_history
+
+@doc """
+function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
+
+ Predicts the response for new data using a fitted Laplace model.
+
+ # Arguments
+ - `m::LaplaceRegressor`: The Laplace model.
+ - `fitresult`: The result of the fitting procedure.
+ - `Xnew`: The new data for which predictions are to be made.
+
+ # Returns
+ for LaplaceRegressor:
+ - An array of Normal distributions, each centered around the predicted mean and variance for the corresponding input in `Xnew`.
+ for LaplaceClassifier:
+ - `MLJBase.UnivariateFinite`: The predicted class probabilities for the new data.
+"""
+function MMI.predict(m::LaplaceModels, fitresult, Xnew)
+    la, decode = fitresult
+    # Idiomatic dispatch check (was `typeof(m) == LaplaceRegressor`).
+    if m isa LaplaceRegressor
+        # Posterior predictive mean and variance for each observation.
+        yhat = LaplaceRedux.predict(la, Xnew; ret_distr=false)
+        means, variances = yhat
+
+        # One Normal per observation; `sqrt` converts variance to stddev,
+        # as `Normal` takes a standard deviation.
+        return vec([Normal(μ, sqrt(σ)) for (μ, σ) in zip(means, variances)])
+
+    else
+        # Class probabilities; `permutedims` puts observations in rows as
+        # expected by `UnivariateFinite`.
+        predictions =
+            LaplaceRedux.predict(la, Xnew; link_approx=m.link_approx, ret_distr=false) |>
+            permutedims
+
+        return MLJBase.UnivariateFinite(decode, predictions; pool=missing)
+    end
+end
+
+# Register package-level metadata (origin, license, wrapper status) for the
+# regressor with MLJ's model registry.
+MMI.metadata_pkg(
+    LaplaceRegressor;
+    name="LaplaceRedux",
+    package_uuid="c52c1a26-f7c5-402b-80be-ba1e638ad478",
+    package_url="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl",
+    is_pure_julia=true,
+    is_wrapper=true,
+    package_license="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl/blob/main/LICENSE",
+)
+
+# Same package-level metadata for the classifier.
+MMI.metadata_pkg(
+    LaplaceClassifier;
+    name="LaplaceRedux",
+    package_uuid="c52c1a26-f7c5-402b-80be-ba1e638ad478",
+    package_url="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl",
+    is_pure_julia=true,
+    is_wrapper=true,
+    package_license="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl/blob/main/LICENSE",
+)
+
+# Model-level metadata: scitypes of accepted inputs/targets and the load path
+# used by `@load LaplaceClassifier pkg=LaplaceRedux`.
+MLJBase.metadata_model(
+    LaplaceClassifier;
+    input_scitype=Union{
+        AbstractMatrix{<:Union{MLJBase.Finite,MLJBase.Continuous}}, # matrix with mixed types
+        MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types
+    },
+    target_scitype=AbstractArray{<:MLJBase.Finite}, # ordered factor or multiclass
+    supports_training_losses=true,
+    load_path="LaplaceRedux.LaplaceClassifier",
+)
+# metadata for each model,
+MLJBase.metadata_model(
+    LaplaceRegressor;
+    input_scitype=Union{
+        AbstractMatrix{<:Union{MLJBase.Finite,MLJBase.Continuous}}, # matrix with mixed types
+        MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types
+    },
+    target_scitype=AbstractArray{MLJBase.Continuous},
+    supports_training_losses=true,
+    load_path="LaplaceRedux.LaplaceRegressor",
+)
+
+# Citation string for the Laplace Redux paper, interpolated into the
+# docstrings of both models below.
+const DOC_LAPLACE_REDUX =
+    "[Laplace Redux – Effortless Bayesian Deep Learning]" *
+    "(https://proceedings.neurips.cc/paper/2021/hash/a3923dbe2f702eff254d67b48ae2f06e-Abstract.html), originally published in " *
+    "Daxberger, E., Kristiadi, A., Immer, A., Eschenhagen, R., Bauer, M., Hennig, P. (2021): \"Laplace Redux – Effortless Bayesian Deep Learning.\", NIPS'21: Proceedings of the 35th International Conference on Neural Information Processing Systems*, Article No. 1537, pp. 20089–20103"
+
+"""
+$(MMI.doc_header(LaplaceClassifier))
+
+`LaplaceClassifier` implements the $DOC_LAPLACE_REDUX for classification models.
+
+# Training data
+
+In MLJ or MLJBase, given a dataset X, y and a Flux chain adapted to the dataset, pass the chain to the model
+
+laplace_model = LaplaceClassifier(model = Flux_Chain,kwargs...)
+
+then bind an instance `laplace_model` to data with
+
+ mach = machine(laplace_model, X, y)
+
+where
+
+- `X`: any table of input features (eg, a `DataFrame`) whose columns
+ each have one of the following element scitypes: `Continuous`,
+ `Count`, or `<:OrderedFactor`; check column scitypes with `schema(X)`
+
+- `y`: is the target, which can be any `AbstractVector` whose element
+ scitype is `<:OrderedFactor` or `<:Multiclass`; check the scitype
+ with `scitype(y)`
+
+Train the machine using `fit!(mach, rows=...)`.
+
+
+# Hyperparameters (format: name-type-default value-restrictions)
+
+- `model::Union{Flux.Chain,Nothing} = nothing`: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 2 hidden layers with 20 neurons each.
+
+- `flux_loss = Flux.Losses.logitcrossentropy` : a Flux loss function
+
+- `optimiser = Adam()` a Flux optimiser
+
+- `epochs::Integer = 1000::(_ > 0)`: the number of training epochs.
+
+- `batch_size::Integer = 32::(_ > 0)`: the batch size.
+
+- `subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
+
+- `subnetwork_indices = nothing`: the indices of the subnetworks.
+
+- `hessian_structure::Union{HessianStructure,Symbol,String} = :full::(_ in (:full, :diagonal))`: the structure of the Hessian matrix, either `:full` or `:diagonal`.
+
+- `backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))`: the backend to use, either `:GGN` or `:EmpiricalFisher`.
+
+- `σ::Float64 = 1.0`: the standard deviation of the prior distribution.
+
+- `μ₀::Float64 = 0.0`: the mean of the prior distribution.
+
+- `P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing`: the covariance matrix of the prior distribution.
+
+- `fit_prior_nsteps::Int = 100::(_ > 0) `: the number of steps used to fit the priors.
+
+- `link_approx::Symbol = :probit::(_ in (:probit, :plugin))`: the approximation to adopt to compute the probabilities.
+
+# Operations
+
+- `predict(mach, Xnew)`: return predictions of the target given
+ features `Xnew` having the same scitype as `X` above. Predictions
+ are probabilistic, but uncalibrated.
+
+- `predict_mode(mach, Xnew)`: instead return the mode of each
+ prediction above.
+
+- `training_losses(mach)`: return the loss history from report
+
+
+# Fitted parameters
+
+The fields of `fitted_params(mach)` are:
+
+ - `μ`: The mean of the posterior distribution.
+
+ - `H`: The Hessian of the posterior distribution.
+
+ - `P`: The precision matrix of the posterior distribution.
+
+ - `Σ`: The covariance matrix of the posterior distribution.
+
+ - `n_data`: The number of data points.
+
+ - `n_params`: The number of parameters.
+
+ - `n_out`: The number of outputs.
+
+ - `loss`: The loss value of the posterior distribution.
+
+
+
+ # Report
+
+The fields of `report(mach)` are:
+
+- `loss_history`: an array containing the total loss per epoch.
+
+# Accessor functions
+
+
+# Examples
+
+```
+using MLJ
+LaplaceClassifier = @load LaplaceClassifier pkg=LaplaceRedux
+
+X, y = @load_iris
+
+# Define the Flux Chain model
+using Flux
+model = Chain(
+ Dense(4, 10, relu),
+ Dense(10, 10, relu),
+ Dense(10, 3)
+)
+
+#Define the LaplaceClassifier
+model = LaplaceClassifier(model=model)
+
+mach = machine(model, X, y) |> fit!
+
+Xnew = (sepal_length = [6.4, 7.2, 7.4],
+ sepal_width = [2.8, 3.0, 2.8],
+ petal_length = [5.6, 5.8, 6.1],
+ petal_width = [2.1, 1.6, 1.9],)
+yhat = predict(mach, Xnew) # probabilistic predictions
+predict_mode(mach, Xnew) # point predictions
+training_losses(mach) # loss history per epoch
+pdf.(yhat, "virginica") # probabilities for the "virginica" class
+fitted_params(mach) # NamedTuple with the fitted params of Laplace
+
+```
+
+See also [LaplaceRedux.jl](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl).
+
+"""
+LaplaceClassifier
+
+"""
+$(MMI.doc_header(LaplaceRegressor))
+
+`LaplaceRegressor` implements the $DOC_LAPLACE_REDUX for regression models.
+
+# Training data
+
+In MLJ or MLJBase, given a dataset X, y and a Flux chain adapted to the dataset, pass the chain to the model
+
+laplace_model = LaplaceRegressor(model = Flux_Chain,kwargs...)
+
+then bind an instance `laplace_model` to data with
+
+ mach = machine(laplace_model, X, y)
+where
+
+- `X`: any table of input features (eg, a `DataFrame`) whose columns
+ each have one of the following element scitypes: `Continuous`,
+ `Count`, or `<:OrderedFactor`; check column scitypes with `schema(X)`
+
+- `y`: is the target, which can be any `AbstractVector` whose element
+ scitype is `<:Continuous`; check the scitype
+ with `scitype(y)`
+
+Train the machine using `fit!(mach, rows=...)`.
+
+
+# Hyperparameters (format: name-type-default value-restrictions)
+
+- `model::Union{Flux.Chain,Nothing} = nothing`: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 2 hidden layers with 20 neurons each.
+- `flux_loss = Flux.Losses.logitcrossentropy` : a Flux loss function
+
+- `optimiser = Adam()` a Flux optimiser
+
+- `epochs::Integer = 1000::(_ > 0)`: the number of training epochs.
+
+- `batch_size::Integer = 32::(_ > 0)`: the batch size.
+
+- `subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
+
+- `subnetwork_indices = nothing`: the indices of the subnetworks.
+
+- `hessian_structure::Union{HessianStructure,Symbol,String} = :full::(_ in (:full, :diagonal))`: the structure of the Hessian matrix, either `:full` or `:diagonal`.
+
+- `backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))`: the backend to use, either `:GGN` or `:EmpiricalFisher`.
+
+- `σ::Float64 = 1.0`: the standard deviation of the prior distribution.
+
+- `μ₀::Float64 = 0.0`: the mean of the prior distribution.
+
+- `P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing`: the covariance matrix of the prior distribution.
+
+- `fit_prior_nsteps::Int = 100::(_ > 0) `: the number of steps used to fit the priors.
+
+
+# Operations
+
+- `predict(mach, Xnew)`: return predictions of the target given
+ features `Xnew` having the same scitype as `X` above. Predictions
+ are probabilistic, but uncalibrated.
+
+- `predict_mode(mach, Xnew)`: instead return the mode of each
+ prediction above.
+
+- `training_losses(mach)`: return the loss history from report
+
+
+# Fitted parameters
+
+The fields of `fitted_params(mach)` are:
+
+ - `μ`: The mean of the posterior distribution.
+
+ - `H`: The Hessian of the posterior distribution.
+
+ - `P`: The precision matrix of the posterior distribution.
+
+ - `Σ`: The covariance matrix of the posterior distribution.
+
+ - `n_data`: The number of data points.
+
+ - `n_params`: The number of parameters.
+
+ - `n_out`: The number of outputs.
+
+ - `loss`: The loss value of the posterior distribution.
+
+
+# Report
+
+The fields of `report(mach)` are:
+
+- `loss_history`: an array containing the total loss per epoch.
+
+
+
+
+# Accessor functions
+
+
+
+# Examples
+
+```
+using MLJ
+using Flux
+LaplaceRegressor = @load LaplaceRegressor pkg=LaplaceRedux
+model = Chain(
+ Dense(4, 10, relu),
+ Dense(10, 10, relu),
+ Dense(10, 1)
+)
+model = LaplaceRegressor(model=model)
+
+X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)
+mach = machine(model, X, y) |> fit!
+
+Xnew, _ = make_regression(3, 4; rng=123)
+yhat = predict(mach, Xnew) # probabilistic predictions
+predict_mode(mach, Xnew) # point predictions
+training_losses(mach) # loss history per epoch
+fitted_params(mach) # NamedTuple with the fitted params of Laplace
+
+```
+
+See also [LaplaceRedux.jl](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl).
+
+"""
+LaplaceRegressor
+
diff --git a/test/Manifest.toml b/test/Manifest.toml
index 31dcb255..9717977b 100644
--- a/test/Manifest.toml
+++ b/test/Manifest.toml
@@ -2,7 +2,13 @@
julia_version = "1.10.5"
manifest_format = "2.0"
-project_hash = "2fde859c875aff2c1b66bd10b3f4f3d64f67067a"
+project_hash = "44de0b245e083feedd88d905933ce8d9b455e504"
+
+[[deps.ARFFFiles]]
+deps = ["CategoricalArrays", "Dates", "Parsers", "Tables"]
+git-tree-sha1 = "678eb18590a8bc6674363da4d5faa4ac09c40a18"
+uuid = "da404889-ca92-49ff-9e8b-0aa6b4d38dc8"
+version = "1.5.0"
[[deps.AbstractFFTs]]
deps = ["LinearAlgebra"]
@@ -233,9 +239,9 @@ version = "0.7.6"
[[deps.ColorSchemes]]
deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"]
-git-tree-sha1 = "b5278586822443594ff615963b0c09755771b3e0"
+git-tree-sha1 = "13951eb68769ad1cd460cdb2e64e5e95f1bf123d"
uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
-version = "3.26.0"
+version = "3.27.0"
[[deps.ColorTypes]]
deps = ["FixedPointNumbers", "Random"]
@@ -421,9 +427,9 @@ version = "1.15.1"
[[deps.Distances]]
deps = ["LinearAlgebra", "Statistics", "StatsAPI"]
-git-tree-sha1 = "66c4c81f259586e8f002eacebc177e1fb06363b0"
+git-tree-sha1 = "c7e3a542b999843086e2f29dac96a618c105be1d"
uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
-version = "0.10.11"
+version = "0.10.12"
weakdeps = ["ChainRulesCore", "SparseArrays"]
[deps.Distances.extensions]
@@ -461,6 +467,12 @@ deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
version = "1.6.0"
+[[deps.EarlyStopping]]
+deps = ["Dates", "Statistics"]
+git-tree-sha1 = "98fdf08b707aaf69f524a6cd0a67858cefe0cfb6"
+uuid = "792122b4-ca99-40de-a6bc-6742525f08b6"
+version = "0.3.0"
+
[[deps.EpollShim_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "8e9441ee83492030ace98f9789a654a6d0b1f643"
@@ -543,9 +555,9 @@ version = "0.8.5"
[[deps.Flux]]
deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLDataDevices", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "Setfield", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"]
-git-tree-sha1 = "37fa32a50c69c10c6ea1465d3054d98c75bd7777"
+git-tree-sha1 = "b78bd94ef1588881983bd6b8b860b2c27293140a"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.14.22"
+version = "0.14.23"
[deps.Flux.extensions]
FluxAMDGPUExt = "AMDGPU"
@@ -554,14 +566,12 @@ version = "0.14.22"
FluxEnzymeExt = "Enzyme"
FluxMPIExt = "MPI"
FluxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
- FluxMetalExt = "Metal"
[deps.Flux.weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
- Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
@@ -692,9 +702,9 @@ version = "1.14.2+1"
[[deps.HTTP]]
deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"]
-git-tree-sha1 = "d1d712be3164d61d1fb98e7ce9bcbc6cc06b45ed"
+git-tree-sha1 = "bc3f416a965ae61968c20d0ad867556367f2817d"
uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3"
-version = "1.10.8"
+version = "1.10.9"
[[deps.HarfBuzz_jll]]
deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll"]
@@ -786,6 +796,12 @@ git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2"
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
version = "0.2.2"
+[[deps.IterationControl]]
+deps = ["EarlyStopping", "InteractiveUtils"]
+git-tree-sha1 = "e663925ebc3d93c1150a7570d114f9ea2f664726"
+uuid = "b3c1a2ee-3fec-4384-bf48-272ea71de57c"
+version = "0.5.4"
+
[[deps.IteratorInterfaceExtensions]]
git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
uuid = "82899510-4779-5014-852e-03e436cf321d"
@@ -793,9 +809,9 @@ version = "1.0.0"
[[deps.JLD2]]
deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Requires", "TranscodingStreams"]
-git-tree-sha1 = "b464b9b461ee989b435a689a4f7d870b68d467ed"
+git-tree-sha1 = "783c1be5213a09609b23237a0c9e5dfd258ae6f2"
uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
-version = "0.5.6"
+version = "0.5.7"
[[deps.JLFzf]]
deps = ["Pipe", "REPL", "Random", "fzf_jll"]
@@ -841,9 +857,9 @@ version = "0.2.4"
[[deps.KernelAbstractions]]
deps = ["Adapt", "Atomix", "InteractiveUtils", "MacroTools", "PrecompileTools", "Requires", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"]
-git-tree-sha1 = "04e52f596d0871fa3890170fa79cb15e481e4cd8"
+git-tree-sha1 = "e73a077abc7fe798fe940deabe30ef6c66bdde52"
uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
-version = "0.9.28"
+version = "0.9.29"
[deps.KernelAbstractions.extensions]
EnzymeExt = "EnzymeCore"
@@ -916,6 +932,12 @@ version = "0.16.5"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SymEngine = "123dc426-2d89-5057-bbad-38513e3affd8"
+[[deps.LatinHypercubeSampling]]
+deps = ["Random", "StableRNGs", "StatsBase", "Test"]
+git-tree-sha1 = "825289d43c753c7f1bf9bed334c253e9913997f8"
+uuid = "a5e1c1ea-c99a-51d3-a14d-a9a37257b02d"
+version = "1.9.0"
+
[[deps.LazyArtifacts]]
deps = ["Artifacts", "Pkg"]
uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
@@ -966,9 +988,9 @@ version = "3.2.2+1"
[[deps.Libgcrypt_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgpg_error_jll"]
-git-tree-sha1 = "9fd170c4bbfd8b935fdc5f8b7aa33532c991a673"
+git-tree-sha1 = "8be878062e0ffa2c3f67bb58a595375eda5de80b"
uuid = "d4300ac3-e22c-5743-9152-c294e39db1e4"
-version = "1.8.11+0"
+version = "1.11.0+0"
[[deps.Libglvnd_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll", "Xorg_libXext_jll"]
@@ -978,15 +1000,15 @@ version = "1.6.0+0"
[[deps.Libgpg_error_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
-git-tree-sha1 = "fbb1f2bef882392312feb1ede3615ddc1e9b99ed"
+git-tree-sha1 = "c6ce1e19f3aec9b59186bdf06cdf3c4fc5f5f3e6"
uuid = "7add5ba3-2f88-524e-9cd5-f83b8a55f7b8"
-version = "1.49.0+0"
+version = "1.50.0+0"
[[deps.Libiconv_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
-git-tree-sha1 = "f9557a255370125b405568f9767d6d195822a175"
+git-tree-sha1 = "61dfdba58e585066d8bce214c5a51eaa0539f269"
uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531"
-version = "1.17.0+0"
+version = "1.17.0+1"
[[deps.Libmount_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
@@ -1031,9 +1053,9 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
[[deps.LoggingExtras]]
deps = ["Dates", "Logging"]
-git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075"
+git-tree-sha1 = "f02b56007b064fbfddb4c9cd60161b6dd0f40df3"
uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36"
-version = "1.0.3"
+version = "1.1.0"
[[deps.MAT]]
deps = ["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"]
@@ -1042,10 +1064,10 @@ uuid = "23992714-dd62-5051-b70f-ba57cb901cac"
version = "0.10.7"
[[deps.MLDataDevices]]
-deps = ["Adapt", "Compat", "Functors", "LinearAlgebra", "Preferences", "Random"]
-git-tree-sha1 = "3207c2e66164e6366440ad3f0243a8d67abb4a47"
+deps = ["Adapt", "Compat", "Functors", "Preferences", "Random"]
+git-tree-sha1 = "5cffc52b59227864b665459e1f7bcc4d3c4fb47b"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
-version = "1.4.1"
+version = "1.4.2"
[deps.MLDataDevices.extensions]
MLDataDevicesAMDGPUExt = "AMDGPU"
@@ -1087,24 +1109,58 @@ git-tree-sha1 = "361c2692ee730944764945859f1a6b31072e275d"
uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458"
version = "0.7.18"
+[[deps.MLFlowClient]]
+deps = ["Dates", "FilePathsBase", "HTTP", "JSON", "ShowCases", "URIs", "UUIDs"]
+git-tree-sha1 = "9abb12b62debc27261c008daa13627255bf79967"
+uuid = "64a0f543-368b-4a9a-827a-e71edb2a0b83"
+version = "0.5.1"
+
+[[deps.MLJ]]
+deps = ["CategoricalArrays", "ComputationalResources", "Distributed", "Distributions", "LinearAlgebra", "MLJBalancing", "MLJBase", "MLJEnsembles", "MLJFlow", "MLJIteration", "MLJModels", "MLJTuning", "OpenML", "Pkg", "ProgressMeter", "Random", "Reexport", "ScientificTypes", "StatisticalMeasures", "Statistics", "StatsBase", "Tables"]
+git-tree-sha1 = "bd2072e9cd65be0a3cb841f3d8cda1d2cacfe5db"
+uuid = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
+version = "0.20.5"
+
+[[deps.MLJBalancing]]
+deps = ["MLJBase", "MLJModelInterface", "MLUtils", "OrderedCollections", "Random", "StatsBase"]
+git-tree-sha1 = "f707a01a92d664479522313907c07afa5d81df19"
+uuid = "45f359ea-796d-4f51-95a5-deb1a414c586"
+version = "0.1.5"
+
[[deps.MLJBase]]
deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LearnAPI", "LinearAlgebra", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "RecipesBase", "Reexport", "ScientificTypes", "Serialization", "StatisticalMeasuresBase", "StatisticalTraits", "Statistics", "StatsBase", "Tables"]
git-tree-sha1 = "6f45e12073bc2f2e73ed0473391db38c31e879c9"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
version = "1.7.0"
+weakdeps = ["StatisticalMeasures"]
[deps.MLJBase.extensions]
DefaultMeasuresExt = "StatisticalMeasures"
- [deps.MLJBase.weakdeps]
- StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
-
[[deps.MLJDecisionTreeInterface]]
deps = ["CategoricalArrays", "DecisionTree", "MLJModelInterface", "Random", "Tables"]
git-tree-sha1 = "90ef4d3b6cacec631c57cc034e1e61b4aa0ce511"
uuid = "c6f25543-311c-4c74-83dc-3ea6d1015661"
version = "0.4.2"
+[[deps.MLJEnsembles]]
+deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Distributed", "Distributions", "MLJModelInterface", "ProgressMeter", "Random", "ScientificTypesBase", "StatisticalMeasuresBase", "StatsBase"]
+git-tree-sha1 = "84a5be55a364bb6b6dc7780bbd64317ebdd3ad1e"
+uuid = "50ed68f4-41fd-4504-931a-ed422449fee0"
+version = "0.4.3"
+
+[[deps.MLJFlow]]
+deps = ["MLFlowClient", "MLJBase", "MLJModelInterface"]
+git-tree-sha1 = "508bff8071d7d1902d6f1b9d1e868d58821f1cfe"
+uuid = "7b7b8358-b45c-48ea-a8ef-7ca328ad328f"
+version = "0.5.0"
+
+[[deps.MLJIteration]]
+deps = ["IterationControl", "MLJBase", "Random", "Serialization"]
+git-tree-sha1 = "ad16cfd261e28204847f509d1221a581286448ae"
+uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
+version = "0.6.3"
+
[[deps.MLJModelInterface]]
deps = ["Random", "ScientificTypesBase", "StatisticalTraits"]
git-tree-sha1 = "ceaff6618408d0e412619321ae43b33b40c1a733"
@@ -1117,6 +1173,12 @@ git-tree-sha1 = "410da88e0e6ece5467293d2c76b51b7c6df7d072"
uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
version = "0.16.17"
+[[deps.MLJTuning]]
+deps = ["ComputationalResources", "Distributed", "Distributions", "LatinHypercubeSampling", "MLJBase", "ProgressMeter", "Random", "RecipesBase", "StatisticalMeasuresBase"]
+git-tree-sha1 = "38aab60b1274ce7d6da784808e3be69e585dbbf6"
+uuid = "03970b2e-30c4-11ea-3135-d1576263f10f"
+version = "0.8.8"
+
[[deps.MLStyle]]
git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8"
uuid = "d8e11817-5142-5d16-987a-aa16d5891078"
@@ -1289,6 +1351,12 @@ deps = ["Artifacts", "Libdl"]
uuid = "05823500-19ac-5b8b-9628-191a04bc5112"
version = "0.8.1+2"
+[[deps.OpenML]]
+deps = ["ARFFFiles", "HTTP", "JSON", "Markdown", "Pkg", "Scratch"]
+git-tree-sha1 = "6efb039ae888699d5a74fb593f6f3e10c7193e33"
+uuid = "8b6db2d4-7670-4922-a472-f9537c81ab66"
+version = "0.3.1"
+
[[deps.OpenMPI_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML", "Zlib_jll"]
git-tree-sha1 = "bfce6d523861a6c562721b262c0d1aaeead2647f"
@@ -1401,9 +1469,9 @@ version = "1.10.0"
[[deps.PlotThemes]]
deps = ["PlotUtils", "Statistics"]
-git-tree-sha1 = "6e55c6841ce3411ccb3457ee52fc48cb698d6fb0"
+git-tree-sha1 = "41031ef3a1be6f5bbbf3e8073f210556daeae5ca"
uuid = "ccf2f8ad-2431-5c83-bf29-c5338b663b6a"
-version = "3.2.0"
+version = "3.3.0"
[[deps.PlotUtils]]
deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "StableRNGs", "Statistics"]
@@ -1606,9 +1674,9 @@ version = "1.2.1"
[[deps.SentinelArrays]]
deps = ["Dates", "Random"]
-git-tree-sha1 = "ff11acffdb082493657550959d4feb4b6149e73a"
+git-tree-sha1 = "305becf8af67eae1dbc912ee9097f00aeeabb8d5"
uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
-version = "1.4.5"
+version = "1.4.6"
[[deps.Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
@@ -1705,6 +1773,20 @@ git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682"
uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
version = "1.4.3"
+[[deps.StatisticalMeasures]]
+deps = ["CategoricalArrays", "CategoricalDistributions", "Distributions", "LearnAPI", "LinearAlgebra", "MacroTools", "OrderedCollections", "PrecompileTools", "ScientificTypesBase", "StatisticalMeasuresBase", "Statistics", "StatsBase"]
+git-tree-sha1 = "c1d4318fa41056b839dfbb3ee841f011fa6e8518"
+uuid = "a19d573c-0a75-4610-95b3-7071388c7541"
+version = "0.1.7"
+
+ [deps.StatisticalMeasures.extensions]
+ LossFunctionsExt = "LossFunctions"
+ ScientificTypesExt = "ScientificTypes"
+
+ [deps.StatisticalMeasures.weakdeps]
+ LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
+ ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
+
[[deps.StatisticalMeasuresBase]]
deps = ["CategoricalArrays", "InteractiveUtils", "MLUtils", "MacroTools", "OrderedCollections", "PrecompileTools", "ScientificTypesBase", "Statistics"]
git-tree-sha1 = "17dfb22e2e4ccc9cd59b487dce52883e0151b4d3"
@@ -1983,9 +2065,9 @@ version = "1.6.1"
[[deps.XML2_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Zlib_jll"]
-git-tree-sha1 = "1165b0443d0eca63ac1e32b8c0eb69ed2f4f8127"
+git-tree-sha1 = "6a451c6f33a176150f315726eba8b92fbfdb9ae7"
uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a"
-version = "2.13.3+0"
+version = "2.13.4+0"
[[deps.XSLT_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "XML2_jll", "Zlib_jll"]
@@ -1995,9 +2077,9 @@ version = "1.1.41+0"
[[deps.XZ_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
-git-tree-sha1 = "ac88fb95ae6447c8dda6a5503f3bafd496ae8632"
+git-tree-sha1 = "15e637a697345f6743674f1322beefbc5dcd5cfc"
uuid = "ffd25f8a-64ca-5728-b0f7-c24cf3aae800"
-version = "5.4.6+0"
+version = "5.6.3+0"
[[deps.Xorg_libICE_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
diff --git a/test/Project.toml b/test/Project.toml
index 1e853790..aa0bf7d9 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -7,6 +7,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
diff --git a/test/calibration.jl b/test/calibration.jl
index 847543ce..988ff1ef 100644
--- a/test/calibration.jl
+++ b/test/calibration.jl
@@ -3,8 +3,6 @@ using LaplaceRedux
using Distributions
using Trapz
-
-
# Test for `sharpness_regression` function
@testset "sharpness_regression distributions tests" begin
@info " testing sharpness_regression with distributions"
@@ -201,4 +199,4 @@ end
distributions = [Normal(0, 1), Normal(2, 1), Normal(4, 1)]
rescaled_distr = rescale_stddev(distributions, 2.0)
@test rescaled_distr == [Normal(0, 2), Normal(2, 2), Normal(4, 2)]
-end
\ No newline at end of file
+end
diff --git a/test/data.jl b/test/data.jl
index e6d61881..05a5b30d 100644
--- a/test/data.jl
+++ b/test/data.jl
@@ -15,7 +15,7 @@ for fun in fun_list
# Generate data with the same seed
Random.seed!(seed)
xs1, ys1 = fun(N; seed=seed)
-
+
Random.seed!(seed)
xs2, ys2 = fun(N; seed=seed)
diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl
new file mode 100644
index 00000000..8f2b0438
--- /dev/null
+++ b/test/direct_mlj_interface.jl
@@ -0,0 +1,100 @@
+using Random: Random
+import Random.seed!
+using MLJBase: MLJBase, categorical
+using MLJ: MLJ
+using Flux
+using StableRNGs
+import LaplaceRedux: LaplaceClassifier, LaplaceRegressor
+
+cv = MLJBase.CV(; nfolds=3)
+
+@testset "Regression" begin
+ @info " testing interface for LaplaceRegressor"
+ flux_model = Chain(Dense(4, 10, relu), Dense(10, 10, relu), Dense(10, 1))
+ model = LaplaceRegressor(; model=flux_model, epochs=20)
+
+ #testing more complex dataset
+ X, y = MLJ.make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)
+ #train, test = partition(eachindex(y), 0.7); # 70:30 split
+ mach = MLJ.machine(model, X, y)
+ MLJBase.fit!(mach; verbosity=0)
+ yhat = MLJBase.predict(mach, X) # probabilistic predictions
+ MLJBase.predict_mode(mach, X) # point predictions
+ MLJBase.fitted_params(mach) #fitted params function
+ MLJBase.training_losses(mach) #training loss history
+ model.epochs = 40 #changing number of epochs
+ MLJBase.fit!(mach; verbosity=0) #testing update function
+ model.epochs = 30 #changing number of epochs to a lower number
+ MLJBase.fit!(mach; verbosity=0) #testing update function
+ model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
+ MLJBase.fit!(mach; verbosity=0) #testing update function (the laplace part)
+ yhat = MLJBase.predict(mach, X) # probabilistic predictions
+ MLJ.evaluate!(mach; resampling=cv, measure=MLJ.log_loss, verbosity=0)
+
+ # Define a different model
+ flux_model_two = Chain(Dense(4, 6, relu), Dense(6, 1))
+ # test update! fallback to fit!
+ model.model = flux_model_two
+ MLJBase.fit!(mach; verbosity=0)
+ model_two = LaplaceRegressor(; model=flux_model_two, epochs=100)
+ @test !MLJBase.is_same_except(model, model_two)
+
+
+
+ #testing default mlp builder
+ model = LaplaceRegressor(; model=nothing, epochs=20)
+ mach = MLJ.machine(model, X, y)
+ MLJBase.fit!(mach; verbosity=1)
+ yhat = MLJBase.predict(mach, X) # probabilistic predictions
+ MLJBase.predict_mode(mach, X) # point predictions
+ MLJBase.fitted_params(mach) #fitted params function
+ MLJBase.training_losses(mach) #training loss history
+ model.epochs = 100 #changing number of epochs
+ MLJBase.fit!(mach; verbosity=1) #testing update function
+
+ #testing dataset_shape for one dimensional function
+ X, y = MLJ.make_regression(100, 1; noise=0.5, sparse=0.2, outliers=0.1)
+ model = LaplaceRegressor(; model=nothing, epochs=20)
+ mach = MLJ.machine(model, X, y)
+ MLJBase.fit!(mach; verbosity=0)
+
+
+
+end
+
+@testset "Classification" begin
+ @info " testing interface for LaplaceClassifier"
+ # Define the model
+ flux_model = Chain(Dense(4, 10, relu), Dense(10, 3))
+
+ model = LaplaceClassifier(; model=flux_model, epochs=20)
+
+ X, y = MLJ.@load_iris
+ mach = MLJ.machine(model, X, y)
+ MLJBase.fit!(mach; verbosity=0)
+ Xnew = (
+ sepal_length=[6.4, 7.2, 7.4],
+ sepal_width=[2.8, 3.0, 2.8],
+ petal_length=[5.6, 5.8, 6.1],
+ petal_width=[2.1, 1.6, 1.9],
+ )
+ yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions
+ MLJBase.predict_mode(mach, Xnew) # point predictions
+ MLJBase.pdf.(yhat, "virginica") # probabilities for the "virginica" class
+ MLJBase.fitted_params(mach) # fitted params
+ MLJBase.training_losses(mach) #training loss history
+ model.epochs = 40 #changing number of epochs
+ MLJBase.fit!(mach; verbosity=0) #testing update function
+ model.epochs = 30 #changing number of epochs to a lower number
+ MLJBase.fit!(mach; verbosity=0) #testing update function
+ model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
+ MLJBase.fit!(mach; verbosity=0) #testing update function (the laplace part)
+ MLJ.evaluate!(mach; resampling=cv, measure=MLJ.brier_loss, verbosity=0)
+
+ # Define a different model
+ flux_model_two = Chain(Dense(4, 6, relu), Dense(6, 3))
+
+ model.model = flux_model_two
+
+ MLJBase.fit!(mach; verbosity=0)
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 92d97033..b29d8e14 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -34,5 +34,7 @@ using Test
@testset "KronDecomposed" begin
include("krondecomposed.jl")
end
-
+ @testset "Interface" begin
+ include("direct_mlj_interface.jl")
+ end
end