Skip to content

Commit

Permalink
upload inits
Browse files Browse the repository at this point in the history
  • Loading branch information
Harry24k committed Nov 3, 2023
1 parent ecc41ad commit 399a908
Show file tree
Hide file tree
Showing 13 changed files with 207 additions and 2 deletions.
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ cython_debug/
#.idea/


__*
__*/*
___*
___*/*
/data/*
/logo/*
/demos/_*
Expand Down
6 changes: 6 additions & 0 deletions mair/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .nn.robmodel import RobModel
from .utils import load_model

# from .utils.datasets import Datasets

__version__ = "1.0.0"
99 changes: 99 additions & 0 deletions mair/attacks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# None attacks
from .attacks.vanila import VANILA
from .attacks.gn import GN

# Linf attacks
from .attacks.fgsm import FGSM
from .attacks.bim import BIM
from .attacks.rfgsm import RFGSM
from .attacks.pgd import PGD
from .attacks.eotpgd import EOTPGD
from .attacks.ffgsm import FFGSM
from .attacks.tpgd import TPGD
from .attacks.mifgsm import MIFGSM
from .attacks.upgd import UPGD
from .attacks.apgd import APGD
from .attacks.apgdt import APGDT
from .attacks.difgsm import DIFGSM
from .attacks.tifgsm import TIFGSM
from .attacks.jitter import Jitter
from .attacks.nifgsm import NIFGSM
from .attacks.pgdrs import PGDRS
from .attacks.sinifgsm import SINIFGSM
from .attacks.vmifgsm import VMIFGSM
from .attacks.vnifgsm import VNIFGSM
from .attacks.spsa import SPSA
from .attacks.pifgsm import PIFGSM
from .attacks.pifgsmpp import PIFGSMPP

# L2 attacks
from .attacks.cw import CW
from .attacks.pgdl2 import PGDL2
from .attacks.pgdrsl2 import PGDRSL2
from .attacks.deepfool import DeepFool
from .attacks.eaden import EADEN

# L1 attacks
from .attacks.eadl1 import EADL1

# L0 attacks
from .attacks.sparsefool import SparseFool
from .attacks.onepixel import OnePixel
from .attacks.pixle import Pixle
from .attacks.jsma import JSMA

# Linf, L2 attacks
from .attacks.fab import FAB
from .attacks.autoattack import AutoAttack
from .attacks.square import Square

# Wrapper Class
from .wrappers.multiattack import MultiAttack
from .wrappers.lgv import LGV

__version__ = "3.5.1"
__all__ = [
"VANILA",
"GN",
"FGSM",
"BIM",
"RFGSM",
"PGD",
"EOTPGD",
"FFGSM",
"TPGD",
"MIFGSM",
"UPGD",
"APGD",
"APGDT",
"DIFGSM",
"TIFGSM",
"Jitter",
"NIFGSM",
"PGDRS",
"SINIFGSM",
"VMIFGSM",
"VNIFGSM",
"SPSA",
"JSMA",
"EADL1",
"EADEN",
"PIFGSM",
"PIFGSMPP",
"CW",
"PGDL2",
"DeepFool",
"PGDRSL2",
"SparseFool",
"OnePixel",
"Pixle",
"FAB",
"AutoAttack",
"Square",
"MultiAttack",
"LGV",
]
__wrapper__ = [
"LGV",
"MultiAttack",
]
Empty file.
Empty file.
4 changes: 4 additions & 0 deletions mair/defenses/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .advtraining.standard import Standard
from .advtraining.at import AT
from .advtraining.trades import TRADES
from .advtraining.mart import MART
Empty file.
Empty file added mair/nn/__init__.py
Empty file.
Empty file added mair/nn/modules/__init__.py
Empty file.
1 change: 1 addition & 0 deletions mair/optim/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .minimizer import *
Empty file added mair/transforms/__init__.py
Empty file.
5 changes: 5 additions & 0 deletions mair/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# from .datasets import Datasets
from .models import load_model
from .cuda import manual_seed
from .data import get_subloader
from .eval import get_accuracy
90 changes: 90 additions & 0 deletions mair/utils/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from .lenet import LeNet, LeNetPP
from .mnist_ates import MNIST_ATES
from .mnist_dat import MNIST_DAT
from .mnist_fast import MNIST_Fast
from .preactresnet import PreActBlock, PreActResNet
from .resnet import ResBasicBlock, ResNet
from .densenet import DenseNet, Bottleneck
from .vgg import VGG
from .wideresnet import WideResNet


def load_model(model_name, n_classes):
if model_name == "LeNet":
return LeNet(n_classes)

if model_name == "LeNetPP":
return LeNetPP(n_classes)

elif model_name == "MNIST_ATES":
return MNIST_ATES(n_classes)

elif model_name == "MNIST_DAT":
return MNIST_DAT(n_classes)

elif model_name == "MNIST_Fast":
return MNIST_Fast(n_classes)

elif model_name == "WRN28-10":
model = WideResNet(
depth=28, num_classes=n_classes, widen_factor=10, dropRate=0.0
)

elif model_name == "WRN34-10":
model = WideResNet(
depth=34, num_classes=n_classes, widen_factor=10, dropRate=0.0
)

elif model_name == "PRN18":
model = PreActResNet(
PreActBlock, num_blocks=[2, 2, 2, 2], num_classes=n_classes
)

elif model_name == "ResNet10":
model = ResNet(ResBasicBlock, [1, 1, 1, 1], n_classes, in_channels=1)

elif model_name == "ResNet18":
model = ResNet(ResBasicBlock, [2, 2, 2, 2], n_classes)

elif model_name == "ResNet34":
model = ResNet(ResBasicBlock, [3, 4, 6, 3], n_classes)

elif model_name == "ResNet50":
model = ResNet(ResBasicBlock, [3, 4, 6, 3], n_classes)

elif model_name == "ResNet101":
model = ResNet(ResBasicBlock, [3, 4, 23, 3], n_classes)

elif model_name == "ResNet152":
model = ResNet(ResBasicBlock, [3, 8, 36, 3], n_classes)

elif model_name == "DenseNet121":
model = DenseNet(Bottleneck, [6, 12, 24, 16], growth_rate=32)

elif model_name == "DenseNet169":
model = DenseNet(Bottleneck, [6, 12, 32, 32], growth_rate=32)

elif model_name == "DenseNet201":
model = DenseNet(Bottleneck, [6, 12, 48, 32], growth_rate=32)

elif model_name == "DenseNet161":
model = DenseNet(Bottleneck, [6, 12, 36, 24], growth_rate=48)

elif model_name == "VGG11":
model = VGG("VGG11", n_classes)

elif model_name == "VGG13":
model = VGG("VGG13", n_classes)

elif model_name == "VGG16":
model = VGG("VGG16", n_classes)

elif model_name == "VGG19":
model = VGG("VGG19", n_classes)

else:
raise ValueError("Invalid model name.")

print(model_name, "is loaded.")

return model

0 comments on commit 399a908

Please sign in to comment.