ASTRA: "AI for Sustainability" Toolkit for Research and Analysis. ASTRA (अस्त्र) means "a tool" or "a weapon" in Sanskrit.
Install the stable version:
pip install astra-lib
Install the latest version from GitHub:
pip install git+https://github.com/sustainability-lab/ASTRA
Please read the contributing guidelines before making a contribution.
from astra.torch.data import load_cifar_10, load_mnist

# load_cifar_10() returns a pair, unpacked here as (ds, ds_name).
# load_mnist is imported alongside it as the MNIST counterpart
# (presumably called the same way).
ds, ds_name = load_cifar_10()
from astra.torch.models import MLP

# MLP with a 100-dim input, hidden widths 128 and 64, and 10 outputs,
# using ReLU activations and a dropout rate of 0.1.
mlp = MLP(
    input_dim=100,
    hidden_dims=[128, 64],
    output_dim=10,
    activation="relu",
    dropout=0.1,
)
from astra.torch.models import CNN

# image_dim=32 and n_channels=3 presumably describe 32x32 RGB inputs such
# as CIFAR-10 images. Two conv stages (32, 64) feed two dense layers
# (128, 64), followed by 10 outputs.
cnn = CNN(
    image_dim=32,
    kernel_size=5,
    n_channels=3,
    conv_hidden_dims=[32, 64],
    dense_hidden_dims=[128, 64],
    output_dim=10,
)
from torchvision.models import EfficientNet_B0_Weights, efficientnet_b0
from astra.torch.models import EfficientNet

# Wrap torchvision's EfficientNet-B0 constructor together with its DEFAULT
# (pretrained) weights, requesting output_dim=10.
model = EfficientNet(
    efficientnet_b0,
    EfficientNet_B0_Weights.DEFAULT,
    output_dim=10,
)
from torchvision.models import ViT_B_16_Weights, vit_b_16
from astra.torch.models import ViT

# Wrap torchvision's ViT-B/16 constructor together with its DEFAULT
# (pretrained) weights, requesting output_dim=10.
model = ViT(
    vit_b_16,
    ViT_B_16_Weights.DEFAULT,
    output_dim=10,
)
from astra.torch.utils import train_fn

# NOTE: inputs, outputs, loss_fn, lr, n_epochs and batch_size are
# placeholders that this snippet does not define. Supply your own values
# before running it. `model` can be any of the models built above.
result = train_fn(
    model,
    inputs,
    outputs,
    loss_fn,
    lr,
    n_epochs,
    batch_size,
    enable_tqdm=True,
)
print(result.keys())  # dict_keys(['epoch_losses', 'iter_losses'])
from astra.torch.utils import count_params

# Number of parameters in the `mlp` model.
n_params = count_params(mlp)
import optree
import torch
from torchvision.models import ViT_B_16_Weights, vit_b_16

from astra.torch.models import ViT
from astra.torch.utils import ravel_pytree

model = ViT(vit_b_16, ViT_B_16_Weights.DEFAULT, output_dim=10)
params = dict(model.named_parameters())

# ravel_pytree returns a flattened form of `params` plus unravel_fn, which
# maps that flat form back to the original dict of tensors (presumably a
# single 1-D tensor, mirroring JAX's ravel_pytree).
flat_params, unravel_fn = ravel_pytree(params)
unraveled_params = unravel_fn(flat_params)  # returns the original params

# Check that the tree structure is preserved.
assert optree.tree_structure(params) == optree.tree_structure(unraveled_params)

# Check that every leaf keeps both its shape and its values. torch.equal
# also compares sizes, whereas `torch.all(a == b)` would broadcast and
# could miss a shape change introduced by the round trip.
for before_leaf, after_leaf in zip(
    optree.tree_leaves(params), optree.tree_leaves(unraveled_params)
):
    assert torch.equal(before_leaf, after_leaf)