diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 64bf7af1..22e67adc 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -22,6 +22,7 @@ jobs: version: - '1.9' - '1.10' + - '1' os: - ubuntu-latest - windows-latest diff --git a/Project.toml b/Project.toml index a06a996b..0124164e 100644 --- a/Project.toml +++ b/Project.toml @@ -4,11 +4,14 @@ authors = ["Patrick Altmeyer"] version = "1.1.1" [deps] +CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" +MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -19,20 +22,23 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Aqua = "0.8" +CategoricalDistributions = "0.1.15" ChainRulesCore = "1.23.0" Compat = "4.7.0" Distributions = "0.25.109" Flux = "0.12, 0.13, 0.14" LinearAlgebra = "1.7, 1.10" +MLJBase = "1" +MLJModelInterface = "1.8.0" MLUtils = "0.4" Optimisers = "0.2, 0.3" Random = "1.9, 1.10" Statistics = "1" Tables = "1.10.1" -Test = "1.9, 1.10" +Test = "1" Tullio = "0.3.5" Zygote = "0.6" -julia = "1.9, 1.10" +julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" diff --git a/dev/issues/predict_slow.jl b/dev/issues/predict_slow.jl index ae713cbd..95c7c1f8 100644 --- a/dev/issues/predict_slow.jl +++ b/dev/issues/predict_slow.jl @@ -4,16 +4,16 @@ using Optimisers X = MLJBase.table(rand(Float32, 100, 3)); y = coerce(rand("abc", 100), Multiclass); -model = LaplaceClassification(optimiser=Optimisers.Adam(0.1), epochs=100); +model = LaplaceClassification(; optimiser=Optimisers.Adam(0.1), epochs=100); fitresult, _, _ = MLJBase.fit(model, 2, X, y); la = fitresult[1]; Xmat = matrix(X) |> permutedims; # Single test sample: -Xtest = Xmat[:,1:10]; +Xtest = Xmat[:, 1:10]; Xtest_tab = MLJBase.table(Xtest'); MLJBase.predict(model, fitresult, Xtest_tab); # warm up LaplaceRedux.predict(la, Xmat); # warm up @time MLJBase.predict(model, fitresult, Xtest_tab); @time LaplaceRedux.predict(la, Xtest); -@time glm_predictive_distribution(la, Xtest); \ No newline at end of file +@time glm_predictive_distribution(la, Xtest); diff --git a/dev/notebooks/mlj-interfacing/direct_mlj.ipynb b/dev/notebooks/mlj-interfacing/direct_mlj.ipynb new file mode 100644 index 00000000..4169dc0e --- /dev/null +++ b/dev/notebooks/mlj-interfacing/direct_mlj.ipynb @@ -0,0 +1,1137 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "using Revise\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m\u001b[1m Activating\u001b[22m\u001b[39m project at `C:\\Users\\Pasqu\\Documents\\julia_projects\\LaplaceRedux.jl\\docs`\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32m\u001b[1mStatus\u001b[22m\u001b[39m `C:\\Users\\Pasqu\\Documents\\julia_projects\\LaplaceRedux.jl\\docs\\Project.toml`\n", + " \u001b[90m[324d7699] \u001b[39mCategoricalArrays v0.10.8\n", + "\u001b[32m⌃\u001b[39m \u001b[90m[a93c6f00] \u001b[39mDataFrames v1.6.1\n", + "\u001b[32m⌃\u001b[39m \u001b[90m[31c24e10] \u001b[39mDistributions v0.25.111\n", + "\u001b[32m⌃\u001b[39m \u001b[90m[e30172f5] \u001b[39mDocumenter v1.6.0\n", + "\u001b[32m⌃\u001b[39m \u001b[90m[587475ba] \u001b[39mFlux v0.14.19\n", + " \u001b[90m[c52c1a26] \u001b[39mLaplaceRedux v1.1.1 `C:\\Users\\Pasqu\\Documents\\julia_projects\\LaplaceRedux.jl`\n", + "\u001b[33m⌅\u001b[39m \u001b[90m[094fc8d1] \u001b[39mMLJFlux v0.5.1\n", + " \u001b[90m[e80e1ace] \u001b[39mMLJModelInterface v1.11.0\n", + " \u001b[90m[3bd65402] \u001b[39mOptimisers v0.3.3\n", + " \u001b[90m[91a5bcdd] \u001b[39mPlots v1.40.8\n", + " \u001b[90m[ce6b1742] \u001b[39mRDatasets v0.7.7\n", + " \u001b[90m[860ef19b] \u001b[39mStableRNGs v1.0.2\n", + " \u001b[90m[2913bbd2] \u001b[39mStatsBase v0.34.3\n", + " \u001b[90m[bd369af6] \u001b[39mTables v1.12.0\n", + "\u001b[32m⌃\u001b[39m \u001b[90m[bd7198b4] \u001b[39mTaijaPlotting v1.2.0\n", + " \u001b[90m[592b5752] \u001b[39mTrapz v2.0.3\n", + "\u001b[32m⌃\u001b[39m \u001b[90m[e88e6eb3] \u001b[39mZygote v0.6.70\n", + "\u001b[32m⌃\u001b[39m \u001b[90m[02a925ec] \u001b[39mcuDNN v1.3.2\n", + " \u001b[90m[9a3f8284] \u001b[39mRandom\n", + " \u001b[90m[10745b16] \u001b[39mStatistics v1.10.0\n", + "\u001b[36m\u001b[1mInfo\u001b[22m\u001b[39m Packages marked with \u001b[32m⌃\u001b[39m and \u001b[33m⌅\u001b[39m have new versions available. Those with \u001b[32m⌃\u001b[39m may be upgradable, but those with \u001b[33m⌅\u001b[39m are restricted by compatibility constraints from upgrading. To see why use `status --outdated`\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: using MLJBase.predict in module LaplaceRedux conflicts with an existing identifier.\n", + "WARNING: using MLJBase.fit! in module LaplaceRedux conflicts with an existing identifier.\n" + ] + } + ], + "source": [ + "using Pkg\n", + "\n", + "\n", + "Pkg.activate(\"C:/Users/Pasqu/Documents/julia_projects/LaplaceRedux.jl/docs\")\n", + "Pkg.status()\n", + "using LaplaceRedux\n", + "using MLJBase\n", + "using Random\n", + "using DataFrames\n", + "using Flux" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(Tables.MatrixTable{Matrix{Float64}} with 100 rows, 2 columns, and schema:\n", + " :x1 Float64\n", + " :x2 Float64, [-0.15053240354230857, -0.16143107735113107, -0.28104782384528254, 0.8905842690519058, -0.2716955057136559, 0.9606721208381163, 0.14403243794060133, 0.13743002853667605, 0.820641892942472, 0.2270783932443115 … 0.4639933046961763, 0.30449384622096687, 0.2744588755171263, -31.785240173822324, 2.58951832655098, 0.10969223924903307, 0.1600255529817666, 0.5913011997917647, -0.39253898214541977, -0.7247243478011852])" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "X, y = make_regression(100, 2; noise=0.5, sparse=0.2, outliers=0.1)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of rows: 100\n", + "Number of columns: 2\n" + ] + } + ], + "source": [ + "using Tables\n", + "# Get the number of rows\n", + "num_rows = Tables.rows(X) |> length\n", + "\n", + "# Get the number of columns\n", + "num_columns = Tables.columnnames(X) |> length\n", + "\n", + "# Display the dimensions\n", + "println(\"Number of rows: \", num_rows)\n", + "println(\"Number of columns: \", num_columns)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Chain(\n", + " Dense(2 => 10, relu), \u001b[90m# 30 parameters\u001b[39m\n", + " Dense(10 => 10, relu), \u001b[90m# 110 parameters\u001b[39m\n", + " Dense(10 => 1), \u001b[90m# 11 parameters\u001b[39m\n", + ") \u001b[90m # Total: 6 arrays, \u001b[39m151 parameters, 988 bytes." + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "using Flux\n", + "# Define the model\n", + "flux_model = Chain(\n", + " Dense(2, 10, relu),\n", + " Dense(10, 10, relu),\n", + " Dense(10, 1)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LaplaceRegressor(\n", + " model = Chain(Dense(2 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), \n", + " flux_loss = Flux.Losses.mse, \n", + " optimiser = Adam(0.001, (0.9, 0.999), 1.0e-8), \n", + " epochs = 1000, \n", + " batch_size = 32, \n", + " subset_of_weights = :all, \n", + " subnetwork_indices = nothing, \n", + " hessian_structure = :full, \n", + " backend = :GGN, \n", + " σ = 1.0, \n", + " μ₀ = 0.0, \n", + " P₀ = nothing, \n", + " fit_prior_nsteps = 100)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "\n", + "# Create an instance of LaplaceRegressor with the Flux model\n", + "laplace_regressor = LaplaceRegressor(\n", + " model = flux_model,\n", + " subset_of_weights = :all,\n", + " subnetwork_indices = nothing,\n", + " hessian_structure = :full,\n", + " backend = :GGN,\n", + " σ = 1.0,\n", + " μ₀ = 0.0,\n", + " P₀ = nothing,\n", + " fit_prior_nsteps = 100\n", + ")\n", + "\n", + "# Display the LaplaceRegressor object\n", + "laplace_regressor" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Tables.MatrixTable{Matrix{Float64}}" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "typeof(X)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
17×5 DataFrame
Rowx1x2x3x4y
Float64Float64Float64Float64Cat…
15.8941912.99796.185128.992862
2-8.148346.23246-1.684978.969051
3-4.882295.35276-0.458768.207561
44.025856.9476913.4032-0.04192232
5-5.536356.55656-1.670638.770411
6-6.618584.65032-1.151988.348971
710.234411.427813.05448.530252
8-6.050528.12027-3.687088.787321
9-5.067694.8631-3.583468.413711
1010.83736.324729.791636.659622
11-6.632265.45149-0.388619.00071
121.628124.6107311.660211.72412
13-6.486796.68166-3.325699.196181
147.955962.2392812.68971.778572
15-6.364665.82985-0.7025028.449761
168.032943.859015.507412.20142
17-7.450677.01011-1.961877.843361
" + ], + "text/latex": [ + "\\begin{tabular}{r|ccccc}\n", + "\t& x1 & x2 & x3 & x4 & y\\\\\n", + "\t\\hline\n", + "\t& Float64 & Float64 & Float64 & Float64 & Cat…\\\\\n", + "\t\\hline\n", + "\t1 & 5.89419 & 12.9979 & 6.18512 & 8.99286 & 2 \\\\\n", + "\t2 & -8.14834 & 6.23246 & -1.68497 & 8.96905 & 1 \\\\\n", + "\t3 & -4.88229 & 5.35276 & -0.45876 & 8.20756 & 1 \\\\\n", + "\t4 & 4.02585 & 6.94769 & 13.4032 & -0.0419223 & 2 \\\\\n", + "\t5 & -5.53635 & 6.55656 & -1.67063 & 8.77041 & 1 \\\\\n", + "\t6 & -6.61858 & 4.65032 & -1.15198 & 8.34897 & 1 \\\\\n", + "\t7 & 10.2344 & 11.4278 & 13.0544 & 8.53025 & 2 \\\\\n", + "\t8 & -6.05052 & 8.12027 & -3.68708 & 8.78732 & 1 \\\\\n", + "\t9 & -5.06769 & 4.8631 & -3.58346 & 8.41371 & 1 \\\\\n", + "\t10 & 10.8373 & 6.32472 & 9.79163 & 6.65962 & 2 \\\\\n", + "\t11 & -6.63226 & 5.45149 & -0.38861 & 9.0007 & 1 \\\\\n", + "\t12 & 1.62812 & 4.61073 & 11.6602 & 11.7241 & 2 \\\\\n", + "\t13 & -6.48679 & 6.68166 & -3.32569 & 9.19618 & 1 \\\\\n", + "\t14 & 7.95596 & 2.23928 & 12.6897 & 1.77857 & 2 \\\\\n", + "\t15 & -6.36466 & 5.82985 & -0.702502 & 8.44976 & 1 \\\\\n", + "\t16 & 8.03294 & 3.85901 & 5.50741 & 2.2014 & 2 \\\\\n", + "\t17 & -7.45067 & 7.01011 & -1.96187 & 7.84336 & 1 \\\\\n", + "\\end{tabular}\n" + ], + "text/plain": [ + "\u001b[1m17×5 DataFrame\u001b[0m\n", + "\u001b[1m Row \u001b[0m│\u001b[1m x1 \u001b[0m\u001b[1m x2 \u001b[0m\u001b[1m x3 \u001b[0m\u001b[1m x4 \u001b[0m\u001b[1m y \u001b[0m\n", + " │\u001b[90m Float64 \u001b[0m\u001b[90m Float64 \u001b[0m\u001b[90m Float64 \u001b[0m\u001b[90m Float64 \u001b[0m\u001b[90m Cat… \u001b[0m\n", + "─────┼─────────────────────────────────────────────────\n", + " 1 │ 5.89419 12.9979 6.18512 8.99286 2\n", + " 2 │ -8.14834 6.23246 -1.68497 8.96905 1\n", + " 3 │ -4.88229 5.35276 -0.45876 8.20756 1\n", + " 4 │ 4.02585 6.94769 13.4032 -0.0419223 2\n", + " 5 │ -5.53635 6.55656 -1.67063 8.77041 1\n", + " 6 │ -6.61858 4.65032 -1.15198 8.34897 1\n", + " 7 │ 10.2344 11.4278 13.0544 8.53025 2\n", + " 8 │ -6.05052 8.12027 -3.68708 8.78732 1\n", + " 9 │ -5.06769 4.8631 -3.58346 8.41371 1\n", + " 10 │ 10.8373 6.32472 9.79163 6.65962 2\n", + " 11 │ -6.63226 5.45149 -0.38861 9.0007 1\n", + " 12 │ 1.62812 4.61073 11.6602 11.7241 2\n", + " 13 │ -6.48679 6.68166 -3.32569 9.19618 1\n", + " 14 │ 7.95596 2.23928 12.6897 1.77857 2\n", + " 15 │ -6.36466 5.82985 -0.702502 8.44976 1\n", + " 16 │ 8.03294 3.85901 5.50741 2.2014 2\n", + " 17 │ -7.45067 7.01011 -1.96187 7.84336 1" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "using MLJ, DataFrames\n", + "X, y = make_blobs(100, 4; centers=2, cluster_std=[1.0, 3.0 ])\n", + "dfBlobs = DataFrame(X)\n", + "dfBlobs.y = y\n", + "first(dfBlobs, 17)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Chain(\n", + " Dense(4 => 10, relu), \u001b[90m# 50 parameters\u001b[39m\n", + " Dense(10 => 10, relu), \u001b[90m# 110 parameters\u001b[39m\n", + " Dense(10 => 2), \u001b[90m# 22 parameters\u001b[39m\n", + ") \u001b[90m # Total: 6 arrays, \u001b[39m182 parameters, 1.086 KiB." + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "using Flux\n", + "# Define the model\n", + "flux_model = Chain(\n", + " Dense(4, 10, relu),\n", + " Dense(10, 10, relu),\n", + " Dense(10, 2)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LaplaceClassifier(\n", + " model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 2)), \n", + " flux_loss = Flux.Losses.logitcrossentropy, \n", + " optimiser = Adam(0.001, (0.9, 0.999), 1.0e-8), \n", + " epochs = 1000, \n", + " batch_size = 32, \n", + " subset_of_weights = :all, \n", + " subnetwork_indices = nothing, \n", + " hessian_structure = :full, \n", + " backend = :GGN, \n", + " σ = 1.0, \n", + " μ₀ = 0.0, \n", + " P₀ = nothing, \n", + " fit_prior_nsteps = 100, \n", + " link_approx = :probit)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Create an instance of LaplaceRegressor with the Flux model\n", + "laplace_classifier = LaplaceClassifier(\n", + " model = flux_model,\n", + " subset_of_weights = :all,\n", + " subnetwork_indices = nothing,\n", + " hessian_structure = :full,\n", + " backend = :GGN,\n", + " σ = 1.0,\n", + " μ₀ = 0.0,\n", + " P₀ = nothing,\n", + " fit_prior_nsteps = 100\n", + ")\n", + "\n", + "# Display the LaplaceRegressor object\n", + "laplace_classifier" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "([5.894190948229072 -8.148338397036424 … 4.6901216310216896 8.754850004556681; 12.99791729188348 6.232458473633034 … 10.904369405648966 12.187158715203248; 6.185122057812649 -1.6849673572899138 … 10.09295616393206 15.616135680604215; 8.992861668498659 8.969052840985539 … 8.673387526088336 3.444822170655238], (Bool[1 0 … 1 1; 0 1 … 0 0], CategoricalArrays.CategoricalValue{Int64, UInt32} 2))" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "X_test_reformat,y_test_reformat= MLJBase.reformat(laplace_classifier,X,y)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2×100 OneHotMatrix(::Vector{UInt32}) with eltype Bool:\n", + " 1 ⋅ ⋅ 1 ⋅ ⋅ 1 ⋅ ⋅ 1 ⋅ 1 ⋅ … 1 1 1 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 1\n", + " ⋅ 1 1 ⋅ 1 1 ⋅ 1 1 ⋅ 1 ⋅ 1 ⋅ ⋅ ⋅ ⋅ 1 1 1 1 1 1 ⋅ ⋅" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "y_test_reformat[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2×3 view(OneHotMatrix(::Vector{UInt32}), :, [2, 3, 4]) with eltype Bool:\n", + " 0 0 1\n", + " 1 1 0" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "view(y_test_reformat[1],:, [2,3,4])" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4×3 view(::Matrix{Float64}, :, [2, 3, 4]) with eltype Float64:\n", + " -8.14834 -4.88229 4.02585\n", + " 6.23246 5.35276 6.94769\n", + " -1.68497 -0.45876 13.4032\n", + " 8.96905 8.20756 -0.0419223" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "view(X_test_reformat, :, [2,3,4])" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((sepal_length = [5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9 … 6.7, 6.9, 5.8, 6.8, 6.7, 6.7, 6.3, 6.5, 6.2, 5.9], sepal_width = [3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1 … 3.1, 3.1, 2.7, 3.2, 3.3, 3.0, 2.5, 3.0, 3.4, 3.0], petal_length = [1.4, 1.4, 1.3, 1.5, 1.4, 1.7, 1.4, 1.5, 1.4, 1.5 … 5.6, 5.1, 5.1, 5.9, 5.7, 5.2, 5.0, 5.2, 5.4, 5.1], petal_width = [0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1 … 2.4, 2.3, 1.9, 2.3, 2.5, 2.3, 1.9, 2.0, 2.3, 1.8]), CategoricalArrays.CategoricalValue{String, UInt32}[\"setosa\", \"setosa\", \"setosa\", \"setosa\", \"setosa\", \"setosa\", \"setosa\", \"setosa\", \"setosa\", \"setosa\" … \"virginica\", \"virginica\", \"virginica\", \"virginica\", \"virginica\", \"virginica\", \"virginica\", \"virginica\", \"virginica\", \"virginica\"])" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "using MLJ\n", + "#DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree\n", + "\n", + "using Flux\n", + "# Define the model\n", + "flux_model = Chain(\n", + " Dense(4, 10, relu),\n", + " Dense(10, 10, relu),\n", + " Dense(10, 3)\n", + ")\n", + "\n", + "model = LaplaceClassifier(model=flux_model)\n", + "\n", + "X, y = @load_iris\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "3-element Vector{String}:\n", + " \"setosa\"\n", + " \"versicolor\"\n", + " \"virginica\"" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "levels(y)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Info: Training machine(LaplaceClassifier(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 3)), …), …).\n", + "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:499\n", + "┌ Warning: Layer with Float32 parameters got Float64 input.\n", + "│ The input will be converted, but any earlier layers may be very slow.\n", + "│ layer = Dense(4 => 10, relu)\n", + "│ summary(x) = 4×32 Matrix{Float64}\n", + "└ @ Flux C:\\Users\\Pasqu\\.julia\\packages\\Flux\\HBF2N\\src\\layers\\stateless.jl:60\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 100: Loss: 4.042439937591553 \n", + "Epoch 200: Loss: 3.533539354801178 \n", + "Epoch 300: Loss: 3.2006053030490875 \n", + "Epoch 400: Loss: 2.9201304614543915 \n", + "Epoch 500: Loss: 2.647124856710434 \n", + "Epoch 600: Loss: 2.3577273190021515 \n", + "Epoch 700: Loss: 2.0284294933080673 \n", + "Epoch 800: Loss: 1.6497879326343536 \n", + "Epoch 900: Loss: 1.3182911276817322 \n", + "Epoch 1000: Loss: 1.077766239643097 \n" + ] + }, + { + "data": { + "text/plain": [ + "trained Machine; caches model-specific representations of data\n", + " model: LaplaceClassifier(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 3)), …)\n", + " args: \n", + " 1:\tSource @378 ⏎ Table{AbstractVector{Continuous}}\n", + " 2:\tSource @187 ⏎ AbstractVector{Multiclass{3}}\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "mach = machine(model, X, y) |> MLJBase.fit!" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "3-element Vector{Float64}:\n", + " 0.9488185211171959\n", + " 0.7685919895442062\n", + " 0.9454937120287679" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "Xnew = (sepal_length = [6.4, 7.2, 7.4],\n", + " sepal_width = [2.8, 3.0, 2.8],\n", + " petal_length = [5.6, 5.8, 6.1],\n", + " petal_width = [2.1, 1.6, 1.9],)\n", + "yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions\n", + "predict_mode(mach, Xnew) # point predictions\n", + "pdf.(yhat, \"virginica\") # probabilities for the \"verginica\" class" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "CV(\n", + " nfolds = 3, \n", + " shuffle = false, \n", + " rng = TaskLocalRNG())" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "using Random\n", + "cv = CV(nfolds=3)\n", + "CV(\n", + " nfolds = 3,\n", + " shuffle = false,\n", + " rng = Random.default_rng())" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(μ = Float32[0.36796558, -0.590969, -0.0005966219, -0.1918786, -0.21046183, 0.33172563, -0.6454487, -0.43057427, 0.09267368, -0.31252357 … 0.31747338, -0.42893136, -0.036560666, 0.12848923, -0.022290554, -0.17562295, 0.10335993, 2.0584264, -0.27860576, -0.9749556],\n", + " H = [19221.80854034424 0.0 … 175.01133728027344 416.5687526464462; 0.0 0.0 … 0.0 0.0; … ; 175.01133728027344 0.0 … 150.0 0.0; 416.5687526464462 0.0 … 0.0 150.0],\n", + " P = [19222.80854034424 0.0 … 175.01133728027344 416.5687526464462; 0.0 1.0 … 0.0 0.0; … ; 175.01133728027344 0.0 … 151.0 0.0; 416.5687526464462 0.0 … 0.0 151.0],\n", + " Σ = [0.4942412474220449 0.0 … 0.00012692197672594825 -0.00029493662610333626; 0.0 1.0 … -0.0 -0.0; … ; 0.00012692197672495198 0.0 … 0.019500175331854216 2.549733600297664e-5; -0.000294936626108007 0.0 … 2.5497336002974752e-5 0.019426448013349775],\n", + " n_data = 160,\n", + " n_params = 193,\n", + " n_out = 3,\n", + " loss = 76.21374633908272,)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "MLJBase.fitted_params(mach)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1100: Loss: 0.9005963057279587 \n", + "Epoch 1200: Loss: 0.7674828618764877 \n", + "Epoch 1300: Loss: 0.6668129041790962 " + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Info: Updating machine(LaplaceClassifier(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 3)), …), …).\n", + "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:500\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Epoch 1400: Loss: 0.5900264084339142 \n", + "Epoch 1500: Loss: 0.5311783701181412 \n" + ] + }, + { + "data": { + "text/plain": [ + "trained Machine; caches model-specific representations of data\n", + " model: LaplaceClassifier(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 3)), …)\n", + " args: \n", + " 1:\tSource @378 ⏎ Table{AbstractVector{Continuous}}\n", + " 2:\tSource @187 ⏎ AbstractVector{Multiclass{3}}\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model.epochs= 1500\n", + "uff=MLJBase.fit!(mach)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " updating only the laplace optimization part\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Info: Updating machine(LaplaceClassifier(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 3)), …), …).\n", + "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:500\n" + ] + }, + { + "data": { + "text/plain": [ + "trained Machine; caches model-specific representations of data\n", + " model: LaplaceClassifier(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 3)), …)\n", + " args: \n", + " 1:\tSource @378 ⏎ Table{AbstractVector{Continuous}}\n", + " 2:\tSource @187 ⏎ AbstractVector{Multiclass{3}}\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model.fit_prior_nsteps = 200\n", + "uff=MLJBase.fit!(mach)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Info: Training machine(LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …), …).\n", + "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:499\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 100: Loss: 123.99619626934783 \n", + "Epoch 200: Loss: 110.70134709925446 \n", + "Epoch 300: Loss: 101.54836914035226 \n", + "Epoch 400: Loss: 93.59849316803738 \n", + "Epoch 500: Loss: 85.18775669356168 \n", + "Epoch 600: Loss: 75.76182416873147 \n", + "Epoch 700: Loss: 67.35201676247976 \n", + "Epoch 800: Loss: 61.22682561405186 \n", + "Epoch 900: Loss: 56.388587609986764 \n", + "Epoch 1000: Loss: 51.8525101259686 \n" + ] + }, + { + "data": { + "text/plain": [ + "1×3 Matrix{Float64}:\n", + " -2.0295 -1.03661 1.29323" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "using MLJ\n", + "#LaplaceRegressor = @load LaplaceRegressor pkg=LaplaceRedux\n", + "flux_model = Chain(\n", + " Dense(4, 10, relu),\n", + " Dense(10, 10, relu),\n", + " Dense(10, 1)\n", + ")\n", + "model = LaplaceRegressor(model=flux_model)\n", + "\n", + "X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)\n", + "mach = machine(model, X, y) \n", + "uff=MLJBase.fit!(mach)\n", + "Xnew, _ = make_regression(3, 4; rng=123)\n", + "yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions\n", + "MLJBase.predict_mode(mach, Xnew) # point predictions\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(μ = Float32[1.0187618, 1.1801981, 1.5047036, 0.09397529, -1.0688609, -0.59974676, 1.2977643, -1.8645017, -0.6493797, 0.8598072 … -3.1050463, -2.267896, 1.018275, -2.0316908, 0.72548866, -2.7309833, -2.1938443, 0.8200029, 1.2517836, 0.3313799],\n", + " H = [8124.873962402344 1238.1352081298828 … 1980.7237396240234 324.216007232666; 1238.1352081298828 496.2319145202637 … 451.8493309020996 62.39444828033447; … ; 1980.7237396240234 451.8493309020996 … 1277.6469421386719 212.83043670654297; 324.216007232666 62.39444828033447 … 212.83043670654297 100.0],\n", + " P = [8125.873962402344 1238.1352081298828 … 1980.7237396240234 324.216007232666; 1238.1352081298828 497.2319145202637 … 451.8493309020996 62.39444828033447; … ; 1980.7237396240234 451.8493309020996 … 1278.6469421386719 212.83043670654297; 324.216007232666 62.39444828033447 … 212.83043670654297 101.0],\n", + " Σ = [0.08818398027482557 0.01068994964485469 … -0.028911073548363073 0.0029380249323247053; 0.01068994964484666 0.3740392430965252 … -0.03321343036791264 0.03808189340124685; … ; -0.02891107354841918 -0.03321343036788035 … 0.37993018751981145 -0.007007167070441715; 0.002938024932320219 0.03808189340124439 … -0.007007167070367749 0.8779610764649773],\n", + " n_data = 128,\n", + " n_params = 171,\n", + " n_out = 1,\n", + " loss = 25.89132467732224,)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "MLJBase.fitted_params(mach)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1000-element Vector{Any}:\n", + " 132.4757012476366\n", + " 132.34841820836348\n", + " 132.2594675536756\n", + " 132.1764167781873\n", + " 132.09720400291872\n", + " 132.02269947440837\n", + " 131.95014241005182\n", + " 131.87929837827625\n", + " 131.8104252672236\n", + " 131.74319014461662\n", + " ⋮\n", + " 52.175198365270234\n", + " 52.148401306572445\n", + " 52.10149557862212\n", + " 52.032551867920525\n", + " 51.99690913470439\n", + " 51.994054497037524\n", + " 51.97321907106968\n", + " 51.932579915942064\n", + " 51.8525101259686" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "a= MLJBase.training_losses(mach)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Info: Updating machine(LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …), …).\n", + "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:500\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1100: Loss: 47.37399118448011 \n", + "Epoch 1200: Loss: 42.67007994813126 \n", + "Epoch 1300: Loss: 38.057605481177234 \n", + "Epoch 1400: Loss: 33.846777755430416 \n", + "Epoch 1500: Loss: 29.514834265471144 \n" + ] + }, + { + "data": { + "text/plain": [ + "trained Machine; caches model-specific representations of data\n", + " model: LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …)\n", + " args: \n", + " 1:\tSource @447 ⏎ Table{AbstractVector{Continuous}}\n", + " 2:\tSource @645 ⏎ AbstractVector{Continuous}\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model.epochs= 1500\n", + "uff=MLJBase.fit!(mach)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1500-element Vector{Any}:\n", + " 132.4757012476366\n", + " 132.34841820836348\n", + " 132.2594675536756\n", + " 132.1764167781873\n", + " 132.09720400291872\n", + " 132.02269947440837\n", + " 131.95014241005182\n", + " 131.87929837827625\n", + " 131.8104252672236\n", + " 131.74319014461662\n", + " ⋮\n", + " 29.979098347946184\n", + " 29.896050775635096\n", + " 29.83609799111873\n", + " 29.763058304997937\n", + " 29.69908817371154\n", + " 29.646864259231556\n", + " 29.623977677653713\n", + " 29.549562760589527\n", + " 29.514834265471144" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "MLJBase.training_losses(mach)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The number of epochs inserted is lower than the number of epochs already been trained. No update is necessary\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Info: Updating machine(LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …), …).\n", + "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:500\n" + ] + }, + { + "data": { + "text/plain": [ + "trained Machine; caches model-specific representations of data\n", + " model: LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …)\n", + " args: \n", + " 1:\tSource @447 ⏎ Table{AbstractVector{Continuous}}\n", + " 2:\tSource @645 ⏎ AbstractVector{Continuous}\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model.epochs= 1200\n", + "uff=MLJBase.fit!(mach)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " updating only the laplace optimization part\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Info: Updating machine(LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …), …).\n", + "└ @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:500\n" + ] + }, + { + "data": { + "text/plain": [ + "trained Machine; caches model-specific representations of data\n", + " model: LaplaceRegressor(model = Chain(Dense(4 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1)), …)\n", + " args: \n", + " 1:\tSource @447 ⏎ Table{AbstractVector{Continuous}}\n", + " 2:\tSource @645 ⏎ AbstractVector{Continuous}\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model.fit_prior_nsteps = 200\n", + "uff=MLJBase.fit!(mach)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "MLJBase.selectrows(model, I, Xmatrix, y) = (view(Xmatrix, :, I), view(y[1], I))" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "#MLJBase.evaluate(model, X, y, resampling=cv, measure=l2, verbosity=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "ename": "BoundsError", + "evalue": "BoundsError: attempt to access 4×100 Matrix{Float64} at index [[35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100], 1:100]", + "output_type": "error", + "traceback": [ + "BoundsError: attempt to access 4×100 Matrix{Float64} at index [[35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100], 1:100]\n", + "\n", + "Stacktrace:\n", + " [1] throw_boundserror(A::Matrix{Float64}, I::Tuple{Vector{Int64}, Base.Slice{Base.OneTo{Int64}}})\n", + " @ Base .\\abstractarray.jl:737\n", + " [2] checkbounds\n", + " @ .\\abstractarray.jl:702 [inlined]\n", + " [3] _getindex\n", + " @ .\\multidimensional.jl:888 [inlined]\n", + " [4] getindex\n", + " @ .\\abstractarray.jl:1291 [inlined]\n", + " [5] _selectrows(::MLJModelInterface.FullInterface, ::Val{:other}, X::Matrix{Float64}, r::Vector{Int64})\n", + " @ MLJModelInterface C:\\Users\\Pasqu\\.julia\\packages\\MLJModelInterface\\y9x5A\\src\\data_utils.jl:350\n", + " [6] selectrows\n", + " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJModelInterface\\y9x5A\\src\\data_utils.jl:340 [inlined]\n", + " [7] selectrows\n", + " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJModelInterface\\y9x5A\\src\\data_utils.jl:336 [inlined]\n", + " [8] #29\n", + " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJModelInterface\\y9x5A\\src\\model_api.jl:88 [inlined]\n", + " [9] map\n", + " @ .\\tuple.jl:292 [inlined]\n", + " [10] selectrows(::LaplaceRegressor, ::Vector{Int64}, ::Matrix{Float64}, ::Tuple{Matrix{Float64}, Nothing})\n", + " @ MLJModelInterface C:\\Users\\Pasqu\\.julia\\packages\\MLJModelInterface\\y9x5A\\src\\model_api.jl:88\n", + " [11] fit_only!(mach::Machine{LaplaceRegressor, LaplaceRegressor, true}; rows::Vector{Int64}, verbosity::Int64, force::Bool, composite::Nothing)\n", + " @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:676\n", + " [12] fit_only!\n", + " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:617 [inlined]\n", + " [13] #fit!#63\n", + " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:789 [inlined]\n", + " [14] fit!\n", + " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\machines.jl:786 [inlined]\n", + " [15] fit_and_extract_on_fold\n", + " @ C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\resampling.jl:1463 [inlined]\n", + " [16] (::MLJBase.var\"#277#278\"{MLJBase.var\"#fit_and_extract_on_fold#304\"{Vector{Tuple{Vector{Int64}, UnitRange{Int64}}}, Nothing, Nothing, Int64, Vector{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.FussyMeasure{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.Multimeasure{StatisticalMeasuresBase.SupportsMissingsMeasure{StatisticalMeasures.LPLossOnScalars{Int64}}, Nothing, StatisticalMeasuresBase.Mean, typeof(identity)}}, Nothing}}}, Vector{typeof(predict_mean)}, Bool, Bool, Vector{Float64}}, Machine{LaplaceRegressor, LaplaceRegressor, true}, Int64})(k::Int64)\n", + " @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\resampling.jl:1289\n", + " [17] _mapreduce(f::MLJBase.var\"#277#278\"{MLJBase.var\"#fit_and_extract_on_fold#304\"{Vector{Tuple{Vector{Int64}, UnitRange{Int64}}}, Nothing, Nothing, Int64, Vector{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.FussyMeasure{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.Multimeasure{StatisticalMeasuresBase.SupportsMissingsMeasure{StatisticalMeasures.LPLossOnScalars{Int64}}, Nothing, StatisticalMeasuresBase.Mean, typeof(identity)}}, Nothing}}}, Vector{typeof(predict_mean)}, Bool, Bool, Vector{Float64}}, Machine{LaplaceRegressor, LaplaceRegressor, true}, Int64}, op::typeof(vcat), ::IndexLinear, A::UnitRange{Int64})\n", + " @ Base .\\reduce.jl:440\n", + " [18] _mapreduce_dim\n", + " @ .\\reducedim.jl:365 [inlined]\n", + " [19] mapreduce\n", + " @ .\\reducedim.jl:357 [inlined]\n", + " [20] _evaluate!(func::MLJBase.var\"#fit_and_extract_on_fold#304\"{Vector{Tuple{Vector{Int64}, UnitRange{Int64}}}, Nothing, Nothing, Int64, Vector{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.FussyMeasure{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.Multimeasure{StatisticalMeasuresBase.SupportsMissingsMeasure{StatisticalMeasures.LPLossOnScalars{Int64}}, Nothing, StatisticalMeasuresBase.Mean, typeof(identity)}}, Nothing}}}, Vector{typeof(predict_mean)}, Bool, Bool, Vector{Float64}}, mach::Machine{LaplaceRegressor, LaplaceRegressor, true}, ::CPU1{Nothing}, nfolds::Int64, verbosity::Int64)\n", + " @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\resampling.jl:1288\n", + " [21] evaluate!(mach::Machine{LaplaceRegressor, LaplaceRegressor, true}, resampling::Vector{Tuple{Vector{Int64}, UnitRange{Int64}}}, weights::Nothing, class_weights::Nothing, rows::Nothing, verbosity::Int64, repeats::Int64, measures::Vector{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.FussyMeasure{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.Multimeasure{StatisticalMeasuresBase.SupportsMissingsMeasure{StatisticalMeasures.LPLossOnScalars{Int64}}, Nothing, StatisticalMeasuresBase.Mean, typeof(identity)}}, Nothing}}}, operations::Vector{typeof(predict_mean)}, acceleration::CPU1{Nothing}, force::Bool, per_observation_flag::Bool, logger::Nothing, user_resampling::CV, compact::Bool)\n", + " @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\resampling.jl:1510\n", + " [22] evaluate!(::Machine{LaplaceRegressor, LaplaceRegressor, true}, ::CV, ::Nothing, ::Nothing, ::Nothing, ::Int64, ::Int64, ::Vector{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.FussyMeasure{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.Multimeasure{StatisticalMeasuresBase.SupportsMissingsMeasure{StatisticalMeasures.LPLossOnScalars{Int64}}, Nothing, StatisticalMeasuresBase.Mean, typeof(identity)}}, Nothing}}}, ::Vector{typeof(predict_mean)}, ::CPU1{Nothing}, ::Bool, ::Bool, ::Nothing, ::CV, ::Bool)\n", + " @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\resampling.jl:1603\n", + " [23] evaluate!(mach::Machine{LaplaceRegressor, LaplaceRegressor, true}; resampling::CV, measures::Nothing, measure::StatisticalMeasuresBase.FussyMeasure{StatisticalMeasuresBase.RobustMeasure{StatisticalMeasuresBase.Multimeasure{StatisticalMeasuresBase.SupportsMissingsMeasure{StatisticalMeasures.LPLossOnScalars{Int64}}, Nothing, StatisticalMeasuresBase.Mean, typeof(identity)}}, Nothing}, weights::Nothing, class_weights::Nothing, operations::Nothing, operation::Nothing, acceleration::CPU1{Nothing}, rows::Nothing, repeats::Int64, force::Bool, check_measure::Bool, per_observation::Bool, verbosity::Int64, logger::Nothing, compact::Bool)\n", + " @ MLJBase C:\\Users\\Pasqu\\.julia\\packages\\MLJBase\\7nGJF\\src\\resampling.jl:1232\n", + " [24] top-level scope\n", + " @ c:\\Users\\Pasqu\\Documents\\julia_projects\\LaplaceRedux.jl\\dev\\notebooks\\mlj-interfacing\\jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X54sZmlsZQ==.jl:1" + ] + } + ], + "source": [ + "evaluate!(mach, resampling=cv, measure=l2, verbosity=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "#flux_model = Chain(\n", + " #Dense(4, 10, relu),\n", + " #Dense(10, 10, relu),\n", + " #Dense(10, 1))\n", + "\n", + "\n", + "#nested_flux_model = Chain(\n", + " #Chain(Dense(10, 5, relu), Dense(5, 5, relu)),\n", + " #Chain(Dense(5, 3, relu), Dense(3, 2)))\n", + "#model = LaplaceRegressor()\n", + "#model.model= nested_flux_model\n", + "\n", + "#copy_model= deepcopy(model)\n", + "\n", + "#copy_model.epochs= 2000\n", + "#copy_model.optimiser = Descent()\n", + "#MLJBase.is_same_except(model , copy_model,:epochs )" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Julia 1.10.5", + "language": "julia", + "name": "julia-1.10" + }, + "language_info": { + "file_extension": ".jl", + "mimetype": "application/julia", + "name": "julia", + "version": "1.10.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 6d42a1e2..d01b6356 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -1,8 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.5" +julia_version = "1.11.1" manifest_format = "2.0" -project_hash = "0bd11d5fa58aad2714bf7893e520fc7c086ef3ca" +project_hash = "aafefca0e42df191a3fa625af8a8b686a9db4944" [[deps.ANSIColoredPrinters]] git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c" @@ -26,31 +26,35 @@ uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" version = "0.4.5" [[deps.Accessors]] -deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown", "Test"] -git-tree-sha1 = "f61b15be1d76846c0ce31d3fcfac5380ae53db6a" +deps = ["CompositionsBase", "ConstructionBase", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown"] +git-tree-sha1 = "b392ede862e506d451fc1616e79aa6f4c673dab8" uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -version = "0.1.37" +version = "0.1.38" [deps.Accessors.extensions] AccessorsAxisKeysExt = "AxisKeys" + AccessorsDatesExt = "Dates" AccessorsIntervalSetsExt = "IntervalSets" AccessorsStaticArraysExt = "StaticArrays" AccessorsStructArraysExt = "StructArrays" + AccessorsTestExt = "Test" AccessorsUnitfulExt = "Unitful" [deps.Accessors.weakdeps] AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" + Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" Requires = "ae029012-a4dd-5104-9daa-d747884805df" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [[deps.Adapt]] deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" +git-tree-sha1 = "d80af0733c99ea80575f612813fa6aa71022d33a" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "4.0.4" +version = "4.1.0" weakdeps = ["StaticArrays"] [deps.Adapt.extensions] @@ -69,7 +73,7 @@ version = "2.3.0" [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" +version = "1.1.2" [[deps.Arpack]] deps = ["Arpack_jll", "Libdl", "LinearAlgebra", "Logging"] @@ -115,6 +119,7 @@ version = "7.16.0" [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" +version = "1.11.0" [[deps.Atomix]] deps = ["UnsafeAtomics"] @@ -157,6 +162,7 @@ version = "0.4.3" [[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" +version = "1.11.0" [[deps.Baselet]] git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" @@ -170,9 +176,9 @@ version = "0.1.9" [[deps.Bzip2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "9e2a6b69137e6969bab0152632dcb3bc108c8bdd" +git-tree-sha1 = "8873e196c2eb87962a2048b3b8e08946535864a1" uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" -version = "1.0.8+1" +version = "1.0.8+2" [[deps.CEnum]] git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" @@ -181,15 +187,15 @@ version = "0.5.0" [[deps.CSV]] deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"] -git-tree-sha1 = "6c834533dc1fabd820c1db03c839bf97e45a3fab" +git-tree-sha1 = "deddd8725e5e1cc49ee205a1964256043720a6c3" uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" -version = "0.10.14" +version = "0.10.15" [[deps.CUDA]] -deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "Crayons", "DataFrames", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LLVMLoopInfo", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "NVTX", "Preferences", "PrettyTables", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "StaticArrays", "Statistics"] -git-tree-sha1 = "fdd9dfb67dfefd548f51000cc400bb51003de247" +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "Crayons", "DataFrames", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LLVMLoopInfo", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "NVTX", "Preferences", "PrettyTables", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "StaticArrays", "Statistics", "demumble_jll"] +git-tree-sha1 = "e0725a467822697171af4dae15cec10b4fc19053" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "5.4.3" +version = "5.5.2" [deps.CUDA.extensions] ChainRulesCoreExt = "ChainRulesCore" @@ -203,9 +209,9 @@ version = "5.4.3" [[deps.CUDA_Driver_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "325058b426c2b421e3d2df3d5fa646d72d2e3e7e" +git-tree-sha1 = "ccd1e54610c222fadfd4737dac66bff786f63656" uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc" -version = "0.9.2+0" +version = "0.10.3+0" [[deps.CUDA_Runtime_Discovery]] deps = ["Libdl"] @@ -215,21 +221,21 @@ version = "0.3.5" [[deps.CUDA_Runtime_jll]] deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "afea94249b821dc754a8ca6695d3daed851e1f5a" +git-tree-sha1 = "e43727b237b2879a34391eeb81887699a26f8f2f" uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" -version = "0.14.1+0" +version = "0.15.3+0" [[deps.CUDNN_jll]] deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "cbf7d75f8c58b147bdf6acea2e5bc96cececa6d4" +git-tree-sha1 = "9851af16a2f357a793daa0f13634c82bc7e40419" uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645" -version = "9.0.0+1" +version = "9.4.0+0" [[deps.Cairo_jll]] deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "a2f1c8c668c8e3cb4cca4e57a8efdb09067bb3fd" +git-tree-sha1 = "009060c9a6168704143100f36ab08f06c2af4642" uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a" -version = "1.18.0+2" +version = "1.18.2+1" [[deps.CategoricalArrays]] deps = ["DataAPI", "Future", "Missings", "Printf", "Requires", "Statistics", "Unicode"] @@ -263,15 +269,15 @@ version = "0.1.15" [[deps.ChainRules]] deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] -git-tree-sha1 = "227985d885b4dbce5e18a96f9326ea1e836e5a03" +git-tree-sha1 = "be227d253d132a6d57f9ccf5f67c0fb6488afd87" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.69.0" +version = "1.71.0" [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra"] -git-tree-sha1 = "71acdbf594aab5bbb2cec89b208c41b4c411e49f" +git-tree-sha1 = "3e4b134270b372f2ed4d4d0e936aabaefc1802bc" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.24.0" +version = "1.25.0" weakdeps = ["SparseArrays"] [deps.ChainRulesCore.extensions] @@ -285,9 +291,9 @@ version = "0.7.6" [[deps.ColorSchemes]] deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] -git-tree-sha1 = "b5278586822443594ff615963b0c09755771b3e0" +git-tree-sha1 = "13951eb68769ad1cd460cdb2e64e5e95f1bf123d" uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" -version = "3.26.0" +version = "3.27.0" [[deps.ColorTypes]] deps = ["FixedPointNumbers", "Random"] @@ -411,10 +417,10 @@ uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" version = "1.16.0" [[deps.DataFrames]] -deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] -git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" +deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] +git-tree-sha1 = "fb61b4812c49343d7ef0b533ba982c46021938a6" uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -version = "1.6.1" +version = "1.7.0" [[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] @@ -430,6 +436,7 @@ version = "1.0.0" [[deps.Dates]] deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" +version = "1.11.0" [[deps.Dbus_jll]] deps = ["Artifacts", "Expat_jll", "JLLWrappers", "Libdl"] @@ -486,12 +493,13 @@ weakdeps = ["ChainRulesCore", "SparseArrays"] [[deps.Distributed]] deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" +version = "1.11.0" [[deps.Distributions]] deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] -git-tree-sha1 = "e6c693a0e4394f8fda0e51a5bdf5aef26f8235e9" +git-tree-sha1 = "d7477ecdafb813ddee2ae727afa94e9dcb5f3fb0" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.111" +version = "0.25.112" weakdeps = ["ChainRulesCore", "DensityInterface", "Test"] [deps.Distributions.extensions] @@ -499,6 +507,12 @@ weakdeps = ["ChainRulesCore", "DensityInterface", "Test"] DistributionsDensityInterfaceExt = "DensityInterface" DistributionsTestExt = "Test" +[[deps.DocInventories]] +deps = ["CodecZlib", "Downloads", "TOML"] +git-tree-sha1 = "e97cfa8680a39396924dcdca4b7ff1014ed5c499" +uuid = "43dc2714-ed3b-44b5-b226-857eda1aa7de" +version = "1.0.0" + [[deps.DocStringExtensions]] deps = ["LibGit2"] git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" @@ -507,9 +521,21 @@ version = "0.9.3" [[deps.Documenter]] deps = ["ANSIColoredPrinters", "AbstractTrees", "Base64", "CodecZlib", "Dates", "DocStringExtensions", "Downloads", "Git", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "MarkdownAST", "Pkg", "PrecompileTools", "REPL", "RegistryInstances", "SHA", "TOML", "Test", "Unicode"] -git-tree-sha1 = "9d29b99b6b2b6bc2382a4c8dbec6eb694f389853" +git-tree-sha1 = "5a1ee886566f2fa9318df1273d8b778b9d42712d" uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "1.6.0" +version = "1.7.0" + +[[deps.DocumenterInterLinks]] +deps = ["CodecZlib", "DocInventories", "Documenter", "DocumenterInventoryWritingBackport", "Markdown", "MarkdownAST", "TOML"] +git-tree-sha1 = "00dceb038f6cb24f4d8d6a9f2feb85bbe58305fd" +uuid = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656" +version = "1.0.0" + +[[deps.DocumenterInventoryWritingBackport]] +deps = ["CodecZlib", "Documenter", "TOML"] +git-tree-sha1 = "1b89024e375353961bb98b9818b44a4e38961cc4" +uuid = "195adf08-069f-4855-af3e-8933a2cdae94" +version = "0.1.0" [[deps.Downloads]] deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] @@ -541,9 +567,9 @@ version = "0.1.10" [[deps.FFMPEG]] deps = ["FFMPEG_jll"] -git-tree-sha1 = "b57e3acbe22f8484b4b5ff66a7499717fe1a9cc8" +git-tree-sha1 = "53ebe7511fa11d33bec688a9178fac4e49eeee00" uuid = "c87230d0-a227-11e9-1b43-d7ebe4e7570a" -version = "0.4.1" +version = "0.4.2" [[deps.FFMPEG_jll]] deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "JLLWrappers", "LAME_jll", "Libdl", "Ogg_jll", "OpenSSL_jll", "Opus_jll", "PCRE2_jll", "Zlib_jll", "libaom_jll", "libass_jll", "libfdk_aac_jll", "libvorbis_jll", "x264_jll", "x265_jll"] @@ -565,9 +591,9 @@ version = "0.1.1" [[deps.FileIO]] deps = ["Pkg", "Requires", "UUIDs"] -git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322" +git-tree-sha1 = "62ca0547a14c57e98154423419d8a342dca75ca9" uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" -version = "1.16.3" +version = "1.16.4" [[deps.FilePathsBase]] deps = ["Compat", "Dates"] @@ -582,6 +608,7 @@ weakdeps = ["Mmap", "Test"] [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" +version = "1.11.0" [[deps.FillArrays]] deps = ["LinearAlgebra"] @@ -596,19 +623,21 @@ weakdeps = ["PDMats", "SparseArrays", "Statistics"] FillArraysStatisticsExt = "Statistics" [[deps.FiniteDiff]] -deps = ["ArrayInterface", "LinearAlgebra", "Setfield", "SparseArrays"] -git-tree-sha1 = "f9219347ebf700e77ca1d48ef84e4a82a6701882" +deps = ["ArrayInterface", "LinearAlgebra", "Setfield"] +git-tree-sha1 = "b10bdafd1647f57ace3885143936749d61638c3b" uuid = "6a86dc24-6348-571c-b903-95158fe2bd41" -version = "2.24.0" +version = "2.26.0" [deps.FiniteDiff.extensions] FiniteDiffBandedMatricesExt = "BandedMatrices" FiniteDiffBlockBandedMatricesExt = "BlockBandedMatrices" + FiniteDiffSparseArraysExt = "SparseArrays" FiniteDiffStaticArraysExt = "StaticArrays" [deps.FiniteDiff.weakdeps] BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [[deps.FixedPointNumbers]] @@ -618,23 +647,27 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" version = "0.8.5" [[deps.Flux]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] -git-tree-sha1 = "fbf100b4bed74c9b6fac0ebd1031e04977d35b3b" +deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLDataDevices", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "Setfield", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] +git-tree-sha1 = "37fa32a50c69c10c6ea1465d3054d98c75bd7777" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.19" +version = "0.14.22" [deps.Flux.extensions] FluxAMDGPUExt = "AMDGPU" FluxCUDAExt = "CUDA" FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] FluxEnzymeExt = "Enzyme" + FluxMPIExt = "MPI" + FluxMPINCCLExt = ["CUDA", "MPI", "NCCL"] FluxMetalExt = "Metal" [deps.Flux.weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" + MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [[deps.Fontconfig_jll]] @@ -679,6 +712,7 @@ version = "0.4.12" [[deps.Future]] deps = ["Random"] uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" +version = "1.11.0" [[deps.GLFW_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libglvnd_jll", "Xorg_libXcursor_jll", "Xorg_libXi_jll", "Xorg_libXinerama_jll", "Xorg_libXrandr_jll", "libdecor_jll", "xkbcommon_jll"] @@ -699,22 +733,22 @@ uuid = "46192b85-c4d5-4398-a991-12ede77f4527" version = "0.1.6" [[deps.GPUCompiler]] -deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Preferences", "Scratch", "Serialization", "TOML", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "ab29216184312f99ff957b32cd63c2fe9c928b91" +deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "PrecompileTools", "Preferences", "Scratch", "Serialization", "TOML", "TimerOutputs", "UUIDs"] +git-tree-sha1 = "1d6f290a5eb1201cd63574fbc4440c788d5cb38f" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.26.7" +version = "0.27.8" [[deps.GR]] deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Preferences", "Printf", "Qt6Wayland_jll", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "p7zip_jll"] -git-tree-sha1 = "629693584cef594c3f6f99e76e7a7ad17e60e8d5" +git-tree-sha1 = "ee28ddcd5517d54e417182fec3886e7412d3926f" uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71" -version = "0.73.7" +version = "0.73.8" [[deps.GR_jll]] deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "FreeType2_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Qt6Base_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "a8863b69c2a0859f2c2c87ebdc4c6712e88bdf0d" +git-tree-sha1 = "f31929b9e67066bee48eec8b03c0df47d31a74b3" uuid = "d2c73de3-f751-5644-a686-071e5b155ba9" -version = "0.73.7+0" +version = "0.73.8+0" [[deps.Gettext_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "XML2_jll"] @@ -730,15 +764,15 @@ version = "1.3.1" [[deps.Git_jll]] deps = ["Artifacts", "Expat_jll", "JLLWrappers", "LibCURL_jll", "Libdl", "Libiconv_jll", "OpenSSL_jll", "PCRE2_jll", "Zlib_jll"] -git-tree-sha1 = "d18fb8a1f3609361ebda9bf029b60fd0f120c809" +git-tree-sha1 = "ea372033d09e4552a04fd38361cd019f9003f4f4" uuid = "f8c6e375-362e-5223-8a59-34ff63f689eb" -version = "2.44.0+2" +version = "2.46.2+0" [[deps.Glib_jll]] deps = ["Artifacts", "Gettext_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Libiconv_jll", "Libmount_jll", "PCRE2_jll", "Zlib_jll"] -git-tree-sha1 = "7c82e6a6cd34e9d935e9aa4051b66c6ff3af59ba" +git-tree-sha1 = "674ff0db93fffcd11a3573986e550d66cd4fd71f" uuid = "7746bdde-850d-59dc-9ae8-88ece973131d" -version = "2.80.2+0" +version = "2.80.5+0" [[deps.Graphite2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -814,11 +848,12 @@ version = "1.4.2" [[deps.InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +version = "1.11.0" [[deps.InverseFunctions]] -git-tree-sha1 = "2787db24f4e03daf859c6509ff87764e4182f7d1" +git-tree-sha1 = "a779299d77cd080bf77b97535acecd73e1c5e5cb" uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.16" +version = "0.1.17" weakdeps = ["Dates", "Test"] [deps.InverseFunctions.extensions] @@ -848,9 +883,9 @@ version = "1.0.0" [[deps.JLD2]] deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Requires", "TranscodingStreams"] -git-tree-sha1 = "a0746c21bdc986d0dc293efa6b1faee112c37c28" +git-tree-sha1 = "b464b9b461ee989b435a689a4f7d870b68d467ed" uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" -version = "0.4.53" +version = "0.5.6" [[deps.JLFzf]] deps = ["Pipe", "REPL", "Random", "fzf_jll"] @@ -860,9 +895,9 @@ version = "0.1.8" [[deps.JLLWrappers]] deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "f389674c99bfcde17dc57454011aa44d5a260a40" +git-tree-sha1 = "be3dc50a92e5a386872a493a10050136d4703f9b" uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.6.0" +version = "1.6.1" [[deps.JSON]] deps = ["Dates", "Mmap", "Parsers", "Unicode"] @@ -872,9 +907,9 @@ version = "0.21.4" [[deps.JpegTurbo_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "c84a835e1a09b289ffcd2271bf2a337bbdda6637" +git-tree-sha1 = "25ee0be4d43d0269027024d75a24c24d6c6e590c" uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" -version = "3.0.3+0" +version = "3.0.4+0" [[deps.JuliaNVTXCallbacks_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -890,9 +925,9 @@ version = "0.2.4" [[deps.KernelAbstractions]] deps = ["Adapt", "Atomix", "InteractiveUtils", "MacroTools", "PrecompileTools", "Requires", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "cb1cff88ef2f3a157cbad75bbe6b229e1975e498" +git-tree-sha1 = "04e52f596d0871fa3890170fa79cb15e481e4cd8" uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.25" +version = "0.9.28" [deps.KernelAbstractions.extensions] EnzymeExt = "EnzymeCore" @@ -911,16 +946,16 @@ uuid = "c1c5ebd0-6772-5130-a774-d5fcae4a789d" version = "3.100.2+0" [[deps.LERC_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "bf36f528eec6634efc60d7ec062008f171071434" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "36bdbc52f13a7d1dcb0f3cd694e01677a515655b" uuid = "88015f11-f218-50d7-93a8-a6af411a945d" -version = "3.0.0+1" +version = "4.0.0+0" [[deps.LLVM]] -deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] -git-tree-sha1 = "2470e69781ddd70b8878491233cd09bc1bd7fc96" +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Unicode"] +git-tree-sha1 = "d422dfd9707bec6617335dc2ea3c5172a87d5908" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "8.1.0" +version = "9.1.3" weakdeps = ["BFloat16s"] [deps.LLVM.extensions] @@ -928,9 +963,9 @@ weakdeps = ["BFloat16s"] [[deps.LLVMExtra_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "597d1c758c9ae5d985ba4202386a607c675ee700" +git-tree-sha1 = "05a8bd5a42309a9ec82f700876903abce1017dd3" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.31+0" +version = "0.0.34+0" [[deps.LLVMLoopInfo]] git-tree-sha1 = "2e5c102cfc41f48ae4740c7eca7743cc7e7b75ea" @@ -939,26 +974,26 @@ version = "1.0.0" [[deps.LLVMOpenMP_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "e16271d212accd09d52ee0ae98956b8a05c4b626" +git-tree-sha1 = "78211fb6cbc872f77cad3fc0b6cf647d923f4929" uuid = "1d63c593-3942-5779-bab2-d838dc0a180e" -version = "17.0.6+0" +version = "18.1.7+0" [[deps.LZO_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "70c5da094887fd2cae843b8db33920bac4b6f07d" +git-tree-sha1 = "854a9c268c43b77b0a27f22d7fab8d33cdb3a731" uuid = "dd4b983a-f0e5-5f8d-a1b7-129d4a5fb1ac" -version = "2.10.2+0" +version = "2.10.2+1" [[deps.LaTeXStrings]] -git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" +git-tree-sha1 = "dda21b8cbd6a6c40d9d02a73230f9d70fed6918c" uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" -version = "1.3.1" +version = "1.4.0" [[deps.LaplaceRedux]] -deps = ["ChainRulesCore", "Compat", "ComputationalResources", "Distributions", "Flux", "LinearAlgebra", "MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"] -git-tree-sha1 = "a84b72a27c93c72a6af5d22216eb81a419b1b97a" +deps = ["CategoricalDistributions", "ChainRulesCore", "Compat", "ComputationalResources", "Distributions", "Flux", "LinearAlgebra", "MLJBase", "MLJModelInterface", "MLUtils", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"] +path = ".." uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478" -version = "1.0.2" +version = "1.1.1" [[deps.Latexify]] deps = ["Format", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Requires"] @@ -984,6 +1019,7 @@ version = "1.2.2" [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" +version = "1.11.0" [[deps.LearnAPI]] deps = ["InteractiveUtils", "Statistics"] @@ -999,16 +1035,17 @@ version = "0.6.4" [[deps.LibCURL_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "8.4.0+0" +version = "8.6.0+0" [[deps.LibGit2]] deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" +version = "1.11.0" [[deps.LibGit2_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" -version = "1.6.4+0" +version = "1.7.2+0" [[deps.LibSSH2_jll]] deps = ["Artifacts", "Libdl", "MbedTLS_jll"] @@ -1017,6 +1054,7 @@ version = "1.11.0+1" [[deps.Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" +version = "1.11.0" [[deps.Libffi_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -1056,9 +1094,9 @@ version = "2.40.1+0" [[deps.Libtiff_jll]] deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "LERC_jll", "Libdl", "XZ_jll", "Zlib_jll", "Zstd_jll"] -git-tree-sha1 = "2da088d113af58221c52828a80378e16be7d037a" +git-tree-sha1 = "b404131d06f7886402758c9ce2214b636eb4d54a" uuid = "89763e89-9b03-5906-acba-b20f662cd828" -version = "4.5.1+1" +version = "4.7.0+0" [[deps.Libuuid_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] @@ -1075,6 +1113,7 @@ version = "7.3.0" [[deps.LinearAlgebra]] deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +version = "1.11.0" [[deps.LinearMaps]] deps = ["LinearAlgebra"] @@ -1106,6 +1145,7 @@ version = "0.3.28" [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" +version = "1.11.0" [[deps.LoggingExtras]] deps = ["Dates", "Logging"] @@ -1113,6 +1153,46 @@ git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075" uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" version = "1.0.3" +[[deps.MLDataDevices]] +deps = ["Adapt", "Compat", "Functors", "LinearAlgebra", "Preferences", "Random"] +git-tree-sha1 = "3207c2e66164e6366440ad3f0243a8d67abb4a47" +uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" +version = "1.4.1" + + [deps.MLDataDevices.extensions] + MLDataDevicesAMDGPUExt = "AMDGPU" + MLDataDevicesCUDAExt = "CUDA" + MLDataDevicesChainRulesCoreExt = "ChainRulesCore" + MLDataDevicesFillArraysExt = "FillArrays" + MLDataDevicesGPUArraysExt = "GPUArrays" + MLDataDevicesMLUtilsExt = "MLUtils" + MLDataDevicesMetalExt = ["GPUArrays", "Metal"] + MLDataDevicesReactantExt = "Reactant" + MLDataDevicesRecursiveArrayToolsExt = "RecursiveArrayTools" + MLDataDevicesReverseDiffExt = "ReverseDiff" + MLDataDevicesSparseArraysExt = "SparseArrays" + MLDataDevicesTrackerExt = "Tracker" + MLDataDevicesZygoteExt = "Zygote" + MLDataDevicescuDNNExt = ["CUDA", "cuDNN"] + MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"] + + [deps.MLDataDevices.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" + GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" + MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" + Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" + RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" + [[deps.MLJBase]] deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LearnAPI", "LinearAlgebra", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "RecipesBase", "Reexport", "ScientificTypes", "Serialization", "StatisticalMeasuresBase", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] git-tree-sha1 = "6f45e12073bc2f2e73ed0473391db38c31e879c9" @@ -1175,6 +1255,7 @@ version = "0.5.13" [[deps.Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" +version = "1.11.0" [[deps.MarkdownAST]] deps = ["AbstractTrees", "Markdown"] @@ -1191,7 +1272,7 @@ version = "1.1.9" [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+1" +version = "2.28.6+0" [[deps.Measures]] git-tree-sha1 = "c13304c81eec1ed3af7fc20e75fb6b26092a1102" @@ -1200,9 +1281,9 @@ version = "0.3.2" [[deps.Metalhead]] deps = ["Artifacts", "BSON", "ChainRulesCore", "Flux", "Functors", "JLD2", "LazyArtifacts", "MLUtils", "NNlib", "PartialFunctions", "Random", "Statistics"] -git-tree-sha1 = "5aac9a2b511afda7bf89df5044a2e0b429f83152" +git-tree-sha1 = "aef476e4958303f5ea9e1deb81a1ba2f510d4e11" uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" -version = "0.9.3" +version = "0.9.4" weakdeps = ["CUDA"] [deps.Metalhead.extensions] @@ -1222,6 +1303,7 @@ version = "1.2.0" [[deps.Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" +version = "1.11.0" [[deps.Mocking]] deps = ["Compat", "ExprTools"] @@ -1231,7 +1313,7 @@ version = "0.8.1" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2023.1.10" +version = "2023.12.12" [[deps.MultivariateStats]] deps = ["Arpack", "Distributions", "LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI", "StatsBase"] @@ -1246,10 +1328,10 @@ uuid = "d41bc354-129a-5804-8e4c-c37616107c6c" version = "7.8.3" [[deps.NNlib]] -deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "ae52c156a63bb647f80c26319b104e99e5977e51" +deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "da09a1e112fd75f9af2a5229323f01b56ec96a4c" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.22" +version = "0.9.24" [deps.NNlib.extensions] NNlibAMDGPUExt = "AMDGPU" @@ -1257,12 +1339,14 @@ version = "0.9.22" NNlibCUDAExt = "CUDA" NNlibEnzymeCoreExt = "EnzymeCore" NNlibFFTWExt = "FFTW" + NNlibForwardDiffExt = "ForwardDiff" [deps.NNlib.weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [[deps.NVTX]] @@ -1302,9 +1386,9 @@ version = "0.2.3" [[deps.NearestNeighbors]] deps = ["Distances", "StaticArrays"] -git-tree-sha1 = "91a67b4d73842da90b526011fa85c5c4c9343fe0" +git-tree-sha1 = "3cebfc94a0754cc329ebc3bab1e6c89621e791ad" uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" -version = "0.4.18" +version = "0.4.20" [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" @@ -1325,7 +1409,7 @@ version = "0.2.5" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.23+4" +version = "0.3.27+1" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] @@ -1340,9 +1424,9 @@ version = "1.4.3" [[deps.OpenSSL_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "a028ee3cb5641cccc4c24e90c36b0a4f7707bdf5" +git-tree-sha1 = "7493f61f55a6cce7325f197443aa80d32554ba10" uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "3.0.14+0" +version = "3.0.15+1" [[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] @@ -1432,9 +1516,13 @@ uuid = "30392449-352a-5448-841d-b1acce4e97dc" version = "0.43.4+0" [[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "Random", "SHA", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.10.0" +version = "1.11.0" +weakdeps = ["REPL"] + + [deps.Pkg.extensions] + REPLExt = "REPL" [[deps.PlotThemes]] deps = ["PlotUtils", "Statistics"] @@ -1443,10 +1531,10 @@ uuid = "ccf2f8ad-2431-5c83-bf29-c5338b663b6a" version = "3.2.0" [[deps.PlotUtils]] -deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "Statistics"] -git-tree-sha1 = "7b1a9df27f072ac4c9c7cbe5efb198489258d1f5" +deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "StableRNGs", "Statistics"] +git-tree-sha1 = "650a022b2ce86c7dcfbdecf00f78afeeb20e5655" uuid = "995b91a9-d308-5afd-9ec6-746e21dbc043" -version = "1.4.1" +version = "1.4.2" [[deps.Plots]] deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "PrecompileTools", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "RelocatableFolders", "Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "TOML", "UUIDs", "UnicodeFun", "UnitfulLatexify", "Unzip"] @@ -1499,13 +1587,14 @@ version = "0.2.0" [[deps.PrettyTables]] deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] -git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7" +git-tree-sha1 = "1101cd475833706e4d0e7b122218257178f48f34" uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "2.3.2" +version = "2.4.0" [[deps.Printf]] deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" +version = "1.11.0" [[deps.ProgressLogging]] deps = ["Logging", "SHA", "UUIDs"] @@ -1550,9 +1639,9 @@ version = "6.7.1+1" [[deps.QuadGK]] deps = ["DataStructures", "LinearAlgebra"] -git-tree-sha1 = "1d587203cf851a51bf1ea31ad7ff89eff8d625ea" +git-tree-sha1 = "cda3b045cf9ef07a08ad46731f5a3165e56cf3da" uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" -version = "2.11.0" +version = "2.11.1" [deps.QuadGK.extensions] QuadGKEnzymeExt = "Enzyme" @@ -1573,12 +1662,14 @@ uuid = "ce6b1742-4840-55fa-b093-852dadbb1d8b" version = "0.7.7" [[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +deps = ["InteractiveUtils", "Markdown", "Sockets", "StyledStrings", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" +version = "1.11.0" [[deps.Random]] deps = ["SHA"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +version = "1.11.0" [[deps.Random123]] deps = ["Random", "RandomNumbers"] @@ -1647,15 +1738,15 @@ version = "1.3.0" [[deps.Rmath]] deps = ["Random", "Rmath_jll"] -git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b" +git-tree-sha1 = "852bd0f55565a9e973fcfee83a84413270224dc4" uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" -version = "0.7.1" +version = "0.8.0" [[deps.Rmath_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "e60724fd3beea548353984dc61c943ecddb0e29a" +git-tree-sha1 = "58cdd8fb2201a6267e1db87ff148dd6c1dbd8ad8" uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" -version = "0.4.3+0" +version = "0.5.1+0" [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" @@ -1692,6 +1783,7 @@ version = "1.4.5" [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +version = "1.11.0" [[deps.Setfield]] deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] @@ -1711,9 +1803,9 @@ uuid = "992d4aef-0814-514b-bc4d-f2e9a6c4116f" version = "1.0.3" [[deps.SimpleBufferStream]] -git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" +git-tree-sha1 = "f305871d2f381d21527c770d4788c06c097c9bc1" uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" -version = "1.1.0" +version = "1.2.0" [[deps.SimpleTraits]] deps = ["InteractiveUtils", "MacroTools"] @@ -1723,6 +1815,7 @@ version = "0.9.4" [[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" +version = "1.11.0" [[deps.SortingAlgorithms]] deps = ["DataStructures"] @@ -1733,7 +1826,7 @@ version = "1.2.1" [[deps.SparseArrays]] deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -version = "1.10.0" +version = "1.11.0" [[deps.SparseInverseSubset]] deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] @@ -1765,9 +1858,9 @@ version = "1.0.2" [[deps.StaticArrays]] deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50" +git-tree-sha1 = "777657803913ffc7e8cc20f0fd04b634f871af8f" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.7" +version = "1.9.8" weakdeps = ["ChainRulesCore", "Statistics"] [deps.StaticArrays.extensions] @@ -1792,9 +1885,14 @@ uuid = "64bff920-2084-43da-a3e6-9bb72801c0c9" version = "3.4.0" [[deps.Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] +deps = ["LinearAlgebra"] +git-tree-sha1 = "ae3bb1eb3bba077cd276bc5cfc337cc65c3075c0" uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.10.0" +version = "1.11.1" +weakdeps = ["SparseArrays"] + + [deps.Statistics.extensions] + SparseArraysExt = ["SparseArrays"] [[deps.StatsAPI]] deps = ["LinearAlgebra"] @@ -1810,9 +1908,9 @@ version = "0.34.3" [[deps.StatsFuns]] deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] -git-tree-sha1 = "cef0472124fab0695b58ca35a77c6fb942fdab8a" +git-tree-sha1 = "b423576adc27097764a90e163157bcfc9acf0f46" uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -version = "1.3.1" +version = "1.3.2" weakdeps = ["ChainRulesCore", "InverseFunctions"] [deps.StatsFuns.extensions] @@ -1821,9 +1919,9 @@ weakdeps = ["ChainRulesCore", "InverseFunctions"] [[deps.StringManipulation]] deps = ["PrecompileTools"] -git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" +git-tree-sha1 = "a6b1675a536c5ad1a60e5a5153e1fee12eb146e3" uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" -version = "0.3.4" +version = "0.4.0" [[deps.StructArrays]] deps = ["ConstructionBase", "DataAPI", "Tables"] @@ -1838,6 +1936,10 @@ weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"] StructArraysSparseArraysExt = "SparseArrays" StructArraysStaticArraysExt = "StaticArrays" +[[deps.StyledStrings]] +uuid = "f489334b-da3d-4c2e-b8f0-e476e12c162b" +version = "1.11.0" + [[deps.SuiteSparse]] deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" @@ -1845,7 +1947,7 @@ uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [[deps.SuiteSparse_jll]] deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "7.2.1+1" +version = "7.7.0+0" [[deps.TOML]] deps = ["Dates"] @@ -1854,9 +1956,9 @@ version = "1.0.3" [[deps.TZJData]] deps = ["Artifacts"] -git-tree-sha1 = "1607ad46cf8d642aa779a1d45af1c8620dbf6915" +git-tree-sha1 = "36b40607bf2bf856828690e097e1c799623b0602" uuid = "dc5dba14-91b3-4cab-a142-028a31da12f7" -version = "1.2.0+2024a" +version = "1.3.0+2024b" [[deps.TableTraits]] deps = ["IteratorInterfaceExtensions"] @@ -1871,16 +1973,15 @@ uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" version = "1.12.0" [[deps.TaijaBase]] -deps = ["CategoricalArrays", "Distributions", "Flux", "MLUtils", "Optimisers", "StatsBase", "Tables"] -git-tree-sha1 = "1c80c4472c6ab6e8c9fa544a22d907295b388dd0" +git-tree-sha1 = "4076f60078b12095ca71a2c26e2e4515e3a6a5e5" uuid = "10284c91-9f28-4c9a-abbf-ee43576dfff6" -version = "1.2.2" +version = "1.2.3" [[deps.TaijaPlotting]] -deps = ["CategoricalArrays", "ConformalPrediction", "CounterfactualExplanations", "DataAPI", "Distributions", "Flux", "LaplaceRedux", "LinearAlgebra", "MLJBase", "MLUtils", "MultivariateStats", "NaturalSort", "NearestNeighborModels", "Plots", "Trapz"] -git-tree-sha1 = "2fc71041e1c215cf6ef3dc2d3b8419499c4b40ff" +deps = ["CategoricalArrays", "ConformalPrediction", "CounterfactualExplanations", "DataAPI", "Distributions", "LaplaceRedux", "LinearAlgebra", "MLJBase", "MLUtils", "MultivariateStats", "NaturalSort", "NearestNeighborModels", "OneHotArrays", "Plots", "RecipesBase", "Trapz"] +git-tree-sha1 = "01c76535e8b87b05d0fbc275d44714c78e49fe6e" uuid = "bd7198b4-c7d6-400c-9bab-9a24614b0240" -version = "1.2.0" +version = "1.3.0" [[deps.Tar]] deps = ["ArgTools", "SHA"] @@ -1896,6 +1997,7 @@ version = "0.1.1" [[deps.Test]] deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +version = "1.11.0" [[deps.ThreadsX]] deps = ["Accessors", "ArgCheck", "BangBang", "ConstructionBase", "InitialValues", "MicroCollections", "Referenceables", "SplittablesBase", "Transducers"] @@ -1905,9 +2007,9 @@ version = "0.1.12" [[deps.TimeZones]] deps = ["Dates", "Downloads", "InlineStrings", "Mocking", "Printf", "Scratch", "TZJData", "Unicode", "p7zip_jll"] -git-tree-sha1 = "b92aebdd3555f3a7e3267cf17702033c2814ef48" +git-tree-sha1 = "8323074bc977aa85cf5ad71099a83ac75b0ac107" uuid = "f269a46b-ccf7-5d73-abea-4c690281aa53" -version = "1.18.0" +version = "1.18.1" weakdeps = ["RecipesBase"] [deps.TimeZones.extensions] @@ -1915,22 +2017,23 @@ weakdeps = ["RecipesBase"] [[deps.TimerOutputs]] deps = ["ExprTools", "Printf"] -git-tree-sha1 = "5a13ae8a41237cff5ecf34f73eb1b8f42fff6531" +git-tree-sha1 = "3a6f063d690135f5c1ba351412c82bae4d1402bf" uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.24" +version = "0.5.25" [[deps.TranscodingStreams]] -git-tree-sha1 = "e84b3a11b9bece70d14cce63406bbc79ed3464d2" +git-tree-sha1 = "0c45878dcfdcfa8480052b6ab162cdd138781742" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.11.2" +version = "0.11.3" [[deps.Transducers]] -deps = ["Accessors", "Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"] -git-tree-sha1 = "5215a069867476fc8e3469602006b9670e68da23" +deps = ["Accessors", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"] +git-tree-sha1 = "7deeab4ff96b85c5f72c824cae53a1398da3d1cb" uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.82" +version = "0.4.84" [deps.Transducers.extensions] + TransducersAdaptExt = "Adapt" TransducersBlockArraysExt = "BlockArrays" TransducersDataFramesExt = "DataFrames" TransducersLazyArraysExt = "LazyArrays" @@ -1938,6 +2041,7 @@ version = "0.4.82" TransducersReferenceablesExt = "Referenceables" [deps.Transducers.weakdeps] + Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" @@ -1975,6 +2079,7 @@ version = "1.5.1" [[deps.UUIDs]] deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +version = "1.11.0" [[deps.UnPack]] git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b" @@ -1983,6 +2088,7 @@ version = "1.0.2" [[deps.Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" +version = "1.11.0" [[deps.UnicodeFun]] deps = ["REPL"] @@ -2221,15 +2327,15 @@ version = "1.2.13+1" [[deps.Zstd_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "e678132f07ddb5bfa46857f0d7620fb9be675d3b" +git-tree-sha1 = "555d1076590a6cc2fdee2ef1469451f872d8b41b" uuid = "3161d3a3-bdf6-5164-811a-617609db77b4" -version = "1.5.6+0" +version = "1.5.6+1" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54" +git-tree-sha1 = "f816633be6dc5c0ed9ffedda157ecfda0b3b6a69" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.70" +version = "0.6.72" [deps.Zygote.extensions] ZygoteColorsExt = "Colors" @@ -2249,9 +2355,15 @@ version = "0.2.5" [[deps.cuDNN]] deps = ["CEnum", "CUDA", "CUDA_Runtime_Discovery", "CUDNN_jll"] -git-tree-sha1 = "4909e87d6d62c29a897d54d9001c63932e41cb0e" +git-tree-sha1 = "4b3ac62501ca73263eaa0d034c772f13c647fba6" uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" -version = "1.3.2" +version = "1.4.0" + +[[deps.demumble_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "6498e3581023f8e530f34760d18f75a69e3a4ea8" +uuid = "1e29f10c-031c-5a83-9565-69cddfc27673" +version = "1.3.0+0" [[deps.eudev_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "gperf_jll"] @@ -2314,9 +2426,9 @@ version = "1.18.0+0" [[deps.libpng_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Zlib_jll"] -git-tree-sha1 = "d7015d2e18a5fd9a4f47de711837e980519781a4" +git-tree-sha1 = "b70c870239dc3d7bc094eb2d6be9b73d27bef280" uuid = "b53b4c65-9356-5827-b1ea-8c7a1a84506f" -version = "1.6.43+1" +version = "1.6.44+0" [[deps.libvorbis_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Ogg_jll", "Pkg"] @@ -2333,7 +2445,7 @@ version = "1.1.6+0" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.52.0+1" +version = "1.59.0+0" [[deps.p7zip_jll]] deps = ["Artifacts", "Libdl"] diff --git a/docs/Project.toml b/docs/Project.toml index 7118d5fe..e8f6d849 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -3,9 +3,12 @@ CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" +MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/docs/make.jl b/docs/make.jl index 8aedd777..f24ea50b 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -24,7 +24,7 @@ makedocs(; "Calibrated forecasts" => "tutorials/calibration.md", ], "Reference" => "reference.md", - "MLJ interface"=> "mlj_interface.md", + "MLJ interface" => "mlj_interface.md", "Additional Resources" => "resources/_resources.md", ], ) diff --git a/src/LaplaceRedux.jl b/src/LaplaceRedux.jl index ed931ca7..b6324f95 100644 --- a/src/LaplaceRedux.jl +++ b/src/LaplaceRedux.jl @@ -26,5 +26,8 @@ export empirical_frequency_binary_classification, empirical_frequency_regression, sharpness_regression, extract_mean_and_variance, - sigma_scaling, rescale_stddev + sigma_scaling, + rescale_stddev + +include("direct_mlj.jl") end diff --git a/src/baselaplace/predicting.jl b/src/baselaplace/predicting.jl index 4a0a1d6c..feb457b3 100644 --- a/src/baselaplace/predicting.jl +++ b/src/baselaplace/predicting.jl @@ -93,7 +93,7 @@ Computes the Bayesian predictivie distribution from a neural network with a Lapl - `link_approx::Symbol=:probit`: Link function approximation. Options are `:probit` and `:plugin`. - `predict_proba::Bool=true`: If `true` (default) apply a sigmoid or a softmax function to the output of the Flux model. - `return_distr::Bool=false`: if `false` (default), the function outputs either the direct output of the chain or pseudo-probabilities (if `predict_proba=true`). - if `true` predict return a Bernoulli distribution in binary classification tasks and a categorical distribution in multiclassification tasks. + if `true` predict returns a probability distribution. # Returns @@ -144,7 +144,6 @@ function predict( else return fμ, pred_var end - end # Classification: diff --git a/src/calibration_functions.jl b/src/calibration_functions.jl index 1d313248..d6aa5062 100644 --- a/src/calibration_functions.jl +++ b/src/calibration_functions.jl @@ -78,7 +78,7 @@ Outputs: \ """ function sharpness_classification(y_binary, distributions::Vector{Bernoulli{Float64}}) mean_class_one = mean(mean.(distributions[findall(y_binary .== 1)])) - mean_class_zero = mean( 1 .- mean.(distributions[findall(y_binary .== 0)])) + mean_class_zero = mean(1 .- mean.(distributions[findall(y_binary .== 0)])) return mean_class_one, mean_class_zero end @@ -176,8 +176,9 @@ Inputs: \ Outputs: \ - `sigma`: the scalar that maximize the likelihood. """ -function sigma_scaling(distr::Vector{Normal{T}}, y_cal::Vector{<:AbstractFloat} - ) where T <: AbstractFloat +function sigma_scaling( + distr::Vector{Normal{T}}, y_cal::Vector{<:AbstractFloat} +) where {T<:AbstractFloat} means, variances = extract_mean_and_variance(distr) sigma = sqrt(1 / length(y_cal) * sum(norm.(y_cal .- means) ./ variances)) @@ -198,4 +199,4 @@ Outputs: \ function rescale_stddev(distr::Vector{Normal{T}}, s::T) where {T<:AbstractFloat} rescaled_distr = [Normal(mean(d), std(d) * s) for d in distr] return rescaled_distr -end \ No newline at end of file +end diff --git a/src/data/functions.jl b/src/data/functions.jl index 99e4440c..5ae83805 100644 --- a/src/data/functions.jl +++ b/src/data/functions.jl @@ -1,4 +1,4 @@ -import Random +using Random: Random """ toy_data_linear(N=100) @@ -42,7 +42,7 @@ toy_data_non_linear() ``` """ -function toy_data_non_linear( N=100; seed=nothing) +function toy_data_non_linear(N=100; seed=nothing) #set seed if available if seed !== nothing diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl new file mode 100644 index 00000000..61aab7dd --- /dev/null +++ b/src/direct_mlj.jl @@ -0,0 +1,846 @@ +using Optimisers: Optimisers +using Flux +using Random +using Tables +using LinearAlgebra +using LaplaceRedux +using MLJBase: MLJBase +import MLJModelInterface as MMI +using Distributions: Normal + +MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJBase.Probabilistic + model::Union{Flux.Chain,Nothing} = nothing + flux_loss = Flux.Losses.logitcrossentropy + optimiser = Optimisers.Adam() + epochs::Integer = 1000::(_ > 0) + batch_size::Integer = 32::(_ > 0) + subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) + subnetwork_indices = nothing + hessian_structure::Union{HessianStructure,Symbol,String} = + :full::(_ in (:full, :diagonal)) + backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher)) + σ::Float64 = 1.0 + μ₀::Float64 = 0.0 + P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing + fit_prior_nsteps::Int = 100::(_ > 0) + link_approx::Symbol = :probit::(_ in (:probit, :plugin)) +end + +MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJBase.Probabilistic + model::Union{Flux.Chain,Nothing} = nothing + flux_loss = Flux.Losses.mse + optimiser = Optimisers.Adam() + epochs::Integer = 1000::(_ > 0) + batch_size::Integer = 32::(_ > 0) + subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) + subnetwork_indices = nothing + hessian_structure::Union{HessianStructure,Symbol,String} = + :full::(_ in (:full, :diagonal)) + backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher)) + σ::Float64 = 1.0 + μ₀::Float64 = 0.0 + P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing + fit_prior_nsteps::Int = 100::(_ > 0) +end + +LaplaceModels = Union{LaplaceRegressor,LaplaceClassifier} + +# for fit: +function MMI.reformat(::LaplaceRegressor, X, y) + return (MLJBase.matrix(X) |> permutedims, (reshape(y, 1, :), nothing)) +end + +function MMI.reformat(::LaplaceClassifier, X, y) + X = MLJBase.matrix(X) |> permutedims + y = MLJBase.categorical(y) + labels = y.pool.levels + y = Flux.onehotbatch(y, labels) # One-hot encoding + + return X, (y, labels) +end + +MMI.reformat(::LaplaceModels, X) = (MLJBase.matrix(X) |> permutedims,) + +MMI.selectrows(::LaplaceModels, I, Xmatrix, y) = (Xmatrix[:, I], (y[1][:, I], y[2])) +MMI.selectrows(::LaplaceModels, I, Xmatrix) = (Xmatrix[:, I],) + + + +""" + function dataset_shape(model::LaplaceRegression, X, y) + +Compute the the number of features of the X input dataset and the number of variables to predict from the y output dataset. + +# Arguments +- `model::LaplaceModels`: The Laplace model to fit. +- `X`: The input data for training. +- `y`: The target labels for training one-hot encoded. + +# Returns +- (input size, output size) +""" +function dataset_shape(model::LaplaceModels, X, y) + n_input = size(X, 1) + dims = size(y) + if length(dims) == 1 + n_output = 1 + else + n_output = dims[1] + end + return (n_input, n_output) +end + + +""" + default_build( seed::Int, shape) + +Builds a default MLP Flux model compatible with the dimensions of the dataset, with reproducible initial weights. + +# Arguments +- `seed::Int`: The seed for random number generation. +- `shape`: a tuple containing the dimensions of the input layer and the output layer. + +# Returns +- The constructed Flux model, which consist in a simple MLP with 2 hidden layers with 20 neurons each and an input and output layers compatible with the dataset. +""" +function default_build(seed::Int, shape) + Random.seed!(seed) + (n_input, n_output) = shape + + chain = Chain( + Dense(n_input, 20, relu), + Dense(20, 20, relu), + Dense(20, 20, relu), + Dense(20, n_output) + ) + + return chain +end + + + + +@doc """ + MMI.fit(m::Union{LaplaceRegressor,LaplaceClassifier}, verbosity, X, y) + +Fit a Laplace model using the provided features and target values. + +# Arguments +- `m::Laplace`: The Laplace (LaplaceRegressor or LaplaceClassifier) model to be fitted. +- `verbosity`: Verbosity level for logging. +- `X`: Input features, expected to be in a format compatible with MLJBase.matrix. +- `y`: Target values. + +# Returns +- `fitresult`: a tuple (la,decode) cointaing the fitted Laplace model and y[1],the first element of the categorical y vector. +- `cache`: a tuple containing a deepcopy of the model, the current state of the optimiser and the training loss history. +- `report`: A Namedtuple containing the loss history of the fitting process. +""" +function MMI.fit(m::LaplaceModels, verbosity, X, y) + y, decode = y + + if (m.model === nothing) + @warn "Warning: no Flux model has been provided in the model. LaplaceRedux will use a standard MLP with 2 hidden layers with 20 neurons each and input and output layers compatible with the dataset." + shape = dataset_shape(m, X, y) + + m.model = default_build(11, shape) + + end + + # Make a copy of the model because Flux does not allow to mutate hyperparameters + copied_model = deepcopy(m.model) + + data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size) + state_tree = Optimisers.setup(m.optimiser, copied_model) + loss_history = [] + + for epoch in 1:(m.epochs) + loss_per_epoch = 0.0 + + for (X_batch, y_batch) in data_loader + # Forward pass: compute predictions + y_pred = copied_model(X_batch) + + # Compute loss + loss = m.flux_loss(y_pred, y_batch) + + # Compute gradients + grads, _ = gradient(copied_model, X_batch) do grad_model, X + # Recompute predictions inside gradient context + y_pred = grad_model(X) + m.flux_loss(y_pred, y_batch) + end + + # Update parameters using the optimizer and computed gradients + state_tree, copied_model = Optimisers.update!(state_tree, copied_model, grads) + + # Accumulate the loss for this batch + loss_per_epoch += sum(loss) # Summing the batch loss + end + + push!(loss_history, loss_per_epoch) + + # Print loss every 100 epochs if verbosity is 1 or more + if verbosity >= 1 && epoch % 100 == 0 + println("Epoch $epoch: Loss: $loss_per_epoch ") + end + end + + la = LaplaceRedux.Laplace( + copied_model; + likelihood=:regression, + subset_of_weights=m.subset_of_weights, + subnetwork_indices=m.subnetwork_indices, + hessian_structure=m.hessian_structure, + backend=m.backend, + σ=m.σ, + μ₀=m.μ₀, + P₀=m.P₀, + ) + + if typeof(m) == LaplaceClassifier + la.likelihood = :classification + end + + # fit the Laplace model: + LaplaceRedux.fit!(la, data_loader) + optimize_prior!(la; verbosity= verbosity, n_steps=m.fit_prior_nsteps) + + fitresult = (la, decode) + report = (loss_history=loss_history,) + cache = (deepcopy(m), state_tree, loss_history) + return fitresult, cache, report +end + +@doc """ + MMI.update(m::Union{LaplaceRegressor,LaplaceClassifier}, verbosity, X, y) + +Update the Laplace model using the provided new data points. + +# Arguments +- `m`: The Laplace (LaplaceRegressor or LaplaceClassifier) model to be fitted. +- `verbosity`: Verbosity level for logging. +- `X`: New input features, expected to be in a format compatible with MLJBase.matrix. +- `y`: New target values. + +# Returns +- `fitresult`: a tuple (la,decode) cointaing the updated fitted Laplace model and y[1],the first element of the categorical y vector. +- `cache`: a tuple containing a deepcopy of the model, the updated current state of the optimiser and training loss history. +- `report`: A Namedtuple containing the complete loss history of the fitting process. +""" +function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y) + y_up, decode = y + + data_loader = Flux.DataLoader((X, y_up); batchsize=m.batch_size) + old_model = old_cache[1] + old_state_tree = old_cache[2] + old_loss_history = old_cache[3] + + epochs = m.epochs + + if MMI.is_same_except(m, old_model, :epochs) + old_la = old_fitresult[1] + if epochs > old_model.epochs + for epoch in (old_model.epochs + 1):(epochs) + loss_per_epoch = 0.0 + + for (X_batch, y_batch) in data_loader + # Forward pass: compute predictions + y_pred = old_la.model(X_batch) + + # Compute loss + loss = m.flux_loss(y_pred, y_batch) + + # Compute gradients + grads, _ = gradient(old_la.model, X_batch) do grad_model, X + # Recompute predictions inside gradient context + y_pred = grad_model(X) + m.flux_loss(y_pred, y_batch) + end + + # Update parameters using the optimizer and computed gradients + old_state_tree, old_la.model = Optimisers.update!( + old_state_tree, old_la.model, grads + ) + + # Accumulate the loss for this batch + loss_per_epoch += sum(loss) # Summing the batch loss + end + + push!(old_loss_history, loss_per_epoch) + + # Print loss every 100 epochs if verbosity is 1 or more + if verbosity >= 1 && epoch % 100 == 0 + println("Epoch $epoch: Loss: $loss_per_epoch ") + end + end + + la = LaplaceRedux.Laplace( + old_la.model; + likelihood=:regression, + subset_of_weights=m.subset_of_weights, + subnetwork_indices=m.subnetwork_indices, + hessian_structure=m.hessian_structure, + backend=m.backend, + σ=m.σ, + μ₀=m.μ₀, + P₀=m.P₀, + ) + if typeof(m) == LaplaceClassifier + la.likelihood = :classification + end + + # fit the Laplace model: + LaplaceRedux.fit!(la, data_loader) + optimize_prior!(la; verbosity = verbosity, n_steps=m.fit_prior_nsteps) + + fitresult = (la, decode) + report = (loss_history=old_loss_history,) + cache = (deepcopy(m), old_state_tree, old_loss_history) + + else + println( + "The number of epochs inserted is lower than the number of epochs already been trained. No update is necessary", + ) + fitresult = (old_la, decode) + report = (loss_history=old_loss_history,) + cache = (deepcopy(m), old_state_tree, old_loss_history) + end + + elseif MMI.is_same_except( + m, + old_model, + :fit_prior_nsteps, + :subset_of_weights, + :subnetwork_indices, + :hessian_structure, + :backend, + :σ, + :μ₀, + :P₀, + ) + println(" updating only the laplace optimization part") + old_la = old_fitresult[1] + + la = LaplaceRedux.Laplace( + old_la.model; + likelihood=:regression, + subset_of_weights=m.subset_of_weights, + subnetwork_indices=m.subnetwork_indices, + hessian_structure=m.hessian_structure, + backend=m.backend, + σ=m.σ, + μ₀=m.μ₀, + P₀=m.P₀, + ) + if typeof(m) == LaplaceClassifier + la.likelihood = :classification + end + + # fit the Laplace model: + LaplaceRedux.fit!(la, data_loader) + optimize_prior!(la; verbosity = verbosity, n_steps=m.fit_prior_nsteps) + + fitresult = (la, decode) + report = (loss_history=old_loss_history,) + cache = (deepcopy(m), old_state_tree, old_loss_history) + + end + + return fitresult, cache, report +end + +@doc """ + function MMI.is_same_except(m1::LaplaceModels, m2::LaplaceModels, exceptions::Symbol...) + +If both `m1` and `m2` are of `MLJType`, return `true` if the +following conditions all hold, and `false` otherwise: + +- `typeof(m1) === typeof(m2)` + +- `propertynames(m1) === propertynames(m2)` + +- with the exception of properties listed as `exceptions` or bound to + an `AbstractRNG`, each pair of corresponding property values is + either "equal" or both undefined. (If a property appears as a + `propertyname` but not a `fieldname`, it is deemed as always defined.) + +The meaining of "equal" depends on the type of the property value: + +- values that are themselves of `MLJType` are "equal" if they are + equal in the sense of `is_same_except` with no exceptions. + +- values that are not of `MLJType` are "equal" if they are `==`. + +In the special case of a "deep" property, "equal" has a different +meaning; see `MLJBase.deep_properties` for details. + +If `m1` or `m2` are not `MLJType` objects, then return `==(m1, m2)`. + +""" +function MMI.is_same_except(m1::LaplaceModels, m2::LaplaceModels, exceptions::Symbol...) + typeof(m1) === typeof(m2) || return false + names = propertynames(m1) + propertynames(m2) === names || return false + + for name in names + if !(name in exceptions) + if !_isdefined(m1, name) + !_isdefined(m2, name) || return false + elseif _isdefined(m2, name) + if name in MLJBase.deep_properties(LaplaceRegressor) + _equal_to_depth_one(getproperty(m1, name), getproperty(m2, name)) || + return false + else + ( + MMI.is_same_except(getproperty(m1, name), getproperty(m2, name)) || + getproperty(m1, name) isa AbstractRNG || + getproperty(m2, name) isa AbstractRNG || + ( + getproperty(m1, name) isa Flux.Chain && + getproperty(m2, name) isa Flux.Chain && + _equal_flux_chain(getproperty(m1, name), getproperty(m2, name)) + ) + ) || return false + end + else + return false + end + end + end + return true +end +function _isdefined(object, name) + pnames = propertynames(object) + fnames = fieldnames(typeof(object)) + name in pnames && !(name in fnames) && return true + return isdefined(object, name) +end + +function _equal_flux_chain(chain1::Flux.Chain, chain2::Flux.Chain) + if length(chain1.layers) != length(chain2.layers) + return false + end + params1 = Flux.params(chain1) + params2 = Flux.params(chain2) + if length(params1) != length(params2) + return false + end + for (p1, p2) in zip(params1, params2) + if !isequal(p1, p2) + return false + end + end + for (layer1, layer2) in zip(chain1.layers, chain2.layers) + if typeof(layer1) != typeof(layer2) + return false + end + end + return true +end + +@doc """ + + function MMI.fitted_params(model::LaplaceRegressor, fitresult) + + + This function extracts the fitted parameters from a `LaplaceRegressor` model. + + # Arguments + - `model::LaplaceRegressor`: The Laplace regression model. + - `fitresult`: the Laplace approximation (`la`). + + # Returns + A named tuple containing: + - `μ`: The mean of the posterior distribution. + - `H`: The Hessian of the posterior distribution. + - `P`: The precision matrix of the posterior distribution. + - `Σ`: The covariance matrix of the posterior distribution. + - `n_data`: The number of data points. + - `n_params`: The number of parameters. + - `n_out`: The number of outputs. + - `loss`: The loss value of the posterior distribution. + +""" +function MMI.fitted_params(model::LaplaceModels, fitresult) + la, decode = fitresult + posterior = la.posterior + return ( + μ=posterior.μ, + H=posterior.H, + P=posterior.P, + Σ=posterior.Σ, + n_data=posterior.n_data, + n_params=posterior.n_params, + n_out=posterior.n_out, + loss=posterior.loss, + ) +end + +@doc """ + MMI.training_losses(model::Union{LaplaceRegressor,LaplaceClassifier}, report) + +Retrieve the training loss history from the given `report`. + +# Arguments +- `model`: The model for which the training losses are being retrieved. +- `report`: An object containing the training report, which includes the loss history. + +# Returns +- A collection representing the loss history from the training report. +""" +function MMI.training_losses(model::LaplaceModels, report) + return report.loss_history +end + +@doc """ +function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) + + Predicts the response for new data using a fitted Laplace model. + + # Arguments + - `m::LaplaceRegressor`: The Laplace model. + - `fitresult`: The result of the fitting procedure. + - `Xnew`: The new data for which predictions are to be made. + + # Returns + for LaplaceRegressor: + - An array of Normal distributions, each centered around the predicted mean and variance for the corresponding input in `Xnew`. + for LaplaceClassifier: + - `MLJBase.UnivariateFinite`: The predicted class probabilities for the new data. +""" +function MMI.predict(m::LaplaceModels, fitresult, Xnew) + la, decode = fitresult + if typeof(m) == LaplaceRegressor + yhat = LaplaceRedux.predict(la, Xnew; ret_distr=false) + # Extract mean and variance matrices + means, variances = yhat + + # Create Normal distributions from the means and variances + return vec([Normal(μ, sqrt(σ)) for (μ, σ) in zip(means, variances)]) + + else + predictions = + LaplaceRedux.predict(la, Xnew; link_approx=m.link_approx, ret_distr=false) |> + permutedims + + return MLJBase.UnivariateFinite(decode, predictions; pool=missing) + end +end + +MMI.metadata_pkg( + LaplaceRegressor; + name="LaplaceRedux", + package_uuid="c52c1a26-f7c5-402b-80be-ba1e638ad478", + package_url="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl", + is_pure_julia=true, + is_wrapper=true, + package_license="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl/blob/main/LICENSE", +) + +MMI.metadata_pkg( + LaplaceClassifier; + name="LaplaceRedux", + package_uuid="c52c1a26-f7c5-402b-80be-ba1e638ad478", + package_url="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl", + is_pure_julia=true, + is_wrapper=true, + package_license="https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl/blob/main/LICENSE", +) + +MLJBase.metadata_model( + LaplaceClassifier; + input_scitype=Union{ + AbstractMatrix{<:Union{MLJBase.Finite,MLJBase.Continuous}}, # matrix with mixed types + MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types + }, + target_scitype=AbstractArray{<:MLJBase.Finite}, # ordered factor or multiclass + supports_training_losses=true, + load_path="LaplaceRedux.LaplaceClassifier", +) +# metadata for each model, +MLJBase.metadata_model( + LaplaceRegressor; + input_scitype=Union{ + AbstractMatrix{<:Union{MLJBase.Finite,MLJBase.Continuous}}, # matrix with mixed types + MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types + }, + target_scitype=AbstractArray{MLJBase.Continuous}, + supports_training_losses=true, + load_path="LaplaceRedux.LaplaceRegressor", +) + +const DOC_LAPLACE_REDUX = + "[Laplace Redux – Effortless Bayesian Deep Learning]" * + "(https://proceedings.neurips.cc/paper/2021/hash/a3923dbe2f702eff254d67b48ae2f06e-Abstract.html), originally published in " * + "Daxberger, E., Kristiadi, A., Immer, A., Eschenhagen, R., Bauer, M., Hennig, P. (2021): \"Laplace Redux – Effortless Bayesian Deep Learning.\", NIPS'21: Proceedings of the 35th International Conference on Neural Information Processing Systems*, Article No. 1537, pp. 20089–20103" + +""" +$(MMI.doc_header(LaplaceClassifier)) + +`LaplaceClassifier` implements the $DOC_LAPLACE_REDUX for classification models. + +# Training data + +In MLJ or MLJBase, given a dataset X,y and a Flux Chain adapt to the dataset, pass the chain to the model + +laplace_model = LaplaceClassifier(model = Flux_Chain,kwargs...) + +then bind an instance `laplace_model` to data with + + mach = machine(laplace_model, X, y) + +where + +- `X`: any table of input features (eg, a `DataFrame`) whose columns + each have one of the following element scitypes: `Continuous`, + `Count`, or `<:OrderedFactor`; check column scitypes with `schema(X)` + +- `y`: is the target, which can be any `AbstractVector` whose element + scitype is `<:OrderedFactor` or `<:Multiclass`; check the scitype + with `scitype(y)` + +Train the machine using `fit!(mach, rows=...)`. + + +# Hyperparameters (format: name-type-default value-restrictions) + +- `model::Union{Flux.Chain,Nothing} = nothing`: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 2 hidden layer with 20 neurons each. + +- `flux_loss = Flux.Losses.logitcrossentropy` : a Flux loss function + +- `optimiser = Adam()` a Flux optimiser + +- `epochs::Integer = 1000::(_ > 0)`: the number of training epochs. + +- `batch_size::Integer = 32::(_ > 0)`: the batch size. + +- `subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. + +- `subnetwork_indices = nothing`: the indices of the subnetworks. + +- `hessian_structure::Union{HessianStructure,Symbol,String} = :full::(_ in (:full, :diagonal))`: the structure of the Hessian matrix, either `:full` or `:diagonal`. + +- `backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))`: the backend to use, either `:GGN` or `:EmpiricalFisher`. + +- `σ::Float64 = 1.0`: the standard deviation of the prior distribution. + +- `μ₀::Float64 = 0.0`: the mean of the prior distribution. + +- `P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing`: the covariance matrix of the prior distribution. + +- `fit_prior_nsteps::Int = 100::(_ > 0) `: the number of steps used to fit the priors. + +- `link_approx::Symbol = :probit::(_ in (:probit, :plugin))`: the approximation to adopt to compute the probabilities. + +# Operations + +- `predict(mach, Xnew)`: return predictions of the target given + features `Xnew` having the same scitype as `X` above. Predictions + are probabilistic, but uncalibrated. + +- `predict_mode(mach, Xnew)`: instead return the mode of each + prediction above. + +- `training_losses(mach)`: return the loss history from report + + +# Fitted parameters + +The fields of `fitted_params(mach)` are: + + - `μ`: The mean of the posterior distribution. + + - `H`: The Hessian of the posterior distribution. + + - `P`: The precision matrix of the posterior distribution. + + - `Σ`: The covariance matrix of the posterior distribution. + + - `n_data`: The number of data points. + + - `n_params`: The number of parameters. + + - `n_out`: The number of outputs. + + - `loss`: The loss value of the posterior distribution. + + + + # Report + +The fields of `report(mach)` are: + +- `loss_history`: an array containing the total loss per epoch. + +# Accessor functions + + +# Examples + +``` +using MLJ +LaplaceClassifier = @load LaplaceClassifier pkg=LaplaceRedux + +X, y = @load_iris + +# Define the Flux Chain model +using Flux +model = Chain( + Dense(4, 10, relu), + Dense(10, 10, relu), + Dense(10, 3) +) + +#Define the LaplaceClassifier +model = LaplaceClassifier(model=model) + +mach = machine(model, X, y) |> fit! + +Xnew = (sepal_length = [6.4, 7.2, 7.4], + sepal_width = [2.8, 3.0, 2.8], + petal_length = [5.6, 5.8, 6.1], + petal_width = [2.1, 1.6, 1.9],) +yhat = predict(mach, Xnew) # probabilistic predictions +predict_mode(mach, Xnew) # point predictions +training_losses(mach) # loss history per epoch +pdf.(yhat, "virginica") # probabilities for the "verginica" class +fitted_params(mach) # NamedTuple with the fitted params of Laplace + +``` + +See also [LaplaceRedux.jl](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl). + +""" +LaplaceClassifier + +""" +$(MMI.doc_header(LaplaceRegressor)) + +`LaplaceRegressor` implements the $DOC_LAPLACE_REDUX for regression models. + +# Training data + +In MLJ or MLJBase, given a dataset X,y and a Flux Chain adapt to the dataset, pass the chain to the model + +laplace_model = LaplaceRegressor(model = Flux_Chain,kwargs...) + +then bind an instance `laplace_model` to data with + + mach = machine(laplace_model, X, y) +where + +- `X`: any table of input features (eg, a `DataFrame`) whose columns + each have one of the following element scitypes: `Continuous`, + `Count`, or `<:OrderedFactor`; check column scitypes with `schema(X)` + +- `y`: is the target, which can be any `AbstractVector` whose element + scitype is `<:Continuous`; check the scitype + with `scitype(y)` + +Train the machine using `fit!(mach, rows=...)`. + + +# Hyperparameters (format: name-type-default value-restrictions) + +- `model::Union{Flux.Chain,Nothing} = nothing`: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 2 hidden layer with 20 neurons each. +- `flux_loss = Flux.Losses.logitcrossentropy` : a Flux loss function + +- `optimiser = Adam()` a Flux optimiser + +- `epochs::Integer = 1000::(_ > 0)`: the number of training epochs. + +- `batch_size::Integer = 32::(_ > 0)`: the batch size. + +- `subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. + +- `subnetwork_indices = nothing`: the indices of the subnetworks. + +- `hessian_structure::Union{HessianStructure,Symbol,String} = :full::(_ in (:full, :diagonal))`: the structure of the Hessian matrix, either `:full` or `:diagonal`. + +- `backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))`: the backend to use, either `:GGN` or `:EmpiricalFisher`. + +- `σ::Float64 = 1.0`: the standard deviation of the prior distribution. + +- `μ₀::Float64 = 0.0`: the mean of the prior distribution. + +- `P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing`: the covariance matrix of the prior distribution. + +- `fit_prior_nsteps::Int = 100::(_ > 0) `: the number of steps used to fit the priors. + + +# Operations + +- `predict(mach, Xnew)`: return predictions of the target given + features `Xnew` having the same scitype as `X` above. Predictions + are probabilistic, but uncalibrated. + +- `predict_mode(mach, Xnew)`: instead return the mode of each + prediction above. + +- `training_losses(mach)`: return the loss history from report + + +# Fitted parameters + +The fields of `fitted_params(mach)` are: + + - `μ`: The mean of the posterior distribution. + + - `H`: The Hessian of the posterior distribution. + + - `P`: The precision matrix of the posterior distribution. + + - `Σ`: The covariance matrix of the posterior distribution. + + - `n_data`: The number of data points. + + - `n_params`: The number of parameters. + + - `n_out`: The number of outputs. + + - `loss`: The loss value of the posterior distribution. + + +# Report + +The fields of `report(mach)` are: + +- `loss_history`: an array containing the total loss per epoch. + + + + +# Accessor functions + + + +# Examples + +``` +using MLJ +using Flux +LaplaceRegressor = @load LaplaceRegressor pkg=LaplaceRedux +model = Chain( + Dense(4, 10, relu), + Dense(10, 10, relu), + Dense(10, 1) +) +model = LaplaceRegressor(model=model) + +X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1) +mach = machine(model, X, y) |> fit! + +Xnew, _ = make_regression(3, 4; rng=123) +yhat = predict(mach, Xnew) # probabilistic predictions +predict_mode(mach, Xnew) # point predictions +training_losses(mach) # loss history per epoch +fitted_params(mach) # NamedTuple with the fitted params of Laplace + +``` + +See also [LaplaceRedux.jl](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl). + +""" +LaplaceRegressor + diff --git a/test/Manifest.toml b/test/Manifest.toml index 31dcb255..9717977b 100644 --- a/test/Manifest.toml +++ b/test/Manifest.toml @@ -2,7 +2,13 @@ julia_version = "1.10.5" manifest_format = "2.0" -project_hash = "2fde859c875aff2c1b66bd10b3f4f3d64f67067a" +project_hash = "44de0b245e083feedd88d905933ce8d9b455e504" + +[[deps.ARFFFiles]] +deps = ["CategoricalArrays", "Dates", "Parsers", "Tables"] +git-tree-sha1 = "678eb18590a8bc6674363da4d5faa4ac09c40a18" +uuid = "da404889-ca92-49ff-9e8b-0aa6b4d38dc8" +version = "1.5.0" [[deps.AbstractFFTs]] deps = ["LinearAlgebra"] @@ -233,9 +239,9 @@ version = "0.7.6" [[deps.ColorSchemes]] deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] -git-tree-sha1 = "b5278586822443594ff615963b0c09755771b3e0" +git-tree-sha1 = "13951eb68769ad1cd460cdb2e64e5e95f1bf123d" uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" -version = "3.26.0" +version = "3.27.0" [[deps.ColorTypes]] deps = ["FixedPointNumbers", "Random"] @@ -421,9 +427,9 @@ version = "1.15.1" [[deps.Distances]] deps = ["LinearAlgebra", "Statistics", "StatsAPI"] -git-tree-sha1 = "66c4c81f259586e8f002eacebc177e1fb06363b0" +git-tree-sha1 = "c7e3a542b999843086e2f29dac96a618c105be1d" uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -version = "0.10.11" +version = "0.10.12" weakdeps = ["ChainRulesCore", "SparseArrays"] [deps.Distances.extensions] @@ -461,6 +467,12 @@ deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" version = "1.6.0" +[[deps.EarlyStopping]] +deps = ["Dates", "Statistics"] +git-tree-sha1 = "98fdf08b707aaf69f524a6cd0a67858cefe0cfb6" +uuid = "792122b4-ca99-40de-a6bc-6742525f08b6" +version = "0.3.0" + [[deps.EpollShim_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] git-tree-sha1 = "8e9441ee83492030ace98f9789a654a6d0b1f643" @@ -543,9 +555,9 @@ version = "0.8.5" [[deps.Flux]] deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLDataDevices", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "Setfield", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] -git-tree-sha1 = "37fa32a50c69c10c6ea1465d3054d98c75bd7777" +git-tree-sha1 = "b78bd94ef1588881983bd6b8b860b2c27293140a" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.22" +version = "0.14.23" [deps.Flux.extensions] FluxAMDGPUExt = "AMDGPU" @@ -554,14 +566,12 @@ version = "0.14.22" FluxEnzymeExt = "Enzyme" FluxMPIExt = "MPI" FluxMPINCCLExt = ["CUDA", "MPI", "NCCL"] - FluxMetalExt = "Metal" [deps.Flux.weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" - Metal = "dde4c033-4e86-420c-a63e-0dd931031962" NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" @@ -692,9 +702,9 @@ version = "1.14.2+1" [[deps.HTTP]] deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] -git-tree-sha1 = "d1d712be3164d61d1fb98e7ce9bcbc6cc06b45ed" +git-tree-sha1 = "bc3f416a965ae61968c20d0ad867556367f2817d" uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" -version = "1.10.8" +version = "1.10.9" [[deps.HarfBuzz_jll]] deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll"] @@ -786,6 +796,12 @@ git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" version = "0.2.2" +[[deps.IterationControl]] +deps = ["EarlyStopping", "InteractiveUtils"] +git-tree-sha1 = "e663925ebc3d93c1150a7570d114f9ea2f664726" +uuid = "b3c1a2ee-3fec-4384-bf48-272ea71de57c" +version = "0.5.4" + [[deps.IteratorInterfaceExtensions]] git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" uuid = "82899510-4779-5014-852e-03e436cf321d" @@ -793,9 +809,9 @@ version = "1.0.0" [[deps.JLD2]] deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Requires", "TranscodingStreams"] -git-tree-sha1 = "b464b9b461ee989b435a689a4f7d870b68d467ed" +git-tree-sha1 = "783c1be5213a09609b23237a0c9e5dfd258ae6f2" uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" -version = "0.5.6" +version = "0.5.7" [[deps.JLFzf]] deps = ["Pipe", "REPL", "Random", "fzf_jll"] @@ -841,9 +857,9 @@ version = "0.2.4" [[deps.KernelAbstractions]] deps = ["Adapt", "Atomix", "InteractiveUtils", "MacroTools", "PrecompileTools", "Requires", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "04e52f596d0871fa3890170fa79cb15e481e4cd8" +git-tree-sha1 = "e73a077abc7fe798fe940deabe30ef6c66bdde52" uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.28" +version = "0.9.29" [deps.KernelAbstractions.extensions] EnzymeExt = "EnzymeCore" @@ -916,6 +932,12 @@ version = "0.16.5" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SymEngine = "123dc426-2d89-5057-bbad-38513e3affd8" +[[deps.LatinHypercubeSampling]] +deps = ["Random", "StableRNGs", "StatsBase", "Test"] +git-tree-sha1 = "825289d43c753c7f1bf9bed334c253e9913997f8" +uuid = "a5e1c1ea-c99a-51d3-a14d-a9a37257b02d" +version = "1.9.0" + [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" @@ -966,9 +988,9 @@ version = "3.2.2+1" [[deps.Libgcrypt_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgpg_error_jll"] -git-tree-sha1 = "9fd170c4bbfd8b935fdc5f8b7aa33532c991a673" +git-tree-sha1 = "8be878062e0ffa2c3f67bb58a595375eda5de80b" uuid = "d4300ac3-e22c-5743-9152-c294e39db1e4" -version = "1.8.11+0" +version = "1.11.0+0" [[deps.Libglvnd_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll", "Xorg_libXext_jll"] @@ -978,15 +1000,15 @@ version = "1.6.0+0" [[deps.Libgpg_error_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "fbb1f2bef882392312feb1ede3615ddc1e9b99ed" +git-tree-sha1 = "c6ce1e19f3aec9b59186bdf06cdf3c4fc5f5f3e6" uuid = "7add5ba3-2f88-524e-9cd5-f83b8a55f7b8" -version = "1.49.0+0" +version = "1.50.0+0" [[deps.Libiconv_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "f9557a255370125b405568f9767d6d195822a175" +git-tree-sha1 = "61dfdba58e585066d8bce214c5a51eaa0539f269" uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" -version = "1.17.0+0" +version = "1.17.0+1" [[deps.Libmount_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] @@ -1031,9 +1053,9 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[deps.LoggingExtras]] deps = ["Dates", "Logging"] -git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075" +git-tree-sha1 = "f02b56007b064fbfddb4c9cd60161b6dd0f40df3" uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" -version = "1.0.3" +version = "1.1.0" [[deps.MAT]] deps = ["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"] @@ -1042,10 +1064,10 @@ uuid = "23992714-dd62-5051-b70f-ba57cb901cac" version = "0.10.7" [[deps.MLDataDevices]] -deps = ["Adapt", "Compat", "Functors", "LinearAlgebra", "Preferences", "Random"] -git-tree-sha1 = "3207c2e66164e6366440ad3f0243a8d67abb4a47" +deps = ["Adapt", "Compat", "Functors", "Preferences", "Random"] +git-tree-sha1 = "5cffc52b59227864b665459e1f7bcc4d3c4fb47b" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" -version = "1.4.1" +version = "1.4.2" [deps.MLDataDevices.extensions] MLDataDevicesAMDGPUExt = "AMDGPU" @@ -1087,24 +1109,58 @@ git-tree-sha1 = "361c2692ee730944764945859f1a6b31072e275d" uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458" version = "0.7.18" +[[deps.MLFlowClient]] +deps = ["Dates", "FilePathsBase", "HTTP", "JSON", "ShowCases", "URIs", "UUIDs"] +git-tree-sha1 = "9abb12b62debc27261c008daa13627255bf79967" +uuid = "64a0f543-368b-4a9a-827a-e71edb2a0b83" +version = "0.5.1" + +[[deps.MLJ]] +deps = ["CategoricalArrays", "ComputationalResources", "Distributed", "Distributions", "LinearAlgebra", "MLJBalancing", "MLJBase", "MLJEnsembles", "MLJFlow", "MLJIteration", "MLJModels", "MLJTuning", "OpenML", "Pkg", "ProgressMeter", "Random", "Reexport", "ScientificTypes", "StatisticalMeasures", "Statistics", "StatsBase", "Tables"] +git-tree-sha1 = "bd2072e9cd65be0a3cb841f3d8cda1d2cacfe5db" +uuid = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" +version = "0.20.5" + +[[deps.MLJBalancing]] +deps = ["MLJBase", "MLJModelInterface", "MLUtils", "OrderedCollections", "Random", "StatsBase"] +git-tree-sha1 = "f707a01a92d664479522313907c07afa5d81df19" +uuid = "45f359ea-796d-4f51-95a5-deb1a414c586" +version = "0.1.5" + [[deps.MLJBase]] deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LearnAPI", "LinearAlgebra", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "RecipesBase", "Reexport", "ScientificTypes", "Serialization", "StatisticalMeasuresBase", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] git-tree-sha1 = "6f45e12073bc2f2e73ed0473391db38c31e879c9" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" version = "1.7.0" +weakdeps = ["StatisticalMeasures"] [deps.MLJBase.extensions] DefaultMeasuresExt = "StatisticalMeasures" - [deps.MLJBase.weakdeps] - StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541" - [[deps.MLJDecisionTreeInterface]] deps = ["CategoricalArrays", "DecisionTree", "MLJModelInterface", "Random", "Tables"] git-tree-sha1 = "90ef4d3b6cacec631c57cc034e1e61b4aa0ce511" uuid = "c6f25543-311c-4c74-83dc-3ea6d1015661" version = "0.4.2" +[[deps.MLJEnsembles]] +deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Distributed", "Distributions", "MLJModelInterface", "ProgressMeter", "Random", "ScientificTypesBase", "StatisticalMeasuresBase", "StatsBase"] +git-tree-sha1 = "84a5be55a364bb6b6dc7780bbd64317ebdd3ad1e" +uuid = "50ed68f4-41fd-4504-931a-ed422449fee0" +version = "0.4.3" + +[[deps.MLJFlow]] +deps = ["MLFlowClient", "MLJBase", "MLJModelInterface"] +git-tree-sha1 = "508bff8071d7d1902d6f1b9d1e868d58821f1cfe" +uuid = "7b7b8358-b45c-48ea-a8ef-7ca328ad328f" +version = "0.5.0" + +[[deps.MLJIteration]] +deps = ["IterationControl", "MLJBase", "Random", "Serialization"] +git-tree-sha1 = "ad16cfd261e28204847f509d1221a581286448ae" +uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55" +version = "0.6.3" + [[deps.MLJModelInterface]] deps = ["Random", "ScientificTypesBase", "StatisticalTraits"] git-tree-sha1 = "ceaff6618408d0e412619321ae43b33b40c1a733" @@ -1117,6 +1173,12 @@ git-tree-sha1 = "410da88e0e6ece5467293d2c76b51b7c6df7d072" uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7" version = "0.16.17" +[[deps.MLJTuning]] +deps = ["ComputationalResources", "Distributed", "Distributions", "LatinHypercubeSampling", "MLJBase", "ProgressMeter", "Random", "RecipesBase", "StatisticalMeasuresBase"] +git-tree-sha1 = "38aab60b1274ce7d6da784808e3be69e585dbbf6" +uuid = "03970b2e-30c4-11ea-3135-d1576263f10f" +version = "0.8.8" + [[deps.MLStyle]] git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" uuid = "d8e11817-5142-5d16-987a-aa16d5891078" @@ -1289,6 +1351,12 @@ deps = ["Artifacts", "Libdl"] uuid = "05823500-19ac-5b8b-9628-191a04bc5112" version = "0.8.1+2" +[[deps.OpenML]] +deps = ["ARFFFiles", "HTTP", "JSON", "Markdown", "Pkg", "Scratch"] +git-tree-sha1 = "6efb039ae888699d5a74fb593f6f3e10c7193e33" +uuid = "8b6db2d4-7670-4922-a472-f9537c81ab66" +version = "0.3.1" + [[deps.OpenMPI_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML", "Zlib_jll"] git-tree-sha1 = "bfce6d523861a6c562721b262c0d1aaeead2647f" @@ -1401,9 +1469,9 @@ version = "1.10.0" [[deps.PlotThemes]] deps = ["PlotUtils", "Statistics"] -git-tree-sha1 = "6e55c6841ce3411ccb3457ee52fc48cb698d6fb0" +git-tree-sha1 = "41031ef3a1be6f5bbbf3e8073f210556daeae5ca" uuid = "ccf2f8ad-2431-5c83-bf29-c5338b663b6a" -version = "3.2.0" +version = "3.3.0" [[deps.PlotUtils]] deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "StableRNGs", "Statistics"] @@ -1606,9 +1674,9 @@ version = "1.2.1" [[deps.SentinelArrays]] deps = ["Dates", "Random"] -git-tree-sha1 = "ff11acffdb082493657550959d4feb4b6149e73a" +git-tree-sha1 = "305becf8af67eae1dbc912ee9097f00aeeabb8d5" uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" -version = "1.4.5" +version = "1.4.6" [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" @@ -1705,6 +1773,20 @@ git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" version = "1.4.3" +[[deps.StatisticalMeasures]] +deps = ["CategoricalArrays", "CategoricalDistributions", "Distributions", "LearnAPI", "LinearAlgebra", "MacroTools", "OrderedCollections", "PrecompileTools", "ScientificTypesBase", "StatisticalMeasuresBase", "Statistics", "StatsBase"] +git-tree-sha1 = "c1d4318fa41056b839dfbb3ee841f011fa6e8518" +uuid = "a19d573c-0a75-4610-95b3-7071388c7541" +version = "0.1.7" + + [deps.StatisticalMeasures.extensions] + LossFunctionsExt = "LossFunctions" + ScientificTypesExt = "ScientificTypes" + + [deps.StatisticalMeasures.weakdeps] + LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" + ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81" + [[deps.StatisticalMeasuresBase]] deps = ["CategoricalArrays", "InteractiveUtils", "MLUtils", "MacroTools", "OrderedCollections", "PrecompileTools", "ScientificTypesBase", "Statistics"] git-tree-sha1 = "17dfb22e2e4ccc9cd59b487dce52883e0151b4d3" @@ -1983,9 +2065,9 @@ version = "1.6.1" [[deps.XML2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Zlib_jll"] -git-tree-sha1 = "1165b0443d0eca63ac1e32b8c0eb69ed2f4f8127" +git-tree-sha1 = "6a451c6f33a176150f315726eba8b92fbfdb9ae7" uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" -version = "2.13.3+0" +version = "2.13.4+0" [[deps.XSLT_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "XML2_jll", "Zlib_jll"] @@ -1995,9 +2077,9 @@ version = "1.1.41+0" [[deps.XZ_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "ac88fb95ae6447c8dda6a5503f3bafd496ae8632" +git-tree-sha1 = "15e637a697345f6743674f1322beefbc5dcd5cfc" uuid = "ffd25f8a-64ca-5728-b0f7-c24cf3aae800" -version = "5.4.6+0" +version = "5.6.3+0" [[deps.Xorg_libICE_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] diff --git a/test/Project.toml b/test/Project.toml index 1e853790..aa0bf7d9 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -7,6 +7,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" diff --git a/test/calibration.jl b/test/calibration.jl index 847543ce..988ff1ef 100644 --- a/test/calibration.jl +++ b/test/calibration.jl @@ -3,8 +3,6 @@ using LaplaceRedux using Distributions using Trapz - - # Test for `sharpness_regression` function @testset "sharpness_regression distributions tests" begin @info " testing sharpness_regression with distributions" @@ -201,4 +199,4 @@ end distributions = [Normal(0, 1), Normal(2, 1), Normal(4, 1)] rescaled_distr = rescale_stddev(distributions, 2.0) @test rescaled_distr == [Normal(0, 2), Normal(2, 2), Normal(4, 2)] -end \ No newline at end of file +end diff --git a/test/data.jl b/test/data.jl index e6d61881..05a5b30d 100644 --- a/test/data.jl +++ b/test/data.jl @@ -15,7 +15,7 @@ for fun in fun_list # Generate data with the same seed Random.seed!(seed) xs1, ys1 = fun(N; seed=seed) - + Random.seed!(seed) xs2, ys2 = fun(N; seed=seed) diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl new file mode 100644 index 00000000..8f2b0438 --- /dev/null +++ b/test/direct_mlj_interface.jl @@ -0,0 +1,100 @@ +using Random: Random +import Random.seed! +using MLJBase: MLJBase, categorical +using MLJ: MLJ +using Flux +using StableRNGs +import LaplaceRedux: LaplaceClassifier, LaplaceRegressor + +cv = MLJBase.CV(; nfolds=3) + +@testset "Regression" begin + @info " testing interface for LaplaceRegressor" + flux_model = Chain(Dense(4, 10, relu), Dense(10, 10, relu), Dense(10, 1)) + model = LaplaceRegressor(; model=flux_model, epochs=20) + + #testing more complex dataset + X, y = MLJ.make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1) + #train, test = partition(eachindex(y), 0.7); # 70:30 split + mach = MLJ.machine(model, X, y) + MLJBase.fit!(mach; verbosity=0) + yhat = MLJBase.predict(mach, X) # probabilistic predictions + MLJBase.predict_mode(mach, X) # point predictions + MLJBase.fitted_params(mach) #fitted params function + MLJBase.training_losses(mach) #training loss history + model.epochs = 40 #changing number of epochs + MLJBase.fit!(mach; verbosity=0) #testing update function + model.epochs = 30 #changing number of epochs to a lower number + MLJBase.fit!(mach; verbosity=0) #testing update function + model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps + MLJBase.fit!(mach; verbosity=0) #testing update function (the laplace part) + yhat = MLJBase.predict(mach, X) # probabilistic predictions + MLJ.evaluate!(mach; resampling=cv, measure=MLJ.log_loss, verbosity=0) + + # Define a different model + flux_model_two = Chain(Dense(4, 6, relu), Dense(6, 1)) + # test update! fallback to fit! + model.model = flux_model_two + MLJBase.fit!(mach; verbosity=0) + model_two = LaplaceRegressor(; model=flux_model_two, epochs=100) + @test !MLJBase.is_same_except(model, model_two) + + + + #testing default mlp builder + model = LaplaceRegressor(; model=nothing, epochs=20) + mach = MLJ.machine(model, X, y) + MLJBase.fit!(mach; verbosity=1) + yhat = MLJBase.predict(mach, X) # probabilistic predictions + MLJBase.predict_mode(mach, X) # point predictions + MLJBase.fitted_params(mach) #fitted params function + MLJBase.training_losses(mach) #training loss history + model.epochs = 100 #changing number of epochs + MLJBase.fit!(mach; verbosity=1) #testing update function + + #testing dataset_shape for one dimensional function + X, y = MLJ.make_regression(100, 1; noise=0.5, sparse=0.2, outliers=0.1) + model = LaplaceRegressor(; model=nothing, epochs=20) + mach = MLJ.machine(model, X, y) + MLJBase.fit!(mach; verbosity=0) + + + +end + +@testset "Classification" begin + @info " testing interface for LaplaceClassifier" + # Define the model + flux_model = Chain(Dense(4, 10, relu), Dense(10, 3)) + + model = LaplaceClassifier(; model=flux_model, epochs=20) + + X, y = MLJ.@load_iris + mach = MLJ.machine(model, X, y) + MLJBase.fit!(mach; verbosity=0) + Xnew = ( + sepal_length=[6.4, 7.2, 7.4], + sepal_width=[2.8, 3.0, 2.8], + petal_length=[5.6, 5.8, 6.1], + petal_width=[2.1, 1.6, 1.9], + ) + yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions + MLJBase.predict_mode(mach, Xnew) # point predictions + MLJBase.pdf.(yhat, "virginica") # probabilities for the "verginica" class + MLJBase.fitted_params(mach) # fitted params + MLJBase.training_losses(mach) #training loss history + model.epochs = 40 #changing number of epochs + MLJBase.fit!(mach; verbosity=0) #testing update function + model.epochs = 30 #changing number of epochs to a lower number + MLJBase.fit!(mach; verbosity=0) #testing update function + model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps + MLJBase.fit!(mach; verbosity=0) #testing update function (the laplace part) + MLJ.evaluate!(mach; resampling=cv, measure=MLJ.brier_loss, verbosity=0) + + # Define a different model + flux_model_two = Chain(Dense(4, 6, relu), Dense(6, 3)) + + model.model = flux_model_two + + MLJBase.fit!(mach; verbosity=0) +end diff --git a/test/runtests.jl b/test/runtests.jl index 92d97033..b29d8e14 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,5 +34,7 @@ using Test @testset "KronDecomposed" begin include("krondecomposed.jl") end - + @testset "Interface" begin + include("direct_mlj_interface.jl") + end end