Skip to content

Commit

Permalink
changes
Browse files Browse the repository at this point in the history
  • Loading branch information
pasq-cat committed Sep 13, 2024
1 parent 255cf19 commit 35ac2d8
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions src/direct_mlj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ It has the following Hyperparameters:
- `fit_prior_nsteps`: the number of steps used to fit the priors.
"""
MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic
flux_model::Flux.Chain= nothing
flux_model::Flux.Chain = nothing
subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
subnetwork_indices = nothing
hessian_structure::Union{HessianStructure,Symbol,String} =
Expand All @@ -39,23 +39,25 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilist
end


function MMI.fit(m::LaplaceRegressor, verbosity, X, y, w=nothing)
function MMI.fit(m::LaplaceRegressor, verbosity, X, y)

This comment has been minimized.

Copy link
@pasq-cat

pasq-cat Sep 13, 2024

Author Member

@pat-alt @MojiFarmanbar I am having some trouble with the input data format. According to the mlj documentation i must suppose that X is a table so i extract the names of the features and then convert to a matrix but then i have problem when i try to fit the laplace object.

This comment has been minimized.

Copy link
@pat-alt

pat-alt Sep 13, 2024

Member

I think you might have to transpose the matrix?

 X = MLJBase.matrix(X) |> permutedims
features = Tables.schema(X).names

X = MLJBase.matrix(X)

la = LaplaceRedux.Laplace(
m.flux_model;
likelihood=:regression,
subset_of_weights=model.subset_of_weights,
subnetwork_indices=model.subnetwork_indices,
hessian_structure=model.hessian_structure,
backend=model.backend,
subset_of_weights=m.subset_of_weights,
subnetwork_indices=m.subnetwork_indices,
hessian_structure=m.hessian_structure,
backend=m.backend,
σ=m.σ,
μ₀=m.μ₀,
P₀=m.P₀,
)

println(la)

# fit the Laplace model:
LaplaceRedux.fit!(la, zip(X, y))
optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps)
Expand Down Expand Up @@ -122,7 +124,7 @@ end



function MMI.fit(m::LaplaceClassifier, verbosity, X, y, w=nothing)
function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
features = Tables.schema(X).names
Xmatrix = MLJBase.matrix(X)
decode = y[1]
Expand Down

0 comments on commit 35ac2d8

Please sign in to comment.