
Commit

various change in the documentation
pasq-cat committed Oct 30, 2024
1 parent e78e9b8 commit 6a5f26f
Showing 2 changed files with 5 additions and 6 deletions.
9 changes: 3 additions & 6 deletions src/direct_mlj.jl
@@ -1,4 +1,3 @@
#module MLJLaplaceRedux
using Optimisers: Optimisers
using Flux
using Random
@@ -81,7 +80,6 @@ Compute the number of features of the X input dataset and the number of var
- (input size, output size)
"""
function dataset_shape(model::LaplaceModels, X, y)
#X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X
n_input = size(X, 1)
dims = size(y)
if length(dims) == 1
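The `dataset_shape` helper above infers the network's input and output sizes from the data. A minimal standalone sketch of that logic, assuming features run along the first dimension of `X` and that a matrix target carries one output per row (`dataset_shape_sketch` is a hypothetical name; the multi-output branch is an assumption, since the diff cuts off before the `else`):

```julia
# Illustrative sketch of the shape inference in `dataset_shape`:
# count features along the first dimension of X, and treat y as a
# vector (single output) or a matrix (rows = outputs).
function dataset_shape_sketch(X, y)
    n_input = size(X, 1)
    dims = size(y)
    n_output = length(dims) == 1 ? 1 : dims[1]
    return (n_input, n_output)
end

dataset_shape_sketch(rand(4, 10), rand(10))  # single-output case: (4, 1)
```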
@@ -142,6 +140,7 @@ function MMI.fit(m::LaplaceModels, verbosity, X, y)
y, decode = y

if (m.model === nothing)
@warn "No Flux model has been provided. LaplaceRedux will use a standard MLP with 3 hidden layers of 20 neurons each and input and output layers compatible with the dataset."
shape = dataset_shape(m, X, y)

m.model = default_build(11, shape)
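The fallback path above hands the inferred shape to the package's `default_build`. As a hedged sketch of the architecture the warning describes (an MLP with 3 hidden layers of 20 neurons each) — the function name, `relu` activation, and layer details below are illustrative assumptions, not the package's actual `default_build`:

```julia
using Flux

# Hypothetical sketch of the default fallback architecture: an MLP with
# 3 hidden layers of 20 neurons each, sized to (n_in, n_out) from the data.
function default_mlp_sketch((n_in, n_out))
    return Chain(
        Dense(n_in => 20, relu),
        Dense(20 => 20, relu),
        Dense(20 => 20, relu),
        Dense(20 => n_out),
    )
end
```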
@@ -606,7 +605,7 @@ Train the machine using `fit!(mach, rows=...)`.
# Hyperparameters (format: name-type-default value-restrictions)
- `model::Flux.Chain = nothing`: a Flux model provided by the user and compatible with the dataset.
- `model::Union{Flux.Chain,Nothing} = nothing`: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 3 hidden layers of 20 neurons each.
- `flux_loss = Flux.Losses.logitcrossentropy` : a Flux loss function
@@ -744,8 +743,7 @@ Train the machine using `fit!(mach, rows=...)`.
# Hyperparameters (format: name-type-default value-restrictions)
- `model::Flux.Chain = nothing`: a Flux model provided by the user and compatible with the dataset.
- `model::Union{Flux.Chain,Nothing} = nothing`: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 3 hidden layers of 20 neurons each.
- `flux_loss = Flux.Losses.logitcrossentropy` : a Flux loss function
- `optimiser = Adam()`: a Flux optimiser
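The `flux_loss` and `optimiser` defaults listed above are plain Flux objects. A hedged sketch of what one training step with those defaults looks like in Flux itself — toy data, and explicitly not LaplaceRedux's internal training loop:

```julia
using Flux

# Illustrative Flux usage of the documented defaults: one gradient step
# with logitcrossentropy and Adam on a toy 2-class batch.
model = Chain(Dense(3 => 20, relu), Dense(20 => 2))
x = rand(Float32, 3, 8)                  # 8 observations, 3 features
y = Flux.onehotbatch(rand(1:2, 8), 1:2)  # one-hot class labels
opt_state = Flux.setup(Adam(), model)
loss, grads = Flux.withgradient(m -> Flux.Losses.logitcrossentropy(m(x), y), model)
Flux.update!(opt_state, model, grads[1])
```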
@@ -846,4 +844,3 @@ See also [LaplaceRedux.jl](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl
"""
LaplaceRegressor

#end # module
2 changes: 2 additions & 0 deletions test/direct_mlj_interface.jl
@@ -36,6 +36,8 @@ cv = MLJBase.CV(; nfolds=3)
# test update! fallback to fit!
model.model = flux_model_two
MLJBase.fit!(mach; verbosity=0)
model_two = LaplaceRegressor(; model=flux_model_two, epochs=100)
@test !MLJBase.is_same_except(model, model_two)



