Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Direct mlj interface #126

Merged
merged 62 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
51ed526
new first commit
pasq-cat Sep 12, 2024
e8e96d1
various stuff
pasq-cat Sep 12, 2024
255cf19
fixes
pasq-cat Sep 12, 2024
35ac2d8
changes
pasq-cat Sep 13, 2024
44469a2
there is still a problem with the classifier
pasq-cat Sep 18, 2024
5315264
almost fixed
pasq-cat Sep 18, 2024
f6e7a00
works but i have to fix the hyperparameters
pasq-cat Sep 18, 2024
d7c4f7b
question on parameters....
pasq-cat Sep 18, 2024
bbab460
there is some problem with the one hot encoding
pasq-cat Sep 18, 2024
8af38ae
fixed error in univariatefinite
pasq-cat Sep 19, 2024
fe19d4d
performance improvement
pasq-cat Sep 19, 2024
d809afb
JuliaFormatter
pasq-cat Sep 21, 2024
33d84f5
juliaformatter+docstrings
pasq-cat Sep 21, 2024
9731297
removed predict_proba and ret_Distr from the struct
pasq-cat Sep 21, 2024
f70d239
mlj docstring in progress
pasq-cat Sep 21, 2024
80c6553
ah fixed constant , added prototype for regression
pasq-cat Sep 21, 2024
d1c895c
small stuff here and there in the docstring plus
pasq-cat Sep 21, 2024
19ffa16
still writing this long ass docstring
pasq-cat Sep 21, 2024
de0bd91
added fit_params functions
pasq-cat Sep 22, 2024
87df85f
switched to customized loop
pasq-cat Sep 22, 2024
24459a1
fixed error in custom loop
pasq-cat Sep 22, 2024
0e2ca03
various fixes
pasq-cat Sep 22, 2024
841d5eb
added reformat. must updated again the doc string....
pasq-cat Sep 22, 2024
de784f1
work on the docstring and then made it in a module
pasq-cat Sep 22, 2024
b7a99f6
fixed uuid, made test file.for direct_mlj. shut down the tests for ml…
pasq-cat Sep 23, 2024
c44b8d8
added tests. should be good....
pasq-cat Sep 23, 2024
b762185
added mlj to the dependency in test
pasq-cat Sep 23, 2024
ced3da0
prep for update + added mljmodelinterface to doc env
pasq-cat Sep 25, 2024
b700f85
changed the loop so that it nows uses optimisers from optimisers.jl
pasq-cat Oct 1, 2024
da6fc76
started joining the functions in a single common function for both mo…
pasq-cat Oct 3, 2024
70df568
various fixes
pasq-cat Oct 4, 2024
9889872
merged functions for both cases
pasq-cat Oct 4, 2024
0f46fd6
julia formatter
pasq-cat Oct 4, 2024
ab8b6bf
added unit tests
pasq-cat Oct 15, 2024
263cc67
more units
pasq-cat Oct 15, 2024
453b49f
fix
pasq-cat Oct 15, 2024
656b24e
changed unit test and a minor fix in the update function. there is st…
pasq-cat Oct 16, 2024
7c4d744
only things left to fix are the selectrows functions
pasq-cat Oct 16, 2024
f872d96
returning one-hot encoded directly
pat-alt Oct 16, 2024
71a3611
nearly there I think
pat-alt Oct 16, 2024
74d778e
one more issue with regression
pat-alt Oct 16, 2024
80784bb
fixed predict so that it return a vector of distributions-> fixed eva…
pasq-cat Oct 18, 2024
be80e32
amend: fixed predict so that it return a vector of distributions-> fi…
pasq-cat Oct 18, 2024
d426844
Merge branch 'direct_mlj_interface' of https://github.com/JuliaTrustw…
pasq-cat Oct 18, 2024
f4fcd95
madea mess with commits.... bah
pasq-cat Oct 18, 2024
851784f
trying to increase patch coverage
pasq-cat Oct 18, 2024
0752b83
fkn hell this codecov bot is worse than the inquisition
pasq-cat Oct 18, 2024
573ffd8
uhmmmmmm
pasq-cat Oct 21, 2024
db14b84
fixed _isdefined
pasq-cat Oct 21, 2024
82c5714
trying to fix docs issue and no longer importing MLJ nor MLJBase name…
pat-alt Oct 22, 2024
7202013
formatting
pat-alt Oct 22, 2024
a05e25f
removing mlj_flux
pat-alt Oct 22, 2024
05df2e1
fixed issues
pat-alt Oct 22, 2024
02abec2
removing reference to deep_propertier
pat-alt Oct 23, 2024
374aca5
hadn't saved file
pat-alt Oct 23, 2024
06343bd
Merge branch 'main' into local_direct_mlj
pasq-cat Oct 29, 2024
59917f8
added default mlp
pasq-cat Oct 29, 2024
9da5d7d
reducing number of epochs and trying to extende patch coverage
pasq-cat Oct 30, 2024
503763d
removed the else because it seems to have no role.
pasq-cat Oct 30, 2024
e78e9b8
ops forgot to remove the comment
pasq-cat Oct 30, 2024
6a5f26f
various change in the documentation
pasq-cat Oct 30, 2024
12f2584
ufffffffffffffffffffff
pasq-cat Oct 30, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jobs:
version:
- '1.9'
- '1.10'
- '1'
os:
- ubuntu-latest
- windows-latest
Expand Down
10 changes: 8 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@ authors = ["Patrick Altmeyer"]
version = "1.1.1"

[deps]
CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -19,20 +22,23 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Aqua = "0.8"
CategoricalDistributions = "0.1.15"
ChainRulesCore = "1.23.0"
Compat = "4.7.0"
Distributions = "0.25.109"
Flux = "0.12, 0.13, 0.14"
LinearAlgebra = "1.7, 1.10"
MLJBase = "1"
MLJModelInterface = "1.8.0"
MLUtils = "0.4"
Optimisers = "0.2, 0.3"
Random = "1.9, 1.10"
Statistics = "1"
Tables = "1.10.1"
Test = "1.9, 1.10"
Test = "1"
Tullio = "0.3.5"
Zygote = "0.6"
julia = "1.9, 1.10"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Expand Down
6 changes: 3 additions & 3 deletions dev/issues/predict_slow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@ using Optimisers

X = MLJBase.table(rand(Float32, 100, 3));
y = coerce(rand("abc", 100), Multiclass);
model = LaplaceClassification(optimiser=Optimisers.Adam(0.1), epochs=100);
model = LaplaceClassification(; optimiser=Optimisers.Adam(0.1), epochs=100);
fitresult, _, _ = MLJBase.fit(model, 2, X, y);
la = fitresult[1];
Xmat = matrix(X) |> permutedims;

# Single test sample:
Xtest = Xmat[:,1:10];
Xtest = Xmat[:, 1:10];
Xtest_tab = MLJBase.table(Xtest');
MLJBase.predict(model, fitresult, Xtest_tab); # warm up
LaplaceRedux.predict(la, Xmat); # warm up
@time MLJBase.predict(model, fitresult, Xtest_tab);
@time LaplaceRedux.predict(la, Xtest);
@time glm_predictive_distribution(la, Xtest);
@time glm_predictive_distribution(la, Xtest);
1,137 changes: 1,137 additions & 0 deletions dev/notebooks/mlj-interfacing/direct_mlj.ipynb

Large diffs are not rendered by default.

Loading
Loading