Skip to content

Commit

Permalink
Merge pull request #126 from JuliaTrustworthyAI/direct_mlj_interface
Browse files Browse the repository at this point in the history
Direct mlj interface
  • Loading branch information
pasq-cat authored Nov 7, 2024
2 parents 38b6e69 + 12f2584 commit 5f66429
Show file tree
Hide file tree
Showing 18 changed files with 2,500 additions and 209 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jobs:
version:
- '1.9'
- '1.10'
- '1'
os:
- ubuntu-latest
- windows-latest
Expand Down
10 changes: 8 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@ authors = ["Patrick Altmeyer"]
version = "1.1.1"

[deps]
CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -19,20 +22,23 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Aqua = "0.8"
CategoricalDistributions = "0.1.15"
ChainRulesCore = "1.23.0"
Compat = "4.7.0"
Distributions = "0.25.109"
Flux = "0.12, 0.13, 0.14"
LinearAlgebra = "1.7, 1.10"
MLJBase = "1"
MLJModelInterface = "1.8.0"
MLUtils = "0.4"
Optimisers = "0.2, 0.3"
Random = "1.9, 1.10"
Statistics = "1"
Tables = "1.10.1"
Test = "1.9, 1.10"
Test = "1"
Tullio = "0.3.5"
Zygote = "0.6"
julia = "1.9, 1.10"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Expand Down
6 changes: 3 additions & 3 deletions dev/issues/predict_slow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@ using Optimisers

X = MLJBase.table(rand(Float32, 100, 3));
y = coerce(rand("abc", 100), Multiclass);
model = LaplaceClassification(optimiser=Optimisers.Adam(0.1), epochs=100);
model = LaplaceClassification(; optimiser=Optimisers.Adam(0.1), epochs=100);
fitresult, _, _ = MLJBase.fit(model, 2, X, y);
la = fitresult[1];
Xmat = matrix(X) |> permutedims;

# Single test sample:
Xtest = Xmat[:,1:10];
Xtest = Xmat[:, 1:10];
Xtest_tab = MLJBase.table(Xtest');
MLJBase.predict(model, fitresult, Xtest_tab); # warm up
LaplaceRedux.predict(la, Xmat); # warm up
@time MLJBase.predict(model, fitresult, Xtest_tab);
@time LaplaceRedux.predict(la, Xtest);
@time glm_predictive_distribution(la, Xtest);
@time glm_predictive_distribution(la, Xtest);
1,137 changes: 1,137 additions & 0 deletions dev/notebooks/mlj-interfacing/direct_mlj.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit 5f66429

Please sign in to comment.