-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
207 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -160,8 +160,8 @@ cython_debug/ | |
#.idea/ | ||
|
||
|
||
__* | ||
__*/* | ||
___* | ||
___*/* | ||
/data/* | ||
/logo/* | ||
/demos/_* | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from .nn.robmodel import RobModel | ||
from .utils import load_model | ||
|
||
# from .utils.datasets import Datasets | ||
|
||
__version__ = "1.0.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# None attacks | ||
from .attacks.vanila import VANILA | ||
from .attacks.gn import GN | ||
|
||
# Linf attacks | ||
from .attacks.fgsm import FGSM | ||
from .attacks.bim import BIM | ||
from .attacks.rfgsm import RFGSM | ||
from .attacks.pgd import PGD | ||
from .attacks.eotpgd import EOTPGD | ||
from .attacks.ffgsm import FFGSM | ||
from .attacks.tpgd import TPGD | ||
from .attacks.mifgsm import MIFGSM | ||
from .attacks.upgd import UPGD | ||
from .attacks.apgd import APGD | ||
from .attacks.apgdt import APGDT | ||
from .attacks.difgsm import DIFGSM | ||
from .attacks.tifgsm import TIFGSM | ||
from .attacks.jitter import Jitter | ||
from .attacks.nifgsm import NIFGSM | ||
from .attacks.pgdrs import PGDRS | ||
from .attacks.sinifgsm import SINIFGSM | ||
from .attacks.vmifgsm import VMIFGSM | ||
from .attacks.vnifgsm import VNIFGSM | ||
from .attacks.spsa import SPSA | ||
from .attacks.pifgsm import PIFGSM | ||
from .attacks.pifgsmpp import PIFGSMPP | ||
|
||
# L2 attacks | ||
from .attacks.cw import CW | ||
from .attacks.pgdl2 import PGDL2 | ||
from .attacks.pgdrsl2 import PGDRSL2 | ||
from .attacks.deepfool import DeepFool | ||
from .attacks.eaden import EADEN | ||
|
||
# L1 attacks | ||
from .attacks.eadl1 import EADL1 | ||
|
||
# L0 attacks | ||
from .attacks.sparsefool import SparseFool | ||
from .attacks.onepixel import OnePixel | ||
from .attacks.pixle import Pixle | ||
from .attacks.jsma import JSMA | ||
|
||
# Linf, L2 attacks | ||
from .attacks.fab import FAB | ||
from .attacks.autoattack import AutoAttack | ||
from .attacks.square import Square | ||
|
||
# Wrapper Class | ||
from .wrappers.multiattack import MultiAttack | ||
from .wrappers.lgv import LGV | ||
|
||
__version__ = "3.5.1" | ||
__all__ = [ | ||
"VANILA", | ||
"GN", | ||
"FGSM", | ||
"BIM", | ||
"RFGSM", | ||
"PGD", | ||
"EOTPGD", | ||
"FFGSM", | ||
"TPGD", | ||
"MIFGSM", | ||
"UPGD", | ||
"APGD", | ||
"APGDT", | ||
"DIFGSM", | ||
"TIFGSM", | ||
"Jitter", | ||
"NIFGSM", | ||
"PGDRS", | ||
"SINIFGSM", | ||
"VMIFGSM", | ||
"VNIFGSM", | ||
"SPSA", | ||
"JSMA", | ||
"EADL1", | ||
"EADEN", | ||
"PIFGSM", | ||
"PIFGSMPP", | ||
"CW", | ||
"PGDL2", | ||
"DeepFool", | ||
"PGDRSL2", | ||
"SparseFool", | ||
"OnePixel", | ||
"Pixle", | ||
"FAB", | ||
"AutoAttack", | ||
"Square", | ||
"MultiAttack", | ||
"LGV", | ||
] | ||
__wrapper__ = [ | ||
"LGV", | ||
"MultiAttack", | ||
] |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .advtraining.standard import Standard | ||
from .advtraining.at import AT | ||
from .advtraining.trades import TRADES | ||
from .advtraining.mart import MART |
Empty file.
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .minimizer import * |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# from .datasets import Datasets | ||
from .models import load_model | ||
from .cuda import manual_seed | ||
from .data import get_subloader | ||
from .eval import get_accuracy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from .lenet import LeNet, LeNetPP | ||
from .mnist_ates import MNIST_ATES | ||
from .mnist_dat import MNIST_DAT | ||
from .mnist_fast import MNIST_Fast | ||
from .preactresnet import PreActBlock, PreActResNet | ||
from .resnet import ResBasicBlock, ResNet | ||
from .densenet import DenseNet, Bottleneck | ||
from .vgg import VGG | ||
from .wideresnet import WideResNet | ||
|
||
|
||
def load_model(model_name, n_classes): | ||
if model_name == "LeNet": | ||
return LeNet(n_classes) | ||
|
||
if model_name == "LeNetPP": | ||
return LeNetPP(n_classes) | ||
|
||
elif model_name == "MNIST_ATES": | ||
return MNIST_ATES(n_classes) | ||
|
||
elif model_name == "MNIST_DAT": | ||
return MNIST_DAT(n_classes) | ||
|
||
elif model_name == "MNIST_Fast": | ||
return MNIST_Fast(n_classes) | ||
|
||
elif model_name == "WRN28-10": | ||
model = WideResNet( | ||
depth=28, num_classes=n_classes, widen_factor=10, dropRate=0.0 | ||
) | ||
|
||
elif model_name == "WRN34-10": | ||
model = WideResNet( | ||
depth=34, num_classes=n_classes, widen_factor=10, dropRate=0.0 | ||
) | ||
|
||
elif model_name == "PRN18": | ||
model = PreActResNet( | ||
PreActBlock, num_blocks=[2, 2, 2, 2], num_classes=n_classes | ||
) | ||
|
||
elif model_name == "ResNet10": | ||
model = ResNet(ResBasicBlock, [1, 1, 1, 1], n_classes, in_channels=1) | ||
|
||
elif model_name == "ResNet18": | ||
model = ResNet(ResBasicBlock, [2, 2, 2, 2], n_classes) | ||
|
||
elif model_name == "ResNet34": | ||
model = ResNet(ResBasicBlock, [3, 4, 6, 3], n_classes) | ||
|
||
elif model_name == "ResNet50": | ||
model = ResNet(ResBasicBlock, [3, 4, 6, 3], n_classes) | ||
|
||
elif model_name == "ResNet101": | ||
model = ResNet(ResBasicBlock, [3, 4, 23, 3], n_classes) | ||
|
||
elif model_name == "ResNet152": | ||
model = ResNet(ResBasicBlock, [3, 8, 36, 3], n_classes) | ||
|
||
elif model_name == "DenseNet121": | ||
model = DenseNet(Bottleneck, [6, 12, 24, 16], growth_rate=32) | ||
|
||
elif model_name == "DenseNet169": | ||
model = DenseNet(Bottleneck, [6, 12, 32, 32], growth_rate=32) | ||
|
||
elif model_name == "DenseNet201": | ||
model = DenseNet(Bottleneck, [6, 12, 48, 32], growth_rate=32) | ||
|
||
elif model_name == "DenseNet161": | ||
model = DenseNet(Bottleneck, [6, 12, 36, 24], growth_rate=48) | ||
|
||
elif model_name == "VGG11": | ||
model = VGG("VGG11", n_classes) | ||
|
||
elif model_name == "VGG13": | ||
model = VGG("VGG13", n_classes) | ||
|
||
elif model_name == "VGG16": | ||
model = VGG("VGG16", n_classes) | ||
|
||
elif model_name == "VGG19": | ||
model = VGG("VGG19", n_classes) | ||
|
||
else: | ||
raise ValueError("Invalid model name.") | ||
|
||
print(model_name, "is loaded.") | ||
|
||
return model |