diff --git a/.flake8 b/.flake8 index 9214cb9e..0efb58e6 100644 --- a/.flake8 +++ b/.flake8 @@ -1,2 +1,2 @@ [flake8] -ignore = E501, W503 \ No newline at end of file +ignore = E501, W503, E203 diff --git a/config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml b/config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml new file mode 100644 index 00000000..2b5d9a3c --- /dev/null +++ b/config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml @@ -0,0 +1,24 @@ +defaults: + - _self_ + - /model/clip-vit@models: + - clip-vit-base-patch32 + - clip-vit-base-patch32_sun397 + - clip-vit-base-patch32_stanford-cars + - clip-vit-base-patch32_resisc45 + - clip-vit-base-patch32_eurosat + - clip-vit-base-patch32_svhn + - clip-vit-base-patch32_gtsrb + - clip-vit-base-patch32_mnist + - clip-vit-base-patch32_dtd + - /dataset/image_classification/train@train_datasets: + - tiny-imagenet + +_target_: fusion_bench.modelpool.CLIPVisionModelPool +_recursive_: false + +models: ??? +train_datasets: ??? + +processor: + _target_: transformers.CLIPProcessor.from_pretrained + pretrained_model_name_or_path: openai/clip-vit-base-patch32 diff --git a/fusion_bench/method/smile_upscaling/smile_upscaling.py b/fusion_bench/method/smile_upscaling/smile_upscaling.py index d8fea030..4056a7bc 100644 --- a/fusion_bench/method/smile_upscaling/smile_upscaling.py +++ b/fusion_bench/method/smile_upscaling/smile_upscaling.py @@ -424,8 +424,8 @@ def _average_experts(self, pretarined_model, finetuned_models, name: str): def _upscale_submodules( self, - pretrained_model, - finetuned_model, + pretrained_model: nn.Module, + finetuned_models: List[nn.Module], tqdm_desc: str = "Upscaling Linear Modules", ): """ @@ -446,9 +446,9 @@ def _upscale_submodules( if isinstance(module, self._linear_layer_cls): self._upscale_linear_layer( pretrained_model=pretrained_model, - finetuned_models=finetuned_model, + finetuned_models=finetuned_models, name=name, ) elif config.average_experts and len(tuple(module.named_modules())) == 1: # if the module is a leaf module, we perform a parameter average - self._average_experts(pretrained_model, finetuned_model, name) + self._average_experts(pretrained_model, finetuned_models, name) diff --git a/fusion_bench/method/weighted_average/weighted_average.py b/fusion_bench/method/weighted_average/weighted_average.py index a1185b29..9ddde161 100644 --- a/fusion_bench/method/weighted_average/weighted_average.py +++ b/fusion_bench/method/weighted_average/weighted_average.py @@ -37,9 +37,17 @@ class WeightedAverageAlgorithm(BaseModelFusionAlgorithm, SimpleProfilerMixin): "weights": "weights", } - def __init__(self, normalize: bool, weights: List[float], **kwargs): + def __init__( + self, + normalize: bool, + weights: List[float], + verbose: bool = True, + **kwargs, + ): self.normalize = normalize self.weights = weights + self.verbose = verbose + log.disabled = not self.verbose super().__init__(**kwargs) @override @@ -70,7 +78,8 @@ def run(self, modelpool: BaseModelPool): ) if self.normalize: weights = weights / np.sum(weights) - print(f"weights: {weights}, normalized: {self.normalize}") + if self.verbose: + print(f"weights: {weights}, normalized: {self.normalize}") sd: Optional[StateDictType] = None forward_model = None @@ -88,5 +97,6 @@ def run(self, modelpool: BaseModelPool): ) forward_model.load_state_dict(sd) - self.print_profile_summary() + if self.verbose: + self.print_profile_summary() return forward_model diff --git a/fusion_bench/modelpool/base_pool.py b/fusion_bench/modelpool/base_pool.py index cb987582..11db3e45 100644 --- a/fusion_bench/modelpool/base_pool.py +++ b/fusion_bench/modelpool/base_pool.py @@ -145,6 +145,13 @@ def load_model( ) return model + def load_pretrained_model(self, *args, **kwargs): + assert ( + self.has_pretrained + ), "No pretrained model available. Check `_pretrained_` is in the `models` key." + model = self.load_model("_pretrained_", *args, **kwargs) + return model + def load_pretrained_or_first_model(self, *args, **kwargs): """ Load the pretrained model if available, otherwise load the first available model. diff --git a/fusion_bench/models/smile_moe/__init__.py b/fusion_bench/models/smile_moe/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fusion_bench/models/smile_moe/linear.py b/fusion_bench/models/smile_moe/linear.py new file mode 100644 index 00000000..096252f1 --- /dev/null +++ b/fusion_bench/models/smile_moe/linear.py @@ -0,0 +1,256 @@ +import logging +from typing import Dict, List, Tuple # noqa: F401 + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +log = logging.getLogger(__name__) + + +class ExpertNotTrainedError(Exception): + pass + + +def _is_all_zeros(tensor: Tensor | List[Tensor]) -> bool: + if isinstance(tensor, Tensor): + return torch.allclose(tensor, torch.zeros_like(tensor)) + else: + return all(_is_all_zeros(t) for t in tensor) + + +def _svd(w: Tensor, full_matrices=True) -> Tuple[Tensor, Tensor, Tensor]: + u, s, vh = torch.linalg.svd( + w, full_matrices=full_matrices, driver="gesvd" if w.is_cuda else None + ) + v = vh.T + return u, s, v + + +def svd( + w: Tensor, full_matrices=True, accelerator=None +) -> Tuple[Tensor, Tensor, Tensor]: + if accelerator is None: + return _svd(w, full_matrices=full_matrices) + original_device = w.device + w = w.to(accelerator) + u, s, v = _svd(w) + return u.to(original_device), s.to(original_device), v.to(original_device) + + +class SmileGate(nn.Module): + def __init__( + self, + input_features: int, + w_diff_list: List[Tensor], + k: int, + svd_list=None, # cached `svd_list`, pass it to avoid recomputing + upscaling_accelerator=None, + ): + super().__init__() + self.input_features = input_features + self.num_experts = len(w_diff_list) + weights = [] + for i, w_diff in enumerate(w_diff_list): + if svd_list is None: + u, s, v = svd(w_diff, accelerator=upscaling_accelerator) + else: + u, s, v = svd_list[i] + u = u[:, :k] + s = s[:k] + v = v[:, :k] + + # weights.append((s * v).T) + weights.append(v.T) + self.k = s.size(0) # k is the actual k after truncation + + weights = ( + torch.stack(weights, dim=0) + .reshape(self.num_experts * self.k, -1) + .contiguous() + ) + self.weights = nn.Parameter( + weights + ) # weights should be a tensor of shape (num_experts * k, n) + + def forward(self, x: Tensor): + batch_size = x.size(0) + if self.num_experts == 1: + return torch.ones(batch_size, 1, device=x.device, dtype=x.dtype) + + routing_weights = F.linear(x, self.weights).view( + batch_size, self.num_experts, self.k + ) + routing_weights = routing_weights.norm(p=2, dim=2) + return routing_weights + + +class SmileCompressedLinear(nn.Module): + def __init__(self, model: nn.Linear, k: int, svd_cache=None): + super().__init__() + if svd_cache is None: + u, s, v = svd(model.weight) + else: + u, s, v = svd_cache + if k > 0: + u = u[:, :k] + s = s[:k] + v = v[:, :k] + + self.u = nn.Parameter(u) + self.svh = nn.Parameter((s * v).T) + + if model.bias is not None: + self.bias = nn.Parameter(model.bias.data, requires_grad=True) + else: + self.register_parameter("bias", None) + + def forward(self, x): + x = F.linear(x, self.svh) + x = F.linear(x, self.u, self.bias) + return x + + +class SmileMoELinear(nn.Module): + @torch.no_grad() + def __init__( + self, + pretrained_model: nn.Linear, + finetuned_models: List[nn.Linear], + gate_k: int, + k: int, + top_k: int = 1, + full_matrices=True, + upscaling_accelerator=None, + routing_use_diff=True, + ): + super().__init__() + self.num_experts = len(finetuned_models) + self.top_k = top_k + self.k = k + self.gate_k = gate_k + self.in_features = pretrained_model.in_features + self.out_features = pretrained_model.out_features + + w_diff_list = [m.weight - pretrained_model.weight for m in finetuned_models] + if _is_all_zeros(w_diff_list): + # All fine-tuned models are identical to the pretrained model + raise ExpertNotTrainedError() + + if routing_use_diff or k > 0: + svd_cache_list = [ + svd(w, full_matrices=full_matrices, accelerator=upscaling_accelerator) + for w in w_diff_list + ] # the svd cache list to avoid recomputing + + # construct the gate network + if routing_use_diff: + self.gate = SmileGate( + input_features=self.in_features, + w_diff_list=w_diff_list, + k=gate_k, + svd_list=svd_cache_list, + upscaling_accelerator=upscaling_accelerator, + ) + else: + self.gate = SmileGate( + input_features=self.in_features, + w_diff_list=[m.weight for m in finetuned_models], + k=gate_k, + svd_list=None, + upscaling_accelerator=upscaling_accelerator, + ) + + # construct experts + for m, w_diff in zip(finetuned_models, w_diff_list): + m.weight.data = w_diff + if k > 0: + experts = [ + SmileCompressedLinear(m, k, svd_cache=svd_cache) + for m, svd_cache in zip(finetuned_models, svd_cache_list) + ] + else: + # if k is not set (<0), we use the full fine-tuned model + experts = finetuned_models + self.experts = nn.ModuleList(experts) + + if pretrained_model.bias is not None: + for m in experts: + m.bias.data = m.bias.data - pretrained_model.bias + # assign the pretrained model (the shared part) + self.pretrained_model = pretrained_model + + def forward(self, hidden_states: Tensor): + pretrained_out = self.pretrained_model(hidden_states) + + input_shape = hidden_states.size() + hidden_states = hidden_states.view(-1, self.in_features) + + router_logits = self.gate(hidden_states) + routing_weights = F.softmax(router_logits, dim=1) + # sample the expert according to the routing weights + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + + final_hidden_states = torch.zeros( + (hidden_states.size(0), self.out_features), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot( + selected_experts, num_classes=self.num_experts + ).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, self.in_features) + if current_state.numel() == 0: + continue + current_hidden_states = ( + expert_layer(current_state) * routing_weights[top_x, idx, None] + ) + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype) + ) + final_hidden_states = final_hidden_states.reshape( + *input_shape[:-1], self.out_features + ) + final_hidden_states = pretrained_out + final_hidden_states + return final_hidden_states + + @property + def weight(self): + """ + Mimic linear layer. Bacause in some cases, user might indicate the device (or dtype of parameters) of the linear layer using `linear_layer.weight.device` + """ + return self.pretrained_model.weight + + @property + def bias(self): + return self.pretrained_model.bias + + def __repr__(self): + return ( + f"SingularMoELinear(" + f"in_features={self.pretrained_model.in_features}, " + f"out_features={self.pretrained_model.out_features}, " + f"num_experts={self.num_experts}, " + f"top_k={self.top_k}, " + f"gate_k={self.gate_k}, " + f"k={self.k}" + f")" + ) diff --git a/fusion_bench/models/utils.py b/fusion_bench/models/utils.py index cade99d6..8263aaf2 100644 --- a/fusion_bench/models/utils.py +++ b/fusion_bench/models/utils.py @@ -1,5 +1,7 @@ from typing import List +from torch import nn + def del_attr(obj, names: List[str]): """ @@ -45,3 +47,26 @@ def get_attr(obj, names: List[str]): return getattr(obj, names[0]) else: return get_attr(getattr(obj, names[0]), names[1:]) + + +def find_layers_with_type( + module: nn.Module, + layer_types=[nn.Linear], + prefix="", +): + """ + Recursively find the layers of a certain type in a module. + + Args: + module (nn.Module): PyTorch module. + layer_types (list): List of layer types to find. + prefix (str): A prefix to add to the layer names. + + Returns: + dict: Dictionary of layers of the given type(s) within the module. + """ + res = {} + for name, submodule in module.named_modules(prefix=prefix): + if isinstance(submodule, tuple(layer_types)): + res[name] = submodule + return res diff --git a/fusion_bench/programs/fabric_fusion_program.py b/fusion_bench/programs/fabric_fusion_program.py index 7be8df3b..5ded8a6d 100644 --- a/fusion_bench/programs/fabric_fusion_program.py +++ b/fusion_bench/programs/fabric_fusion_program.py @@ -1,11 +1,8 @@ -import importlib -import importlib.resources import json import logging import os -from typing import Callable, Dict, Iterable, Optional, Union +from typing import Callable, Dict, Iterable, Optional, Union # noqa: F401 -import hydra import lightning as L from omegaconf import DictConfig, OmegaConf from torch import nn diff --git a/fusion_bench/tasks/clip_classification/imagenet.py b/fusion_bench/tasks/clip_classification/imagenet.py new file mode 100644 index 00000000..0b96d4a0 --- /dev/null +++ b/fusion_bench/tasks/clip_classification/imagenet.py @@ -0,0 +1,1091 @@ +classnames = [ + "tench", + "goldfish", + "great white shark", + "tiger shark", + "hammerhead shark", + "electric ray", + "stingray", + "rooster", + "hen", + "ostrich", + "brambling", + "goldfinch", + "house finch", + "junco", + "indigo bunting", + "American robin", + "bulbul", + "jay", + "magpie", + "chickadee", + "American dipper", + "kite (bird of prey)", + "bald eagle", + "vulture", + "great grey owl", + "fire salamander", + "smooth newt", + "newt", + "spotted salamander", + "axolotl", + "American bullfrog", + "tree frog", + "tailed frog", + "loggerhead sea turtle", + "leatherback sea turtle", + "mud turtle", + "terrapin", + "box turtle", + "banded gecko", + "green iguana", + "Carolina anole", + "desert grassland whiptail lizard", + "agama", + "frilled-necked lizard", + "alligator lizard", + "Gila monster", + "European green lizard", + "chameleon", + "Komodo dragon", + "Nile crocodile", + "American alligator", + "triceratops", + "worm snake", + "ring-necked snake", + "eastern hog-nosed snake", + "smooth green snake", + "kingsnake", + "garter snake", + "water snake", + "vine snake", + "night snake", + "boa constrictor", + "African rock python", + "Indian cobra", + "green mamba", + "sea snake", + "Saharan horned viper", + "eastern diamondback rattlesnake", + "sidewinder rattlesnake", + "trilobite", + "harvestman", + "scorpion", + "yellow garden spider", + "barn spider", + "European garden spider", + "southern black widow", + "tarantula", + "wolf spider", + "tick", + "centipede", + "black grouse", + "ptarmigan", + "ruffed grouse", + "prairie grouse", + "peafowl", + "quail", + "partridge", + "african grey parrot", + "macaw", + "sulphur-crested cockatoo", + "lorikeet", + "coucal", + "bee eater", + "hornbill", + "hummingbird", + "jacamar", + "toucan", + "duck", + "red-breasted merganser", + "goose", + "black swan", + "tusker", + "echidna", + "platypus", + "wallaby", + "koala", + "wombat", + "jellyfish", + "sea anemone", + "brain coral", + "flatworm", + "nematode", + "conch", + "snail", + "slug", + "sea slug", + "chiton", + "chambered nautilus", + "Dungeness crab", + "rock crab", + "fiddler crab", + "red king crab", + "American lobster", + "spiny lobster", + "crayfish", + "hermit crab", + "isopod", + "white stork", + "black stork", + "spoonbill", + "flamingo", + "little blue heron", + "great egret", + "bittern bird", + "crane bird", + "limpkin", + "common gallinule", + "American coot", + "bustard", + "ruddy turnstone", + "dunlin", + "common redshank", + "dowitcher", + "oystercatcher", + "pelican", + "king penguin", + "albatross", + "grey whale", + "killer whale", + "dugong", + "sea lion", + "Chihuahua", + "Japanese Chin", + "Maltese", + "Pekingese", + "Shih Tzu", + "King Charles Spaniel", + "Papillon", + "toy terrier", + "Rhodesian Ridgeback", + "Afghan Hound", + "Basset Hound", + "Beagle", + "Bloodhound", + "Bluetick Coonhound", + "Black and Tan Coonhound", + "Treeing Walker Coonhound", + "English foxhound", + "Redbone Coonhound", + "borzoi", + "Irish Wolfhound", + "Italian Greyhound", + "Whippet", + "Ibizan Hound", + "Norwegian Elkhound", + "Otterhound", + "Saluki", + "Scottish Deerhound", + "Weimaraner", + "Staffordshire Bull Terrier", + "American Staffordshire Terrier", + "Bedlington Terrier", + "Border Terrier", + "Kerry Blue Terrier", + "Irish Terrier", + "Norfolk Terrier", + "Norwich Terrier", + "Yorkshire Terrier", + "Wire Fox Terrier", + "Lakeland Terrier", + "Sealyham Terrier", + "Airedale Terrier", + "Cairn Terrier", + "Australian Terrier", + "Dandie Dinmont Terrier", + "Boston Terrier", + "Miniature Schnauzer", + "Giant Schnauzer", + "Standard Schnauzer", + "Scottish Terrier", + "Tibetan Terrier", + "Australian Silky Terrier", + "Soft-coated Wheaten Terrier", + "West Highland White Terrier", + "Lhasa Apso", + "Flat-Coated Retriever", + "Curly-coated Retriever", + "Golden Retriever", + "Labrador Retriever", + "Chesapeake Bay Retriever", + "German Shorthaired Pointer", + "Vizsla", + "English Setter", + "Irish Setter", + "Gordon Setter", + "Brittany dog", + "Clumber Spaniel", + "English Springer Spaniel", + "Welsh Springer Spaniel", + "Cocker Spaniel", + "Sussex Spaniel", + "Irish Water Spaniel", + "Kuvasz", + "Schipperke", + "Groenendael dog", + "Malinois", + "Briard", + "Australian Kelpie", + "Komondor", + "Old English Sheepdog", + "Shetland Sheepdog", + "collie", + "Border Collie", + "Bouvier des Flandres dog", + "Rottweiler", + "German Shepherd Dog", + "Dobermann", + "Miniature Pinscher", + "Greater Swiss Mountain Dog", + "Bernese Mountain Dog", + "Appenzeller Sennenhund", + "Entlebucher Sennenhund", + "Boxer", + "Bullmastiff", + "Tibetan Mastiff", + "French Bulldog", + "Great Dane", + "St. Bernard", + "husky", + "Alaskan Malamute", + "Siberian Husky", + "Dalmatian", + "Affenpinscher", + "Basenji", + "pug", + "Leonberger", + "Newfoundland dog", + "Great Pyrenees dog", + "Samoyed", + "Pomeranian", + "Chow Chow", + "Keeshond", + "brussels griffon", + "Pembroke Welsh Corgi", + "Cardigan Welsh Corgi", + "Toy Poodle", + "Miniature Poodle", + "Standard Poodle", + "Mexican hairless dog (xoloitzcuintli)", + "grey wolf", + "Alaskan tundra wolf", + "red wolf or maned wolf", + "coyote", + "dingo", + "dhole", + "African wild dog", + "hyena", + "red fox", + "kit fox", + "Arctic fox", + "grey fox", + "tabby cat", + "tiger cat", + "Persian cat", + "Siamese cat", + "Egyptian Mau", + "cougar", + "lynx", + "leopard", + "snow leopard", + "jaguar", + "lion", + "tiger", + "cheetah", + "brown bear", + "American black bear", + "polar bear", + "sloth bear", + "mongoose", + "meerkat", + "tiger beetle", + "ladybug", + "ground beetle", + "longhorn beetle", + "leaf beetle", + "dung beetle", + "rhinoceros beetle", + "weevil", + "fly", + "bee", + "ant", + "grasshopper", + "cricket insect", + "stick insect", + "cockroach", + "praying mantis", + "cicada", + "leafhopper", + "lacewing", + "dragonfly", + "damselfly", + "red admiral butterfly", + "ringlet butterfly", + "monarch butterfly", + "small white butterfly", + "sulphur butterfly", + "gossamer-winged butterfly", + "starfish", + "sea urchin", + "sea cucumber", + "cottontail rabbit", + "hare", + "Angora rabbit", + "hamster", + "porcupine", + "fox squirrel", + "marmot", + "beaver", + "guinea pig", + "common sorrel horse", + "zebra", + "pig", + "wild boar", + "warthog", + "hippopotamus", + "ox", + "water buffalo", + "bison", + "ram (adult male sheep)", + "bighorn sheep", + "Alpine ibex", + "hartebeest", + "impala (antelope)", + "gazelle", + "arabian camel", + "llama", + "weasel", + "mink", + "European polecat", + "black-footed ferret", + "otter", + "skunk", + "badger", + "armadillo", + "three-toed sloth", + "orangutan", + "gorilla", + "chimpanzee", + "gibbon", + "siamang", + "guenon", + "patas monkey", + "baboon", + "macaque", + "langur", + "black-and-white colobus", + "proboscis monkey", + "marmoset", + "white-headed capuchin", + "howler monkey", + "titi monkey", + "Geoffroy's spider monkey", + "common squirrel monkey", + "ring-tailed lemur", + "indri", + "Asian elephant", + "African bush elephant", + "red panda", + "giant panda", + "snoek fish", + "eel", + "silver salmon", + "rock beauty fish", + "clownfish", + "sturgeon", + "gar fish", + "lionfish", + "pufferfish", + "abacus", + "abaya", + "academic gown", + "accordion", + "acoustic guitar", + "aircraft carrier", + "airliner", + "airship", + "altar", + "ambulance", + "amphibious vehicle", + "analog clock", + "apiary", + "apron", + "trash can", + "assault rifle", + "backpack", + "bakery", + "balance beam", + "balloon", + "ballpoint pen", + "Band-Aid", + "banjo", + "baluster / handrail", + "barbell", + "barber chair", + "barbershop", + "barn", + "barometer", + "barrel", + "wheelbarrow", + "baseball", + "basketball", + "bassinet", + "bassoon", + "swimming cap", + "bath towel", + "bathtub", + "station wagon", + "lighthouse", + "beaker", + "military hat (bearskin or shako)", + "beer bottle", + "beer glass", + "bell tower", + "baby bib", + "tandem bicycle", + "bikini", + "ring binder", + "binoculars", + "birdhouse", + "boathouse", + "bobsleigh", + "bolo tie", + "poke bonnet", + "bookcase", + "bookstore", + "bottle cap", + "hunting bow", + "bow tie", + "brass memorial plaque", + "bra", + "breakwater", + "breastplate", + "broom", + "bucket", + "buckle", + "bulletproof vest", + "high-speed train", + "butcher shop", + "taxicab", + "cauldron", + "candle", + "cannon", + "canoe", + "can opener", + "cardigan", + "car mirror", + "carousel", + "tool kit", + "cardboard box / carton", + "car wheel", + "automated teller machine", + "cassette", + "cassette player", + "castle", + "catamaran", + "CD player", + "cello", + "mobile phone", + "chain", + "chain-link fence", + "chain mail", + "chainsaw", + "storage chest", + "chiffonier", + "bell or wind chime", + "china cabinet", + "Christmas stocking", + "church", + "movie theater", + "cleaver", + "cliff dwelling", + "cloak", + "clogs", + "cocktail shaker", + "coffee mug", + "coffeemaker", + "spiral or coil", + "combination lock", + "computer keyboard", + "candy store", + "container ship", + "convertible", + "corkscrew", + "cornet", + "cowboy boot", + "cowboy hat", + "cradle", + "construction crane", + "crash helmet", + "crate", + "infant bed", + "Crock Pot", + "croquet ball", + "crutch", + "cuirass", + "dam", + "desk", + "desktop computer", + "rotary dial telephone", + "diaper", + "digital clock", + "digital watch", + "dining table", + "dishcloth", + "dishwasher", + "disc brake", + "dock", + "dog sled", + "dome", + "doormat", + "drilling rig", + "drum", + "drumstick", + "dumbbell", + "Dutch oven", + "electric fan", + "electric guitar", + "electric locomotive", + "entertainment center", + "envelope", + "espresso machine", + "face powder", + "feather boa", + "filing cabinet", + "fireboat", + "fire truck", + "fire screen", + "flagpole", + "flute", + "folding chair", + "football helmet", + "forklift", + "fountain", + "fountain pen", + "four-poster bed", + "freight car", + "French horn", + "frying pan", + "fur coat", + "garbage truck", + "gas mask or respirator", + "gas pump", + "goblet", + "go-kart", + "golf ball", + "golf cart", + "gondola", + "gong", + "gown", + "grand piano", + "greenhouse", + "radiator grille", + "grocery store", + "guillotine", + "hair clip", + "hair spray", + "half-track", + "hammer", + "hamper", + "hair dryer", + "hand-held computer", + "handkerchief", + "hard disk drive", + "harmonica", + "harp", + "combine harvester", + "hatchet", + "holster", + "home theater", + "honeycomb", + "hook", + "hoop skirt", + "gymnastic horizontal bar", + "horse-drawn vehicle", + "hourglass", + "iPod", + "clothes iron", + "carved pumpkin", + "jeans", + "jeep", + "T-shirt", + "jigsaw puzzle", + "rickshaw", + "joystick", + "kimono", + "knee pad", + "knot", + "lab coat", + "ladle", + "lampshade", + "laptop computer", + "lawn mower", + "lens cap", + "letter opener", + "library", + "lifeboat", + "lighter", + "limousine", + "ocean liner", + "lipstick", + "slip-on shoe", + "lotion", + "music speaker", + "loupe magnifying glass", + "sawmill", + "magnetic compass", + "messenger bag", + "mailbox", + "tights", + "one-piece bathing suit", + "manhole cover", + "maraca", + "marimba", + "mask", + "matchstick", + "maypole", + "maze", + "measuring cup", + "medicine cabinet", + "megalith", + "microphone", + "microwave oven", + "military uniform", + "milk can", + "minibus", + "miniskirt", + "minivan", + "missile", + "mitten", + "mixing bowl", + "mobile home", + "ford model t", + "modem", + "monastery", + "monitor", + "moped", + "mortar and pestle", + "graduation cap", + "mosque", + "mosquito net", + "vespa", + "mountain bike", + "tent", + "computer mouse", + "mousetrap", + "moving van", + "muzzle", + "metal nail", + "neck brace", + "necklace", + "baby pacifier", + "notebook computer", + "obelisk", + "oboe", + "ocarina", + "odometer", + "oil filter", + "pipe organ", + "oscilloscope", + "overskirt", + "bullock cart", + "oxygen mask", + "product packet / packaging", + "paddle", + "paddle wheel", + "padlock", + "paintbrush", + "pajamas", + "palace", + "pan flute", + "paper towel", + "parachute", + "parallel bars", + "park bench", + "parking meter", + "railroad car", + "patio", + "payphone", + "pedestal", + "pencil case", + "pencil sharpener", + "perfume", + "Petri dish", + "photocopier", + "plectrum", + "Pickelhaube", + "picket fence", + "pickup truck", + "pier", + "piggy bank", + "pill bottle", + "pillow", + "ping-pong ball", + "pinwheel", + "pirate ship", + "drink pitcher", + "block plane", + "planetarium", + "plastic bag", + "plate rack", + "farm plow", + "plunger", + "Polaroid camera", + "pole", + "police van", + "poncho", + "pool table", + "soda bottle", + "plant pot", + "potter's wheel", + "power drill", + "prayer rug", + "printer", + "prison", + "missile", + "projector", + "hockey puck", + "punching bag", + "purse", + "quill", + "quilt", + "race car", + "racket", + "radiator", + "radio", + "radio telescope", + "rain barrel", + "recreational vehicle", + "fishing casting reel", + "reflex camera", + "refrigerator", + "remote control", + "restaurant", + "revolver", + "rifle", + "rocking chair", + "rotisserie", + "eraser", + "rugby ball", + "ruler measuring stick", + "sneaker", + "safe", + "safety pin", + "salt shaker", + "sandal", + "sarong", + "saxophone", + "scabbard", + "weighing scale", + "school bus", + "schooner", + "scoreboard", + "CRT monitor", + "screw", + "screwdriver", + "seat belt", + "sewing machine", + "shield", + "shoe store", + "shoji screen / room divider", + "shopping basket", + "shopping cart", + "shovel", + "shower cap", + "shower curtain", + "ski", + "balaclava ski mask", + "sleeping bag", + "slide rule", + "sliding door", + "slot machine", + "snorkel", + "snowmobile", + "snowplow", + "soap dispenser", + "soccer ball", + "sock", + "solar thermal collector", + "sombrero", + "soup bowl", + "keyboard space bar", + "space heater", + "space shuttle", + "spatula", + "motorboat", + "spider web", + "spindle", + "sports car", + "spotlight", + "stage", + "steam locomotive", + "through arch bridge", + "steel drum", + "stethoscope", + "scarf", + "stone wall", + "stopwatch", + "stove", + "strainer", + "tram", + "stretcher", + "couch", + "stupa", + "submarine", + "suit", + "sundial", + "sunglasses", + "sunglasses", + "sunscreen", + "suspension bridge", + "mop", + "sweatshirt", + "swim trunks / shorts", + "swing", + "electrical switch", + "syringe", + "table lamp", + "tank", + "tape player", + "teapot", + "teddy bear", + "television", + "tennis ball", + "thatched roof", + "front curtain", + "thimble", + "threshing machine", + "throne", + "tile roof", + "toaster", + "tobacco shop", + "toilet seat", + "torch", + "totem pole", + "tow truck", + "toy store", + "tractor", + "semi-trailer truck", + "tray", + "trench coat", + "tricycle", + "trimaran", + "tripod", + "triumphal arch", + "trolleybus", + "trombone", + "hot tub", + "turnstile", + "typewriter keyboard", + "umbrella", + "unicycle", + "upright piano", + "vacuum cleaner", + "vase", + "vaulted or arched ceiling", + "velvet fabric", + "vending machine", + "vestment", + "viaduct", + "violin", + "volleyball", + "waffle iron", + "wall clock", + "wallet", + "wardrobe", + "military aircraft", + "sink", + "washing machine", + "water bottle", + "water jug", + "water tower", + "whiskey jug", + "whistle", + "hair wig", + "window screen", + "window shade", + "Windsor tie", + "wine bottle", + "airplane wing", + "wok", + "wooden spoon", + "wool", + "split-rail fence", + "shipwreck", + "sailboat", + "yurt", + "website", + "comic book", + "crossword", + "traffic or street sign", + "traffic light", + "dust jacket", + "menu", + "plate", + "guacamole", + "consomme", + "hot pot", + "trifle", + "ice cream", + "popsicle", + "baguette", + "bagel", + "pretzel", + "cheeseburger", + "hot dog", + "mashed potatoes", + "cabbage", + "broccoli", + "cauliflower", + "zucchini", + "spaghetti squash", + "acorn squash", + "butternut squash", + "cucumber", + "artichoke", + "bell pepper", + "cardoon", + "mushroom", + "Granny Smith apple", + "strawberry", + "orange", + "lemon", + "fig", + "pineapple", + "banana", + "jackfruit", + "cherimoya (custard apple)", + "pomegranate", + "hay", + "carbonara", + "chocolate syrup", + "dough", + "meatloaf", + "pizza", + "pot pie", + "burrito", + "red wine", + "espresso", + "tea cup", + "eggnog", + "mountain", + "bubble", + "cliff", + "coral reef", + "geyser", + "lakeshore", + "promontory", + "sandbar", + "beach", + "valley", + "volcano", + "baseball player", + "bridegroom", + "scuba diver", + "rapeseed", + "daisy", + "yellow lady's slipper", + "corn", + "acorn", + "rose hip", + "horse chestnut seed", + "coral fungus", + "agaric", + "gyromitra", + "stinkhorn mushroom", + "earth star fungus", + "hen of the woods mushroom", + "bolete", + "corn cob", + "toilet paper", +] + +imagenet_templates = [ + "a bad photo of a {}.", + "a photo of many {}.", + "a sculpture of a {}.", + "a photo of the hard to see {}.", + "a low resolution photo of the {}.", + "a rendering of a {}.", + "graffiti of a {}.", + "a bad photo of the {}.", + "a cropped photo of the {}.", + "a tattoo of a {}.", + "the embroidered {}.", + "a photo of a hard to see {}.", + "a bright photo of a {}.", + "a photo of a clean {}.", + "a photo of a dirty {}.", + "a dark photo of the {}.", + "a drawing of a {}.", + "a photo of my {}.", + "the plastic {}.", + "a photo of the cool {}.", + "a close-up photo of a {}.", + "a black and white photo of the {}.", + "a painting of the {}.", + "a painting of a {}.", + "a pixelated photo of the {}.", + "a sculpture of the {}.", + "a bright photo of the {}.", + "a cropped photo of a {}.", + "a plastic {}.", + "a photo of the dirty {}.", + "a jpeg corrupted photo of a {}.", + "a blurry photo of the {}.", + "a photo of the {}.", + "a good photo of the {}.", + "a rendering of the {}.", + "a {} in a video game.", + "a photo of one {}.", + "a doodle of a {}.", + "a close-up photo of the {}.", + "a photo of a {}.", + "the origami {}.", + "the {} in a video game.", + "a sketch of a {}.", + "a doodle of the {}.", + "a origami {}.", + "a low resolution photo of a {}.", + "the toy {}.", + "a rendition of the {}.", + "a photo of the clean {}.", + "a photo of a large {}.", + "a rendition of a {}.", + "a photo of a nice {}.", + "a photo of a weird {}.", + "a blurry photo of a {}.", + "a cartoon {}.", + "art of a {}.", + "a sketch of the {}.", + "a embroidered {}.", + "a pixelated photo of a {}.", + "itap of the {}.", + "a jpeg corrupted photo of the {}.", + "a good photo of a {}.", + "a plushie {}.", + "a photo of the nice {}.", + "a photo of the small {}.", + "a photo of the weird {}.", + "the cartoon {}.", + "art of the {}.", + "a drawing of the {}.", + "a photo of the large {}.", + "a black and white photo of a {}.", + "the plushie {}.", + "a dark photo of a {}.", + "itap of a {}.", + "graffiti of the {}.", + "a toy {}.", + "itap of my {}.", + "a photo of a cool {}.", + "a photo of a small {}.", + "a tattoo of the {}.", +] + +templates = [lambda c: prompt.format(c) for prompt in imagenet_templates] + +if __name__ == "__main__": + print(f"number of classnames: {len(classnames)}") + print(f"number of templates: {len(templates)}")