From 1d0d9fb1ca6a677016b8d9172bd6c393b248fac0 Mon Sep 17 00:00:00 2001 From: conradry Date: Wed, 3 Mar 2021 10:37:50 -0500 Subject: [PATCH 01/19] correct conda channels --- environment.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/environment.yml b/environment.yml index c3a1bfa..7e4bc35 100644 --- a/environment.yml +++ b/environment.yml @@ -1,14 +1,17 @@ name: cellem channels: - defaults + - conda-forge + - simpleitk dependencies: + - pip - pytorch - torchvision - - albumentations - h5py - mlflow - simpleitk - scikit-learn - pip: + - albumentations - imagehash - segmentation-models-pytorch From c911a5a9f937e3e8645c79f1d01f7fae72162784 Mon Sep 17 00:00:00 2001 From: conradry Date: Thu, 19 Aug 2021 08:48:49 -0400 Subject: [PATCH 02/19] refactor dataset prep --- dataset/{raw => }/cleanup2d.py | 0 dataset/deduplicated/deduplicate.py | 192 -------------------------- dataset/{preprocess => }/mrc2byte.py | 0 dataset/raw/crop_patches.py | 119 ---------------- dataset/raw/cross_section3d.py | 158 --------------------- dataset/{preprocess => }/vid2stack.py | 0 6 files changed, 469 deletions(-) rename dataset/{raw => }/cleanup2d.py (100%) delete mode 100644 dataset/deduplicated/deduplicate.py rename dataset/{preprocess => }/mrc2byte.py (100%) delete mode 100644 dataset/raw/crop_patches.py delete mode 100644 dataset/raw/cross_section3d.py rename dataset/{preprocess => }/vid2stack.py (100%) diff --git a/dataset/raw/cleanup2d.py b/dataset/cleanup2d.py similarity index 100% rename from dataset/raw/cleanup2d.py rename to dataset/cleanup2d.py diff --git a/dataset/deduplicated/deduplicate.py b/dataset/deduplicated/deduplicate.py deleted file mode 100644 index 41e2cdc..0000000 --- a/dataset/deduplicated/deduplicate.py +++ /dev/null @@ -1,192 +0,0 @@ -""" -Description: ------------- - -It is assumed that this script will be run after the cross_section3d.py and -crop_patches.py scripts. Errors are certain to occur if that is not the case. - -This script takes a directory containing image patches and their corresponding -hashes and performs deduplication of patches from within the same dataset. The resulting -array of deduplicated patches is stored is a list of filepaths in the given savedir with -the name deduplicated_fpaths.npz. - -Because we save the results of deduplication for each dataset separately, this script -easily handles the addition of new images to the patchdir. Datasets that have already -been deduplicated will be skipped. 
- -Example usage: --------------- - -python deduplicate.py {patchdir} {savedir} --min_distance 12 --processes 32 - -For help with arguments: ------------------------- - -python deduplicate.py --help -""" - -import os, argparse -import numpy as np -import dask.array as da -from glob import glob -from multiprocessing import Pool -from tqdm import tqdm - -#main function of the script -if __name__ == "__main__": - - #setup the argument parser - parser = argparse.ArgumentParser(description='Create dataset for nn experimentation') - parser.add_argument('patchdir', type=str, metavar='patchdir', help='Directory containing image patches and hashes') - parser.add_argument('savedir', type=str, metavar='savedir', - help='Path to save array containing the paths of exemplar images') - parser.add_argument('-d', '--min_distance', dest='min_distance', type=int, metavar='min_distance', default=12, - help='Minimum Hamming distance between hashes to be considered unique') - parser.add_argument('-p', '--processes', dest='processes', type=int, metavar='processes', default=32, - help='Number of processes to run, more processes run faster but consume memory') - - - #parse the arguments - args = parser.parse_args() - patchdir = args.patchdir - savedir = args.savedir - min_distance = args.min_distance - processes = args.processes - - #to avoid running this long script only to get a nasty error - #let's make sure that the savedir exists - if not os.path.isdir(savedir): - os.mkdir(savedir) - - #load all the image filenames and save them as a dask array. - #With over 5 million strings in each array, the memory - #requirements become fairly hefty. The dask array - #has slightly slower I/O but saves a considerable amount - #of memory. In case this deduplication script has been run before, we'll - #check to see if the dask array already exists - #in case new images were added in the intervening period, we're - #going to check that the impaths file is the most recently - #made update to the directory - #da_impaths_path = os.path.join(patchdir, 'raw_fpaths.npz') - #gather_impaths = False - #if os.path.isfile(da_impaths_path): - # impaths_update_time = os.stat(da_impaths_path).st_mtime - # patchdir_update_time = os.stat(patchdir).st_mtime - # gather_impaths = impaths_update_time != patchdir_update_time - #else: - # gather_impaths = True - - #make the list of impaths, if it doesn't exist or isn't - #recent enough - #if gather_impaths: - impaths = np.sort(glob(os.path.join(patchdir, '*.tiff'))) - #impaths = da.from_array(impaths) - #da.to_npy_stack(da_impaths_path, impaths) - #del impaths - - #load the dask array of impaths - #impaths = da.from_npy_stack(da_impaths_path) - print(f'Found {len(impaths)} images to deduplicate') - - def get_dataset_name(imf): - #function to extract the name of a dataset from the patch image file path - #in the cross_section.py script we added the handy -LOC- indicator to - #easily identify the source dataset from location information - return imf.split('/')[-1].split('-LOC-')[0] - - #extract the set of unique dataset names from all the impaths - with Pool(processes) as pool: - datasets = np.sort(pool.map(get_dataset_name, impaths)) - - #because we sorted the impaths, we know that all images from the - #same dataset will be grouped together. 
therefore, we only need - #to know the index of the first instance of a unique dataset name - #in order to get the indices of all the patches from that dataset - unq_datasets, indices = np.unique(datasets, return_index=True) - print(f'Deduplicating {len(unq_datasets)} unique datasets') - - #we can delete the datasets array now - del datasets - - #add the last index for impaths such that we have complete intervals - indices = np.append(indices, len(impaths)) - - #make groups of image patches by source dataset - groups_impaths = [] - for si, ei in zip(indices[:-1], indices[1:]): - #have to call .compute() for a dask array - groups_impaths.append(impaths[si:ei]) - - #now we can delete the impaths dask array - del impaths - - #sanity check that we have the same number of - #unique datasets and impath groups - assert(len(unq_datasets) == len(groups_impaths)) - - #define the function for deduplication of a group of image paths - def group_dedupe(args): - #two arguments are the unique dataset name and the filepaths of - #the patches that belong to that dataset - dataset_name, impaths = args - - #check if we already processed this dataset name, if so - #then we can skip it, this makes it very easy to add new datasets - exemplar_fpath = os.path.join(savedir, f'{dataset_name}_exemplars.npy') - if os.path.isfile(exemplar_fpath): - return None - - #randomly permute the impaths such that we'll have random ordering - impaths = np.random.permutation(impaths) - - #requires that hash array and tiff image are in the - #same directory (which is how crop_patches.py is setup) - hashes = np.array([np.load(ip.replace('.tiff', '.npy')).ravel() for ip in impaths]) - - #make a list of exemplar images to keep - exemplars = [] - impaths = np.array(impaths) - - #loop through the hashes and assign images to sets of near duplicates - #until all of the hashes are exhausted - while len(hashes) > 0: - #the reference hash is the first one in the list - #of remaining hashes - ref_hash = hashes[0] - - #a match has Hamming distance less than min_distance - matches = np.where(np.logical_xor(ref_hash, hashes).sum(1) <= min_distance)[0] - - #choose the first match as the exemplar and add - #it's filepath to the list. this is random because we - #permuted the paths earlier. a different image could be - #chosen on another run of this script - exemplars.append(impaths[matches[0]]) - - #remove all the matched images from both hashes and impaths - hashes = np.delete(hashes, matches, axis=0) - impaths = np.delete(impaths, matches, axis=0) - - #because this script can take a long time to complete, let's save checkpoint - #results for each dataset when it's finished with deduplication, then we have - #the option to resume later on - np.save(exemplar_fpath, np.array(exemplars)) - - #run the dataset level deduplication on multiple groups at once - #results for each group are saved in separate .npy files, if the - #.npy file already exists, then it will be skipped. 
This makes it - #easier to add new datasets to the existing directory structure - with Pool(processes) as pool: - pool.map(group_dedupe, list(zip(unq_datasets, groups_impaths))) - - #now that all the patches from individual datasets are deduplicated, - #we'll combine all the separate .npy arrays into a single dask array and save it - exemplar_fpaths = glob(os.path.join(savedir, '*_exemplars.npy')) - deduplicated_fpaths = np.concatenate([np.load(fp) for fp in exemplar_fpaths]) - - #convert to dask and save - deduplicated_fpaths = da.from_array(deduplicated_fpaths) - da.to_npy_stack(os.path.join(savedir, 'deduplicated_fpaths.npz'), deduplicated_fpaths) - - #print the total number of deduplicated patches - print(f'{len(deduplicated_fpaths)} patches remaining after deduplication.') diff --git a/dataset/preprocess/mrc2byte.py b/dataset/mrc2byte.py similarity index 100% rename from dataset/preprocess/mrc2byte.py rename to dataset/mrc2byte.py diff --git a/dataset/raw/crop_patches.py b/dataset/raw/crop_patches.py deleted file mode 100644 index 9c603e9..0000000 --- a/dataset/raw/crop_patches.py +++ /dev/null @@ -1,119 +0,0 @@ -""" -Description: ------------- - -It is assumed that this script will be run after the cross_section3d.py script. Errors -will arise if this is not the case. - -This script takes a directory containing 2d tiff images and crops those large images -into squares of a given dimension (default 224). In addition to creating the patch, -it also calculates and saves a difference hash for that patch. Doing both in a single -step significantly cuts down on I/O time. Both the patch and the hash are saved -in the same directory. All patches are .tiff and all hashes are .npy. - -Example usage: --------------- - -python crop_patches.py {imdir} {patchdir} --crop_size 224 --hash_size 8 --processes 4 - -For help with arguments: ------------------------- - -python crop_patches.py --help -""" - -import os -import argparse -import imagehash -import numpy as np -from skimage.io import imread, imsave -from glob import glob -from PIL import Image -from multiprocessing import Pool - -def calculate_hash(image, crop_size, hash_size=8): - #Creates the dhash for the resized image - #this guarantees that smaller images are not more likely - #to be recognized as unique - return imagehash.dhash(Image.fromarray(image).resize((crop_size, crop_size), resample=2), hash_size=hash_size).hash - -def patch_and_hash(impath, patchdir, crop_size=224, hash_size=8): - #load the image - image = imread(impath) - - #handle rgb by keeping only the first channel - if image.ndim == 3: - image = image[..., 0] - - #assumes that we are working with the output of cross_sections.py - #which saves all images as .tiff - prefix = impath.split('/')[-1].split('.tiff')[0] - - #get the image size - ysize, xsize = image.shape - - #this means that the smallest allowed image patch must have - #at least half of the desired crop size in both dimensions - ny = max(1, int(round(ysize / crop_size))) - nx = max(1, int(round(xsize / crop_size))) - - for y in range(ny): - #start and end indices for y - ys = y * crop_size - ye = min(ys + crop_size, ysize) - for x in range(nx): - #start and end indices for x - xs = x * crop_size - xe = min(xs + crop_size, xsize) - - #crop the patch and calculate its hash - patch = image[ys:ye, xs:xe] - patch_hash = calculate_hash(patch, crop_size, hash_size) - - #make the output file paths - patch_path = os.path.join(patchdir, f'{prefix}_{ys}_{xs}.tiff') - hash_path = patch_path.replace('.tiff', '.npy') - - #save 
the patch and the hash - imsave(patch_path, patch, check_contrast=False) - np.save(hash_path, patch_hash) - -#main function of the script -if __name__ == "__main__": - #setup the argument parser - parser = argparse.ArgumentParser(description='Create dataset for nn experimentation') - parser.add_argument('imdir', type=str, metavar='imdir', help='Directory containing tiff images') - parser.add_argument('patchdir', type=str, metavar='patchdir', help='Directory in which to save cropped patches') - parser.add_argument('-cs', '--crop_size', dest='crop_size', type=int, metavar='crop_size', default=224, - help='Size of square image patches. Default 224.') - parser.add_argument('-hs', '--hash_size', dest='hash_size', type=int, metavar='hash_size', default=8, - help='Size of the image hash. Default 8.') - parser.add_argument('-p', '--processes', dest='processes', type=int, metavar='processes', default=32, - help='Number of processes to run, more processes will run faster but consume more memory') - - - args = parser.parse_args() - - #read in the parser arguments - imdir = args.imdir - patchdir = args.patchdir - crop_size = args.crop_size - hash_size = args.hash_size - processes = args.processes - - #make sure the patchdir exists - if not os.path.isdir(patchdir): - os.mkdir(patchdir) - - #get the list of all tiff files in the imdir - impaths = np.sort(glob(os.path.join(imdir, '*.tiff'))) - print(f'Found {len(impaths)} tiff images to crop.') - - def map_func(impath): - patch_and_hash(impath, patchdir, crop_size, hash_size) - return None - - #loop over the images and save patches and hashes - #using the given number of processes - with Pool(processes) as pool: - pool.map(map_func, impaths) \ No newline at end of file diff --git a/dataset/raw/cross_section3d.py b/dataset/raw/cross_section3d.py deleted file mode 100644 index b3470b8..0000000 --- a/dataset/raw/cross_section3d.py +++ /dev/null @@ -1,158 +0,0 @@ -""" -Description: ------------- - -This script accepts a directory with image volume files and slices cross sections -from the given axes (xy, xz, yz). The resultant cross sections are saved in -the given save directory. - -Importantly, the saved image files are given a slightly different filename: -We add '-LOC-{axis}_{slice_index}' to the end of the filename, where axis denotes the -cross-sectioning plane (0->xy, 1->xz, 2->yz) and the slice index is the position of -the cross-section on that axis. Once images from 2d and 3d datasets -start getting mixed together, it can be difficult to keep track of the -provenance of each patch. Everything that appears before '-LOC-' is the -name of the original dataset, the axis and slice index allow us to lookup the -exact location of the cross-section in the volume. 
- -Example usage: --------------- - -python cross_section3d.py {imdir} {savedir} --axes 0 1 2 --spacing 1 --processes 4 - -For help with arguments: ------------------------- - -python cross_section3d.py --help -""" - -import os, math -import argparse -import numpy as np -import SimpleITK as sitk -from glob import glob -from skimage.io import imsave -from multiprocessing import Pool - -MAX_VALUES_BY_DTYPE = { - np.dtype("uint8"): 255, - np.dtype("uint16"): 65535, - np.dtype("int16"): 32767, - np.dtype("uint32"): 4294967295, - np.dtype("float32"): 1.0, -} - -#main function of the script -if __name__ == "__main__": - - #setup the argument parser - parser = argparse.ArgumentParser(description='Create dataset for nn experimentation') - parser.add_argument('imdir', type=str, metavar='imdir', help='Directory containing volume files') - parser.add_argument('savedir', type=str, metavar='savedir', help='Path to save the cross sections') - parser.add_argument('-a', '--axes', dest='axes', type=int, metavar='axes', nargs='+', default=[0, 1, 2], - help='Volume axes along which to slice (0-xy, 1-xz, 2-yz)') - parser.add_argument('-s', '--spacing', dest='spacing', type=int, metavar='spacing', default=1, - help='Spacing between image slices') - parser.add_argument('-p', '--processes', dest='processes', type=int, metavar='processes', default=4, - help='Number of processes to run, more processes will run faster but consume more memory') - - - args = parser.parse_args() - - #read in the parser arguments - imdir = args.imdir - savedir = args.savedir - axes = args.axes - spacing = args.spacing - processes = args.processes - - #check if the savedir exists, if not create it - if not os.path.isdir(savedir): - os.mkdir(savedir) - - #get the list of all volumes (mrc, tif, nrrd, nii.gz) - fpaths = np.array(glob(imdir + '*')) - - print(f'Found {len(fpaths)} image volumes') - - #loop over each fpath and save the slices - def create_slices(fp): - #try to load the volume, if it's not possible - #then pass - try: - im = sitk.ReadImage(fp) - print(im.GetSize(), fp) - except: - print('Failed to open: ', fp) - pass - - #extract the pixel size from the volume - #if the z-pixel size is more than 20% different - #from the x-pixel size, don't slice over orthogonal - #directions - pixel_sizes = im.GetSpacing() - anisotropy = np.abs(pixel_sizes[0] - pixel_sizes[2]) / pixel_sizes[0] - - #convert the volume to numpy - im = sitk.GetArrayFromImage(im) - - assert (im.min() >= 0), 'Negative images not allowed!' - - #check if the volume is uint8, convert if not - if im.dtype != np.uint8: - dtype = im.dtype - max_value = MAX_VALUES_BY_DTYPE[dtype] - im = im.astype(np.float32) / max_value - im = (im * 255).astype(np.uint8) - - #establish a filename prefix from the imvolume - #extract the experiment name from the filepath - #add a special case for .nii.gz files - if fp[-5:] == 'nii.gz': - fext = 'nii.gz' - else: - fext = fp.split('.')[-1] - exp_name = fp.split('/')[-1].split(f'.{fext}')[0] - - #loop over the axes and save slices - for axis in axes: - #only process xy slices if the volume is anisotropic - if anisotropy > 0.2 and axis != 0: - continue - - #get the axis dimension and get evenly spaced slice indices - nmax = im.shape[axis] - 1 - slice_indices = np.arange(0, nmax, spacing, dtype=np.long) - - #for the index naming convention we want to pad all slice indices - #with zeros up to some length: eg. 
1 --> 0001 to match 999 --> 0999 - #we get the number of zeros from nmax - zpad = math.ceil(math.log(nmax, 10)) - - for idx in slice_indices: - index_str = str(idx).zfill(zpad) - #add the -LOC- to indicate the point of separation between - #the dataset name and the slice location information - slice_name = f'{exp_name}-LOC-{axis}_{index_str}.tiff' - - #don't save anything if the slice already exists - if os.path.isfile(os.path.join(savedir, slice_name)): - continue - - #slice the volume on the proper axis - if axis == 0: - im_slice = im[idx] - elif axis == 1: - im_slice = im[:, idx] - else: - im_slice = im[:, :, idx] - - #save the slice - imsave(os.path.join(savedir, slice_name), im_slice, check_contrast=False) - - #running the function with multiple processes - #results in a much faster runtime - with Pool(processes) as pool: - pool.map(create_slices, fpaths) - - print('Finished') \ No newline at end of file diff --git a/dataset/preprocess/vid2stack.py b/dataset/vid2stack.py similarity index 100% rename from dataset/preprocess/vid2stack.py rename to dataset/vid2stack.py From 60cef802870164a253271195a843e69360d49237 Mon Sep 17 00:00:00 2001 From: conradry Date: Thu, 19 Aug 2021 08:57:00 -0400 Subject: [PATCH 03/19] swav pretraining --- pretraining/swav/LARC.py | 132 +++++++++ pretraining/swav/__init__.py | 0 pretraining/swav/dataset.py | 89 ++++++ pretraining/swav/resnet50.py | 353 ++++++++++++++++++++++ pretraining/swav/swav_config.yaml | 58 ++++ pretraining/swav/train_swav.py | 478 ++++++++++++++++++++++++++++++ pretraining/swav/utils.py | 95 ++++++ 7 files changed, 1205 insertions(+) create mode 100644 pretraining/swav/LARC.py create mode 100644 pretraining/swav/__init__.py create mode 100644 pretraining/swav/dataset.py create mode 100644 pretraining/swav/resnet50.py create mode 100644 pretraining/swav/swav_config.yaml create mode 100644 pretraining/swav/train_swav.py create mode 100644 pretraining/swav/utils.py diff --git a/pretraining/swav/LARC.py b/pretraining/swav/LARC.py new file mode 100644 index 0000000..2b22725 --- /dev/null +++ b/pretraining/swav/LARC.py @@ -0,0 +1,132 @@ +""" +Copied with modificiation from https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py + +Modifications: +-------------- + +1. Added a condition in step() to not adapt the lr for parameter groups +with 'adapt_lr' == False. + +""" + +import torch +from torch import nn +from torch.nn.parameter import Parameter + +class LARC: + """ + :class:`LARC` is a pytorch implementation of both the scaling and clipping variants of LARC, + in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive + local learning rate for each individual parameter. The algorithm is designed to improve + convergence of large batch training. + + See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate. + + In practice it modifies the gradients of parameters as a proxy for modifying the learning rate + of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer. + + ``` + model = ... + optim = torch.optim.Adam(model.parameters(), lr=...) + optim = LARC(optim) + ``` + + It can even be used in conjunction with apex.fp16_utils.FP16_optimizer. + + ``` + model = ... + optim = torch.optim.Adam(model.parameters(), lr=...) + optim = LARC(optim) + optim = apex.fp16_utils.FP16_Optimizer(optim) + ``` + + Args: + optimizer: Pytorch optimizer to wrap and modify learning rate for. + trust_coefficient: Trust coefficient for calculating the lr. 
See https://arxiv.org/abs/1708.03888 + clip: Decides between clipping or scaling mode of LARC. If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter. If `clip=False` the learning rate is set to `local_lr*optimizer_lr`. + eps: epsilon kludge to help with numerical stability while calculating adaptive_lr + """ + + def __init__( + self, + optimizer, + trust_coefficient=1e-3, + clip=False, + eps=1e-8 + ): + self.optim = optimizer + self.trust_coefficient = trust_coefficient + self.eps = eps + self.clip = clip + + def __getstate__(self): + return self.optim.__getstate__() + + def __setstate__(self, state): + self.optim.__setstate__(state) + + @property + def state(self): + return self.optim.state + + def __repr__(self): + return self.optim.__repr__() + + @property + def param_groups(self): + return self.optim.param_groups + + @param_groups.setter + def param_groups(self, value): + self.optim.param_groups = value + + def state_dict(self): + return self.optim.state_dict() + + def load_state_dict(self, state_dict): + self.optim.load_state_dict(state_dict) + + def zero_grad(self): + self.optim.zero_grad() + + def add_param_group(self, param_group): + self.optim.add_param_group( param_group) + + def step(self): + with torch.no_grad(): + weight_decays = [] + for group in self.optim.param_groups: + # absorb weight decay control from optimizer + weight_decay = group['weight_decay'] if 'weight_decay' in group else 0 + weight_decays.append(weight_decay) + group['weight_decay'] = 0 + + #check if adapt_lr flag exists and if it's False + #don't adapt the lr for those parameters + if 'adapt_lr' in group: + if group['adapt_lr'] is False: + continue #go to next group + + for p in group['params']: + if p.grad is None: + continue + + param_norm = torch.norm(p.data) + grad_norm = torch.norm(p.grad.data) + + if param_norm != 0 and grad_norm != 0: + # calculate adaptive lr + weight decay + adaptive_lr = self.trust_coefficient * (param_norm) / (grad_norm + param_norm * weight_decay + self.eps) + + # clip learning rate for LARC + if self.clip: + # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)` + adaptive_lr = min(adaptive_lr/group['lr'], 1) + + p.grad.data += weight_decay * p.data + p.grad.data *= adaptive_lr + + self.optim.step() + # return weight decay control to optimizer + for i, group in enumerate(self.optim.param_groups): + group['weight_decay'] = weight_decays[i] \ No newline at end of file diff --git a/pretraining/swav/__init__.py b/pretraining/swav/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pretraining/swav/dataset.py b/pretraining/swav/dataset.py new file mode 100644 index 0000000..5a694dc --- /dev/null +++ b/pretraining/swav/dataset.py @@ -0,0 +1,89 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# +import os +import pickle +import random +import torch +import numpy as np +from glob import glob +from PIL import Image +from PIL import ImageFilter +from torch.utils.data import Dataset + +class MultiCropDataset(Dataset): + def __init__( + self, + data_path, + transforms + ): + super(MultiCropDataset, self).__init__() + self.data_path = data_path + + manifest_file = os.path.join(data_path, 'manifest.pkl') + if os.path.isfile(manifest_file): + with open(manifest_file, mode='rb') as f: + self.fpaths = pickle.load(f) + else: + self.fpaths = glob(data_path + '**/*') + with open(manifest_file, mode='wb') as f: + pickle.dump(self.fpaths, f) + + print(f'Found {len(self.fpaths)} images in dataset.') + self.tfs = transforms + + def __len__(self): + return len(self.fpaths) + + def __getitem__(self, index): + #get the filepath to load + f = self.fpaths[index] + + # process multiple transformed crops of the image + image = Image.open(f) + multi_crops = list(map( + lambda tfs: tfs(image), self.tfs + )) + + return multi_crops + +class RandomGaussianBlur: + """ + Apply Gaussian Blur to the PIL image. Take the radius and probability of + application as the parameter. + This transform was used in SimCLR - https://arxiv.org/abs/2002.05709 + """ + + def __init__(self, p=0.5, radius_min=0.1, radius_max=2.): + self.prob = p + self.radius_min = radius_min + self.radius_max = radius_max + + def __call__(self, img): + do_it = np.random.rand() <= self.prob + if not do_it: + return img + + return img.filter( + ImageFilter.GaussianBlur( + radius=random.uniform(self.radius_min, self.radius_max) + ) + ) + +class GaussNoise: + """Gaussian Noise to be applied to images that have been scaled to fit in the range 0-1""" + def __init__(self, var_limit=(1e-5, 1e-4), p=0.5): + self.var_limit = np.log(var_limit) + self.p = p + + def __call__(self, image): + if np.random.random() < self.p: + sigma = np.exp(np.random.uniform(*self.var_limit)) ** 0.5 + noise = np.random.normal(0, sigma, size=image.shape).astype(np.float32) + image = image + torch.from_numpy(noise) + image = torch.clamp(image, 0, 1) + + return image diff --git a/pretraining/swav/resnet50.py b/pretraining/swav/resnet50.py new file mode 100644 index 0000000..41e90aa --- /dev/null +++ b/pretraining/swav/resnet50.py @@ -0,0 +1,353 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# + +import torch +import torch.nn as nn + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + __constants__ = ["downsample"] + + def __init__( + self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None, + ): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError("BasicBlock only supports groups=1 and base_width=64") + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + __constants__ = ["downsample"] + + def __init__( + self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None, + ): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.0)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + def __init__( + self, + block, + layers, + zero_init_residual=False, + groups=1, + widen=1, + width_per_group=64, + replace_stride_with_dilation=None, + norm_layer=None, + normalize=False, + output_dim=0, + hidden_mlp=0, + nmb_prototypes=0, + eval_mode=False, + ): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.eval_mode = eval_mode + self.padding = nn.ConstantPad2d(1, 0.0) + + self.inplanes = width_per_group * widen + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if 
len(replace_stride_with_dilation) != 3: + raise ValueError( + "replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation) + ) + self.groups = groups + self.base_width = width_per_group + + # change padding 3 -> 2 compared to original torchvision code because added a padding layer + num_out_filters = width_per_group * widen + self.conv1 = nn.Conv2d( + 1, num_out_filters, kernel_size=7, stride=2, padding=2, bias=False + ) + self.bn1 = norm_layer(num_out_filters) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, num_out_filters, layers[0]) + num_out_filters *= 2 + self.layer2 = self._make_layer( + block, num_out_filters, layers[1], stride=2, dilate=replace_stride_with_dilation[0] + ) + num_out_filters *= 2 + self.layer3 = self._make_layer( + block, num_out_filters, layers[2], stride=2, dilate=replace_stride_with_dilation[1] + ) + num_out_filters *= 2 + self.layer4 = self._make_layer( + block, num_out_filters, layers[3], stride=2, dilate=replace_stride_with_dilation[2] + ) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + + # normalize output features + self.l2norm = normalize + + # projection head + if output_dim == 0: + self.projection_head = None + elif hidden_mlp == 0: + self.projection_head = nn.Linear(num_out_filters * block.expansion, output_dim) + else: + self.projection_head = nn.Sequential( + nn.Linear(num_out_filters * block.expansion, hidden_mlp), + nn.BatchNorm1d(hidden_mlp), + nn.ReLU(inplace=True), + nn.Linear(hidden_mlp, output_dim), + ) + + # prototype layer + self.prototypes = None + if isinstance(nmb_prototypes, list): + self.prototypes = MultiPrototypes(output_dim, nmb_prototypes) + elif nmb_prototypes > 0: + self.prototypes = nn.Linear(output_dim, nmb_prototypes, bias=False) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + norm_layer, + ) + ) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + ) + ) + + return nn.Sequential(*layers) + + def forward_backbone(self, x): + x = self.padding(x) + + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + if self.eval_mode: + return x + + x = self.avgpool(x) + x = torch.flatten(x, 1) + + return x + + def forward_head(self, x): + if self.projection_head is not None: + x = self.projection_head(x) + + if self.l2norm: + x = nn.functional.normalize(x, dim=1, p=2) + + if self.prototypes is not None: + return x, self.prototypes(x) + return x + + def forward(self, inputs): + if not isinstance(inputs, list): + inputs = [inputs] + idx_crops = torch.cumsum(torch.unique_consecutive( + torch.tensor([inp.shape[-1] for inp in inputs]), + return_counts=True, + )[1], 0) + start_idx = 0 + for end_idx in idx_crops: + _out = self.forward_backbone(torch.cat(inputs[start_idx: end_idx]).cuda(non_blocking=True)) + if start_idx == 0: + output = _out + else: + output = torch.cat((output, _out)) + start_idx = end_idx + return self.forward_head(output) + + +class MultiPrototypes(nn.Module): + def __init__(self, output_dim, nmb_prototypes): + super(MultiPrototypes, self).__init__() + self.nmb_heads = len(nmb_prototypes) + for i, k in enumerate(nmb_prototypes): + self.add_module("prototypes" + str(i), nn.Linear(output_dim, k, bias=False)) + + def forward(self, x): + out = [] + for i in range(self.nmb_heads): + out.append(getattr(self, "prototypes" + str(i))(x)) + return out + + +def resnet50(**kwargs): + return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + + +def resnet50w2(**kwargs): + return ResNet(Bottleneck, [3, 4, 6, 3], widen=2, **kwargs) + + +def resnet50w4(**kwargs): + return ResNet(Bottleneck, [3, 4, 6, 3], widen=4, **kwargs) + + +def resnet50w5(**kwargs): + return ResNet(Bottleneck, [3, 4, 6, 3], widen=5, **kwargs) diff --git a/pretraining/swav/swav_config.yaml b/pretraining/swav/swav_config.yaml new file mode 100644 index 0000000..33f5cc3 --- /dev/null +++ b/pretraining/swav/swav_config.yaml @@ -0,0 +1,58 @@ +# training parameters +experiment_name: "SWaV_CEM1.5M" +data_path: "/data/IASEM/conradrw/data/cem_datasets/v1_cem/cem1.4M/" +model_path: "/data/IASEM/conradrw/models/SWaV_cem1.4M/" + +print_freq: 1 + +arch: "resnet50" +hidden_mlp: 2048 +workers: 16 +checkpoint_freq: 25 +use_fp16: True +seed: 1447 +resume: null + +epochs: 400 +warmup_epochs: 0 +start_warmup: 0 +batch_size: 64 
+base_lr: 0.015 +final_lr: 0.0006 +wd: 0.000001 +freeze_prototypes_niters: 5005 + +# distributed training parameters +world_size: 1 +rank: 0 +dist_url: "tcp://localhost:10001" +dist_backend: "nccl" +multiprocessing_distributed: True + +# SWaV parameters +nmb_crops: + - 2 + - 6 +size_crops: + - 224 + - 96 +min_scale_crops: + - 0.14 + - 0.05 +max_scale_crops: + - 1. + - 0.14 +crops_for_assign: + - 0 + - 1 +temperature: 0.1 +epsilon: 0.05 +sinkhorn_iterations: 3 +feat_dim: 128 +nmb_prototypes: 3000 +queue_length: 3840 +epoch_queue_starts: 15 + +norms: + - 0.58331613 + - 0.09966064 diff --git a/pretraining/swav/train_swav.py b/pretraining/swav/train_swav.py new file mode 100644 index 0000000..02d904a --- /dev/null +++ b/pretraining/swav/train_swav.py @@ -0,0 +1,478 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +import argparse +import builtins +import math +import os +import yaml +import shutil +import time +import mlflow + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.multiprocessing as mp +import torch.distributed as dist +import torchvision.transforms as tf + +from torch.cuda.amp import autocast +from torch.cuda.amp import GradScaler + +import resnet50 as resnet_models +from LARC import LARC +from utils import ( + fix_random_seeds, + AverageMeter, + init_distributed_mode +) +from dataset import MultiCropDataset, RandomGaussianBlur, GaussNoise + +def parse_args(): + parser = argparse.ArgumentParser(description='PyTorch SWaV Training') + parser.add_argument('config', help='Path to .yaml training config file') + return parser.parse_args() + +def main(): + args = parse_args() + with open(args.config, 'r') as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + config = {**config, **vars(args)} + + # load config dictionary into args + args = argparse.Namespace(**config) + + if not os.path.isdir(args.model_path): + os.mkdir(args.model_path) + + #world size is the number of processes that will run + if args.dist_url == "env://" and args.world_size == -1: + args.world_size = int(os.environ["WORLD_SIZE"]) + + args.distributed = args.world_size > 1 or args.multiprocessing_distributed + + ngpus_per_node = torch.cuda.device_count() + + args.ngpus_per_node = ngpus_per_node + if args.multiprocessing_distributed: + # Since we have ngpus_per_node processes per node, the total world_size + # needs to be adjusted accordingly + args.world_size = ngpus_per_node * args.world_size + # Use torch.multiprocessing.spawn to launch distributed processes: the + # main_worker process function + mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) + else: + # Simply call main_worker function + main_worker(args.gpu, ngpus_per_node, args) + +def main_worker(gpu, ngpus_per_node, args): + args.gpu = gpu + + # suppress printing if not master process + if args.multiprocessing_distributed and args.gpu != 0: + def print_pass(*args): + pass + builtins.print = print_pass + + if args.dist_url == "env://" and args.rank == -1: + args.rank = int(os.environ["RANK"]) + if args.multiprocessing_distributed: + # For multiprocessing distributed training, rank needs to be the + # global rank among all the processes + args.rank = args.rank * ngpus_per_node + gpu + + # initialize distributed environment + 
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + + # opt for random seeds for now + #fix_random_seeds(args.seed) + + # set the image transforms + mean, std = args.norms + normalize = tf.Normalize(mean=[mean], std=[std]) + + # list of transforms, one for each crop + assert len(args.size_crops) == len(args.nmb_crops) + assert len(args.min_scale_crops) == len(args.nmb_crops) + assert len(args.max_scale_crops) == len(args.nmb_crops) + transforms = [] + for i in range(len(args.size_crops)): + crop_size = args.size_crops[i] + min_scale = args.min_scale_crops[i] + max_scale = args.max_scale_crops[i] + num = args.nmb_crops[i] + + transforms.extend([tf.Compose([ + tf.Grayscale(3), + tf.RandomApply([tf.RandomRotation(180)], p=0.5), + tf.RandomResizedCrop(crop_size, scale=(min_scale, max_scale)), + tf.ColorJitter(0.4, 0.4, 0.4, 0.1), + RandomGaussianBlur(0.5, 0.1, 2.), + tf.Grayscale(1), + tf.RandomHorizontalFlip(), + tf.RandomVerticalFlip(), + tf.ToTensor(), + GaussNoise(p=0.5), + normalize + ])] * num) + + # build data + train_dataset = MultiCropDataset( + args.data_path, + transforms + ) + + sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + train_loader = torch.utils.data.DataLoader( + train_dataset, + sampler=sampler, + batch_size=args.batch_size, + num_workers=args.workers, + pin_memory=True, + drop_last=True + ) + + # build model + model = resnet_models.__dict__[args.arch]( + normalize=True, + hidden_mlp=args.hidden_mlp, + output_dim=args.feat_dim, + nmb_prototypes=args.nmb_prototypes, + ) + + # synchronize batch norm layers + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + + # For multiprocessing distributed, DistributedDataParallel constructor + # should always set the single device scope, otherwise, + # DistributedDataParallel will use all available devices. 
+ if args.gpu is not None: + torch.cuda.set_device(args.gpu) + model.cuda(args.gpu) + # When using a single GPU per process and per + # DistributedDataParallel, we need to divide the batch size + # ourselves based on the total number of GPUs we have + args.batch_size = int(args.batch_size / ngpus_per_node) + args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) + model = nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + else: + model.cuda() + # DistributedDataParallel will divide and allocate batch_size to all + # available GPUs if device_ids are not set + model = torch.nn.parallel.DistributedDataParallel(model) + + # build optimizer + optimizer = torch.optim.SGD( + model.parameters(), + lr=args.base_lr, + momentum=0.9, + weight_decay=args.wd, + ) + + optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False) + warmup_lr_schedule = np.linspace(args.start_warmup, args.base_lr, len(train_loader) * args.warmup_epochs) + iters = np.arange(len(train_loader) * (args.epochs - args.warmup_epochs)) + cosine_lr_schedule = np.array([args.final_lr + 0.5 * (args.base_lr - args.final_lr) * (1 + \ + math.cos(math.pi * t / (len(train_loader) * (args.epochs - args.warmup_epochs)))) for t in iters]) + lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule)) + + # init gradient scaler if needed + scaler = GradScaler() if args.use_fp16 else None + + # optionally resume from a checkpoint + run_id = None + args.start_epoch = 0 + + # optionally resume from a checkpoint + if args.resume is not None: + print(f"=> loading checkpoint '{args.resume}'") + + if args.gpu is None: + checkpoint = torch.load(args.resume) + else: + # Map model to be loaded to specified single gpu. + checkpoint = torch.load(args.resume, map_location=f'cuda:{args.gpu}') + + args.start_epoch = checkpoint['epoch'] + model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + if scaler is not None: + scaler.load_state_dict(checkpoint['scaler']) + + run_id = checkpoint['run_id'] + print(f"=> loaded checkpoint '{args.resume}' (epoch {args.start_epoch})") + + # log parameters for run, or resume existing run + if run_id is None and args.rank == 0: + # log parameters in mlflow + mlflow.end_run() + mlflow.set_experiment(args.experiment_name) + mlflow.log_artifact(args.config) + + #we don't want to add everything in the config + #to mlflow parameters, we'll just add the most + #likely to change parameters + mlflow.log_param('architecture', args.arch) + mlflow.log_param('epochs', args.epochs) + mlflow.log_param('batch_size', args.batch_size) + mlflow.log_param('base_lr', args.base_lr) + mlflow.log_param('final_lr', args.final_lr) + mlflow.log_param('temperature', args.temperature) + mlflow.log_param('feature_dim', args.feat_dim) + mlflow.log_param('queue_length', args.queue_length) + else: + # resume existing run + mlflow.start_run(run_id=run_id) + + # build the queue, or resume it + queue = None + queue_path = os.path.join(args.model_path, "queue" + str(args.rank) + ".pth") + if os.path.isfile(queue_path): + queue = torch.load(queue_path)["queue"] + + # the queue needs to be divisible by the batch size + args.queue_length -= args.queue_length % (args.batch_size * args.world_size) + + cudnn.benchmark = True + for epoch in range(args.start_epoch, args.epochs): + # set sampler + train_loader.sampler.set_epoch(epoch) + + # optionally starts a queue + if args.queue_length > 0 and epoch >= args.epoch_queue_starts and queue is None: + queue = 
torch.zeros( + len(args.crops_for_assign), + args.queue_length // args.world_size, + args.feat_dim, + ).cuda() + + # train the network + scores, queue = train(train_loader, model, optimizer, scaler, epoch, lr_schedule, queue, args) + training_stats.update(scores) + + # save checkpoints + if args.rank == 0: + save_dict = { + "epoch": epoch + 1, + "state_dict": model.state_dict(), + "optimizer": optimizer.state_dict(), + "run_id": mlflow.active_run().info.run_id, + "norms": args.norms + } + if args.use_fp16: + save_dict["scaler"] = scaler.state_dict() + + torch.save( + save_dict, + os.path.join(args.model_path, "checkpoint.pth.tar"), + ) + + if epoch % args.checkpoint_freq == 0 or epoch == args.epochs - 1: + shutil.copyfile( + os.path.join(args.model_path, "checkpoint.pth.tar"), + os.path.join(args.model_path, f"ckp-{epoch}.pth"), + ) + + if queue is not None: + torch.save({"queue": queue}, queue_path) + +def train(train_loader, model, optimizer, scaler, epoch, lr_schedule, queue, args): + batch_time = AverageMeter('Time', ':6.3f') + data_time = AverageMeter('Data', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + + progress = ProgressMeter( + len(train_loader), + [batch_time, data_time, losses], + prefix="Epoch: [{}]".format(epoch)) + + model.train() + use_the_queue = False + + end = time.time() + for it, inputs in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + # update learning rate + iteration = epoch * len(train_loader) + it + for param_group in optimizer.param_groups: + param_group["lr"] = lr_schedule[iteration] + + # normalize the prototypes + with torch.no_grad(): + w = model.module.prototypes.weight.data.clone() + w = nn.functional.normalize(w, dim=1, p=2) + model.module.prototypes.weight.copy_(w) + + # ============ multi-res forward passes ... ============ + if scaler is not None: + with autocast(): + embedding, output = model(inputs) + + embedding = embedding.detach() + bs = inputs[0].size(0) + + # ============ swav loss ... ============ + loss = 0 + for i, crop_id in enumerate(args.crops_for_assign): + with torch.no_grad(): + out = output[bs * crop_id: bs * (crop_id + 1)].detach() + + # time to use the queue + if queue is not None: + if use_the_queue or not torch.all(queue[i, -1, :] == 0): + use_the_queue = True + out = torch.cat((torch.mm( + queue[i], + model.module.prototypes.weight.t() + ), out)) + + # fill the queue + queue[i, bs:] = queue[i, :-bs].clone() + queue[i, :bs] = embedding[crop_id * bs: (crop_id + 1) * bs] + + # get assignments + q = distributed_sinkhorn(out, args)[-bs:] + + # cluster assignment prediction + subloss = 0 + for v in np.delete(np.arange(np.sum(args.nmb_crops)), crop_id): + x = output[bs * v: bs * (v + 1)] / args.temperature + subloss -= torch.mean(torch.sum(q * F.log_softmax(x, dim=1), dim=1)) + + loss += subloss / (np.sum(args.nmb_crops) - 1) + + loss /= len(args.crops_for_assign) + else: + embedding, output = model(inputs) + + embedding = embedding.detach() + bs = inputs[0].size(0) + + # ============ swav loss ... 
============ + loss = 0 + for i, crop_id in enumerate(args.crops_for_assign): + with torch.no_grad(): + out = output[bs * crop_id: bs * (crop_id + 1)].detach() + + # time to use the queue + if queue is not None: + if use_the_queue or not torch.all(queue[i, -1, :] == 0): + use_the_queue = True + out = torch.cat((torch.mm( + queue[i], + model.module.prototypes.weight.t() + ), out)) + + # fill the queue + queue[i, bs:] = queue[i, :-bs].clone() + queue[i, :bs] = embedding[crop_id * bs: (crop_id + 1) * bs] + + # get assignments + q = distributed_sinkhorn(out, args)[-bs:] + + # cluster assignment prediction + subloss = 0 + for v in np.delete(np.arange(np.sum(args.nmb_crops)), crop_id): + x = output[bs * v: bs * (v + 1)] / args.temperature + subloss -= torch.mean(torch.sum(q * F.log_softmax(x, dim=1), dim=1)) + + loss += subloss / (np.sum(args.nmb_crops) - 1) + + loss /= len(args.crops_for_assign) + + # ============ backward and optim step ... ============ + optimizer.zero_grad() + if scaler is not None: + scaler.scale(loss).backward() + + # cancel gradients for the prototypes + if iteration < args.freeze_prototypes_niters: + for name, p in model.named_parameters(): + if "prototypes" in name: + p.grad = None + + scaler.step(optimizer) + scaler.update() + else: + loss.backward() + + # cancel gradients for the prototypes + if iteration < args.freeze_prototypes_niters: + for name, p in model.named_parameters(): + if "prototypes" in name: + p.grad = None + + optimizer.step() + + # ============ misc ... ============ + # acc1/acc5 are (K+1)-way contrast classifier accuracy + # measure accuracy and record loss + losses.update(loss.item(), inputs[0].size(0)) + batch_time.update(time.time() - end) + end = time.time() + + if args.rank == 0 and it % args.print_freq == 0: + progress.display(it) + + if args.rank == 0: + # store metrics to mlflow + mlflow.log_metric('loss', losses.avg, step=epoch) + + return (epoch, losses.avg), queue + +class ProgressMeter(object): + def __init__(self, num_batches, meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def display(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print('\t'.join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = '{:' + str(num_digits) + 'd}' + return '[' + fmt + '/' + fmt.format(num_batches) + ']' + +@torch.no_grad() +def distributed_sinkhorn(out, args): + Q = torch.exp(out / args.epsilon).t() # Q is K-by-B for consistency with notations from our paper + B = Q.shape[1] * args.world_size # number of samples to assign + K = Q.shape[0] # how many prototypes + + # make the matrix sums to 1 + sum_Q = torch.sum(Q) + dist.all_reduce(sum_Q) + Q /= sum_Q + + for it in range(args.sinkhorn_iterations): + # normalize each row: total weight per prototype must be 1/K + sum_of_rows = torch.sum(Q, dim=1, keepdim=True) + dist.all_reduce(sum_of_rows) + Q /= sum_of_rows + Q /= K + + # normalize each column: total weight per sample must be 1/B + Q /= torch.sum(Q, dim=0, keepdim=True) + Q /= B + + Q *= B # the colomns must sum to 1 so that Q is an assignment + return Q.t() + + +if __name__ == "__main__": + main() diff --git a/pretraining/swav/utils.py b/pretraining/swav/utils.py new file mode 100644 index 0000000..14e053e --- /dev/null +++ b/pretraining/swav/utils.py @@ -0,0 +1,95 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import argparse +import pickle +import os +import numpy as np +import torch +import torch.distributed as dist + +def init_distributed_mode(args): + """ + Initialize the following variables: + - world_size + - rank + """ + + args.is_slurm_job = "SLURM_JOB_ID" in os.environ + + if args.is_slurm_job: + args.rank = int(os.environ["SLURM_PROCID"]) + args.world_size = int(os.environ["SLURM_NNODES"]) * int( + os.environ["SLURM_TASKS_PER_NODE"][0] + ) + else: + # multi-GPU job (local or multi-node) - jobs started with torch.distributed.launch + # read environment variables + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + + # prepare distributed + dist.init_process_group( + backend="nccl", + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + + # set cuda device + args.gpu_to_work_on = args.rank % torch.cuda.device_count() + torch.cuda.set_device(args.gpu_to_work_on) + + return + +def fix_random_seeds(seed=31): + """ + Fix random seeds. + """ + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + + return res \ No newline at end of file From 805a26c8c97fc5a94d3c626fb0be5a5bcb40e512 Mon Sep 17 00:00:00 2001 From: conradry Date: Thu, 19 Aug 2021 08:58:41 -0400 Subject: [PATCH 04/19] refactor moco --- pretraining/mocov2/dataset.py | 14 ++++++--- pretraining/{ => mocov2}/resnet.py | 0 pretraining/{ => mocov2}/train_mocov2.py | 4 +-- pretraining/mocov2_config.yaml | 36 ------------------------ 4 files changed, 12 insertions(+), 42 deletions(-) rename pretraining/{ => mocov2}/resnet.py (100%) rename pretraining/{ => mocov2}/train_mocov2.py (99%) delete mode 100644 pretraining/mocov2_config.yaml diff --git a/pretraining/mocov2/dataset.py b/pretraining/mocov2/dataset.py index deb7733..17fe51b 100644 --- a/pretraining/mocov2/dataset.py +++ b/pretraining/mocov2/dataset.py @@ -1,3 +1,4 @@ +import os import random import torch import numpy as np @@ -14,18 +15,23 @@ class EMData(Dataset): def __init__(self, fpaths_dask_array, tfs): super(EMData, self).__init__() - self.fpaths_dask_array = fpaths_dask_array + #self.fpaths_dask_array = fpaths_dask_array + self.fpaths = np.array(glob(os.path.join(fpaths_dask_array, '*.tiff'))) self.tfs = tfs - self.fpaths = da.from_npy_stack(fpaths_dask_array) - print(f'Loaded {fpaths_dask_array} with {len(self.fpaths)} tiff images') + benchmarks = ['urocell', 'guay', 'cremi', 'perez', 
'lucchi', 'kasthuri'] + for bnk in benchmarks: + indices = np.where(np.core.defchararray.find(self.fpaths, bnk) == -1)[0] + self.fpaths = self.fpaths[indices] + + print(f'Found {len(self.fpaths)} tiff images') def __len__(self): return len(self.fpaths) def __getitem__(self, idx): #get the filepath to load - f = self.fpaths[idx].compute() + f = self.fpaths[idx]#.compute() #load the image and add an empty channel dim image = Image.open(f) diff --git a/pretraining/resnet.py b/pretraining/mocov2/resnet.py similarity index 100% rename from pretraining/resnet.py rename to pretraining/mocov2/resnet.py diff --git a/pretraining/train_mocov2.py b/pretraining/mocov2/train_mocov2.py similarity index 99% rename from pretraining/train_mocov2.py rename to pretraining/mocov2/train_mocov2.py index 3203eef..15890c4 100755 --- a/pretraining/train_mocov2.py +++ b/pretraining/mocov2/train_mocov2.py @@ -1,4 +1,5 @@ """ + Copied with modification from https://github.com/facebookresearch/moco/blob/master/main_moco.py Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved @@ -37,7 +38,7 @@ import mocov2.builder as builder from mocov2.dataset import EMData, GaussianBlur, GaussNoise -from resnet import resnet50 +from ..resnet import resnet50 import mlflow @@ -51,7 +52,6 @@ def parse_args(): return vars(parser.parse_args()) - def main(): args = parse_args() diff --git a/pretraining/mocov2_config.yaml b/pretraining/mocov2_config.yaml deleted file mode 100644 index bfa665e..0000000 --- a/pretraining/mocov2_config.yaml +++ /dev/null @@ -1,36 +0,0 @@ -#basic definitions -experiment_name: "Filtered_CellEMNet_MoCoV2" -data_file: # a .npz dask array of filepaths -model_dir: # the directory in which to save model states -arch: "resnet50" -workers: 16 -epochs: 2 -save_freq: 20 -print_freq: 10 -batch_size: 128 -lr: 0.015 -schedule: - - 120 - - 160 -momentum: 0.9 -weight_decay: 0.0001 - -resume: "" -world_size: 1 -rank: 0 -dist_url: "tcp://localhost:10001" -dist_backend: "nccl" -multiprocessing_distributed: True - -moco_dim: 128 -moco_k: 65536 -moco_m: 0.999 -moco_t: 0.2 - -mlp: True -cos: False -logging: True - -norms: - mean: 0.58331613 - std: 0.09966064 \ No newline at end of file From 2f6235e4a9ac6d23ff88707dc9aaa5b899b1c67d Mon Sep 17 00:00:00 2001 From: conradry Date: Thu, 19 Aug 2021 08:59:18 -0400 Subject: [PATCH 05/19] dont forget config file --- pretraining/mocov2/mocov2_config.yaml | 36 +++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 pretraining/mocov2/mocov2_config.yaml diff --git a/pretraining/mocov2/mocov2_config.yaml b/pretraining/mocov2/mocov2_config.yaml new file mode 100644 index 0000000..8fb58d4 --- /dev/null +++ b/pretraining/mocov2/mocov2_config.yaml @@ -0,0 +1,36 @@ +#basic definitions +experiment_name: "Filtered_CellEMNet_MoCoV2_No_Benchmarks" +data_file: "/data/IASEM/conradrw/data/cem500k/" # a .npz dask array of filepaths +model_dir: "/data/conradrw/FEM_No_Benchmarks/" # the directory in which to save model states +arch: "resnet50" +workers: 16 +epochs: 200 +save_freq: 20 +print_freq: 10 +batch_size: 128 +lr: 0.015 +schedule: + - 120 + - 160 +momentum: 0.9 +weight_decay: 0.0001 + +resume: "/data/conradrw/FEM_No_Benchmarks/current.pth.tar" +world_size: 1 +rank: 0 +dist_url: "tcp://localhost:10001" +dist_backend: "nccl" +multiprocessing_distributed: True + +moco_dim: 128 +moco_k: 65536 +moco_m: 0.999 +moco_t: 0.2 + +mlp: True +cos: False +logging: True + +norms: + mean: 0.58331613 + std: 0.09966064 \ No newline at end of file From 
39c0c4e086fdfc985e87606583de83cae76d053d Mon Sep 17 00:00:00 2001 From: conradry Date: Thu, 19 Aug 2021 09:08:30 -0400 Subject: [PATCH 06/19] pixpro initial --- pretraining/pixpro/LARC.py | 132 ++++++++ pretraining/pixpro/__init__.py | 0 pretraining/pixpro/builder.py | 183 ++++++++++ pretraining/pixpro/dataset.py | 87 +++++ pretraining/pixpro/pixpro_config.yaml | 33 ++ pretraining/pixpro/train_pixpro.py | 469 ++++++++++++++++++++++++++ 6 files changed, 904 insertions(+) create mode 100644 pretraining/pixpro/LARC.py create mode 100644 pretraining/pixpro/__init__.py create mode 100644 pretraining/pixpro/builder.py create mode 100644 pretraining/pixpro/dataset.py create mode 100644 pretraining/pixpro/pixpro_config.yaml create mode 100644 pretraining/pixpro/train_pixpro.py diff --git a/pretraining/pixpro/LARC.py b/pretraining/pixpro/LARC.py new file mode 100644 index 0000000..0d1d6ea --- /dev/null +++ b/pretraining/pixpro/LARC.py @@ -0,0 +1,132 @@ +""" +Copied with modificiation from https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py + +Modifications: +-------------- + +1. Added a condition in step() to not adapt the lr for parameter groups +with 'adapt_lr' == False. + +""" + +import torch +from torch import nn +from torch.nn.parameter import Parameter + +class LARC(object): + """ + :class:`LARC` is a pytorch implementation of both the scaling and clipping variants of LARC, + in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive + local learning rate for each individual parameter. The algorithm is designed to improve + convergence of large batch training. + + See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate. + + In practice it modifies the gradients of parameters as a proxy for modifying the learning rate + of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer. + + ``` + model = ... + optim = torch.optim.Adam(model.parameters(), lr=...) + optim = LARC(optim) + ``` + + It can even be used in conjunction with apex.fp16_utils.FP16_optimizer. + + ``` + model = ... + optim = torch.optim.Adam(model.parameters(), lr=...) + optim = LARC(optim) + optim = apex.fp16_utils.FP16_Optimizer(optim) + ``` + + Args: + optimizer: Pytorch optimizer to wrap and modify learning rate for. + trust_coefficient: Trust coefficient for calculating the lr. See https://arxiv.org/abs/1708.03888 + clip: Decides between clipping or scaling mode of LARC. If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter. If `clip=False` the learning rate is set to `local_lr*optimizer_lr`. 
+ eps: epsilon kludge to help with numerical stability while calculating adaptive_lr + """ + + def __init__( + self, + optimizer, + trust_coefficient=1e-3, + clip=False, + eps=1e-8 + ): + self.optim = optimizer + self.trust_coefficient = trust_coefficient + self.eps = eps + self.clip = clip + + def __getstate__(self): + return self.optim.__getstate__() + + def __setstate__(self, state): + self.optim.__setstate__(state) + + @property + def state(self): + return self.optim.state + + def __repr__(self): + return self.optim.__repr__() + + @property + def param_groups(self): + return self.optim.param_groups + + @param_groups.setter + def param_groups(self, value): + self.optim.param_groups = value + + def state_dict(self): + return self.optim.state_dict() + + def load_state_dict(self, state_dict): + self.optim.load_state_dict(state_dict) + + def zero_grad(self): + self.optim.zero_grad() + + def add_param_group(self, param_group): + self.optim.add_param_group( param_group) + + def step(self): + with torch.no_grad(): + weight_decays = [] + for group in self.optim.param_groups: + # absorb weight decay control from optimizer + weight_decay = group['weight_decay'] if 'weight_decay' in group else 0 + weight_decays.append(weight_decay) + group['weight_decay'] = 0 + + #check if adapt_lr flag exists and if it's False + #don't adapt the lr for those parameters + if 'adapt_lr' in group: + if group['adapt_lr'] is False: + continue #go to next group + + for p in group['params']: + if p.grad is None: + continue + + param_norm = torch.norm(p.data) + grad_norm = torch.norm(p.grad.data) + + if param_norm != 0 and grad_norm != 0: + # calculate adaptive lr + weight decay + adaptive_lr = self.trust_coefficient * (param_norm) / (grad_norm + param_norm * weight_decay + self.eps) + + # clip learning rate for LARC + if self.clip: + # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)` + adaptive_lr = min(adaptive_lr/group['lr'], 1) + + p.grad.data += weight_decay * p.data + p.grad.data *= adaptive_lr + + self.optim.step() + # return weight decay control to optimizer + for i, group in enumerate(self.optim.param_groups): + group['weight_decay'] = weight_decays[i] \ No newline at end of file diff --git a/pretraining/pixpro/__init__.py b/pretraining/pixpro/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pretraining/pixpro/builder.py b/pretraining/pixpro/builder.py new file mode 100644 index 0000000..e73431c --- /dev/null +++ b/pretraining/pixpro/builder.py @@ -0,0 +1,183 @@ +import random, math, sys +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from copy import deepcopy +from collections import OrderedDict +from torch.utils.checkpoint import checkpoint + +class PPM(nn.Module): + def __init__(self, nin, gamma=2, nlayers=1): + super(PPM, self).__init__() + + #the networks can have 0-2 layers + if nlayers == 0: + layers = [nn.Identity()] + elif nlayers == 1: + layers = [nn.Conv2d(nin, nin, 1)] + elif nlayers == 2: + layers = [ + nn.Conv2d(nin, nin, 1, bias=False), + nn.BatchNorm2d(nin), + nn.ReLU(), + nn.Conv2d(nin, nin, 1) + ] + else: + raise Exception(f'nlayers must be 0, 1, or 2, got {nlayers}') + + self.transform = nn.Sequential(*layers) + self.gamma = gamma + self.cosine_sim = nn.CosineSimilarity(dim=1) + + def forward(self, x): + #(B, C, H, W, 1, 1) x (B, C, 1, 1, H, W) --> (B, H, W, H, W) + #relu is same as max(sim, 0) but differentiable + xi = x[:, :, :, :, None, None] + xj = x[:, :, None, None, :, :] + s = 
F.relu(self.cosine_sim(xi, xj)) ** self.gamma + + #output of "g" in the paper (B, C, H, W) + gx = self.transform(x) + + #use einsum and skip all the transposes + #and matrix multiplies + return torch.einsum('bijhw, bchw -> bcij', s, gx) + +class Encoder(nn.Module): + def __init__(self, backbone): + super(Encoder, self).__init__() + + #use ordered dict to retain layer names + #last two layers are avgpool and fc + layers = OrderedDict([layer for layer in backbone().named_children()][:-2]) + self.backbone = nn.Sequential(layers) + + #get the number of channels from last convolution + for layer in self.backbone.modules(): + if isinstance(layer, nn.Conv2d): + projection_nin = layer.out_channels + + #note that projection_nin = 2048 for resnet50 (matches paper) + self.projection = nn.Sequential( + nn.Conv2d(projection_nin, projection_nin, 1, bias=False), + nn.BatchNorm2d(projection_nin), + nn.ReLU(inplace=True), + nn.Conv2d(projection_nin, 256, 1) + ) + + def forward(self, x): + return self.projection(self.backbone(x)) + +class PixPro(nn.Module): + def __init__( + self, + backbone, + momentum=0.99, + ppm_layers=1, + ppm_gamma=2, + downsampling=32 + ): + super(PixPro, self).__init__() + + #create the encoder and momentum encoder + self.encoder = Encoder(backbone) + self.mom_encoder = deepcopy(self.encoder) + + #turn off gradients + for mom_param in self.mom_encoder.parameters(): + mom_param.requires_grad = False + + #hardcoded: encoder outputs 256 + self.ppm = PPM(256, gamma=ppm_gamma, nlayers=ppm_layers) + + self.grid_downsample = nn.AvgPool2d(downsampling, stride=downsampling) + self.momentum = momentum + + @torch.no_grad() + def _update_mom_encoder(self): + for param, mom_param in zip(self.encoder.parameters(), self.mom_encoder.parameters()): + mom_param.data = mom_param.data * self.momentum + param.data * (1. - self.momentum) + + def forward(self, view1, view2, view1_grid, view2_grid): + #pass each view through each encoder + y1 = self.ppm(self.encoder(view1)) + y2 = self.ppm(self.encoder(view2)) + + with torch.no_grad(): + #update mom encoder before forward + self._update_mom_encoder() + + z1 = self.mom_encoder(view1) + z2 = self.mom_encoder(view2) + + view1_grid = self.grid_downsample(view1_grid) + view2_grid = self.grid_downsample(view2_grid) + + return y1, y2, z1, z2, view1_grid, view2_grid + +def grid_distances(grid1, grid2): + #grid: (B, 2, H, W) --> (B, 2, H * W) + h, w = grid1.size()[-2:] + grid1 = grid1.flatten(2, -1)[..., :, None] #(B, 2, H * W, 1) + grid2 = grid2.flatten(2, -1)[..., None, :] #(B, 2, 1, H * W) + + y_distances = grid1[:, 0] - grid2[:, 0] + x_distances = grid1[:, 1] - grid2[:, 1] + + return torch.sqrt(y_distances ** 2 + x_distances ** 2) + +class ConsistencyLoss(nn.Module): + def __init__(self, distance_thr=0.7): + super(ConsistencyLoss, self).__init__() + self.distance_thr = distance_thr + self.cosine_sim = nn.CosineSimilarity(dim=1) + + def forward(self, y1, y2, z1, z2, view1_grid, view2_grid): + #(B, C, H * W) + y1 = y1.flatten(2, -1) + y2 = y2.flatten(2, -1) + z1 = z1.flatten(2, -1) + z2 = z2.flatten(2, -1) + + #pairwise distances between grid coordinates + #(B, C, H * W, H * W) + distances = grid_distances(view1_grid, view2_grid) + + #determine normalization factors for view1 and view2 + #(i.e. 
distance between "feature map bins") + #(B,) + view1_bin = torch.norm(view1_grid[..., 1, 1] - view1_grid[..., 0, 0], dim=-1) + view2_bin = torch.norm(view2_grid[..., 1, 1] - view2_grid[..., 0, 0], dim=-1) + + #(B, H * W, H * W) + view1_distances = distances / view1_bin[:, None, None] + view2_distances = distances / view2_bin[:, None, None] + + #(B, C, H * W, 1) x (B, C, 1, H * W) --> (B, H * W, H * W) + #important to keep view1 outputs (y1 and z1) as first items + #in the cosine sim measurement because distances tensor + #has that ordering fixed + view1_similarity = self.cosine_sim(y1[..., :, None], z2[..., None, :]) + view2_similarity = self.cosine_sim(z1[..., :, None], y2[..., None, :]) + similarities = view1_similarity + view2_similarity + + #only consider points that are matches for both views + view1_mask = view1_distances <= self.distance_thr + view2_mask = view2_distances <= self.distance_thr + mask = torch.logical_and(view1_mask, view2_mask) + + #mask-out non-matches with zeros + similarities = torch.where(mask, similarities, torch.zeros_like(similarities)) + matches_per_image = mask.sum(dim=(-1, -2)) #(B,) + + #average over images + similarities = similarities.sum(dim=(-1, -2)).masked_select(matches_per_image > 0) + matches_per_image = matches_per_image.masked_select(matches_per_image > 0) #(b,) + similarities = similarities / matches_per_image #(B,) + + #average over batch + matched_images = (matches_per_image > 0).sum() + similarities = similarities.sum() / matched_images #( ) + + return -similarities diff --git a/pretraining/pixpro/dataset.py b/pretraining/pixpro/dataset.py new file mode 100644 index 0000000..9c0876f --- /dev/null +++ b/pretraining/pixpro/dataset.py @@ -0,0 +1,87 @@ +import os +import cv2 +import numpy as np +import torch +from torch.utils.data import Dataset +from glob import glob +from albumentations import ImageOnlyTransform + +class ContrastData(Dataset): + def __init__( + self, + imdir, + space_tfs, + view1_color_tfs, + view2_color_tfs=None + ): + super(ContrastData, self).__init__() + + self.imdir = imdir + self.fpaths = glob(os.path.join(imdir, '*.tiff'), recursive=True) + #self.fnames = os.listdir(imdir) + + print(f'Found {len(self.fpaths)} images in directory') + + #crops, resizes, flips, rotations, etc. 
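The ConsistencyLoss above decides which pairs of features count as positives by comparing the two views' downsampled coordinate grids: pairwise distances are normalized by each view's feature-bin spacing and thresholded at pixpro_t (0.7). A compact sketch of just that matching rule, with shapes following the code above (the helper name and example call are illustrative):

```python
import torch

def match_mask(grid1, grid2, distance_thr=0.7):
    # grid1, grid2: (B, 2, H, W) coordinate grids in original-image pixel units
    g1 = grid1.flatten(2)[..., :, None]   # (B, 2, H*W, 1)
    g2 = grid2.flatten(2)[..., None, :]   # (B, 2, 1, H*W)
    dists = torch.sqrt((g1[:, 0] - g2[:, 0]) ** 2 + (g1[:, 1] - g2[:, 1]) ** 2)  # (B, HW, HW)

    # spacing between adjacent feature-map bins in each view, used to normalize distances
    bin1 = torch.norm(grid1[..., 1, 1] - grid1[..., 0, 0], dim=-1)  # (B,)
    bin2 = torch.norm(grid2[..., 1, 1] - grid2[..., 0, 0], dim=-1)

    near1 = dists / bin1[:, None, None] <= distance_thr
    near2 = dists / bin2[:, None, None] <= distance_thr
    return near1 & near2  # (B, HW, HW) boolean mask of positive pairs

# e.g. with 7x7 feature maps:
# mask = match_mask(torch.rand(2, 2, 7, 7) * 224, torch.rand(2, 2, 7, 7) * 224)
```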
+ self.space_tfs = space_tfs + + #brightness, contrast, jitter, blur, and + #normalization + self.view1_color_tfs = view1_color_tfs + self.view2_color_tfs = view2_color_tfs + + def __len__(self): + return len(self.fpaths) + + def __getitem__(self, idx): + fpath = self.fpaths[idx] + #fpath = os.path.join(self.imdir, self.fnames[idx]) + image = cv2.imread(fpath, 0) + + y = np.arange(0, image.shape[0], dtype=np.float32) + x = np.arange(0, image.shape[1], dtype=np.float32) + grid_y, grid_x = np.meshgrid(y, x) + grid_y, grid_x = grid_y.T, grid_x.T + + #space transforms treat coordinate grid like an image + #bilinear interp is good, nearest would be bad + view1_data = self.space_tfs(image=image, grid_y=grid_y[..., None], grid_x=grid_x[..., None]) + view2_data = self.space_tfs(image=image, grid_y=grid_y[..., None], grid_x=grid_x[..., None]) + + view1 = view1_data['image'] + view1_grid = np.concatenate([view1_data['grid_y'], view1_data['grid_x']], axis=-1) + view2 = view2_data['image'] + view2_grid = np.concatenate([view2_data['grid_y'], view2_data['grid_x']], axis=-1) + + view1 = self.view1_color_tfs(image=view1)['image'] + if self.view2_color_tfs is not None: + view2 = self.view2_color_tfs(image=view2)['image'] + else: + view2 = self.view1_color_tfs(image=view2)['image'] + + output = { + 'fpath': fpath, + 'view1': view1, + 'view1_grid': torch.from_numpy(view1_grid).permute(2, 0, 1), + 'view2': view2, + 'view2_grid': torch.from_numpy(view2_grid).permute(2, 0, 1) + } + + return output + +class Grayscale(ImageOnlyTransform): + """ + Resizes an image, but not the mask, to be divisible by a specific + number like 32. Necessary for evaluation with segmentation models + that use downsampling. + """ + def __init__(self, channels=1, always_apply=True, p=1.0): + super(Grayscale, self).__init__(always_apply, p) + self.channels = channels + + def apply(self, img, **params): + if img.ndim == 2: + img = np.repeat(img[..., None], self.channels, axis=-1) + elif img.ndim == 3 and img.shape[-1] == 3: + img = img[..., 0] + return img diff --git a/pretraining/pixpro/pixpro_config.yaml b/pretraining/pixpro/pixpro_config.yaml new file mode 100644 index 0000000..8de43d7 --- /dev/null +++ b/pretraining/pixpro/pixpro_config.yaml @@ -0,0 +1,33 @@ +#basic definitions +experiment_name: "CEM500K_PixPro100" +data_dir: "/data/IASEM/conradrw/data/cem500k/" # a .npz dask array of filepaths +model_dir: "/data/IASEM/conradrw/models/CEM500K_PixPro100/" # the directory in which to save model states +arch: "resnet50" +workers: 32 +start_epoch: 0 +epochs: 100 +save_freq: 20 +print_freq: 10 +batch_size: 512 +lr: 2 +momentum: 0.9 +weight_decay: 0.00001 +logging: True + +resume: "" +world_size: 1 +rank: 0 +dist_url: "tcp://localhost:10001" +dist_backend: "nccl" +multiprocessing_distributed: True + +pixpro_mom: 0.9934 +ppm_layers: 1 +ppm_gamma: 2 +pixpro_t: 0.7 + +fp16: True + +norms: + mean: 0.58331613 + std: 0.09966064 diff --git a/pretraining/pixpro/train_pixpro.py b/pretraining/pixpro/train_pixpro.py new file mode 100644 index 0000000..6b8c9b1 --- /dev/null +++ b/pretraining/pixpro/train_pixpro.py @@ -0,0 +1,469 @@ +""" +Copied and modified from: +https://github.com/facebookresearch/moco/blob/master/main_moco.py + +Which is really copied and modified from the source of all distributed +training scripts: +https://github.com/pytorch/examples/blob/master/imagenet/main.py +""" +import argparse +import builtins +import math +import os, sys +import random +import shutil +import time +import warnings +import yaml + +import torch +import 
torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torch.optim +import torch.multiprocessing as mp +import torch.utils.data +import torch.utils.data.distributed +import torchvision.transforms as transforms +import torchvision.datasets as datasets +import torchvision.models as models + +from LARC import LARC + +from torch.cuda.amp import autocast +from torch.cuda.amp import GradScaler + +import albumentations as A +from albumentations.pytorch import ToTensorV2 + +from dataset import ContrastData, Grayscale +from builder import PixPro, ConsistencyLoss + +sys.path.append('/home/conradrw/nbs/cellemnet/pretraining/') +from resnet import resnet50 + +import mlflow + +model_names = sorted(name for name in models.__dict__ + if name.islower() and not name.startswith("__") + and callable(models.__dict__[name])) + +def parse_args(): + parser = argparse.ArgumentParser(description='PyTorch PixPro Training') + parser.add_argument('config', help='Path to .yaml training config file') + + return vars(parser.parse_args()) + +def main(): + args = parse_args() + + with open(args['config'], 'r') as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + if not os.path.isdir(config['model_dir']): + os.mkdir(config['model_dir']) + + config['config_file'] = args['config'] + + #world size is the number of processes that will run + if config['dist_url'] == "env://" and config['world_size'] == -1: + config['world_size'] = int(os.environ["WORLD_SIZE"]) + + config['distributed'] = config['world_size'] > 1 or config['multiprocessing_distributed'] + + ngpus_per_node = torch.cuda.device_count() + config['ngpus_per_node'] = ngpus_per_node + if config['multiprocessing_distributed']: + # Since we have ngpus_per_node processes per node, the total world_size + # needs to be adjusted accordingly + config['world_size'] = ngpus_per_node * config['world_size'] + # Use torch.multiprocessing.spawn to launch distributed processes: the + # main_worker process function + mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, config)) + else: + # Simply call main_worker function + main_worker(config['gpu'], ngpus_per_node, config) + +def main_worker(gpu, ngpus_per_node, config): + config['gpu'] = gpu + + # suppress printing if not master process + if config['multiprocessing_distributed'] and config['gpu'] != 0: + def print_pass(*args): + pass + builtins.print = print_pass + + if config['gpu'] is not None: + print("Use GPU: {} for training".format(config['gpu'])) + + if config['distributed']: + if config['dist_url'] == "env://" and config['rank'] == -1: + config['rank'] = int(os.environ["RANK"]) + if config['multiprocessing_distributed']: + # For multiprocessing distributed training, rank needs to be the + # global rank among all the processes + config['rank'] = config['rank'] * ngpus_per_node + gpu + + dist.init_process_group(backend=config['dist_backend'], init_method=config['dist_url'], + world_size=config['world_size'], rank=config['rank']) + + print("=> creating model '{}'".format(config['arch'])) + + model = PixPro( + resnet50, + config['pixpro_mom'], config['ppm_layers'], config['ppm_gamma'] + ) + + if config['distributed']: + # For multiprocessing distributed, DistributedDataParallel constructor + # should always set the single device scope, otherwise, + # DistributedDataParallel will use all available devices. 
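main_worker above receives a local GPU index from mp.spawn and combines it with the node rank to form a global process rank; the matching per-GPU batch-size and worker split appears just below. The arithmetic, with illustrative numbers (batch_size=512 and workers=32 as in pixpro_config.yaml, two hypothetical 8-GPU nodes):

```python
ngpus_per_node = 8
node_rank = 1                 # rank passed in the config for the second node
local_gpu = 3                 # this process's GPU index, supplied by mp.spawn
world_size = 2 * ngpus_per_node                        # 2 nodes x 8 GPUs = 16 processes

global_rank = node_rank * ngpus_per_node + local_gpu   # 8 + 3 = 11
per_gpu_batch = 512 // ngpus_per_node                  # 64 images per process
per_gpu_workers = (32 + ngpus_per_node - 1) // ngpus_per_node  # ceil(32 / 8) = 4

print(global_rank, world_size, per_gpu_batch, per_gpu_workers)
```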
+ # Turn on SyncBatchNorm + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + + if config['gpu'] is not None: + torch.cuda.set_device(config['gpu']) + model.cuda(config['gpu']) + # When using a single GPU per process and per + # DistributedDataParallel, we need to divide the batch size + # ourselves based on the total number of GPUs we have + config['batch_size'] = int(config['batch_size'] / ngpus_per_node) + config['workers'] = int((config['workers'] + ngpus_per_node - 1) / ngpus_per_node) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config['gpu']]) + else: + model.cuda() + # DistributedDataParallel will divide and allocate batch_size to all + # available GPUs if device_ids are not set + model = torch.nn.parallel.DistributedDataParallel(model) + elif config['gpu'] is not None: + torch.cuda.set_device(config['gpu']) + model = model.cuda(config['gpu']) + # comment out the following line for debugging + raise NotImplementedError("Only DistributedDataParallel is supported.") + else: + # AllGather implementation (batch shuffle, queue update, etc.) in + # this code only supports DistributedDataParallel. + raise NotImplementedError("Only DistributedDataParallel is supported.") + + #define loss criterion and optimizer + criterion = ConsistencyLoss(distance_thr=config['pixpro_t']).cuda(config['gpu']) + + optimizer = configure_optimizer(model, config) + + # optionally resume from a checkpoint + if config['resume']: + if os.path.isfile(config['resume']): + print("=> loading checkpoint '{}'".format(config['resume'])) + if config['gpu'] is None: + checkpoint = torch.load(config['resume']) + else: + # Map model to be loaded to specified single gpu. + loc = 'cuda:{}'.format(config['gpu']) + checkpoint = torch.load(config['resume'], map_location=loc) + config['start_epoch'] = checkpoint['epoch'] + model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + print("=> loaded checkpoint '{}' (epoch {})" + .format(config['resume'], checkpoint['epoch'])) + else: + config['start_epoch'] = 0 + print("=> no checkpoint found at '{}'".format(config['resume'])) + + cudnn.benchmark = True + + norms = config['norms'] + mean_pixel = norms['mean'] + std_pixel = norms['std'] + normalize = A.Normalize(mean=[mean_pixel], std=[std_pixel]) + + #physical space only + space_tfs = A.Compose([ + A.RandomResizedCrop(224, 224, scale=(0.2, 1.0)), + Grayscale(3), + A.HorizontalFlip(), + A.VerticalFlip() + ], additional_targets={'grid_y': 'image', 'grid_x': 'image'}) + + #could work for both views + view1_color_tfs = A.Compose([ + A.ColorJitter(0.4, 0.4, 0.2, 0.1, p=0.8), + Grayscale(1), + A.GaussianBlur(blur_limit=23, sigma_limit=(0.1, 2.0), p=1.0), + normalize, + ToTensorV2() + ]) + + #technically optional, but used in the BYOL paper + view2_color_tfs = A.Compose([ + A.ColorJitter(0.4, 0.4, 0.2, 0.1, p=0.8), + Grayscale(1), + A.GaussianBlur(blur_limit=23, sigma_limit=(0.1, 2.0), p=0.1), + A.GaussNoise(p=0.2), + normalize, + ToTensorV2() + ]) + + train_dataset = ContrastData( + config['data_dir'], space_tfs, view1_color_tfs, view2_color_tfs + ) + + if config['distributed']: + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + else: + train_sampler = None + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=config['batch_size'], shuffle=(train_sampler is None), + num_workers=config['workers'], pin_memory=True, sampler=train_sampler, drop_last=True) + + # encoder momentum is updated by STEP and not 
EPOCH + config['train_steps'] = config['epochs'] * len(train_loader) + config['current_step'] = config['start_epoch'] * len(train_loader) + + if config['fp16']: + scaler = GradScaler() + else: + scaler = None + + #log parameters, if needed: + if config['logging'] and (config['multiprocessing_distributed'] + and config['rank'] % ngpus_per_node == 0): + + #end any old runs + mlflow.end_run() + mlflow.set_experiment(config['experiment_name']) + mlflow.log_artifact(config['config_file']) + + #we don't want to add everything in the config + #to mlflow parameters, we'll just add the most + #likely to change parameters + mlflow.log_param('data_dir', config['data_dir']) + mlflow.log_param('architecture', config['arch']) + mlflow.log_param('epochs', config['epochs']) + mlflow.log_param('batch_size', config['batch_size']) + mlflow.log_param('learning_rate', config['lr']) + mlflow.log_param('pixpro_mom', config['pixpro_mom']) + mlflow.log_param('ppm_layers', config['ppm_layers']) + mlflow.log_param('ppm_gamma', config['ppm_gamma']) + mlflow.log_param('pixpro_t', config['pixpro_t']) + + for epoch in range(config['start_epoch'], config['epochs']): + if config['distributed']: + train_sampler.set_epoch(epoch) + + adjust_learning_rate(optimizer, epoch, config) + + # train for one epoch + train(train_loader, model, criterion, optimizer, scaler, epoch, config) + + if not config['multiprocessing_distributed'] or (config['multiprocessing_distributed'] + and config['rank'] % ngpus_per_node == 0): + save_checkpoint({ + 'epoch': epoch + 1, + 'arch': config['arch'], + 'state_dict': model.state_dict(), + 'optimizer' : optimizer.state_dict(), + 'norms': [mean_pixel, std_pixel], + }, is_best=False, filename=os.path.join(config['model_dir'], 'current.pth.tar')) + + #save checkpoint every save_freq epochs + if (epoch + 1) % config['save_freq'] == 0: + save_checkpoint({ + 'epoch': epoch + 1, + 'arch': config['arch'], + 'state_dict': model.state_dict(), + 'optimizer' : optimizer.state_dict(), + 'norms': [mean_pixel, std_pixel], + }, is_best=False, filename=os.path.join(config['model_dir'] + 'checkpoint_{:04d}.pth.tar'.format(epoch + 1))) + +def train(train_loader, model, criterion, optimizer, scaler, epoch, config): + batch_time = AverageMeter('Time', ':6.3f') + data_time = AverageMeter('Data', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + + progress = ProgressMeter( + len(train_loader), + [batch_time, data_time, losses], + prefix="Epoch: [{}]".format(epoch) + ) + + # switch to train mode + model.train() + + end = time.time() + for i, batch in enumerate(train_loader): + view1 = batch['view1'] + view1_grid = batch['view1_grid'] + view2 = batch['view2'] + view2_grid = batch['view2_grid'] + + # measure data loading time + data_time.update(time.time() - end) + + if config['gpu'] is not None: + view1 = view1.cuda(config['gpu'], non_blocking=True) + view1_grid = view1_grid.cuda(config['gpu'], non_blocking=True) + view2 = view2.cuda(config['gpu'], non_blocking=True) + view2_grid = view2_grid.cuda(config['gpu'], non_blocking=True) + + optimizer.zero_grad() + + # compute output and loss + if config['fp16']: + with autocast(): + output = model(view1, view2, view1_grid, view2_grid) + loss = criterion(*output) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + else: + output = model(view1, view2, view1_grid, view2_grid) + loss = criterion(*output) + loss.backward() + optimizer.step() + + # avg loss from batch size + losses.update(loss.item(), view1.size(0)) + + # update current step and encoder 
momentum + config['current_step'] += 1 + adjust_encoder_momentum(model, config) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % config['print_freq'] == 0: + progress.display(i) + + if config['rank'] % config['ngpus_per_node'] == 0: + # store metrics to mlflow + mlflow.log_metric('sim_loss', losses.avg, step=epoch) + + +def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, 'model_best.pth.tar') + +def configure_optimizer(model, config): + """ + Takes an optimizer and separates parameters into two groups + that either use weight decay or are exempt. + + Only BatchNorm parameters and biases are excluded. + """ + decay = set() + no_decay = set() + + blacklist = (nn.BatchNorm2d,) + for mn, m in model.named_modules(): + for pn, p in m.named_parameters(recurse=False): + full_name = '%s.%s' % (mn, pn) if mn else pn + + if full_name.endswith('bias'): + no_decay.add(full_name) + elif full_name.endswith('weight') and isinstance(m, blacklist): + no_decay.add(full_name) + else: + decay.add(full_name) + + param_dict = {pn: p for pn, p in model.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + assert(len(inter_params) == 0), "Overlapping decay and no decay" + assert(len(param_dict.keys() - union_params) == 0), "Missing decay parameters" + + decay_params = [param_dict[pn] for pn in sorted(list(decay))] + no_decay_params = [param_dict[pn] for pn in sorted(list(no_decay))] + + #the adapt_lr key tells LARS not to adapt the lr (see 'LARC.py') + param_groups = [ + {"params": decay_params, "weight_decay": config['weight_decay'], "adapt_lr": True}, + {"params": no_decay_params, "weight_decay": 0., "adapt_lr": False} + ] + + base_optimizer = torch.optim.SGD( + param_groups, lr=config['lr'], momentum=config['momentum'] + ) + + #LARC without clipping == LARS + #lower trust_coefficient to match SimCLR and BYOL + #(too high of a trust_coefficient leads to NaN losses!) 
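As the comments above note, LARC with clip=False is plain LARS: each parameter tensor's gradient is rescaled by a local trust ratio before SGD applies the base learning rate (the wrapping call follows just below). A small numeric sketch of that ratio using made-up tensors; the real computation lives in LARC.step() above:

```python
import torch

trust_coefficient = 1e-3
weight_decay = 1e-5   # matches pixpro_config.yaml
eps = 1e-8

param = torch.randn(256, 256)        # stand-in for a weight tensor
grad = 0.01 * torch.randn(256, 256)  # stand-in for its gradient

param_norm = torch.norm(param)
grad_norm = torch.norm(grad)

# local trust ratio: trust * ||w|| / (||g|| + wd * ||w|| + eps)
adaptive_lr = trust_coefficient * param_norm / (grad_norm + weight_decay * param_norm + eps)

# scaling mode (clip=False): weight decay is folded in and the gradient is rescaled,
# so the effective step size becomes lr * adaptive_lr for this tensor
scaled_grad = (grad + weight_decay * param) * adaptive_lr
print(float(adaptive_lr))
```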
+ optimizer = LARC(optimizer=base_optimizer, trust_coefficient=1e-3) + + return optimizer + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + +class ProgressMeter(object): + def __init__(self, num_batches, meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def display(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print('\t'.join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = '{:' + str(num_digits) + 'd}' + return '[' + fmt + '/' + fmt.format(num_batches) + ']' + +def adjust_encoder_momentum(model, config): + base_mom = config['pixpro_mom'] + new_mom = 1 - (1 - base_mom) * (math.cos(math.pi * config['current_step'] / config['train_steps']) + 1) / 2 + model.module.momentum = new_mom + +def adjust_learning_rate(optimizer, epoch, config): + """Decay the learning rate based on schedule""" + lr = config['lr'] + #cosine lr schedule + lr *= 0.5 * (1. + math.cos(math.pi * epoch / config['epochs'])) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +if __name__ == '__main__': + main() From d162e8dac0a007b66ee42608d1dfb0c7c8bfe527 Mon Sep 17 00:00:00 2001 From: conradry Date: Thu, 19 Aug 2021 09:08:53 -0400 Subject: [PATCH 07/19] dont need pixpro --- pretraining/pixpro/LARC.py | 132 -------- pretraining/pixpro/__init__.py | 0 pretraining/pixpro/builder.py | 183 ---------- pretraining/pixpro/dataset.py | 87 ----- pretraining/pixpro/pixpro_config.yaml | 33 -- pretraining/pixpro/train_pixpro.py | 469 -------------------------- 6 files changed, 904 deletions(-) delete mode 100644 pretraining/pixpro/LARC.py delete mode 100644 pretraining/pixpro/__init__.py delete mode 100644 pretraining/pixpro/builder.py delete mode 100644 pretraining/pixpro/dataset.py delete mode 100644 pretraining/pixpro/pixpro_config.yaml delete mode 100644 pretraining/pixpro/train_pixpro.py diff --git a/pretraining/pixpro/LARC.py b/pretraining/pixpro/LARC.py deleted file mode 100644 index 0d1d6ea..0000000 --- a/pretraining/pixpro/LARC.py +++ /dev/null @@ -1,132 +0,0 @@ -""" -Copied with modificiation from https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py - -Modifications: --------------- - -1. Added a condition in step() to not adapt the lr for parameter groups -with 'adapt_lr' == False. 
- -""" - -import torch -from torch import nn -from torch.nn.parameter import Parameter - -class LARC(object): - """ - :class:`LARC` is a pytorch implementation of both the scaling and clipping variants of LARC, - in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive - local learning rate for each individual parameter. The algorithm is designed to improve - convergence of large batch training. - - See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate. - - In practice it modifies the gradients of parameters as a proxy for modifying the learning rate - of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer. - - ``` - model = ... - optim = torch.optim.Adam(model.parameters(), lr=...) - optim = LARC(optim) - ``` - - It can even be used in conjunction with apex.fp16_utils.FP16_optimizer. - - ``` - model = ... - optim = torch.optim.Adam(model.parameters(), lr=...) - optim = LARC(optim) - optim = apex.fp16_utils.FP16_Optimizer(optim) - ``` - - Args: - optimizer: Pytorch optimizer to wrap and modify learning rate for. - trust_coefficient: Trust coefficient for calculating the lr. See https://arxiv.org/abs/1708.03888 - clip: Decides between clipping or scaling mode of LARC. If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter. If `clip=False` the learning rate is set to `local_lr*optimizer_lr`. - eps: epsilon kludge to help with numerical stability while calculating adaptive_lr - """ - - def __init__( - self, - optimizer, - trust_coefficient=1e-3, - clip=False, - eps=1e-8 - ): - self.optim = optimizer - self.trust_coefficient = trust_coefficient - self.eps = eps - self.clip = clip - - def __getstate__(self): - return self.optim.__getstate__() - - def __setstate__(self, state): - self.optim.__setstate__(state) - - @property - def state(self): - return self.optim.state - - def __repr__(self): - return self.optim.__repr__() - - @property - def param_groups(self): - return self.optim.param_groups - - @param_groups.setter - def param_groups(self, value): - self.optim.param_groups = value - - def state_dict(self): - return self.optim.state_dict() - - def load_state_dict(self, state_dict): - self.optim.load_state_dict(state_dict) - - def zero_grad(self): - self.optim.zero_grad() - - def add_param_group(self, param_group): - self.optim.add_param_group( param_group) - - def step(self): - with torch.no_grad(): - weight_decays = [] - for group in self.optim.param_groups: - # absorb weight decay control from optimizer - weight_decay = group['weight_decay'] if 'weight_decay' in group else 0 - weight_decays.append(weight_decay) - group['weight_decay'] = 0 - - #check if adapt_lr flag exists and if it's False - #don't adapt the lr for those parameters - if 'adapt_lr' in group: - if group['adapt_lr'] is False: - continue #go to next group - - for p in group['params']: - if p.grad is None: - continue - - param_norm = torch.norm(p.data) - grad_norm = torch.norm(p.grad.data) - - if param_norm != 0 and grad_norm != 0: - # calculate adaptive lr + weight decay - adaptive_lr = self.trust_coefficient * (param_norm) / (grad_norm + param_norm * weight_decay + self.eps) - - # clip learning rate for LARC - if self.clip: - # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)` - adaptive_lr = min(adaptive_lr/group['lr'], 1) - - p.grad.data += weight_decay * p.data - p.grad.data *= adaptive_lr - - self.optim.step() - # return weight decay control 
to optimizer - for i, group in enumerate(self.optim.param_groups): - group['weight_decay'] = weight_decays[i] \ No newline at end of file diff --git a/pretraining/pixpro/__init__.py b/pretraining/pixpro/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/pretraining/pixpro/builder.py b/pretraining/pixpro/builder.py deleted file mode 100644 index e73431c..0000000 --- a/pretraining/pixpro/builder.py +++ /dev/null @@ -1,183 +0,0 @@ -import random, math, sys -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from copy import deepcopy -from collections import OrderedDict -from torch.utils.checkpoint import checkpoint - -class PPM(nn.Module): - def __init__(self, nin, gamma=2, nlayers=1): - super(PPM, self).__init__() - - #the networks can have 0-2 layers - if nlayers == 0: - layers = [nn.Identity()] - elif nlayers == 1: - layers = [nn.Conv2d(nin, nin, 1)] - elif nlayers == 2: - layers = [ - nn.Conv2d(nin, nin, 1, bias=False), - nn.BatchNorm2d(nin), - nn.ReLU(), - nn.Conv2d(nin, nin, 1) - ] - else: - raise Exception(f'nlayers must be 0, 1, or 2, got {nlayers}') - - self.transform = nn.Sequential(*layers) - self.gamma = gamma - self.cosine_sim = nn.CosineSimilarity(dim=1) - - def forward(self, x): - #(B, C, H, W, 1, 1) x (B, C, 1, 1, H, W) --> (B, H, W, H, W) - #relu is same as max(sim, 0) but differentiable - xi = x[:, :, :, :, None, None] - xj = x[:, :, None, None, :, :] - s = F.relu(self.cosine_sim(xi, xj)) ** self.gamma - - #output of "g" in the paper (B, C, H, W) - gx = self.transform(x) - - #use einsum and skip all the transposes - #and matrix multiplies - return torch.einsum('bijhw, bchw -> bcij', s, gx) - -class Encoder(nn.Module): - def __init__(self, backbone): - super(Encoder, self).__init__() - - #use ordered dict to retain layer names - #last two layers are avgpool and fc - layers = OrderedDict([layer for layer in backbone().named_children()][:-2]) - self.backbone = nn.Sequential(layers) - - #get the number of channels from last convolution - for layer in self.backbone.modules(): - if isinstance(layer, nn.Conv2d): - projection_nin = layer.out_channels - - #note that projection_nin = 2048 for resnet50 (matches paper) - self.projection = nn.Sequential( - nn.Conv2d(projection_nin, projection_nin, 1, bias=False), - nn.BatchNorm2d(projection_nin), - nn.ReLU(inplace=True), - nn.Conv2d(projection_nin, 256, 1) - ) - - def forward(self, x): - return self.projection(self.backbone(x)) - -class PixPro(nn.Module): - def __init__( - self, - backbone, - momentum=0.99, - ppm_layers=1, - ppm_gamma=2, - downsampling=32 - ): - super(PixPro, self).__init__() - - #create the encoder and momentum encoder - self.encoder = Encoder(backbone) - self.mom_encoder = deepcopy(self.encoder) - - #turn off gradients - for mom_param in self.mom_encoder.parameters(): - mom_param.requires_grad = False - - #hardcoded: encoder outputs 256 - self.ppm = PPM(256, gamma=ppm_gamma, nlayers=ppm_layers) - - self.grid_downsample = nn.AvgPool2d(downsampling, stride=downsampling) - self.momentum = momentum - - @torch.no_grad() - def _update_mom_encoder(self): - for param, mom_param in zip(self.encoder.parameters(), self.mom_encoder.parameters()): - mom_param.data = mom_param.data * self.momentum + param.data * (1. 
- self.momentum) - - def forward(self, view1, view2, view1_grid, view2_grid): - #pass each view through each encoder - y1 = self.ppm(self.encoder(view1)) - y2 = self.ppm(self.encoder(view2)) - - with torch.no_grad(): - #update mom encoder before forward - self._update_mom_encoder() - - z1 = self.mom_encoder(view1) - z2 = self.mom_encoder(view2) - - view1_grid = self.grid_downsample(view1_grid) - view2_grid = self.grid_downsample(view2_grid) - - return y1, y2, z1, z2, view1_grid, view2_grid - -def grid_distances(grid1, grid2): - #grid: (B, 2, H, W) --> (B, 2, H * W) - h, w = grid1.size()[-2:] - grid1 = grid1.flatten(2, -1)[..., :, None] #(B, 2, H * W, 1) - grid2 = grid2.flatten(2, -1)[..., None, :] #(B, 2, 1, H * W) - - y_distances = grid1[:, 0] - grid2[:, 0] - x_distances = grid1[:, 1] - grid2[:, 1] - - return torch.sqrt(y_distances ** 2 + x_distances ** 2) - -class ConsistencyLoss(nn.Module): - def __init__(self, distance_thr=0.7): - super(ConsistencyLoss, self).__init__() - self.distance_thr = distance_thr - self.cosine_sim = nn.CosineSimilarity(dim=1) - - def forward(self, y1, y2, z1, z2, view1_grid, view2_grid): - #(B, C, H * W) - y1 = y1.flatten(2, -1) - y2 = y2.flatten(2, -1) - z1 = z1.flatten(2, -1) - z2 = z2.flatten(2, -1) - - #pairwise distances between grid coordinates - #(B, C, H * W, H * W) - distances = grid_distances(view1_grid, view2_grid) - - #determine normalization factors for view1 and view2 - #(i.e. distance between "feature map bins") - #(B,) - view1_bin = torch.norm(view1_grid[..., 1, 1] - view1_grid[..., 0, 0], dim=-1) - view2_bin = torch.norm(view2_grid[..., 1, 1] - view2_grid[..., 0, 0], dim=-1) - - #(B, H * W, H * W) - view1_distances = distances / view1_bin[:, None, None] - view2_distances = distances / view2_bin[:, None, None] - - #(B, C, H * W, 1) x (B, C, 1, H * W) --> (B, H * W, H * W) - #important to keep view1 outputs (y1 and z1) as first items - #in the cosine sim measurement because distances tensor - #has that ordering fixed - view1_similarity = self.cosine_sim(y1[..., :, None], z2[..., None, :]) - view2_similarity = self.cosine_sim(z1[..., :, None], y2[..., None, :]) - similarities = view1_similarity + view2_similarity - - #only consider points that are matches for both views - view1_mask = view1_distances <= self.distance_thr - view2_mask = view2_distances <= self.distance_thr - mask = torch.logical_and(view1_mask, view2_mask) - - #mask-out non-matches with zeros - similarities = torch.where(mask, similarities, torch.zeros_like(similarities)) - matches_per_image = mask.sum(dim=(-1, -2)) #(B,) - - #average over images - similarities = similarities.sum(dim=(-1, -2)).masked_select(matches_per_image > 0) - matches_per_image = matches_per_image.masked_select(matches_per_image > 0) #(b,) - similarities = similarities / matches_per_image #(B,) - - #average over batch - matched_images = (matches_per_image > 0).sum() - similarities = similarities.sum() / matched_images #( ) - - return -similarities diff --git a/pretraining/pixpro/dataset.py b/pretraining/pixpro/dataset.py deleted file mode 100644 index 9c0876f..0000000 --- a/pretraining/pixpro/dataset.py +++ /dev/null @@ -1,87 +0,0 @@ -import os -import cv2 -import numpy as np -import torch -from torch.utils.data import Dataset -from glob import glob -from albumentations import ImageOnlyTransform - -class ContrastData(Dataset): - def __init__( - self, - imdir, - space_tfs, - view1_color_tfs, - view2_color_tfs=None - ): - super(ContrastData, self).__init__() - - self.imdir = imdir - self.fpaths = 
glob(os.path.join(imdir, '*.tiff'), recursive=True) - #self.fnames = os.listdir(imdir) - - print(f'Found {len(self.fpaths)} images in directory') - - #crops, resizes, flips, rotations, etc. - self.space_tfs = space_tfs - - #brightness, contrast, jitter, blur, and - #normalization - self.view1_color_tfs = view1_color_tfs - self.view2_color_tfs = view2_color_tfs - - def __len__(self): - return len(self.fpaths) - - def __getitem__(self, idx): - fpath = self.fpaths[idx] - #fpath = os.path.join(self.imdir, self.fnames[idx]) - image = cv2.imread(fpath, 0) - - y = np.arange(0, image.shape[0], dtype=np.float32) - x = np.arange(0, image.shape[1], dtype=np.float32) - grid_y, grid_x = np.meshgrid(y, x) - grid_y, grid_x = grid_y.T, grid_x.T - - #space transforms treat coordinate grid like an image - #bilinear interp is good, nearest would be bad - view1_data = self.space_tfs(image=image, grid_y=grid_y[..., None], grid_x=grid_x[..., None]) - view2_data = self.space_tfs(image=image, grid_y=grid_y[..., None], grid_x=grid_x[..., None]) - - view1 = view1_data['image'] - view1_grid = np.concatenate([view1_data['grid_y'], view1_data['grid_x']], axis=-1) - view2 = view2_data['image'] - view2_grid = np.concatenate([view2_data['grid_y'], view2_data['grid_x']], axis=-1) - - view1 = self.view1_color_tfs(image=view1)['image'] - if self.view2_color_tfs is not None: - view2 = self.view2_color_tfs(image=view2)['image'] - else: - view2 = self.view1_color_tfs(image=view2)['image'] - - output = { - 'fpath': fpath, - 'view1': view1, - 'view1_grid': torch.from_numpy(view1_grid).permute(2, 0, 1), - 'view2': view2, - 'view2_grid': torch.from_numpy(view2_grid).permute(2, 0, 1) - } - - return output - -class Grayscale(ImageOnlyTransform): - """ - Resizes an image, but not the mask, to be divisible by a specific - number like 32. Necessary for evaluation with segmentation models - that use downsampling. 
- """ - def __init__(self, channels=1, always_apply=True, p=1.0): - super(Grayscale, self).__init__(always_apply, p) - self.channels = channels - - def apply(self, img, **params): - if img.ndim == 2: - img = np.repeat(img[..., None], self.channels, axis=-1) - elif img.ndim == 3 and img.shape[-1] == 3: - img = img[..., 0] - return img diff --git a/pretraining/pixpro/pixpro_config.yaml b/pretraining/pixpro/pixpro_config.yaml deleted file mode 100644 index 8de43d7..0000000 --- a/pretraining/pixpro/pixpro_config.yaml +++ /dev/null @@ -1,33 +0,0 @@ -#basic definitions -experiment_name: "CEM500K_PixPro100" -data_dir: "/data/IASEM/conradrw/data/cem500k/" # a .npz dask array of filepaths -model_dir: "/data/IASEM/conradrw/models/CEM500K_PixPro100/" # the directory in which to save model states -arch: "resnet50" -workers: 32 -start_epoch: 0 -epochs: 100 -save_freq: 20 -print_freq: 10 -batch_size: 512 -lr: 2 -momentum: 0.9 -weight_decay: 0.00001 -logging: True - -resume: "" -world_size: 1 -rank: 0 -dist_url: "tcp://localhost:10001" -dist_backend: "nccl" -multiprocessing_distributed: True - -pixpro_mom: 0.9934 -ppm_layers: 1 -ppm_gamma: 2 -pixpro_t: 0.7 - -fp16: True - -norms: - mean: 0.58331613 - std: 0.09966064 diff --git a/pretraining/pixpro/train_pixpro.py b/pretraining/pixpro/train_pixpro.py deleted file mode 100644 index 6b8c9b1..0000000 --- a/pretraining/pixpro/train_pixpro.py +++ /dev/null @@ -1,469 +0,0 @@ -""" -Copied and modified from: -https://github.com/facebookresearch/moco/blob/master/main_moco.py - -Which is really copied and modified from the source of all distributed -training scripts: -https://github.com/pytorch/examples/blob/master/imagenet/main.py -""" -import argparse -import builtins -import math -import os, sys -import random -import shutil -import time -import warnings -import yaml - -import torch -import torch.nn as nn -import torch.nn.parallel -import torch.backends.cudnn as cudnn -import torch.distributed as dist -import torch.optim -import torch.multiprocessing as mp -import torch.utils.data -import torch.utils.data.distributed -import torchvision.transforms as transforms -import torchvision.datasets as datasets -import torchvision.models as models - -from LARC import LARC - -from torch.cuda.amp import autocast -from torch.cuda.amp import GradScaler - -import albumentations as A -from albumentations.pytorch import ToTensorV2 - -from dataset import ContrastData, Grayscale -from builder import PixPro, ConsistencyLoss - -sys.path.append('/home/conradrw/nbs/cellemnet/pretraining/') -from resnet import resnet50 - -import mlflow - -model_names = sorted(name for name in models.__dict__ - if name.islower() and not name.startswith("__") - and callable(models.__dict__[name])) - -def parse_args(): - parser = argparse.ArgumentParser(description='PyTorch PixPro Training') - parser.add_argument('config', help='Path to .yaml training config file') - - return vars(parser.parse_args()) - -def main(): - args = parse_args() - - with open(args['config'], 'r') as f: - config = yaml.load(f, Loader=yaml.FullLoader) - - if not os.path.isdir(config['model_dir']): - os.mkdir(config['model_dir']) - - config['config_file'] = args['config'] - - #world size is the number of processes that will run - if config['dist_url'] == "env://" and config['world_size'] == -1: - config['world_size'] = int(os.environ["WORLD_SIZE"]) - - config['distributed'] = config['world_size'] > 1 or config['multiprocessing_distributed'] - - ngpus_per_node = torch.cuda.device_count() - config['ngpus_per_node'] = ngpus_per_node - 
if config['multiprocessing_distributed']: - # Since we have ngpus_per_node processes per node, the total world_size - # needs to be adjusted accordingly - config['world_size'] = ngpus_per_node * config['world_size'] - # Use torch.multiprocessing.spawn to launch distributed processes: the - # main_worker process function - mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, config)) - else: - # Simply call main_worker function - main_worker(config['gpu'], ngpus_per_node, config) - -def main_worker(gpu, ngpus_per_node, config): - config['gpu'] = gpu - - # suppress printing if not master process - if config['multiprocessing_distributed'] and config['gpu'] != 0: - def print_pass(*args): - pass - builtins.print = print_pass - - if config['gpu'] is not None: - print("Use GPU: {} for training".format(config['gpu'])) - - if config['distributed']: - if config['dist_url'] == "env://" and config['rank'] == -1: - config['rank'] = int(os.environ["RANK"]) - if config['multiprocessing_distributed']: - # For multiprocessing distributed training, rank needs to be the - # global rank among all the processes - config['rank'] = config['rank'] * ngpus_per_node + gpu - - dist.init_process_group(backend=config['dist_backend'], init_method=config['dist_url'], - world_size=config['world_size'], rank=config['rank']) - - print("=> creating model '{}'".format(config['arch'])) - - model = PixPro( - resnet50, - config['pixpro_mom'], config['ppm_layers'], config['ppm_gamma'] - ) - - if config['distributed']: - # For multiprocessing distributed, DistributedDataParallel constructor - # should always set the single device scope, otherwise, - # DistributedDataParallel will use all available devices. - # Turn on SyncBatchNorm - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - - if config['gpu'] is not None: - torch.cuda.set_device(config['gpu']) - model.cuda(config['gpu']) - # When using a single GPU per process and per - # DistributedDataParallel, we need to divide the batch size - # ourselves based on the total number of GPUs we have - config['batch_size'] = int(config['batch_size'] / ngpus_per_node) - config['workers'] = int((config['workers'] + ngpus_per_node - 1) / ngpus_per_node) - model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config['gpu']]) - else: - model.cuda() - # DistributedDataParallel will divide and allocate batch_size to all - # available GPUs if device_ids are not set - model = torch.nn.parallel.DistributedDataParallel(model) - elif config['gpu'] is not None: - torch.cuda.set_device(config['gpu']) - model = model.cuda(config['gpu']) - # comment out the following line for debugging - raise NotImplementedError("Only DistributedDataParallel is supported.") - else: - # AllGather implementation (batch shuffle, queue update, etc.) in - # this code only supports DistributedDataParallel. - raise NotImplementedError("Only DistributedDataParallel is supported.") - - #define loss criterion and optimizer - criterion = ConsistencyLoss(distance_thr=config['pixpro_t']).cuda(config['gpu']) - - optimizer = configure_optimizer(model, config) - - # optionally resume from a checkpoint - if config['resume']: - if os.path.isfile(config['resume']): - print("=> loading checkpoint '{}'".format(config['resume'])) - if config['gpu'] is None: - checkpoint = torch.load(config['resume']) - else: - # Map model to be loaded to specified single gpu. 
- loc = 'cuda:{}'.format(config['gpu']) - checkpoint = torch.load(config['resume'], map_location=loc) - config['start_epoch'] = checkpoint['epoch'] - model.load_state_dict(checkpoint['state_dict']) - optimizer.load_state_dict(checkpoint['optimizer']) - print("=> loaded checkpoint '{}' (epoch {})" - .format(config['resume'], checkpoint['epoch'])) - else: - config['start_epoch'] = 0 - print("=> no checkpoint found at '{}'".format(config['resume'])) - - cudnn.benchmark = True - - norms = config['norms'] - mean_pixel = norms['mean'] - std_pixel = norms['std'] - normalize = A.Normalize(mean=[mean_pixel], std=[std_pixel]) - - #physical space only - space_tfs = A.Compose([ - A.RandomResizedCrop(224, 224, scale=(0.2, 1.0)), - Grayscale(3), - A.HorizontalFlip(), - A.VerticalFlip() - ], additional_targets={'grid_y': 'image', 'grid_x': 'image'}) - - #could work for both views - view1_color_tfs = A.Compose([ - A.ColorJitter(0.4, 0.4, 0.2, 0.1, p=0.8), - Grayscale(1), - A.GaussianBlur(blur_limit=23, sigma_limit=(0.1, 2.0), p=1.0), - normalize, - ToTensorV2() - ]) - - #technically optional, but used in the BYOL paper - view2_color_tfs = A.Compose([ - A.ColorJitter(0.4, 0.4, 0.2, 0.1, p=0.8), - Grayscale(1), - A.GaussianBlur(blur_limit=23, sigma_limit=(0.1, 2.0), p=0.1), - A.GaussNoise(p=0.2), - normalize, - ToTensorV2() - ]) - - train_dataset = ContrastData( - config['data_dir'], space_tfs, view1_color_tfs, view2_color_tfs - ) - - if config['distributed']: - train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) - else: - train_sampler = None - - train_loader = torch.utils.data.DataLoader( - train_dataset, batch_size=config['batch_size'], shuffle=(train_sampler is None), - num_workers=config['workers'], pin_memory=True, sampler=train_sampler, drop_last=True) - - # encoder momentum is updated by STEP and not EPOCH - config['train_steps'] = config['epochs'] * len(train_loader) - config['current_step'] = config['start_epoch'] * len(train_loader) - - if config['fp16']: - scaler = GradScaler() - else: - scaler = None - - #log parameters, if needed: - if config['logging'] and (config['multiprocessing_distributed'] - and config['rank'] % ngpus_per_node == 0): - - #end any old runs - mlflow.end_run() - mlflow.set_experiment(config['experiment_name']) - mlflow.log_artifact(config['config_file']) - - #we don't want to add everything in the config - #to mlflow parameters, we'll just add the most - #likely to change parameters - mlflow.log_param('data_dir', config['data_dir']) - mlflow.log_param('architecture', config['arch']) - mlflow.log_param('epochs', config['epochs']) - mlflow.log_param('batch_size', config['batch_size']) - mlflow.log_param('learning_rate', config['lr']) - mlflow.log_param('pixpro_mom', config['pixpro_mom']) - mlflow.log_param('ppm_layers', config['ppm_layers']) - mlflow.log_param('ppm_gamma', config['ppm_gamma']) - mlflow.log_param('pixpro_t', config['pixpro_t']) - - for epoch in range(config['start_epoch'], config['epochs']): - if config['distributed']: - train_sampler.set_epoch(epoch) - - adjust_learning_rate(optimizer, epoch, config) - - # train for one epoch - train(train_loader, model, criterion, optimizer, scaler, epoch, config) - - if not config['multiprocessing_distributed'] or (config['multiprocessing_distributed'] - and config['rank'] % ngpus_per_node == 0): - save_checkpoint({ - 'epoch': epoch + 1, - 'arch': config['arch'], - 'state_dict': model.state_dict(), - 'optimizer' : optimizer.state_dict(), - 'norms': [mean_pixel, std_pixel], - }, is_best=False, 
filename=os.path.join(config['model_dir'], 'current.pth.tar')) - - #save checkpoint every save_freq epochs - if (epoch + 1) % config['save_freq'] == 0: - save_checkpoint({ - 'epoch': epoch + 1, - 'arch': config['arch'], - 'state_dict': model.state_dict(), - 'optimizer' : optimizer.state_dict(), - 'norms': [mean_pixel, std_pixel], - }, is_best=False, filename=os.path.join(config['model_dir'] + 'checkpoint_{:04d}.pth.tar'.format(epoch + 1))) - -def train(train_loader, model, criterion, optimizer, scaler, epoch, config): - batch_time = AverageMeter('Time', ':6.3f') - data_time = AverageMeter('Data', ':6.3f') - losses = AverageMeter('Loss', ':.4e') - - progress = ProgressMeter( - len(train_loader), - [batch_time, data_time, losses], - prefix="Epoch: [{}]".format(epoch) - ) - - # switch to train mode - model.train() - - end = time.time() - for i, batch in enumerate(train_loader): - view1 = batch['view1'] - view1_grid = batch['view1_grid'] - view2 = batch['view2'] - view2_grid = batch['view2_grid'] - - # measure data loading time - data_time.update(time.time() - end) - - if config['gpu'] is not None: - view1 = view1.cuda(config['gpu'], non_blocking=True) - view1_grid = view1_grid.cuda(config['gpu'], non_blocking=True) - view2 = view2.cuda(config['gpu'], non_blocking=True) - view2_grid = view2_grid.cuda(config['gpu'], non_blocking=True) - - optimizer.zero_grad() - - # compute output and loss - if config['fp16']: - with autocast(): - output = model(view1, view2, view1_grid, view2_grid) - loss = criterion(*output) - - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() - else: - output = model(view1, view2, view1_grid, view2_grid) - loss = criterion(*output) - loss.backward() - optimizer.step() - - # avg loss from batch size - losses.update(loss.item(), view1.size(0)) - - # update current step and encoder momentum - config['current_step'] += 1 - adjust_encoder_momentum(model, config) - - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - - if i % config['print_freq'] == 0: - progress.display(i) - - if config['rank'] % config['ngpus_per_node'] == 0: - # store metrics to mlflow - mlflow.log_metric('sim_loss', losses.avg, step=epoch) - - -def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): - torch.save(state, filename) - if is_best: - shutil.copyfile(filename, 'model_best.pth.tar') - -def configure_optimizer(model, config): - """ - Takes an optimizer and separates parameters into two groups - that either use weight decay or are exempt. - - Only BatchNorm parameters and biases are excluded. 
- """ - decay = set() - no_decay = set() - - blacklist = (nn.BatchNorm2d,) - for mn, m in model.named_modules(): - for pn, p in m.named_parameters(recurse=False): - full_name = '%s.%s' % (mn, pn) if mn else pn - - if full_name.endswith('bias'): - no_decay.add(full_name) - elif full_name.endswith('weight') and isinstance(m, blacklist): - no_decay.add(full_name) - else: - decay.add(full_name) - - param_dict = {pn: p for pn, p in model.named_parameters()} - inter_params = decay & no_decay - union_params = decay | no_decay - assert(len(inter_params) == 0), "Overlapping decay and no decay" - assert(len(param_dict.keys() - union_params) == 0), "Missing decay parameters" - - decay_params = [param_dict[pn] for pn in sorted(list(decay))] - no_decay_params = [param_dict[pn] for pn in sorted(list(no_decay))] - - #the adapt_lr key tells LARS not to adapt the lr (see 'LARC.py') - param_groups = [ - {"params": decay_params, "weight_decay": config['weight_decay'], "adapt_lr": True}, - {"params": no_decay_params, "weight_decay": 0., "adapt_lr": False} - ] - - base_optimizer = torch.optim.SGD( - param_groups, lr=config['lr'], momentum=config['momentum'] - ) - - #LARC without clipping == LARS - #lower trust_coefficient to match SimCLR and BYOL - #(too high of a trust_coefficient leads to NaN losses!) - optimizer = LARC(optimizer=base_optimizer, trust_coefficient=1e-3) - - return optimizer - -class AverageMeter(object): - """Computes and stores the average and current value""" - def __init__(self, name, fmt=':f'): - self.name = name - self.fmt = fmt - self.reset() - - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count - - def __str__(self): - fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' - return fmtstr.format(**self.__dict__) - -class ProgressMeter(object): - def __init__(self, num_batches, meters, prefix=""): - self.batch_fmtstr = self._get_batch_fmtstr(num_batches) - self.meters = meters - self.prefix = prefix - - def display(self, batch): - entries = [self.prefix + self.batch_fmtstr.format(batch)] - entries += [str(meter) for meter in self.meters] - print('\t'.join(entries)) - - def _get_batch_fmtstr(self, num_batches): - num_digits = len(str(num_batches // 1)) - fmt = '{:' + str(num_digits) + 'd}' - return '[' + fmt + '/' + fmt.format(num_batches) + ']' - -def adjust_encoder_momentum(model, config): - base_mom = config['pixpro_mom'] - new_mom = 1 - (1 - base_mom) * (math.cos(math.pi * config['current_step'] / config['train_steps']) + 1) / 2 - model.module.momentum = new_mom - -def adjust_learning_rate(optimizer, epoch, config): - """Decay the learning rate based on schedule""" - lr = config['lr'] - #cosine lr schedule - lr *= 0.5 * (1. 
+ math.cos(math.pi * epoch / config['epochs'])) - for param_group in optimizer.param_groups: - param_group['lr'] = lr - -def accuracy(output, target, topk=(1,)): - """Computes the accuracy over the k top predictions for the specified values of k""" - with torch.no_grad(): - maxk = max(topk) - batch_size = target.size(0) - - _, pred = output.topk(maxk, 1, True, True) - pred = pred.t() - correct = pred.eq(target.view(1, -1).expand_as(pred)) - - res = [] - for k in topk: - correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) - res.append(correct_k.mul_(100.0 / batch_size)) - return res - - -if __name__ == '__main__': - main() From 2a298f07f82eaa4fef968fa7d0e8e127458efc2b Mon Sep 17 00:00:00 2001 From: conradry Date: Thu, 19 Aug 2021 10:37:13 -0400 Subject: [PATCH 08/19] pkl patches filtering --- dataset/classify_nn.py | 156 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 156 insertions(+) create mode 100644 dataset/classify_nn.py diff --git a/dataset/classify_nn.py b/dataset/classify_nn.py new file mode 100644 index 0000000..cebc896 --- /dev/null +++ b/dataset/classify_nn.py @@ -0,0 +1,156 @@ +""" +Description: +------------ + +Fits a ResNet34 model to images that have manually been labeled as "informative" or "uninformative". It's assumed that +images have been manually labeled using the corrector.py utilities running in a Jupyter notebook (see notebooks/labeling.ipynb). + +The results of this script are the roc curve plot on a randomly chosen validation set of images, the +model state dict as a .pth file and the model's predictions on all the remaining unlabeled images. + +Example usage: +-------------- + +python classify_nn.py {impaths_file} {savedir} --labels {label_file} --weights {weights_file} + +For help with arguments: +------------------------ + +python classify_nn.py --help +""" + +DEFAULT_WEIGHTS = "https://www.dropbox.com/s/2libiwgx0qdgxqv/patch_quality_classifier_nn.pth?raw=1" + +import os, sys, cv2, argparse +import pickle +import numpy as np +from skimage import io +from glob import glob + +import torch +import torch.nn as nn +import torch.backends.cudnn as cudnn +from torchvision.models import resnet34 +from torch.optim import Adam +from torch.utils.data import DataLoader, Dataset + +from albumentations import Compose, Normalize, Resize +from albumentations.pytorch import ToTensorV2 +from tqdm import tqdm + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description='Classifies a set of images by fitting a random forest to an array of descriptive features' + ) + parser.add_argument('dedupe_dir', type=str) + parser.add_argument('savedir', type=str) + parser.add_argument('--weights', type=str, metavar='weights', + help='Optional, path to nn weights file. 
The default is to download weights used in the paper.') + args = parser.parse_args() + + # parse the arguments + dedupe_dir = args.dedupe_dir + savedir = args.savedir + weights = args.weights + + # make sure the savedir exists + if not os.path.isdir(savedir): + os.mkdir(savedir) + + # list all pkl deduplicated files + fpaths = glob(os.path.join(dedupe_dir, '*.pkl')) + + # set up evaluation transforms (assumes imagenet + # pretrained as default in train_nn.py) + imsize = 224 + normalize = Normalize() #default is imagenet normalization + eval_tfs = Compose([ + Resize(imsize, imsize), + normalize, + ToTensorV2() + ]) + + # create the resnet34 model + model = resnet34() + + # modify the output layer to predict 1 class only + model.fc = nn.Linear(in_features=512, out_features=1) + + # load the weights from file or from online + # load the weights from file or from online + if weights is not None: + state_dict = torch.load(weights, map_location='cpu') + else: + state_dict = torch.hub.load_state_dict_from_url(DEFAULT_WEIGHTS) + + # load in the weights (strictly) + msg = model.load_state_dict(state_dict) + model = model.cuda() + cudnn.benchmark = True + + # make a basic dataset class for loading and + # augmenting images WITHOUT any labels + class SimpleDataset(Dataset): + def __init__(self, image_dict, tfs=None): + super(SimpleDataset, self).__init__() + self.image_dict = image_dict + self.tfs = tfs + + def __len__(self): + return len(self.image_dict['names']) + + def __getitem__(self, idx): + # load the image + fname = self.image_dict['names'][idx] + image = self.image_dict['patches'][idx] + image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) + + # apply transforms + if self.tfs is not None: + image = self.tfs(image=image)['image'] + + return {'fname': fname, 'image': image} + + + for fp in fpaths: + print(f'Filtering on {fp}') + dataset_name = os.path.basename(fp) + if '-ROI-' in dataset_name: + dataset_name = dataset_name.split('-ROI-')[0] + else: + dataset_name = dataset_name[:-4] # remove .pkl + + dataset_savedir = os.path.join(savedir, dataset_name) + if not os.path.exists(dataset_savedir): + os.mkdir(dataset_savedir) + + # load the patches_dict + with open(fp, mode='rb') as handle: + patches_dict = pickle.load(handle) + + # create datasets for the train, validation, and test sets + tst_data = SimpleDataset(patches_dict, eval_tfs) + test = DataLoader(tst_data, batch_size=128, shuffle=False, pin_memory=True, num_workers=4) + + # lastly run inference on the entire set of unlabeled images + print(f'Running inference on test set...') + tst_fnames = [] + tst_predictions = [] + for data in test: + with torch.no_grad(): + # load data onto gpu then forward pass + images = data['image'].cuda(non_blocking=True) + output = model.eval()(images) + predictions = nn.Sigmoid()(output) + + predictions = predictions.detach().cpu().numpy() + tst_predictions.append(predictions) + tst_fnames.append(data['fname']) + + tst_fnames = np.concatenate(tst_fnames, axis=0) + tst_predictions = np.concatenate(tst_predictions, axis=0) + tst_predictions = (tst_predictions[:, 0] > 0.5).astype(np.uint8) + + for ix, (fn, img) in enumerate(zip(patches_dict['names'], patches_dict['patches'])): + if tst_predictions[ix] == 1: + io.imsave(os.path.join(dataset_savedir, fn + '.tiff'), img) \ No newline at end of file From 0ee4d2065c2a85959e93e6604565a86a8e27cb8b Mon Sep 17 00:00:00 2001 From: conradry Date: Fri, 15 Apr 2022 11:50:10 -0400 Subject: [PATCH 09/19] updated curation pipeline --- .../{classify_nn.py => classify_patches.py} | 
37 ++-
 dataset/patchify2d.py                         | 161 +++++++++++
 dataset/patchify3d.py                         | 260 +++++++++++++++++
 dataset/preprocess/binvol.py                  |  50 ++++
 dataset/preprocess/cleanup2d.py               | 137 +++++++++
 dataset/preprocess/mrc2byte.py                |  51 ++++
 dataset/preprocess/vid2stack.py               |  64 +++++
 dataset/{filtered => }/train_nn.py            |   0
 dataset/train_patch_classifier.py             | 267 ++++++++++++++++++
 9 files changed, 1008 insertions(+), 19 deletions(-)
 rename dataset/{classify_nn.py => classify_patches.py} (81%)
 create mode 100644 dataset/patchify2d.py
 create mode 100644 dataset/patchify3d.py
 create mode 100644 dataset/preprocess/binvol.py
 create mode 100644 dataset/preprocess/cleanup2d.py
 create mode 100644 dataset/preprocess/mrc2byte.py
 create mode 100644 dataset/preprocess/vid2stack.py
 rename dataset/{filtered => }/train_nn.py (100%)
 create mode 100644 dataset/train_patch_classifier.py

diff --git a/dataset/classify_nn.py b/dataset/classify_patches.py
similarity index 81%
rename from dataset/classify_nn.py
rename to dataset/classify_patches.py
index cebc896..95b4d63 100644
--- a/dataset/classify_nn.py
+++ b/dataset/classify_patches.py
@@ -2,16 +2,12 @@
 Description:
 ------------
 
-Fits a ResNet34 model to images that have manually been labeled as "informative" or "uninformative". It's assumed that
-images have been manually labeled using the corrector.py utilities running in a Jupyter notebook (see notebooks/labeling.ipynb).
-
-The results of this script are the roc curve plot on a randomly chosen validation set of images, the
-model state dict as a .pth file and the model's predictions on all the remaining unlabeled images.
+Classifies EM images into "informative" or "uninformative".
 
 Example usage:
 --------------
 
-python classify_nn.py {impaths_file} {savedir} --labels {label_file} --weights {weights_file}
+python classify_patches.py {dedupe_dir} {savedir} --weights {weights_file}
 
 For help with arguments:
 ------------------------
@@ -19,7 +15,7 @@
-python classify_nn.py --help
+python classify_patches.py --help
 """
 
-DEFAULT_WEIGHTS = "https://www.dropbox.com/s/2libiwgx0qdgxqv/patch_quality_classifier_nn.pth?raw=1"
+DEFAULT_WEIGHTS = "https://zenodo.org/record/6458015/files/patch_quality_classifier_nn.pth?download=1"
 
 import os, sys, cv2, argparse
 import pickle
@@ -42,10 +38,11 @@
     parser = argparse.ArgumentParser(
         description='Classifies a set of images by fitting a random forest to an array of descriptive features'
     )
-    parser.add_argument('dedupe_dir', type=str)
+    parser.add_argument('dedupe_dir', type=str, help='Directory containing .pkl patch files (output of patchify2d.py/patchify3d.py)')
     parser.add_argument('savedir', type=str)
     parser.add_argument('--weights', type=str, metavar='weights',
                         help='Optional, path to nn weights file. The default is to download weights used in the paper.')
+
     args = parser.parse_args()
 
     # parse the arguments
@@ -53,6 +50,8 @@
     savedir = args.savedir
     weights = args.weights
 
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
     # make sure the savedir exists
     if not os.path.isdir(savedir):
         os.mkdir(savedir)
@@ -85,7 +84,7 @@
 
     # load in the weights (strictly)
     msg = model.load_state_dict(state_dict)
-    model = model.cuda()
+    model = model.to(device)
     cudnn.benchmark = True
 
     # make a basic dataset class for loading and
@@ -111,35 +110,35 @@ def __getitem__(self, idx):
 
         return {'fname': fname, 'image': image}
 
-
-    for fp in fpaths:
-        print(f'Filtering on {fp}')
+    for fp in tqdm(fpaths):
         dataset_name = os.path.basename(fp)
         if '-ROI-' in dataset_name:
             dataset_name = dataset_name.split('-ROI-')[0]
         else:
-            dataset_name = dataset_name[:-4] # remove .pkl
-
+            dataset_name = dataset_name[:-len('.pkl')]
+
         dataset_savedir = os.path.join(savedir, dataset_name)
         if not os.path.exists(dataset_savedir):
             os.mkdir(dataset_savedir)
-
+        else:
+            continue
+
         # load the patches_dict
         with open(fp, mode='rb') as handle:
             patches_dict = pickle.load(handle)
 
         # create datasets for the train, validation, and test sets
         tst_data = SimpleDataset(patches_dict, eval_tfs)
-        test = DataLoader(tst_data, batch_size=128, shuffle=False, pin_memory=True, num_workers=4)
+        test = DataLoader(tst_data, batch_size=128, shuffle=False,
+                          pin_memory=True, num_workers=4)
 
         # lastly run inference on the entire set of unlabeled images
-        print(f'Running inference on test set...')
         tst_fnames = []
         tst_predictions = []
         for data in test:
             with torch.no_grad():
                 # load data onto gpu then forward pass
-                images = data['image'].cuda(non_blocking=True)
+                images = data['image'].to(device, non_blocking=True)
                 output = model.eval()(images)
                 predictions = nn.Sigmoid()(output)
 
@@ -153,4 +152,4 @@ def __getitem__(self, idx):
 
     for ix, (fn, img) in enumerate(zip(patches_dict['names'], patches_dict['patches'])):
         if tst_predictions[ix] == 1:
-            io.imsave(os.path.join(dataset_savedir, fn + '.tiff'), img)
\ No newline at end of file
+            io.imsave(os.path.join(dataset_savedir, fn + '.tiff'), img, check_contrast=False)
\ No newline at end of file
diff --git a/dataset/patchify2d.py b/dataset/patchify2d.py
new file mode 100644
index 0000000..ec7db0e
--- /dev/null
+++ b/dataset/patchify2d.py
@@ -0,0 +1,161 @@
+"""
+Description:
+------------
+
+This script accepts a directory containing subdirectories of 2d image files and
+crops each image into square patches of the given size. The patches and their
+names are stored in a dictionary that is saved as one pickle file per
+subdirectory in the given save directory.
+
+Importantly, the saved patches are given a slightly different name:
+we append '-LOC-2d-{h1}-{h2}_{w1}-{w2}' to a zero-padded index of the image
+within its subdirectory, where h1,h2 are the start and end rows and w1,w2 are
+the start and end columns of the patch. Once images from 2d and 3d datasets
+start getting mixed together, it can be difficult to keep track of the
+provenance of each patch. Everything that appears before '-LOC-' identifies the
+source image, and the '2d' marker plus the row/column ranges give the exact
+location of the patch within that image.
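+
+For example (hypothetical name), a patch stored under the name
+'0007-LOC-2d-0-224_224-448' would be the eighth image in its subdirectory,
+rows 0-224 and columns 224-448 of that image.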
+
+Example usage:
+--------------
+
+python patchify2d.py {imdir} {savedir} --crop_size 224 --processes 4
+
+For help with arguments:
+------------------------
+
+python patchify2d.py --help
+"""
+
+import os
+import math
+import pickle
+import argparse
+import numpy as np
+from glob import glob
+from skimage import io
+from multiprocessing import Pool
+
+MAX_VALUES_BY_DTYPE = {
+    np.dtype("uint8"): 255,
+    np.dtype("uint16"): 65535,
+    np.dtype("int16"): 32767,
+    np.dtype("uint32"): 4294967295,
+    np.dtype("float32"): 1.0,
+}
+
+def patch_crop(image, crop_size=224):
+    if image.ndim == 3:
+        if image.shape[2] not in [1, 3]:
+            print('Accidentally 3d?', image.shape)
+        image = image[..., 0]
+
+    # at least 1 image patch
+    ysize, xsize = image.shape
+    ny = max(1, int(round(ysize / crop_size)))
+    nx = max(1, int(round(xsize / crop_size)))
+
+    patches = []
+    locs = []
+    for y in range(ny):
+        # start and end indices for y
+        ys = y * crop_size
+        ye = min(ys + crop_size, ysize)
+        for x in range(nx):
+            # start and end indices for x
+            xs = x * crop_size
+            xe = min(xs + crop_size, xsize)
+
+            # crop the patch
+            patch = image[ys:ye, xs:xe]
+
+            patches.append(patch)
+            locs.append(f'{ys}-{ye}_{xs}-{xe}')
+
+    return patches, locs
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Create dataset for nn experimentation')
+    parser.add_argument('imdir', type=str, metavar='imdir', help='Directory containing 2d image files')
+    parser.add_argument('savedir', type=str, metavar='savedir', help='Path to save the patch files')
+    parser.add_argument('-cs', '--crop_size', dest='crop_size', type=int, metavar='crop_size', default=224,
+                        help='Size of square image patches. Default 224.')
+    parser.add_argument('-p', '--processes', dest='processes', type=int, metavar='processes', default=4,
+                        help='Number of processes to run, more processes will run faster but consume more memory')
+
+    args = parser.parse_args()
+
+    # read in the parser arguments
+    imdir = args.imdir
+    savedir = args.savedir
+    crop_size = args.crop_size
+    processes = args.processes
+
+    # check if the savedir exists, if not create it
+    if not os.path.isdir(savedir):
+        os.mkdir(savedir)
+
+    # get the list of all images (png, jpg, tif)
+    fpath_groups = {}
+    for sd in glob(os.path.join(imdir, '*')):
+        if os.path.isdir(sd):
+            sd_name = os.path.basename(sd)
+            fpath_groups[sd_name] = [fp for fp in glob(os.path.join(sd, '*')) if not os.path.isdir(fp)]
+
+    print(f'Found {len(fpath_groups.keys())} image groups to process')
+
+    subdirs = list(fpath_groups.keys())
+    fpath_lists = list(fpath_groups.values())
+
+    def create_patches(*args):
+        subdir, fpaths = args[0]
+
+        exp_name = subdir
+        patch_dict = {'names': [], 'patches': []}
+        zpad = 1 + math.ceil(math.log(len(fpaths)))
+
+        # check if results have already been generated,
+        # skip this group if so. useful for resuming
+        out_path = os.path.join(savedir, exp_name + '.pkl')
+        if os.path.isfile(out_path):
+            print(f'Already processed {exp_name}, skipping!')
+            return
+
+        for ix,fp in enumerate(fpaths):
+            # try to load the image, if it's not possible
+            # then skip it but print
+            try:
+                im = io.imread(fp)
+            except:
+                print('Failed to open: ', fp)
+                continue
+
+            assert (im.min() >= 0), 'Negative images not allowed!'
+ + if im.dtype != np.uint8: + dtype = im.dtype + max_value = MAX_VALUES_BY_DTYPE[dtype] + im = im.astype(np.float32) / max_value + im = (im * 255).astype(np.uint8) + + # crop the image into patches + patches, locs = patch_crop(im, crop_size) + + # appropriate filenames with location + imname = str(ix).zfill(zpad) + names = [] + for loc_str in locs: + # add the -LOC- to indicate the point of separation between + # the dataset name and the slice location information + patch_loc_str = f'-LOC-2d-{loc_str}' + names.append(imname + patch_loc_str) + + # store results in patch_dict + patch_dict['names'].extend(names) + patch_dict['patches'].extend(patches) + + with open(out_path, 'wb') as handle: + pickle.dump(patch_dict, handle) + + with Pool(processes) as pool: + pool.map(create_patches, zip(subdirs, fpath_lists)) diff --git a/dataset/patchify3d.py b/dataset/patchify3d.py new file mode 100644 index 0000000..19164e7 --- /dev/null +++ b/dataset/patchify3d.py @@ -0,0 +1,260 @@ +""" +Description: +------------ + +This script accepts a directory with image volume files and slices cross sections +from the given axes (xy, xz, yz). Then it patches the cross-sections into many smaller +images. All patches from a dataset a deduplicated such that patches with nearly identical +content are filtered out. + +Patches along with filenames are stored in a dictionary and saved as pickle +files. Importantly, filenames follow the convention: + +'{dataset_name}-LOC-{slicing_axis}_{slice_index}_{h1}-{h2}_{w1}-{w2}' + +Slicing axis denotes the cross-sectioning plane (0->xy, 1->xz, 2->yz). Slice index +is the index of the image along the slicing axis. h1,h2 are start and end rows and +w1,w2 are start and end columns. This gives enough information to precisely locate +the patch in the original 3D dataset. + +Lastly, if the directory of 3D datasets includes a mixture of isotropic and anisotropic +volumes it is important that each dataset has a correct header recording the voxel +size. This script uses SimpleITK to read the header. If z resolution is more that 25% +different than xy resolution, then cross-sections will only be cut from the xy plane +even if axes 0, 1, 2 are passed to the script (see usage example below). + +Likewise, if there are video files as well, it is essential that they have the word 'video' +somewhere in the filename. 
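+
+As a concrete (hypothetical) example of this convention, a patch named
+'my_volume-LOC-1_0042_224-448_0-224' comes from slice 42 of the xz plane of
+'my_volume', covering rows 224-448 and columns 0-224 of that cross-section.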
+ +Example usage: +-------------- + +python patchify3d.py {imdir} {savedir} --axes 0 1 2 --spacing 1 --processes 4 + +For help with arguments: +------------------------ + +python patchify3d.py --help + +""" + +import os +import math +import pickle +import argparse +import imagehash +import numpy as np +import SimpleITK as sitk +from glob import glob +from PIL import Image +from skimage import io +from multiprocessing import Pool + +MAX_VALUES_BY_DTYPE = { + np.dtype("uint8"): 255, + np.dtype("uint16"): 65535, + np.dtype("int16"): 32767, + np.dtype("uint32"): 4294967295, + np.dtype("float32"): 1.0, +} + +def calculate_hash(image, crop_size, hash_size=8): + # calculate the hash on the resized image + imsize = (crop_size, crop_size) + pil_image = Image.fromarray(image).resize(imsize, resample=2) + + return imagehash.dhash(pil_image, hash_size=hash_size).hash + +def patch_and_hash(image, crop_size=224, hash_size=8): + if image.ndim == 3: + image = image[..., 0] + + # at least 1 image patch of any size + ysize, xsize = image.shape + ny = max(1, int(round(ysize / crop_size))) + nx = max(1, int(round(xsize / crop_size))) + + patches = [] + hashes = [] + locs = [] + for y in range(ny): + # start and end indices for y + ys = y * crop_size + ye = min(ys + crop_size, ysize) + for x in range(nx): + # start and end indices for x + xs = x * crop_size + xe = min(xs + crop_size, xsize) + + # crop the patch and calculate its hash + patch = image[ys:ye, xs:xe] + patch_hash = calculate_hash(patch, crop_size, hash_size) + + patches.append(patch) + hashes.append(patch_hash) + locs.append(f'{ys}-{ye}_{xs}-{xe}') + + return patches, hashes, locs + +def deduplicate(patch_dict, min_distance): + # all hashes are the same size + hashes = np.array(patch_dict['hashes']) + hashes = hashes.reshape(len(hashes), -1) + + # randomly permute the hashes such that we'll have random ordering + random_indices = np.random.permutation(np.arange(0, len(hashes))) + hashes = hashes[random_indices] + + exemplars = [] + while len(hashes) > 0: + ref_hash = hashes[0] + + # a match has Hamming distance less than min_distance + matches = np.where( + np.logical_xor(ref_hash, hashes).sum(1) <= min_distance + )[0] + + # ref_hash is the exemplar (i.e. first in matches) + exemplars.append(random_indices[matches[0]]) + + # remove all the matched images from both hashes and indices + hashes = np.delete(hashes, matches, axis=0) + random_indices = np.delete(random_indices, matches, axis=0) + + names = [] + patches = [] + for index in exemplars: + names.append(patch_dict['names'][index]) + patches.append(patch_dict['patches'][index]) + + return {'names': names, 'patches': patches} + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Create dataset for nn experimentation') + parser.add_argument('imdir', type=str, metavar='imdir', help='Directory containing volume files') + parser.add_argument('savedir', type=str, metavar='savedir', help='Path to save the cross sections') + parser.add_argument('-a', '--axes', dest='axes', type=int, metavar='axes', nargs='+', default=[0, 1, 2], + help='Volume axes along which to slice (0-xy, 1-xz, 2-yz)') + parser.add_argument('-s', '--spacing', dest='spacing', type=int, metavar='spacing', default=1, + help='Spacing between image slices') + parser.add_argument('-cs', '--crop_size', dest='crop_size', type=int, metavar='crop_size', default=224, + help='Size of square image patches. 
Default 224.')
+    parser.add_argument('-hs', '--hash_size', dest='hash_size', type=int, metavar='hash_size', default=8,
+                        help='Size of the image hash. Default 8 (assumes crop size of 224).')
+    parser.add_argument('-d', '--min_distance', dest='min_distance', type=int, metavar='min_distance', default=12,
+                        help='Minimum Hamming distance between hashes to be considered unique. Default 12 (assumes hash size of 8)')
+    parser.add_argument('-p', '--processes', dest='processes', type=int, metavar='processes', default=4,
+                        help='Number of processes to run, more processes will run faster but consume more memory')
+
+    args = parser.parse_args()
+
+    # read in the parser arguments
+    imdir = args.imdir
+    savedir = args.savedir
+    axes = args.axes
+    spacing = args.spacing
+    crop_size = args.crop_size
+    hash_size = args.hash_size
+    min_distance = args.min_distance
+    processes = args.processes
+
+    # check if the savedir exists, if not create it
+    if not os.path.isdir(savedir):
+        os.mkdir(savedir)
+
+    # get the list of all volumes (mrc, tif, nrrd, nii.gz, etc.)
+    fpaths = glob(os.path.join(imdir, '*'))
+    print(f'Found {len(fpaths)} image volumes to process')
+
+    def create_slices(fp):
+        # extract the experiment name from the filepath
+        # add a special case for .nii.gz files
+        if fp.endswith('.nii.gz'):
+            fext = 'nii.gz'
+        else:
+            fext = fp.split('.')[-1]
+
+        exp_name = os.path.basename(fp).split(f'.{fext}')[0]
+
+        # check if results have already been generated
+        # skip this volume, if so. useful for resuming
+        out_path = os.path.join(savedir, exp_name + '.pkl')
+        if os.path.isfile(out_path):
+            print(f'Already processed {fp}, skipping!')
+            return
+
+        # try to load the volume, if reading fails
+        # then skip it but print
+        try:
+            im = sitk.ReadImage(fp)
+
+            if len(im.GetSize()) > 3:
+                im = im[..., 0]
+
+            print('Loaded', fp, im.GetSize())
+        except:
+            print('Failed to open: ', fp)
+            return
+
+        # extract the pixel size from the volume
+        # if the z-pixel size is more than 25% different
+        # from the x-pixel size, don't slice over orthogonal
+        # directions
+        pixel_sizes = im.GetSpacing()
+        anisotropy = np.abs(pixel_sizes[0] - pixel_sizes[2]) / pixel_sizes[0]
+
+        im = sitk.GetArrayFromImage(im)
+        assert (im.min() >= 0), 'Negative images not allowed!'
+ + if im.dtype != np.uint8: + dtype = im.dtype + max_value = MAX_VALUES_BY_DTYPE[dtype] + im = im.astype(np.float32) / max_value + im = (im * 255).astype(np.uint8) + + patch_dict = {'names': [], 'patches': [], 'hashes': []} + for axis in axes: + # only process xy slices if the volume is anisotropic + if (anisotropy > 0.25 or 'video' in exp_name.lower()) and (axis != 0): + continue + + # evenly spaced slices + nmax = im.shape[axis] - 1 + slice_indices = np.arange(0, nmax, spacing, dtype='int') + zpad = math.ceil(math.log(nmax, 10)) + + for idx in slice_indices: + # slice the volume on the proper axis + if axis == 0: + im_slice = im[idx] + elif axis == 1: + im_slice = im[:, idx] + else: + im_slice = im[:, :, idx] + + # crop the image into patches + patches, hashes, locs = patch_and_hash(im_slice, crop_size, hash_size) + + # appropriate filenames with location + names = [] + for loc_str in locs: + # add the -LOC- to indicate the point of separation between + # the dataset name and the slice location information + index_str = str(idx).zfill(zpad) + patch_loc_str = f'-LOC-{axis}_{index_str}_{loc_str}' + names.append(exp_name + patch_loc_str) + + # store results in patch_dict + patch_dict['names'].extend(names) + patch_dict['patches'].extend(patches) + patch_dict['hashes'].extend(hashes) + + patch_dict = deduplicate(patch_dict, min_distance) + + out_path = os.path.join(savedir, exp_name + '.pkl') + with open(out_path, 'wb') as handle: + pickle.dump(patch_dict, handle) + + with Pool(processes) as pool: + pool.map(create_slices, fpaths) \ No newline at end of file diff --git a/dataset/preprocess/binvol.py b/dataset/preprocess/binvol.py new file mode 100644 index 0000000..89bad70 --- /dev/null +++ b/dataset/preprocess/binvol.py @@ -0,0 +1,50 @@ +""" +Description: +------------ + +This script is used to bin .mrc files by a factor of 2. This is a standard +preprocessing step to lower the resolution from under 10 nm to 10-20 nm range. + +For help downloading and installing IMOD, see: +https://bio3d.colorado.edu/imod/ + +Example usage: +-------------- + +python binvol.py {mrcdir} {--inplace} + +For help with arguments: +------------------------ + +python binvol.py --help + +""" + +import os, argparse +import subprocess +from glob import glob + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('mrcdir', type=str, help='Directory containing mrc image volumes') + parser.add_argument('--inplace', action='store_true', + help='If passed, original mrcs will be permanently deleted in favor of the binned data.') + + args = parser.parse_args() + + #read in the argument + mrcdir = args.mrcdir + inplace = args.inplace + + #gather the mrc filepaths + fnames = glob(os.path.join(mrcdir, '*.mrc')) + print('Found {} mrc files to bin'.format(len(fnames))) + FNULL = open(os.devnull, 'w') + + for fn in fnames: + #create the IMOD command and run it + command = ['binvol', fn, fn.replace('.mrc', 'BV2.mrc')] + subprocess.call(command, stdout=FNULL, stderr=subprocess.STDOUT) + + if inplace: + os.remove(fn) \ No newline at end of file diff --git a/dataset/preprocess/cleanup2d.py b/dataset/preprocess/cleanup2d.py new file mode 100644 index 0000000..8c7395a --- /dev/null +++ b/dataset/preprocess/cleanup2d.py @@ -0,0 +1,137 @@ +""" +Description: +------------ + +This script accepts a directory with 2d images and processes them +to make sure they all have unsigned 8-bit pixels and are grayscale. + +The resultant images +are saved in the given save directory. 
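+
+Roughly (sketched from the conversion code below), unsigned integer images are
+divided by the full range of their dtype and rescaled to 0-255, signed integers
+are first shifted to be non-negative, and floating point images are min-max
+normalized before rescaling; the result is always cast to uint8.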
+ +Importantly, the saved image files are given a slightly different filename: +We add '-LOC-2d' to the end of the filename. Once images from 2d and 3d datasets +start getting mixed together, it can be difficult to keep track of the +provenance of each patch. Everything that appears before '-LOC-2d' is the +name of the original dataset and we, of course, know that that dataset is +a 2d EM image. + +Example usage: +-------------- + +python cleanup2d.py {imdir} -o {savedir} -p 4 + +For help with arguments: +------------------------ + +python cleanup2d.py --help +""" + +import os, math +import argparse +import numpy as np +from glob import glob +from skimage.io import imread, imsave +from multiprocessing import Pool + +if __name__ == "__main__": + + #setup the argument parser + parser = argparse.ArgumentParser(description='Create dataset for nn experimentation') + parser.add_argument('imdir', type=str, metavar='imdir', help='Directory containing subdirectories of 2d EM images') + parser.add_argument('-o', type=str, metavar='savedir', dest='savedir', + help='Path to save the processed images as copies, if not given images will be overwritten.') + parser.add_argument('-p', '--processes', dest='processes', type=int, metavar='processes', default=4, + help='Number of processes to run, more processes will run faster but consume more memory') + + + args = parser.parse_args() + + imdir = args.imdir + processes = args.processes + + savedir = args.savedir + if savedir is None: + savedir = args.imdir + else: + os.makedirs(savedir, exist_ok=True) + + # get the list of all images (png, jpg, tif, etc.) + fpath_groups = {} + for sd in glob(os.path.join(imdir, '*')): + if os.path.isdir(sd): + sd_name = os.path.basename(sd) + fpath_groups[sd_name] = [fp for fp in glob(os.path.join(sd, '*')) if not os.path.isdir(fp)] + + print(f'Found {len(fpath_groups.keys())} image groups to process') + + subdirs = list(fpath_groups.keys()) + fpath_lists = list(fpath_groups.values()) + + def process_image(*args): + subdir, fpaths = args[0] + for fp in fpaths: + try: + im = imread(fp) + + # has to be 2d image with or without channels + assert im.ndim < 4 + + # make sure last dim is channel + if im.ndim == 3: + assert im.shape[-1] <= 4 + im = im[..., 0] + + except: + print('Failed to read: ', fp) + pass + + dtype = str(im.dtype) + + is_float = False + unsigned = False + if dtype[0] == 'u': + unsigned = True + elif dtype[0] == 'f': + is_float = True + + if dtype == 'uint8': # nothing to do + pass + else: + # get the bitdepth + bits = int(dtype[-2]) # 16, 32, or 64 + + # explicitly convert the image to float + im = im.astype('float') + + if unsigned: + # unsigned conversion just requires division by max and + # multiplication by 255 + im /= (2 ** bits) + im *= 255 + elif unsigned is False and is_float is False: + # signed conversion adds the additional step of + # subtracting the minimum negative value + im -= -(2 ** (bits - 1)) + im /= (2 ** bits) + im *= 255 + else: + # we're working with float. 
+ # because the range is variable, we'll just subtract + # the minimum, divide by maximum, and multiply by 255 + im -= im.min() + im /= im.max() + im *= 255 + + im = im.astype(np.uint8) + + # save the processed image to the new directory + outdir = os.path.join(savedir, subdir) + if not os.path.exists(outdir): + os.mkdir(outdir) + + imsave(os.path.join(outdir, im_name), im, check_contrast=False) + + with Pool(processes) as pool: + pool.map(process_image, zip(subdirs, fpath_lists)) + + print('Finished 2D image cleanup.') \ No newline at end of file diff --git a/dataset/preprocess/mrc2byte.py b/dataset/preprocess/mrc2byte.py new file mode 100644 index 0000000..bc733bd --- /dev/null +++ b/dataset/preprocess/mrc2byte.py @@ -0,0 +1,51 @@ +""" +Description: +------------ + +Throughout this repository we use SimpleITK to load image volumes. The MRC +file format can sometimes cause issues when the files are saved with signed +bytes. To prevent errors in which images are cliiped from 0-127, it is necessary +to make the mrc volumes unsigned byte type. This script takes a directory +containing mrc files and performs that conversion using IMOD. + +For help downloading and installing IMOD, see: +https://bio3d.colorado.edu/imod/ + +Example usage: +-------------- + +python mrc2byte.py {mrcdir} + +For help with arguments: +------------------------ + +python mrc2byte.py --help + +""" + +import os, argparse +import subprocess +from glob import glob + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('mrcdir', type=str, help='Directory containing mrc image volumes.') + args = parser.parse_args() + + #read in the argument + mrcdir = args.mrcdir + + #gather the mrc filepaths + fnames = glob(os.path.join(mrcdir, '*.mrc')) + print('Found {} mrc files'.format(len(fnames))) + FNULL = open(os.devnull, 'w') + + for fn in fnames: + #create the IMOD command and run it + command = ['newstack', fn, fn, '-by', '0'] + subprocess.call(command, stdout=FNULL, stderr=subprocess.STDOUT) + + #IMOD won't overwrite the old file + #instead it renames it with a '~' at + #the end. here we remove that old file + os.remove(fn + '~') \ No newline at end of file diff --git a/dataset/preprocess/vid2stack.py b/dataset/preprocess/vid2stack.py new file mode 100644 index 0000000..1a49974 --- /dev/null +++ b/dataset/preprocess/vid2stack.py @@ -0,0 +1,64 @@ +""" +Description: +------------ + +Simple script for converting video files (mp4 and avi) into +nrrd image volumes. 
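+
+Note (based on the code below): the stack is written next to the source video
+as a .nrrd file, and if the original filename does not already contain the word
+'video' a '_video.nrrd' suffix is used so that downstream scripts such as
+patchify3d.py treat the stack as a video and only cut xy slices from it.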
+ +Example usage: +-------------- + +python vid2stack.py {viddir} + +For help with arguments: +------------------------ + +python vid2stack.py --help + +""" + +import os, argparse +import cv2 +import numpy as np +import SimpleITK as sitk +from glob import glob + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('viddir', type=str, help='Directory containing video files: avi or mp4') + args = parser.parse_args() + + viddir = args.viddir + + # only avi and mp4 support + vidfiles = glob(os.path.join(viddir, '*.mp4')) + vidfiles = vidfiles + glob(os.path.join(viddir, '*.avi')) + + print(f'Found {len(vidfiles)} video files.') + + for vf in vidfiles: + cap = cv2.VideoCapture(vf) + + # load the first frame + success, frame = cap.read() + # note that grayscale videos have 3 duplicate channels, + frames = [frame[..., 0]] + + while success: + success, frame = cap.read() + if frame is not None: + frames.append(frame[..., 0]) + + video = np.stack(frames, axis=0) + + fdir = os.path.dirname(vf) + fname = os.path.basename(vf) + fext = fname.split('.')[-1] + + if 'video' not in fname.lower(): + suffix = '_video.nrrd' + else: + suffix = '.nrrd' + + outpath = os.path.join(fdir, fname.replace(fext, suffix)) + sitk.WriteImage(sitk.GetImageFromArray(video), outpath) \ No newline at end of file diff --git a/dataset/filtered/train_nn.py b/dataset/train_nn.py similarity index 100% rename from dataset/filtered/train_nn.py rename to dataset/train_nn.py diff --git a/dataset/train_patch_classifier.py b/dataset/train_patch_classifier.py new file mode 100644 index 0000000..c17c280 --- /dev/null +++ b/dataset/train_patch_classifier.py @@ -0,0 +1,267 @@ +""" +Description: +------------ + +Fits a ResNet34 model to images that have manually been labeled as "informative" or "uninformative". It's assumed that +images have been manually labeled using the corrector.py utilities running in a Jupyter notebook (see notebooks/labeling.ipynb). + +The results of this script are the roc curve plot on a randomly chosen validation set of images and the +model state dict as a .pth file. 
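+
+The labels file is expected to contain one of the strings 'good', 'bad', or
+'none' for each image path; images labeled 'none' are treated as unlabeled and
+are ignored during training (see the label handling in the code below).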
+
+Example usage:
+--------------
+
+python train_patch_classifier.py {impaths_file} {labels_fpath} {savedir}
+
+For help with arguments:
+------------------------
+
+python train_patch_classifier.py --help
+"""
+
+import os, sys, cv2, argparse
+import numpy as np
+from sklearn.metrics import confusion_matrix, accuracy_score, roc_curve, roc_auc_score
+from sklearn.model_selection import train_test_split
+
+import torch
+import torch.nn as nn
+import torch.backends.cudnn as cudnn
+from torchvision.models import resnet34
+from torch.optim import Adam
+from torch.utils.data import DataLoader, Dataset
+
+import albumentations as A
+from albumentations.pytorch import ToTensorV2
+from tqdm import tqdm
+from matplotlib import pyplot as plt
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description='Trains a ResNet34 model to classify patch quality from manually labeled images'
+    )
+    parser.add_argument('impaths_file', type=str, metavar='impaths_file',
+                        help='A .npy file containing the absolute paths to a set of images.')
+    parser.add_argument('labels_fpath', type=str, metavar='labels_fpath',
+                        help='A .npy file containing the labels for the images in the impaths_file, valid labels are (good,bad,none).')
+    parser.add_argument('savedir', type=str, metavar='savedir',
+                        help='Directory in which to save model weights and evaluation plots')
+
+    args = parser.parse_args()
+    impaths_file = args.impaths_file
+    labels_fpath = args.labels_fpath
+    savedir = args.savedir
+
+    # make sure the savedir exists
+    if not os.path.isdir(savedir):
+        os.mkdir(savedir)
+
+    impaths = np.load(impaths_file)
+    gt_labels = np.load(labels_fpath)
+
+    assert(len(impaths) == len(gt_labels)), "Number of impaths and labels are different!"
+
+    # it's expected that the gt_labels were generated within a Jupyter notebook by
+    # using the corrector.py labeling utilities
+    # in that case the labels are text with the possible options of "good", "bad", and "none"
+    # those with the label "none" are considered the unlabeled set and we make predictions
+    # about their labels using the model that we train on the labeled images
+    good_indices = np.where(gt_labels == 'good')[0]
+    bad_indices = np.where(gt_labels == 'bad')[0]
+    labeled_indices = np.concatenate([good_indices, bad_indices], axis=0)
+
+    assert len(labeled_indices) >= 64, \
+        f'Need at least 64 labeled patches to train model, got {len(labeled_indices)}'
+
+    # fix the seed to pick validation set
+    np.random.seed(1227)
+    trn_indices, val_indices = train_test_split(labeled_indices, test_size=0.15)
+
+    # unset the seed for random augmentations
+    np.random.seed(None)
+
+    # str to int labels
+    labels = np.zeros((len(impaths), ))
+    labels[good_indices] = 1
+
+    # separate train and validation sets
+    trn_impaths = impaths[trn_indices]
+    trn_labels = labels[trn_indices]
+
+    val_impaths = impaths[val_indices]
+    val_labels = labels[val_indices]
+
+    # augmentations are carefully chosen such that the amount of distortion would not
+    # transform an otherwise "informative" patch into an "uninformative" patch
+    imsize = 224
+    normalize = A.Normalize() # default is imagenet normalization
+    tfs = A.Compose([
+        A.Resize(imsize, imsize),
+        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
+        A.GaussNoise(var_limit=(40, 100.0), p=0.5),
+        A.GaussianBlur(blur_limit=5, p=0.5),
+        A.HorizontalFlip(),
+        A.VerticalFlip(),
+        normalize,
+        ToTensorV2()
+    ])
+
+    eval_tfs = A.Compose([
+        A.Resize(imsize, imsize),
+        normalize,
+        ToTensorV2()
+    ])
+
+    class SimpleDataset(Dataset):
def __init__(self, imfiles, labels, tfs=None): + super(SimpleDataset, self).__init__() + self.imfiles = imfiles + self.labels = labels + self.tfs = tfs + + def __len__(self): + return len(self.imfiles) + + def __getitem__(self, idx): + # load the image + image = cv2.imread(self.imfiles[idx], 0) + image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) + + # load the label + label = self.labels[idx] + + # apply transforms + if self.tfs is not None: + image = self.tfs(image=image)['image'] + + return {'image': image, 'label': label} + + # create datasets for the train and validation sets + trn_data = SimpleDataset(trn_impaths, trn_labels, tfs) + val_data = SimpleDataset(val_impaths, val_labels, eval_tfs) + + # create dataloaders + bsz = 64 + train = DataLoader(trn_data, batch_size=bsz, shuffle=True, pin_memory=True, drop_last=True, num_workers=4) + valid = DataLoader(val_data, batch_size=bsz, shuffle=False, pin_memory=True, num_workers=4) + + # create the model initialized with ImageNet weights + model = resnet34(pretrained=True) + + # freeze all parameters + for param in model.parameters(): + param.requires_grad = False + + # modify the output layer to predict 1 class only + model.fc = nn.Linear(in_features=512, out_features=1) + + # move the model to a cuda device + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + model = model.to(device) + + # unfreeze all parameters below the given finetune layer + finetune_layer = 'layer4' + backbone_groups = [mod[1] for mod in model.named_children()] + if finetune_layer != 'none': + layer_index = {'all': 0, 'layer1': 4, 'layer2': 5, 'layer3': 6, 'layer4': 7} + start_layer = layer_index[finetune_layer] + + #always finetune from the start layer to the last layer in the resnet + for group in backbone_groups[start_layer:]: + for param in group.parameters(): + param.requires_grad = True + + model_parameters = filter(lambda p: p.requires_grad, model.parameters()) + params = sum([np.prod(p.size()) for p in model_parameters]) + print(f'Using model with {params} trainable parameters!') + + criterion = nn.BCEWithLogitsLoss() + optimizer = Adam(model.parameters(), lr=1e-3) + + cudnn.benchmark = True + + def accuracy(output, labels): + output = output.squeeze() + labels = labels.squeeze() > 0 + + output = nn.Sigmoid()(output) > 0.5 + + # measure correct + correct = torch.sum(output == labels).float() + return (correct / len(labels)).item() + + + # runs model training and validation loops for 30 epochs + for epoch in range(30): + rl = 0 + ra = 0 + for data in tqdm(train): + images = data['image'].to(device, non_blocking=True) + labels = data['label'].to(device, non_blocking=True) + + optimizer.zero_grad() + + output = model.train()(images) + loss = criterion(output, labels.unsqueeze(1)) + + loss.backward() + + optimizer.step() + + rl += loss.item() + ra += accuracy(output, labels) + + print(f'Epoch {epoch + 1}, Loss {rl / len(train)}, Accuracy {ra / len(train)}') + + rl = 0 + ra = 0 + for data in valid: + images = data['image'].to(device, non_blocking=True) + labels = data['label'].to(device, non_blocking=True) + + output = model.eval()(images) + loss = criterion(output, labels.unsqueeze(1)) + rl += loss.item() + ra += accuracy(output, labels) + + print(f'Val Loss {rl / len(valid)}, Accuracy {ra / len(valid)}') + + torch.save(model.state_dict(), os.path.join(savedir, 'patch_quality_classifier_nn.pth')) + print(f'Model finished training, weights saved to {savedir}') + + # run more extensive validation and print results + print(f'Evaluating model 
predictions...') + val_predictions = [] + for data in tqdm(valid): + #load data onto gpu + images = data['image'].to(device, non_blocking=True) + + #forward + output = model.eval()(images) + pred = nn.Sigmoid()(output) + val_predictions.append(pred.detach().cpu().numpy()) + + val_predictions = np.concatenate(val_predictions, axis=0) + + tn, fp, fn, tp = confusion_matrix(val_labels, val_predictions > 0.5).ravel() + acc = accuracy_score(val_labels, val_predictions > 0.5) + + print(f'Total validation images: {len(val_data)}') + print(f'True Positives: {tp}') + print(f'True Negatives: {tn}') + print(f'False Positives: {fp}') + print(f'False Negatives: {fn}') + print(f'Accuracy: {acc}') + + fpr_nn, tpr_nn, _ = roc_curve(val_labels, val_predictions) + plt.plot(fpr_nn, tpr_nn, linewidth=8, label=f'ConvNet (AUC = {roc_auc_score(val_labels, val_predictions):.3f})') + plt.xlabel('False positive rate', labelpad=16, fontsize=18, fontweight="bold") + plt.xticks(fontsize=14) + plt.ylabel('True positive rate', labelpad=16, fontsize=18, fontweight="bold") + plt.yticks(fontsize=14) + plt.title('NN Patch Quality Classifier ROC Curve', fontdict={'fontsize': 22, 'fontweight': "bold"}, pad=24) + plt.tight_layout() + plt.legend(loc='best', fontsize=18) + plt.savefig(os.path.join(savedir, "patch_quality_nn_roc_curve.png")) + + print('Finished training patch classifier.') \ No newline at end of file From c17f9ebe4864e6d8929521fa2d163d5393f71302 Mon Sep 17 00:00:00 2001 From: conradry Date: Fri, 15 Apr 2022 11:55:47 -0400 Subject: [PATCH 10/19] removed duplicates --- dataset/cleanup2d.py | 129 -------- dataset/filtered/classify_nn.py | 182 ----------- .../random_forest/calculate_rf_features.py | 89 ------ dataset/filtered/random_forest/classify_rf.py | 118 ------- dataset/mrc2byte.py | 51 --- dataset/train_nn.py | 293 ------------------ dataset/vid2stack.py | 62 ---- 7 files changed, 924 deletions(-) delete mode 100644 dataset/cleanup2d.py delete mode 100644 dataset/filtered/classify_nn.py delete mode 100644 dataset/filtered/random_forest/calculate_rf_features.py delete mode 100644 dataset/filtered/random_forest/classify_rf.py delete mode 100644 dataset/mrc2byte.py delete mode 100644 dataset/train_nn.py delete mode 100644 dataset/vid2stack.py diff --git a/dataset/cleanup2d.py b/dataset/cleanup2d.py deleted file mode 100644 index db9b6c6..0000000 --- a/dataset/cleanup2d.py +++ /dev/null @@ -1,129 +0,0 @@ -""" -Description: ------------- - -This script accepts a directory with 2d images and processes them -to make sure they all have unsigned 8-bit pixels. The resultant images -are saved in the given save directory. - -Importantly, the saved image files are given a slightly different filename: -We add '-LOC-2d' to the end of the filename. Once images from 2d and 3d datasets -start getting mixed together, it can be difficult to keep track of the -provenance of each patch. Everything that appears before '-LOC-2d' is the -name of the original dataset and we, of course, know that that dataset is -a 2d EM image. 
- -Example usage: --------------- - -python cleanup2d.py {imdir} {savedir} --processes 4 - -For help with arguments: ------------------------- - -python cleanup2d.py --help -""" - -import os, math -import argparse -import numpy as np -from glob import glob -from skimage.io import imread, imsave -from multiprocessing import Pool - -#main function of the script -if __name__ == "__main__": - - #setup the argument parser - parser = argparse.ArgumentParser(description='Create dataset for nn experimentation') - parser.add_argument('imdir', type=str, metavar='imdir', help='Directory containing 2d EM images') - parser.add_argument('savedir', type=str, metavar='savedir', help='Path to save the processed images') - parser.add_argument('-p', '--processes', dest='processes', type=int, metavar='processes', default=4, - help='Number of processes to run, more processes will run faster but consume more memory') - - - args = parser.parse_args() - - #read in the parser arguments - imdir = args.imdir - savedir = args.savedir - processes = args.processes - - #check if the savedir exists, if not create it - if not os.path.isdir(savedir): - os.mkdir(savedir) - - #get the list of all images of any format - fpaths = np.array(glob(imdir + '*')) - - print(f'Found {len(fpaths)} images') - - #loop over each fpath and save the slices - def process_image(fp): - #try to read the image, if it's not possible then pass - try: - im = imread(fp) - except: - print('Failed to open: ', fp) - pass - - #first check if we're working with signed or unsigned pixels - dtype = str(im.dtype) - - is_float = False - unsigned = False - if dtype[0] == 'u': - unsigned = True - elif dtype[0] == 'f': - is_float = True - - if dtype == 'uint8': #nothing to do - pass - else: - #get the number of bits per pixel - bits = int(dtype[-2]) #16, 32, or 64 - - #explicitly convert the image to float - im = im.astype('float') - - if unsigned: - #unsigned conversion just requires division by max and - #multiplication by 255 - im /= (2 ** bits) #scales from 0-1 - im *= 255 - elif unsigned is False and is_float is False: - #signed conversion adds the additional step of - #subtracting the minimum negative value - im -= -(2 ** (bits - 1)) - im /= (2 ** bits) #scales from 0-1 - im *= 255 - else: - #this means we're working with float. - #because the range is variable, we'll just subtract - #the minimum, divide by maximum, and multiply by 255 - #this means the pixel range is always 0-255 - im -= im.min() - im /= im.max() - im *= 255 - - #convert to uint8 - im = im.astype(np.uint8) - - #establish a filename prefix from the filepath - fext = fp.split('.')[-1] - exp_name = fp.split('/')[-1].split(f'.{fext}')[0] - - #create a new filename with the -LOC- identifier to help - #find the experiment name in later scripts, the 2d helps - #to distinguish this image from the cross sections of 3d volumes - im_name = f'{exp_name}-LOC-2d.tiff' - - #save the processed image - imsave(os.path.join(savedir, im_name), im, check_contrast=False) - - #running the function with multiple processes - #results in a much faster runtime - with Pool(processes) as pool: - pool.map(process_image, fpaths) - - print('Finished') \ No newline at end of file diff --git a/dataset/filtered/classify_nn.py b/dataset/filtered/classify_nn.py deleted file mode 100644 index 05c7cd5..0000000 --- a/dataset/filtered/classify_nn.py +++ /dev/null @@ -1,182 +0,0 @@ -""" -Description: ------------- - -Fits a ResNet34 model to images that have manually been labeled as "informative" or "uninformative". 
It's assumed that -images have been manually labeled using the corrector.py utilities running in a Jupyter notebook (see notebooks/labeling.ipynb). - -The results of this script are the roc curve plot on a randomly chosen validation set of images, the -model state dict as a .pth file and the model's predictions on all the remaining unlabeled images. - -Example usage: --------------- - -python classify_nn.py {impaths_file} {savedir} --labels {label_file} --weights {weights_file} - -For help with arguments: ------------------------- - -python classify_nn.py --help -""" - -DEFAULT_WEIGHTS = "https://www.dropbox.com/s/2libiwgx0qdgxqv/patch_quality_classifier_nn.pth?raw=1" - -import os, sys, cv2, argparse -import numpy as np -import dask.array as da -from sklearn.metrics import confusion_matrix, accuracy_score, roc_curve, roc_auc_score -from sklearn.model_selection import train_test_split - -import torch -import torch.nn as nn -import torch.backends.cudnn as cudnn -from torchvision.models import resnet34 -from torch.optim import Adam -from torch.utils.data import DataLoader, Dataset - -from albumentations import Compose, Normalize, Resize -from albumentations.pytorch import ToTensorV2 -from tqdm import tqdm - -#main function of the script -if __name__ == "__main__": - - #setup the argument parser - parser = argparse.ArgumentParser( - description='Classifies a set of images by fitting a random forest to an array of descriptive features' - ) - parser.add_argument('impaths_file', type=str, metavar='impaths_file', - help='Path to .npz dask array file containing patch filepaths (for example deduplicated_fpaths.npz)') - parser.add_argument('savedir', type=str, metavar='savedir', - help='Directory in which to save predictions') - parser.add_argument('--labels', type=str, metavar='labels', - help='Optional, path to array file containing image labels (informative or uninformative)') - parser.add_argument('--weights', type=str, metavar='weights', - help='Optional, path to nn weights file. The default is to download weights used in the paper.') - - #parse the arguments - args = parser.parse_args() - impaths_file = args.impaths_file - savedir = args.savedir - gt_labels = args.labels - weights = args.weights - - #make sure the savedir exists - if not os.path.isdir(savedir): - os.mkdir(savedir) - - #load the dask array - impaths = da.from_npy_stack(impaths_file) - - #load the labels array (if there is one) - if gt_labels is not None: - gt_labels = np.load(gt_labels) - else: - gt_labels = np.array(len(impaths) * ['none']) - - #sanity check that the number of labels and impaths are the same - assert(len(impaths) == len(gt_labels)), "Number of impaths and labels are different!" 
- - #it's expected that the gt_labels were generated within a Jupyter notebook by - #using the the corrector.py labeling utilities - #in that case the labels are text with the possible options of "good", "bad", and "none" - #those with the label "none" are considered the unlabeled set and we make predictions - #about their labels using the random forest that we train on the labeled images - good_indices = np.where(gt_labels == 'informative')[0] - bad_indices = np.where(gt_labels == 'uninformative')[0] - labeled_indices = np.concatenate([good_indices, bad_indices], axis=0) - unlabeled_indices = np.setdiff1d(range(len(impaths)), labeled_indices) - - #create the test set - tst_impaths = impaths[unlabeled_indices].compute() - - #set up evaluation transforms (assumes imagenet pretrained as default - #in train_nn.py) - imsize = 224 - normalize = Normalize() #default is imagenet normalization - eval_tfs = Compose([ - Resize(imsize, imsize), - normalize, - ToTensorV2() - ]) - - #make a basic dataset class for loading and augmenting images - #WITHOUT any labels - class SimpleDataset(Dataset): - def __init__(self, imfiles, tfs=None): - super(SimpleDataset, self).__init__() - self.imfiles = imfiles - self.tfs = tfs - - def __len__(self): - return len(self.imfiles) - - def __getitem__(self, idx): - #load the image - image = cv2.imread(self.imfiles[idx], 0) - image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) - - #apply transforms - if self.tfs is not None: - image = self.tfs(image=image)['image'] - - return {'image': image} - - #create datasets for the train, validation, and test sets - tst_data = SimpleDataset(tst_impaths, eval_tfs) - - #create the test dataload - test = DataLoader(tst_data, batch_size=128, shuffle=False, pin_memory=True, num_workers=4) - - #create the resnet34 model - model = resnet34() - - #modify the output layer to predict 1 class only - model.fc = nn.Linear(in_features=512, out_features=1) - - #load the weights from file or from online - if weights is not None: - state_dict = torch.load(weights) - else: - state_dict = torch.hub.load_state_dict_from_url(DEFAULT_WEIGHTS) - - #load in the weights (strictly) - msg = model.load_state_dict(state_dict) - - #move the model to a cuda device - model = model.cuda() - - #faster training - cudnn.benchmark = True - - #lastly run inference on the entire set of unlabeled images - print(f'Running inference on test set...') - tst_predictions = [] - for data in tqdm(test): - with torch.no_grad(): - #load data onto gpu - #then forward pass - images = data['image'].cuda(non_blocking=True) - - output = model.eval()(images) - pred = nn.Sigmoid()(output) - - tst_predictions.append(pred.detach().cpu().numpy()) - - tst_predictions = np.concatenate(tst_predictions, axis=0) - - #create an array of labels that are all zeros and fill in the values from a combination - #of the ground truth labels from training and validation sets and the predicted - #labels for unlabeled indices - #convert gt_labels from strings to integers - predicted_labels = (gt_labels == 'informative').astype(np.uint8) - predicted_labels[unlabeled_indices] = (tst_predictions[:, 0] > 0.5).astype(np.uint8) - - print(f'Saving predictions...') - np.save(os.path.join(savedir, "nn_predictions.npy"), predicted_labels) - - print(f'Saving filepaths...') - filtered_fpaths = da.from_array(impaths[predicted_labels == 1].compute()) - da.to_npy_stack(os.path.join(savedir, 'nn_filtered_fpaths.npz'), filtered_fpaths) - - print('Finished.') \ No newline at end of file diff --git 
a/dataset/filtered/random_forest/calculate_rf_features.py b/dataset/filtered/random_forest/calculate_rf_features.py deleted file mode 100644 index bd0fe4c..0000000 --- a/dataset/filtered/random_forest/calculate_rf_features.py +++ /dev/null @@ -1,89 +0,0 @@ -""" -Description: ------------- - -It is assumed that a deduplicated dataset has already by created -from deduplicate.py. This script calculates four image features with which to train a random -forest model for filtering "good" and "bad" quality patches. - -The impaths_file argument expects a dask array file containing fpaths to tiff images. For -example, the deduplicated_fpaths.npz file created by deduplicate.py. Results are saved -in the given savedir with the same name as the input dask array file, but with the -suffix _rf_features.npy instead. - -Example usage: --------------- - -python calculate_rf_features.py {path}/deduplicated_fpaths.npz {savedir} - -For help with arguments: ------------------------- - -python calculate_rf_features.py --help -""" - -import os -import numpy as np -import dask.array as da -from multiprocessing import Pool -from glob import glob - -from skimage import io -from skimage.morphology import square -from skimage.feature import canny, local_binary_pattern -from skimage.filters.rank import entropy, geometric_mean - -#main function of the script -if __name__ == "__main__": - - #setup the argument parser - parser = argparse.ArgumentParser(description='Create dataset for nn experimentation') - parser.add_argument('impaths_file', type=str, metavar='impaths_file', - help='Path to .npz dask array file containing patch filepaths (for example output of deduplicate.py)') - parser.add_argument('savedir', type=str, metavar='savedir', - help='Directory in which to save the array of calculated features') - - #parse the arguments - args = parser.parse_args() - impaths_file = args.impaths_file - savedir = args.savedir - - #make sure the savedir exists - if not os.path.isdir(savedir): - os.mkdir(savedir) - - #load the dask array - impaths = da.from_npy_stack(impaths_file) - - def calculate_features(imfile): - #these features are based on validation results that showed these four - #features were the most important out of a set of ~30 features - #for the classification of "good" and "bad" patches. although - #restricting ourselves to only these 4 features may result in - #slightly lower prediction accuracy, the time it takes to calculate - #30 features on ~1 million images is excessive - #call .compute() for dask array element - image = io.imread(imfile.compute()) - - features = [] - #first lbp stdev - features.append(local_binary_pattern(image, 8, 8).std()) - - #second median of geo. 
mean - features.append(np.median(geometric_mean(image, square(11)))) - - #third stdev of entropy - features.append(entropy(image, square(7)).std()) - - #fourth mean of the canny filter - features.append(canny(image, sigma=1).mean()) - return features - - with Pool(32) as pool: - features = np.array(list(pool.map(calculate_features, impaths))) - - #save the features as a .npy array in the save directory - #first, get the name of the list of impaths - source_name = impaths_file.split('/').split('.npz')[0] - save_path = os.path.join(savedir, f'{source_name}_rf_features.npy') - np.save(save_path, features) \ No newline at end of file diff --git a/dataset/filtered/random_forest/classify_rf.py b/dataset/filtered/random_forest/classify_rf.py deleted file mode 100644 index 348be8e..0000000 --- a/dataset/filtered/random_forest/classify_rf.py +++ /dev/null @@ -1,118 +0,0 @@ -""" -Description: ------------- - -Fits a random forest model to an array of features that describe images and uses that model to -predict whether unlabeled images are "good" or "bad" quality. It's assumed that the calculate_rf_features.py -script has already been run and that some images have been manually labeled using the corrector.py utilities -running in a Jupyter notebook. - -The results of this script are the roc curve plot on a randomly chosen validation set of images, the -model object as a .sav file and the model's predictions on all the features in the given features array. - -Example usage: --------------- - -python classify_rf.py {features_fpath} {labels_fpath} {savedir} - -For help with arguments: ------------------------- - -python classify_rf.py --help -""" - -import os, argparse, pickle -from sklearn.ensemble import RandomForestClassifier -from sklearn.metrics import confusion_matrix, accuracy_score, plot_roc_curve -from sklearn.model_selection import train_test_split - - -#main function of the script -if __name__ == "__main__": - - #setup the argument parser - parser = argparse.ArgumentParser(description='Classifies a set of images by fitting a random forest to an array of descriptive features') - parser.add_argument('features_fpath', type=str, metavar='features_fpath', help='Path to array file containing image features') - parser.add_argument('labels_fpath', type=str, metavar='labels_fpath', help='Path to array file containing image labels (good or bad)') - parser.add_argument('savedir', type=str, metavar='savedir', help='Directory in which to save model, roc curve, and predictions') - - #parse the arguments - args = parser.parse_args() - features_fpath = args.features_fpath - labels_fpath = args.labels_fpath - savedir = args.savedir - - #make sure the savedir exists - if not os.path.isdir(savedir): - os.mkdir(savedir) - - #load the features and labels arrays - features = np.load(features_fpath) - gt_labels = np.load(labels_fpath) - - #it's expected that the gt_labels were generated within a Jupyter notebook by - #using the the corrector.py labeling utilities - #in that case the labels are text with the possible options of "good", "bad", and "none" - #those with the label "none" are considered the unlabeled set and we make predictions - #about their labels using the random forest that we train on the labeled images - good_indices = np.where(gt_labels == 'good')[0] - bad_indices = np.where(gt_labels == 'bad')[0] - labeled_indices = np.concatenate([good_indices, bad_indices], axis=0) - unlabeled_indices = np.setdiff1d(range(len(features)), labeled_indices) - - #setting the random seed ensures that we're always 
comparing results against the same - #validation dataset (otherwise hyperparameter tuning results would be indecipherable) - np.random.seed(1227) - trn_indices, val_indices = train_test_split(labeled_indices, test_size=0.15) - - #convert the labels from text to integers (0 = "bad", 1= "good") - labels = np.zeros((len(features), )) - labels[good_indices] = 1 - - #separate train and validation sets - trn_features = features[trn_indices] - trn_labels = labels[trn_indices] - val_features = features[val_indices] - val_labels = labels[val_indices] - - #fit the random forest model to the training data - print(f'Fitting random forest model to {len(trn_features)} images...') - rf = RandomForestClassifier(n_estimators=100, max_depth=8, class_weight='balanced', min_samples_split=8) - rf = rf.fit(trn_features, trn_labels) - - #save the model object - pickle.dump(rf, open(os.path.join(savedir, 'random_forest.sav'), 'wb')) - - #evaluate the model on the heldout validation test - print(f'Evaluating model predictions...') - val_predictions_rf = rf.predict_proba(val_features)[:, 1] - tn, fp, fn, tp = confusion_matrix(val_labels, val_predictions_rf > 0.5).ravel() - acc = accuracy_score(val_labels, val_predictions_rf > 0.5) - - print(f'Total validation images: {len(val_features)}') - print(f'True Positives: {tp}') - print(f'True Negatives: {tn}') - print(f'False Positives: {fp}') - print(f'False Negatives: {fn}') - print(f'Accuracy: {acc}') - - #plot roc curve and save it - plot_roc_curve(rf, val_features, val_labels) - plt.savefig(os.path.join(savedir, "rf_roc_curve.png")) - - print(f'Predicting labels for {len(unlabeled_indices)} unlabeled images...') - tst_features = features[unlabeled_indices] - tst_predictions = rf.predict_proba(tst_features)[:, 1] - - #create an array of labels that are all zeros and fill in the values from a combination - #of the ground truth labels from training and validation sets and the predicted - #labels for unlabeled indices - predicted_labels = np.zeros(len(features), dtype=np.uint8) - predicted_labels[trn_indices] = trn_labels.astype(np.uint8) - predicted_labels[val_indices] = val_labels.astype(np.uint8) - predicted_labels[unlabeled_indices] = (tst_predictions > 0.5).astype(np.uint8) - - print(f'Saving results...') - np.save(os.path.join(savedir, "rf_predictions.npy"), predicted_labels) - - print('Finished.') \ No newline at end of file diff --git a/dataset/mrc2byte.py b/dataset/mrc2byte.py deleted file mode 100644 index 976e988..0000000 --- a/dataset/mrc2byte.py +++ /dev/null @@ -1,51 +0,0 @@ -""" -Description: ------------- - -Throughout this repository we use SimpleITK to load image volumes. The MRC -file format can sometimes cause issues when the files are saved with signed -bytes. To prevent errors in which images are cliiped from 0-127, it is necessary -to make the mrc volumes unsigned byte type. This script takes a directory -containing mrc files and performs that conversion using IMOD. 
- -For help downloading and installing IMOD, see: -https://bio3d.colorado.edu/imod/ - -Example usage: --------------- - -python mrc2byte.py {mrcdir} - -For help with arguments: ------------------------- - -python mrc2byte.py --help - -""" - -import os, argparse -import subprocess -from glob import glob - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('mrcdir', type=str, help='Directory containing mrc image volumes') - args = parser.parse_args() - - #read in the argument - mrcdir = args.mrcdir - - #gather the mrc filepaths - fnames = glob(os.path.join(mrcdir, '*.mrc')) - print('Found {} mrc files'.format(len(fnames))) - FNULL = open(os.devnull, 'w') - - for fn in fnames: - #create the IMOD command and run it - command = ['newstack', fn, fn, '-by', '0'] - subprocess.call(command, stdout=FNULL, stderr=subprocess.STDOUT) - - #IMOD won't overwrite the old file - #instead it renames it with a '~' at - #the end. here we remove that old file - os.remove(fn + '~') \ No newline at end of file diff --git a/dataset/train_nn.py b/dataset/train_nn.py deleted file mode 100644 index 1948df4..0000000 --- a/dataset/train_nn.py +++ /dev/null @@ -1,293 +0,0 @@ -""" -Description: ------------- - -Fits a ResNet34 model to images that have manually been labeled as "informative" or "uninformative". It's assumed that -images have been manually labeled using the corrector.py utilities running in a Jupyter notebook (see notebooks/labeling.ipynb). - -The results of this script are the roc curve plot on a randomly chosen validation set of images and the -model state dict as a .pth file. - -Example usage: --------------- - -python train_nn.py {impaths_file} {labels_fpath} {savedir} - -For help with arguments: ------------------------- - -python train_nn.py --help -""" - -import os, sys, cv2, argparse -import numpy as np -import dask.array as da -from sklearn.metrics import confusion_matrix, accuracy_score, roc_curve, roc_auc_score -from sklearn.model_selection import train_test_split - -import torch -import torch.nn as nn -import torch.backends.cudnn as cudnn -from torchvision.models import resnet34 -from torch.optim import Adam -from torch.utils.data import DataLoader, Dataset - -from albumentations import ( - Compose, PadIfNeeded, Normalize, HorizontalFlip, VerticalFlip, RandomBrightnessContrast, - CropNonEmptyMaskIfExists, GaussNoise, RandomBrightnessContrast, RandomResizedCrop, Rotate, RandomCrop, - GaussianBlur, CenterCrop, RandomGamma, ElasticTransform, Resize -) -from albumentations.pytorch import ToTensorV2 -from tqdm import tqdm - -#main function of the script -if __name__ == "__main__": - - #setup the argument parser - parser = argparse.ArgumentParser( - description='Classifies a set of images by fitting a random forest to an array of descriptive features' - ) - parser.add_argument('impaths_file', type=str, metavar='impaths_file', - help='Path to .npz dask array file containing patch filepaths (for example deduplicated_fpaths.npz)') - parser.add_argument('labels_fpath', type=str, metavar='labels_fpath', - help='Path to array file containing image labels (good or bad)') - parser.add_argument('savedir', type=str, metavar='savedir', - help='Directory in which to save model and evaluation plots') - - #parse the arguments - args = parser.parse_args() - impaths_file = args.impaths_file - labels_fpath = args.labels_fpath - savedir = args.savedir - - #make sure the savedir exists - if not os.path.isdir(savedir): - os.mkdir(savedir) - - #load the dask array - impaths = 
da.from_npy_stack(impaths_file) - - #load the labels array - gt_labels = np.load(labels_fpath) - - #sanity check that the number of labels and impaths are the same - assert(len(impaths) == len(gt_labels)), "Number of impaths and labels are different!" - - #it's expected that the gt_labels were generated within a Jupyter notebook by - #using the the corrector.py labeling utilities - #in that case the labels are text with the possible options of "informative", "uninformative", and "none" - #those with the label "none" are considered the unlabeled set and we make predictions - #about their labels using the random forest that we train on the labeled images - good_indices = np.where(gt_labels == 'informative')[0] - bad_indices = np.where(gt_labels == 'uninformative')[0] - labeled_indices = np.concatenate([good_indices, bad_indices], axis=0) - - #setting the random seed ensures that we're always comparing results against the same - #validation dataset (otherwise hyperparameter tuning results would be indecipherable) - #resetting the seed to something random ensures that there aren't issues with the - #random augmentations - np.random.seed(1227) - trn_indices, val_indices = train_test_split(labeled_indices, test_size=0.15) - np.random.seed(None) - - #convert the labels from text to integers (0 = "uninformative", 1= "informative") - labels = np.zeros((len(impaths), )) - labels[good_indices] = 1 - - #separate train and validation sets - trn_impaths = impaths[trn_indices].compute() - trn_labels = labels[trn_indices] - - val_impaths = impaths[val_indices].compute() - val_labels = labels[val_indices] - - #augmentations are carefully chosen such that the amount of distortion would not - #transform an otherwise "informative" patch into an "uninformative" patch (for example, by making it low contrast) - imsize = 224 - normalize = Normalize() #default is imagenet normalization - tfs = Compose([ - Resize(imsize, imsize), - RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5), - GaussNoise(var_limit=(40, 100.0), p=0.5), - GaussianBlur(blur_limit=5, p=0.5), - HorizontalFlip(), - VerticalFlip(), - normalize, - ToTensorV2() - ]) - - eval_tfs = Compose([ - Resize(imsize, imsize), - normalize, - ToTensorV2() - ]) - - #make a basic dataset class for loading and augmenting images - class SimpleDataset(Dataset): - def __init__(self, imfiles, labels, tfs=None): - super(SimpleDataset, self).__init__() - self.imfiles = imfiles - self.labels = labels - self.tfs = tfs - - def __len__(self): - return len(self.imfiles) - - def __getitem__(self, idx): - #load the image - image = cv2.imread(self.imfiles[idx], 0) - image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) - - #load the label - label = self.labels[idx] - - #apply transforms - if self.tfs is not None: - image = self.tfs(image=image)['image'] - - return {'image': image, 'label': label} - - #create datasets for the train and validation sets - trn_data = SimpleDataset(trn_impaths, trn_labels, tfs) - val_data = SimpleDataset(val_impaths, val_labels, eval_tfs) - - #create dataloaders - bsz = 64 - train = DataLoader(trn_data, batch_size=bsz, shuffle=True, pin_memory=True, drop_last=True, num_workers=4) - valid = DataLoader(val_data, batch_size=bsz, shuffle=False, pin_memory=True, num_workers=4) - - #create the model initialized with ImageNet weights - model = resnet34(pretrained=True) - - #freeze all parameters - for param in model.parameters(): - param.requires_grad = False - - #modify the output layer to predict 1 class only - model.fc = 
nn.Linear(in_features=512, out_features=1) - - #move the model to a cuda device - model = model.cuda() - - #unfreeze all parameters below the given finetune layer - finetune_layer = 'layer4' - backbone_groups = [mod[1] for mod in model.named_children()] - if finetune_layer != 'none': - layer_index = {'all': 0, 'layer1': 4, 'layer2': 5, 'layer3': 6, 'layer4': 7} - start_layer = layer_index[finetune_layer] - - #always finetune from the start layer to the last layer in the resnet - for group in backbone_groups[start_layer:]: - for param in group.parameters(): - param.requires_grad = True - - - #print the number of trainable parameters - model_parameters = filter(lambda p: p.requires_grad, model.parameters()) - params = sum([np.prod(p.size()) for p in model_parameters]) - print(f'Using model with {params} trainable parameters!') - - #define loss and optiimizer - criterion = nn.BCEWithLogitsLoss() - optimizer = Adam(model.parameters(), lr=1e-3) - - #faster training - cudnn.benchmark = True - - #basic accuracy calculation - def accuracy(output, labels): - #squeeze both - output = output.squeeze() - labels = labels.squeeze() > 0 - - #sigmoid output - output = nn.Sigmoid()(output) > 0.5 - - #measure correct - correct = torch.sum(output == labels).float() - return (correct / len(labels)).item() - - - #runs model training and validation loops for 20 epochs - for epoch in range(20): - rl = 0 - ra = 0 - for data in tqdm(train): - #load data onto gpu - images = data['image'].cuda(non_blocking=True) - labels = data['label'].cuda(non_blocking=True) - - #zero grad - optimizer.zero_grad() - - #forward - output = model.train()(images) - loss = criterion(output, labels.unsqueeze(1)) - - #backward - loss.backward() - - #step - optimizer.step() - - rl += loss.item() - ra += accuracy(output, labels) - - print(f'Epoch {epoch + 1}, Loss {rl / len(train)}, Accuracy {ra / len(train)}') - - rl = 0 - ra = 0 - for data in valid: - #load data onto gpu - images = data['image'].cuda(non_blocking=True) - labels = data['label'].cuda(non_blocking=True) - - #forward - output = model.eval()(images) - loss = criterion(output, labels.unsqueeze(1)) - rl += loss.item() - ra += accuracy(output, labels) - - print(f'Val Loss {rl / len(valid)}, Accuracy {ra / len(valid)}') - - #save the model - torch.save(model.state_dict(), os.path.join(savedir, 'patch_quality_classifier_nn.pth')) - print(f'Model finished training, weights saved to {savedir}') - - #run more extensive validation and print results - print(f'Evaluating model predictions...') - val_predictions = [] - for data in tqdm(valid): - #load data onto gpu - images = data['image'].cuda(non_blocking=True) - - #forward - output = model.eval()(images) - pred = nn.Sigmoid()(output) - val_predictions.append(pred.detach().cpu().numpy()) - - val_predictions = np.concatenate(val_predictions, axis=0) - - tn, fp, fn, tp = confusion_matrix(val_labels, val_predictions > 0.5).ravel() - acc = accuracy_score(val_labels, val_predictions > 0.5) - - print(f'Total validation images: {len(val_data)}') - print(f'True Positives: {tp}') - print(f'True Negatives: {tn}') - print(f'False Positives: {fp}') - print(f'False Negatives: {fn}') - print(f'Accuracy: {acc}') - - #plot roc curve and save it - fpr_nn, tpr_nn, _ = roc_curve(val_labels, val_predictions) - plt.plot(fpr_nn, tpr_nn, linewidth=8, label=f'ConvNet (AUC = {roc_auc_score(val_labels, val_predictions):.3f})') - plt.xlabel('False positive rate', labelpad=16, fontsize=18, fontweight="bold") - plt.xticks(fontsize=14) - plt.ylabel('True positive 
rate', labelpad=16, fontsize=18, fontweight="bold") - plt.yticks(fontsize=14) - plt.title('NN Patch Quality Classifier ROC Curve', fontdict={'fontsize': 22, 'fontweight': "bold"}, pad=24) - plt.tight_layout() - plt.legend(loc='best', fontsize=18) - plt.savefig(os.path.join(savedir, "patch_quality_nn_roc_curve.png")) - - print('Finished!') \ No newline at end of file diff --git a/dataset/vid2stack.py b/dataset/vid2stack.py deleted file mode 100644 index 9e2e8c7..0000000 --- a/dataset/vid2stack.py +++ /dev/null @@ -1,62 +0,0 @@ -""" -Description: ------------- - -Simple script for converting video files (mp4 and avi) into -nrrd image volumes. - -Example usage: --------------- - -python vid2stack.py {viddir} - -For help with arguments: ------------------------- - -python vid2stack.py --help - -""" - -import os, argparse -import cv2 -import numpy as np -import SimpleITK as sitk -from glob import glob - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('viddir', type=str, help='Directory containing video files: avi or mp4') - args = parser.parse_args() - - #read in the argument - viddir = args.viddir - - #get a list of all mp4 and avi filepaths - vidfiles = glob(os.path.join(viddir, '*.mp4')) - vidfiles = vidfiles + glob(os.path.join(viddir, '*.avi')) - - print(f'Found {len(vidfiles)} video files.') - - for vf in vidfiles: - #load the video into cv2 - cap = cv2.VideoCapture(vf) - - #load the first frame - success, frame = cap.read() - - #loop over the video frames and store them in a list. - #note that grayscale videos have 3 duplicate channels, - #we only extract the first of these channels - frames = [frame[:, :, 0]] - while success: - success, frame = cap.read() - if frame is not None: - frames.append(frame[:, :, 0]) - - #stack the frames with the z-axis in the first - #dimension - video = np.stack(frames, axis=0) - - #save as nrrd files - nrrd_path = '.'.join(vf.split('.')[:-1]) + '.nrrd' - sitk.WriteImage(sitk.GetImageFromArray(video), nrrd_path) \ No newline at end of file From 35dffdc3268c98e96cefd7d00239a03b531bb931 Mon Sep 17 00:00:00 2001 From: conradry Date: Fri, 15 Apr 2022 11:56:36 -0400 Subject: [PATCH 11/19] update readme --- dataset/README.md | 69 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 54 insertions(+), 15 deletions(-) diff --git a/dataset/README.md b/dataset/README.md index 3c97d29..5edb95c 100644 --- a/dataset/README.md +++ b/dataset/README.md @@ -1,30 +1,69 @@ # Dataset Curation -Beginning with a collection of 2d and 3d EM images, the scripts in this directory handle all of the dataset preprocessing and curation. +The scripts in this directory handle all dataset preprocessing and curation. Below is an example workflow, read the script headers for more details. Or to see all available parameters use: -For preliminary data preparation the vid2stack.py and mrc2byte.py scripts in the preprocess directory convert videos into image volumes and mrc volumes from signed to unsigned bytes. +```bash +python {script_name}.py --help +``` -The main curation pipeline starts with the cross-sectioning of 3d volumes into 2d image slices that can be combined together with any 2d EM datasets. Cross-sectioning for 3d data is handled by the raw/cross_section3d.py script and some basic type checking and file renaming for 2d data is done by the raw/cleanup2d.py script. 
It's recommended that 2d and 3d "corpuses" of data be kept in separate directories in order to ensure that the two scripts run smoothly; however, the outputs from the raw/cleanup2d.py and raw/cross_section3d.py scripts should all be saved in the same directory. The collection of 2d that results, which are all 8-bit unsigned tiffs, can then be cropped into patches of a given size using the raw/crop_patches.py script. In summary the first step of the workflow is: +**Note: The ```patchify2d.py```, ```patchify3d.py```, and ```classify_patches.py``` scripts are all designed for continuous integration. Datasets that have been processed previously and are in the designated output directories will be ignored by all of them.** -1. Run raw/cleanup2d.py on directory of 2d EM images. Save results to *save_dir* -2. Run raw/cross_section3d.py on directory of 3d EM images. Save results to *save_dir* -3. Run raw/rop_patches.py on images in *save_dir*. Save results to *raw_save_dir* +## 2D Data Preparation -The completion of this first step in the workflow yields the *Raw* dataset. Note that the raw/crop_patches.py sript not only creates tiff images for each of the patches, but also creates a numpy array of the patch's difference hash. The hashes are used for deduplication. +2D images are expected to be organized into directories, where each directory contains a group of images generated +as part of the same imaging project or at least with roughly the same biological metadata. -Deduplication uses the deduplicated/deduplicate.py script. As input the script expects *raw_save_dir* containing the .tiff images and .npy hashes. If new data is added to the *raw_save_dir* after the deduplication script has already been run, the script will only deduplicate the new datasets. This makes it easy to add new datasets without the somewhat time-consuming burden of rerunning deduplication for the entire *Raw* dataset. In summary: +First, standardize images to single channel grayscale and unsigned 8-bit: -1. Run deduplicated/deduplicate.py on *raw_save_dir*. Save results, which are .npy files for each 2d/3d dataset that contain a list of filepaths for exemplar images, to *deduplicated_save_dir*. +```bash +# make copies in new_directory +python preprocess/cleanup2d.py {dir_of_2d_image_groups} -o {new_directory} --processes 4 +# or, instead, overwrite images inplace +python preprocess/cleanup2d.py {dir_of_2d_image_groups} --processes 4 +``` +Second, crop each image into fixed size patches (typically 224x224): -In addition to .npy files for each datasets, the script also outputs a dask array file called deduplicated_fpaths.npz that contains the list of file paths for exemplar images from all 2d/3d datasets. This collection of file paths defines the *Deduplicated* dataset. +```bash +python patchify2d.py {dir_of_2d_image_groups} {patch_dir} -cs 224 --processes 4 +``` -In the last curation step, uninformative patches are filtered out using a ResNet34 classifier. The filtered/train_nn.py script trains the classifier on a collection of manually labeled image files contained in deduplicated_fpaths.npz. It is assumed that the labeling was performed using the labeling.ipynb notebook included in this repository. In general, training a new classifier shouldn't be necessary; we release the weights for the classifer that we trained on 12,000 labeled images. The filtered/classify_nn.py script performs inference on the set of unlabeled images in deduplicated_fpaths.npz. 
By default, the script will download and use the weights that we released. In summary: +The ```patchify2d.py``` script will save a ```.pkl``` file with the name of each 2D image subdirectory. Pickle files contain a dictionary of patches from all images in the subdirectory along with corresponding filenames. These files are ready for filtering (see below). -1. (Optional) Manually label images in deduplicated_fpaths.npz using labeling.ipynb. -2. (Optional) Run filtered/train_nn.py to train and evaluate a ResNet34 on the images labeled in step 1. -3. Run filtered/classify.py on images images in deduplicated_fpaths.npz. Save dask array of all informative images, nn_filtered_fpaths.npz, to *filtered_save_dir*. +## Video Preparation -These last steps result in the *Filtered* dataset. That's the complete curation pipeline. An optional last step, to generate 3d data, is to run the 3d/reconstruct3d.py script. This script takes the set of filtered images, nn_filtered_fpaths.npz, and the original directory of 3d volumes (i.e. the directory given to cross_section3d.py earlier) and makes data volumes of a given z-thickness. Note that one limitation of this script is that it currently assumes patches are 224x224. +Convert videos in ```.avi``` or ```.mp4``` format to ```.nrrd``` images with correct naming convention (i.e., put the word 'video' in the filename). +```bash +python preprocess/vid2stack.py {dir_of_videos} +``` + +## 3D Data Preparation + +3D datasets are expected to be in a single directory (this includes any video stacks created in the previous section). Supported formats are anything that can be [read by SimpleITK](https://simpleitk.readthedocs.io/en/v1.2.3/Documentation/docs/source/IO.html). It's important that if any volumes are in ```.mrc``` format they be converted to unsigned bytes. With IMOD installed this can be done using: + +```bash +python preprocess/mrc2byte.py {dir_of_mrc_files} +``` + +Next, cross-section, patch, and deduplicate volume files. If processing a combination of isotropic and anisotropic volumes, it's crucial that each dataset has a correct header recording the voxel size. If Z resolution is greater that 25% +different from xy resolution, then cross-sections will only be cut from the xy plane, even if axes 0, 1, 2 are passed to the script (see usage example below). + +```bash +python patchify3d.py {dir_of_3d_datasets} {patch_dir} -cs 224 --axes 0 1 2 --processes 4 +``` + +The ```patchify3d.py``` script will save a ```.pkl``` file with the name of each volume file. Pickle files contain a dictionary of patches along with corresponding filenames. These files are ready for filtering (see below). + +## Filtering + +2D, video, and 3D datasets can be filtered with the same script just put all the ```.pkl``` files in the same directory. By default, filtering uses a ResNet34 model that was trained on 12,000 manually annotated patches. The weights for this model are downloaded from [Zenodo](https://zenodo.org/record/6458015#.YlmNaS-cbTR) automatically. A new model can be trained, if needed, using the ```train_patch_classifier.py``` script. + +Filtering will be fastest with a GPU installed, but it's not required. + +```bash +python classify_patches.py {patch_dir} {save_dir} +``` + +After running filtering, the ```save_dir``` with have one subdirectory for each of the ```.pkl``` files that were processed. Each subdirectory contains single channel grayscale, unsigned 8-bit tiff images. 
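As a quick check that filtering completed as expected, a minimal sketch along these lines (not part of the repository; the `filtered_patches` path is a placeholder) can confirm that each output subdirectory contains single channel, unsigned 8-bit patches:

```python
import os
from glob import glob

import cv2
import numpy as np

# output directory passed to classify_patches.py as {save_dir}; placeholder name
save_dir = 'filtered_patches'

for subdir in sorted(os.listdir(save_dir)):
    tiffs = glob(os.path.join(save_dir, subdir, '*.tiff'))
    if not tiffs:
        continue

    # load one example patch and verify it is grayscale uint8
    patch = cv2.imread(tiffs[0], cv2.IMREAD_GRAYSCALE)
    assert patch is not None and patch.dtype == np.uint8
    print(f'{subdir}: {len(tiffs)} patches, e.g. shape {patch.shape}')
```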
From fd79e5a6300120dcaa36c95bd19678000633092f Mon Sep 17 00:00:00 2001 From: conradry Date: Fri, 15 Apr 2022 12:00:17 -0400 Subject: [PATCH 12/19] swav with weighted sampling --- pretraining/swav/dataset.py | 68 +++- pretraining/swav/models/__init__.py | 2 + pretraining/swav/models/regnet.py | 462 ++++++++++++++++++++++ pretraining/swav/{ => models}/resnet50.py | 0 pretraining/swav/sampler.py | 170 ++++++++ pretraining/swav/train_swav.py | 17 +- 6 files changed, 700 insertions(+), 19 deletions(-) create mode 100644 pretraining/swav/models/__init__.py create mode 100644 pretraining/swav/models/regnet.py rename pretraining/swav/{ => models}/resnet50.py (100%) create mode 100644 pretraining/swav/sampler.py diff --git a/pretraining/swav/dataset.py b/pretraining/swav/dataset.py index 5a694dc..0422028 100644 --- a/pretraining/swav/dataset.py +++ b/pretraining/swav/dataset.py @@ -17,30 +17,70 @@ class MultiCropDataset(Dataset): def __init__( self, - data_path, - transforms + data_dir, + transforms, + weight_gamma=None ): super(MultiCropDataset, self).__init__() - self.data_path = data_path + self.data_dir = data_dir - manifest_file = os.path.join(data_path, 'manifest.pkl') - if os.path.isfile(manifest_file): - with open(manifest_file, mode='rb') as f: - self.fpaths = pickle.load(f) + self.subdirs = [] + for sd in os.listdir(data_dir): + if os.path.isdir(os.path.join(data_dir, sd)): + self.subdirs.append(sd) + + # images and masks as dicts ordered by subdirectory + self.paths_dict = {} + for sd in self.subdirs: + sd_fps = glob(os.path.join(data_dir, f'{sd}/*.tiff')) + if len(sd_fps) > 0: + self.paths_dict[sd] = sd_fps + + # calculate weights per example, if weight gamma is not None + self.weight_gamma = weight_gamma + if weight_gamma is not None: + self.weights = self._example_weights(self.paths_dict, gamma=weight_gamma) else: - self.fpaths = glob(data_path + '**/*') - with open(manifest_file, mode='wb') as f: - pickle.dump(self.fpaths, f) + self.weights = None + + # unpack dicts to lists of images + self.paths = [] + for paths in self.paths_dict.values(): + self.paths.extend(paths) + + print(f'Found {len(self.subdirs)} subdirectories with {len(self.paths)} images.') - print(f'Found {len(self.fpaths)} images in dataset.') self.tfs = transforms def __len__(self): - return len(self.fpaths) + return len(self.paths) + + @staticmethod + def _example_weights(paths_dict, gamma=0.3): + # counts by source subdirectory + counts = np.array( + [len(paths) for paths in paths_dict.values()] + ) + + # invert and gamma the distribution + weights = 1 / counts + weights = weights ** gamma + + # for interpretation, normalize weights + # s.t. 
they sum to 1 + total_weights = weights.sum() + weights /= total_weights + + # repeat weights per n images + example_weights = [] + for w,c in zip(weights, counts): + example_weights.extend([w] * c) + + return torch.tensor(example_weights) def __getitem__(self, index): - #get the filepath to load - f = self.fpaths[index] + # get the filepath to load + f = self.paths[index] # process multiple transformed crops of the image image = Image.open(f) diff --git a/pretraining/swav/models/__init__.py b/pretraining/swav/models/__init__.py new file mode 100644 index 0000000..f5edcd6 --- /dev/null +++ b/pretraining/swav/models/__init__.py @@ -0,0 +1,2 @@ +from models.resnet50 import * +from models.regnet import * diff --git a/pretraining/swav/models/regnet.py b/pretraining/swav/models/regnet.py new file mode 100644 index 0000000..eb559f8 --- /dev/null +++ b/pretraining/swav/models/regnet.py @@ -0,0 +1,462 @@ +""" +RegNet models from https://arxiv.org/abs/2103.06877 and +https://github.com/facebookresearch/pycls/blob/main/pycls/models/anynet.py + +TODO: +Add scaling rules from RegNetZ +Add correct initialization for ResNet +""" + +import numpy as np +import torch +import torch.nn as nn + +__all__ = [ + 'RegNet', + 'regnetx_6p4gf', + 'regnety_200mf', + 'regnety_800mf', + 'regnety_3p2gf', + 'regnety_4gf', + 'regnety_6p4gf', + 'regnety_8gf', + 'regnety_16gf' +] + +def init_weights(m): + """Performs ResNet-style weight initialization.""" + if isinstance(m, nn.Conv2d): + # Note that there is no bias due to BN + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(mean=0.0, std=np.sqrt(2.0 / fan_out)) + elif isinstance(m, nn.BatchNorm2d): + zero_init_gamma = hasattr(m, "final_bn") and m.final_bn + m.weight.data.fill_(0.0 if zero_init_gamma else 1.0) + m.bias.data.zero_() + +def conv_bn_act( + nin, + nout, + kernel_size, + stride=1, + groups=1, + activation=nn.ReLU(inplace=True) +): + padding = (kernel_size - 1) // 2 + # regular convolution and batchnorm + layers = [ + nn.Conv2d(nin, nout, kernel_size, stride=stride, padding=padding, groups=groups, bias=False), + nn.BatchNorm2d(nout) + ] + + # add activation if necessary + if activation: + layers.append(activation) + + return nn.Sequential(*layers) + +class Resample2d(nn.Module): + def __init__( + self, + nin, + nout, + stride=1, + activation=None + ): + super(Resample2d, self).__init__() + + # convolution to downsample channels, if needed + if nin != nout or stride > 1: + self.conv = conv_bn_act(nin, nout, kernel_size=1, stride=stride, activation=activation) + else: + self.conv = nn.Identity() + + def forward(self, x): + x = self.conv(x) + return x + +class SqueezeExcite(nn.Module): + def __init__(self, nin): + super(SqueezeExcite, self).__init__() + self.avg_pool = nn.AvgPool2d((1, 1)) + + # hard code the squeeze factor at 4 + ns = nin // 4 + self.se = nn.Sequential( + nn.Conv2d(nin, ns, 1, bias=True), # squeeze + nn.ReLU(inplace=True), + nn.Conv2d(ns, nin, 1, bias=True), # excite + nn.Sigmoid() + ) + + def forward(self, x): + return x * self.se(self.avg_pool(x)) + +class Stem(nn.Module): + """ + Simple input stem. + """ + def __init__(self, w_in, w_out, kernel_size=3): + super(Stem, self).__init__() + self.cbr = conv_bn_act(w_in, w_out, kernel_size, stride=2) + + def forward(self, x): + x = self.cbr(x) + return x + +class Bottleneck(nn.Module): + """ + ResNet-style bottleneck block. 
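+    Structure: 1x1 conv -> 3x3 grouped conv (with stride) -> optional squeeze-excite -> 1x1 conv with no activation; BottleneckBlock adds the Resample2d shortcut and a final ReLU.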
+ """ + def __init__( + self, + w_in, + w_out, + bottle_ratio=1, + groups=1, + stride=1, + use_se=False + ): + super(Bottleneck, self).__init__() + w_b = int(round(w_out * bottle_ratio)) + self.a = conv_bn_act(w_in, w_b, 1) + self.b = conv_bn_act(w_b, w_b, 3, stride=stride, groups=groups) + self.se = SqueezeExcite(w_b) if use_se else None + self.c = conv_bn_act(w_b, w_out, 1, activation=None) + self.c[1].final_bn = True # layer 1 is the BN layer + + def forward(self, x): + for layer in self.children(): + x = layer(x) + + return x + +class BottleneckBlock(nn.Module): + def __init__( + self, + w_in, + w_out, + bottle_ratio, + groups=1, + stride=1, + use_se=False + ): + super(BottleneckBlock, self).__init__() + self.bottleneck = Bottleneck(w_in, w_out, bottle_ratio, groups, stride, use_se) + self.downsample = Resample2d(w_in, w_out, stride=stride) + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + return self.act(self.downsample(x) + self.bottleneck(x)) + +class Stage(nn.Module): + def __init__( + self, + block, + w_in, + w_out, + depth, + bottle_r=1, + groups=1, + stride=1, + use_se=False + ): + super(Stage, self).__init__() + + assert depth > 0, "Each stage has minimum depth of 1 layer." + + for i in range(depth): + if i == 0: + # only the first layer in a stage + # has expansion and downsampling + layer = block(w_in, w_out, bottle_r, groups, stride, use_se=use_se) + else: + layer = block(w_out, w_out, bottle_r, groups, use_se=use_se) + + self.add_module(f'block{i + 1}', layer) + + def forward(self, x): + for layer in self.children(): + x = layer(x) + + return x + +class MultiPrototypes(nn.Module): + def __init__(self, output_dim, nmb_prototypes): + super(MultiPrototypes, self).__init__() + self.nmb_heads = len(nmb_prototypes) + for i, k in enumerate(nmb_prototypes): + self.add_module("prototypes" + str(i), nn.Linear(output_dim, k, bias=False)) + + def forward(self, x): + out = [] + for i in range(self.nmb_heads): + out.append(getattr(self, "prototypes" + str(i))(x)) + return out + +class RegNet(nn.Module): + """ + Simplest RegNetX/Y-like encoder without classification head + """ + def __init__( + self, + cfg, + im_channels=1, + output_stride=32, + block=BottleneckBlock, + normalize=False, + output_dim=0, + hidden_mlp=0, + nmb_prototypes=0, + eval_mode=False + ): + super(RegNet, self).__init__() + + assert output_stride in [16, 32] + if output_stride == 16: + cfg.strides[-1] = 1 + + # make the stages with correct widths and depths + self.cfg = cfg + groups = cfg.groups + depths = cfg.depths + w_ins = [cfg.w_stem] + cfg.widths[:-1] + w_outs = cfg.widths + strides = cfg.strides + use_se = cfg.use_se + + self.eval_mode = eval_mode + self.padding = nn.ConstantPad2d(1, 0.0) + + self.stem = Stem(im_channels, cfg.w_stem, kernel_size=3) + + for i in range(cfg.num_stages): + stage = Stage(block, w_ins[i], w_outs[i], depths[i], + groups=groups[i], stride=strides[i], use_se=use_se) + + self.add_module(f'stage{i + 1}', stage) + + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + + # normalize output features + self.l2norm = normalize + + # projection head + num_out_filters = w_outs[-1] + if output_dim == 0: + self.projection_head = None + elif hidden_mlp == 0: + self.projection_head = nn.Linear(num_out_filters, output_dim) + else: + self.projection_head = nn.Sequential( + nn.Linear(num_out_filters, hidden_mlp), + nn.BatchNorm1d(hidden_mlp), + nn.ReLU(inplace=True), + nn.Linear(hidden_mlp, output_dim), + ) + + # prototype layer + self.prototypes = None + if isinstance(nmb_prototypes, list): + 
self.prototypes = MultiPrototypes(output_dim, nmb_prototypes) + elif nmb_prototypes > 0: + self.prototypes = nn.Linear(output_dim, nmb_prototypes, bias=False) + + self.apply(init_weights) + + def forward_backbone(self, x): + x = self.padding(x) + + x = self.stem(x) + x = self.stage1(x) + x = self.stage2(x) + x = self.stage3(x) + x = self.stage4(x) + + if self.eval_mode: + return x + + x = self.avgpool(x) + x = torch.flatten(x, 1) + + return x + + def forward_head(self, x): + if self.projection_head is not None: + x = self.projection_head(x) + + if self.l2norm: + x = nn.functional.normalize(x, dim=1, p=2) + + if self.prototypes is not None: + return x, self.prototypes(x) + + return x + + def forward(self, inputs): + if not isinstance(inputs, list): + inputs = [inputs] + + idx_crops = torch.cumsum(torch.unique_consecutive( + torch.tensor([inp.shape[-1] for inp in inputs]), + return_counts=True, + )[1], 0) + + start_idx = 0 + for end_idx in idx_crops: + _out = self.forward_backbone(torch.cat(inputs[start_idx: end_idx]).cuda(non_blocking=True)) + if start_idx == 0: + output = _out + else: + output = torch.cat((output, _out)) + + start_idx = end_idx + + return self.forward_head(output) + +class RegNetConfig: + w_stem = 32 + bottle_ratio = 1 + strides = [2, 2, 2, 2] + + def __init__( + self, + depth, + w_0, + w_a, + w_m, + group_w, + q=8, + use_se=False, + **kwargs + ): + assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0 + self.depth = depth + self.w_0 = w_0 + self.w_a = w_a + self.w_m = w_m + self.group_w = group_w + self.q = q + self.use_se = use_se + + for k,v in kwargs.items(): + setattr(self, k, v) + + self.set_params() + self.adjust_params() + + def adjust_params(self): + """ + Adjusts widths and groups to guarantee compatibility. + """ + ws = self.widths + gws = self.group_widths + b = self.bottle_ratio + + adj_ws = [] + adj_groups = [] + for w, gw in zip(ws, gws): + # group width can't exceed width + # in the bottleneck + w_b = int(max(1, w * b)) + gw = int(min(gw, w_b)) + + # fix width s.t. it is always divisible by + # group width for any bottleneck_ratio + m = np.lcm(gw, b) if b > 1 else gw + w_b = max(m, int(m * round(w_b / m))) + w = int(w_b / b) + + adj_ws.append(w) + adj_groups.append(w_b // gw) + + assert all(w * b % g == 0 for w, g in zip(adj_ws, adj_groups)) + self.widths = adj_ws + self.groups = adj_groups + + def set_params(self): + """ + Generates RegNet parameters following: + https://arxiv.org/pdf/2003.13678.pdf + """ + # capitals for complete sets + # widths of blocks + U = self.w_0 + np.arange(self.depth) * self.w_a # eqn (2) + + # quantize stages by solving eqn (3) for sj + S = np.round( + np.log(U / self.w_0) / np.log(self.w_m) + ) + + # block widths from eqn (4) + W = self.w_0 * np.power(self.w_m, S) + + # round the widths to nearest factor of q + # (makes best use of tensor cores) + W = self.q * np.round(W / self.q).astype(int) + + # group stages by the quantized widths, use + # as many stages as there are unique widths + W, D = np.unique(W, return_counts=True) + assert len(W) == 4, "Bad parameters, only 4 stage networks allowed!" 
+ + self.num_stages = len(W) + self.group_widths = len(W) * [self.group_w] + self.widths = W.tolist() + self.depths = D.tolist() + +def regnetx_6p4gf(**kwargs): + params = { + 'depth': 17, 'w_0': 184, 'w_a': 60.83, + 'w_m': 2.07, 'group_w': 56 + } + return RegNet(RegNetConfig(**params, **kwargs), block=BottleneckBlock, **kwargs) + +def regnety_200mf(**kwargs): + params = { + 'depth': 13, 'w_0': 24, 'w_a': 36.44, + 'w_m': 2.49, 'group_w': 8 + } + return RegNet(RegNetConfig(**params, **kwargs), block=BottleneckBlock, **kwargs) + +def regnety_800mf(**kwargs): + params = { + 'depth': 14, 'w_0': 56, 'w_a': 38.84, + 'w_m': 2.4, 'group_w': 16 + } + return RegNet(RegNetConfig(**params, **kwargs), block=BottleneckBlock, **kwargs) + +def regnety_3p2gf(**kwargs): + params = { + 'depth': 21, 'w_0': 80, 'w_a': 42.63, + 'w_m': 2.66, 'group_w': 24 + } + return RegNet(RegNetConfig(**params, **kwargs), block=BottleneckBlock, **kwargs) + +def regnety_4gf(**kwargs): + params = { + 'depth': 22, 'w_0': 96, 'w_a': 31.41, + 'w_m': 2.24, 'group_w': 64 + } + return RegNet(RegNetConfig(**params, **kwargs), block=BottleneckBlock, **kwargs) + +def regnety_6p4gf(**kwargs): + params = { + 'depth': 25, 'w_0': 112, 'w_a': 33.22, + 'w_m': 2.27, 'group_w': 72, 'use_se': True + } + return RegNet(RegNetConfig(**params, **kwargs), block=BottleneckBlock, **kwargs) + +def regnety_8gf(**kwargs): + params = { + 'depth': 17, 'w_0': 192, 'w_a': 76.82, + 'w_m': 2.19, 'group_w': 56, 'use_se': True + } + return RegNet(RegNetConfig(**params, **kwargs), block=BottleneckBlock, **kwargs) + +def regnety_16gf(**kwargs): + params = { + 'depth': 18, 'w_0': 200, 'w_a': 106.23, + 'w_m': 2.48, 'group_w': 112, 'use_se': True + } + return RegNet(RegNetConfig(**params, **kwargs), block=BottleneckBlock, **kwargs) \ No newline at end of file diff --git a/pretraining/swav/resnet50.py b/pretraining/swav/models/resnet50.py similarity index 100% rename from pretraining/swav/resnet50.py rename to pretraining/swav/models/resnet50.py diff --git a/pretraining/swav/sampler.py b/pretraining/swav/sampler.py new file mode 100644 index 0000000..ca467f2 --- /dev/null +++ b/pretraining/swav/sampler.py @@ -0,0 +1,170 @@ +import math +import torch +import torch.distributed as dist +from torch.utils.data import Sampler + +class DistributedWeightedSampler(Sampler): + """ + Adapted from https://discuss.pytorch.org/t/how-to-use-my-own-sampler-when-i-already-use-distributedsampler/62143/7. 
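+    Intended to combine DistributedSampler-style sharding with weighted sampling: each replica draws a multinomial sample of indices according to the per-example weights (e.g. those computed by MultiCropDataset._example_weights in dataset.py).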
+ """ + def __init__( + self, + dataset, + weights, + num_replicas=None, + rank=None, + shuffle=True, + drop_last=True + ): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.drop_last = drop_last + self.weights = weights + + if self.drop_last and len(self.dataset) % self.num_replicas != 0: + self.num_samples = math.ceil( + (len(self.dataset) - self.num_replicas) / self.num_replicas + ) + else: + self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) + + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + + def __iter__(self): + if self.shuffle: + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = list(range(len(self.dataset))) + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] + else: + # remove tail of data to make it evenly divisible. + indices = indices[:self.total_size] + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + rand_tensor = torch.multinomial( + self.weights[self.rank:self.total_size:self.num_replicas], + self.num_samples + ) + + return iter(rand_tensor.tolist()) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch + +class DistributedEvalSampler(Sampler): + r""" + DistributedEvalSampler is different from DistributedSampler. + It does NOT add extra samples to make it evenly divisible. + DistributedEvalSampler should NOT be used for training. The distributed processes could hang forever. + See this issue for details: https://github.com/pytorch/pytorch/issues/22584 + shuffle is disabled by default + + DistributedEvalSampler is for evaluation purpose where synchronization does not happen every epoch. + Synchronization should be done outside the dataloader loop. + + Sampler that restricts data loading to a subset of the dataset. + + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each + process can pass a :class`~torch.utils.data.DistributedSampler` instance as a + :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the + original dataset that is exclusive to it. + + .. note:: + Dataset is assumed to be of constant size. + + Arguments: + dataset: Dataset used for sampling. + num_replicas (int, optional): Number of processes participating in + distributed training. By default, :attr:`rank` is retrieved from the + current distributed group. + rank (int, optional): Rank of the current process within :attr:`num_replicas`. + By default, :attr:`rank` is retrieved from the current distributed + group. + shuffle (bool, optional): If ``True`` (default), sampler will shuffle the + indices. + seed (int, optional): random seed used to shuffle the sampler if + :attr:`shuffle=True`. 
This number should be identical across all + processes in the distributed group. Default: ``0``. + + .. warning:: + In distributed mode, calling the :meth`set_epoch(epoch) ` method at + the beginning of each epoch **before** creating the :class:`DataLoader` iterator + is necessary to make shuffling work properly across multiple epochs. Otherwise, + the same ordering will be always used. + + Example:: + + >>> sampler = DistributedSampler(dataset) if is_distributed else None + >>> loader = DataLoader(dataset, shuffle=(sampler is None), + ... sampler=sampler) + >>> for epoch in range(start_epoch, n_epochs): + ... if is_distributed: + ... sampler.set_epoch(epoch) + ... train(loader) + """ + + def __init__(self, dataset, num_replicas=None, rank=None): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + # self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + # self.total_size = self.num_samples * self.num_replicas + self.total_size = len(self.dataset) # true value without extra samples + indices = list(range(self.total_size)) + indices = indices[self.rank:self.total_size:self.num_replicas] + self.num_samples = len(indices) # true value without extra samples + + + def __iter__(self): + indices = list(range(len(self.dataset))) + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/pretraining/swav/train_swav.py b/pretraining/swav/train_swav.py index 02d904a..ed226fa 100644 --- a/pretraining/swav/train_swav.py +++ b/pretraining/swav/train_swav.py @@ -27,7 +27,9 @@ from torch.cuda.amp import autocast from torch.cuda.amp import GradScaler -import resnet50 as resnet_models +import models as models + +from sampler import DistributedWeightedSampler from LARC import LARC from utils import ( fix_random_seeds, @@ -129,10 +131,15 @@ def print_pass(*args): # build data train_dataset = MultiCropDataset( args.data_path, - transforms + transforms, + args.weight_gamma ) - sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + if train_dataset.weights is None: + sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + else: + sampler = DistributedWeightedSampler(train_dataset, train_dataset.weights) + train_loader = torch.utils.data.DataLoader( train_dataset, sampler=sampler, @@ -143,7 +150,7 @@ def print_pass(*args): ) # build model - model = resnet_models.__dict__[args.arch]( + model = models.__dict__[args.arch]( normalize=True, hidden_mlp=args.hidden_mlp, output_dim=args.feat_dim, @@ -230,6 +237,7 @@ def print_pass(*args): mlflow.log_param('temperature', args.temperature) mlflow.log_param('feature_dim', args.feat_dim) mlflow.log_param('queue_length', args.queue_length) + mlflow.log_param('weight_gamma', args.weight_gamma) else: # resume existing run mlflow.start_run(run_id=run_id) @@ -258,7 +266,6 @@ def print_pass(*args): # train the network scores, queue = train(train_loader, model, optimizer, scaler, epoch, lr_schedule, queue, args) - training_stats.update(scores) # save checkpoints if 
args.rank == 0: From f5d69fa05557f6ae6a34adbe5bfc06844f4bc8f3 Mon Sep 17 00:00:00 2001 From: conradry Date: Fri, 15 Apr 2022 12:18:12 -0400 Subject: [PATCH 13/19] updated mocov2 dataset --- pretraining/README.md | 19 +++- pretraining/mocov2/dataset.py | 86 +++++++++++++++---- pretraining/mocov2/mocov2_config.yaml | 13 +-- pretraining/mocov2/train_mocov2.py | 10 +-- pretraining/swav/models/__init__.py | 2 +- .../swav/models/{resnet50.py => resnet.py} | 0 pretraining/swav/swav_config.yaml | 21 ++--- 7 files changed, 109 insertions(+), 42 deletions(-) rename pretraining/swav/models/{resnet50.py => resnet.py} (100%) diff --git a/pretraining/README.md b/pretraining/README.md index 9d4a423..ca2f41d 100644 --- a/pretraining/README.md +++ b/pretraining/README.md @@ -1,11 +1,22 @@ # Pretraining -## MoCoV2 +Before getting started download the latest CEM dataset from EMPIAR. At minimum you'll need access to a system +with 4 high-end GPUs (P100 or V100). Typically, pre-training takes 4-5 days (depending on the size of the dataset). + +## SwAV -To run pretraining you'll need to have downloaded the CEM500K data. Update the data_file and model_dir parameters in the mocov2_config.yaml file. Then run: +To run pretraining with SwAV, first update the ```data_path``` and ```model_path``` parameters in ```swav/swav_config.yaml```, then run: +```bash +python swav/train_swav.py swav_config.yaml ``` -python train_mocov2.py mocov2_config.yaml + +## MoCoV2 + +To run pretraining with MoCoV2, first update the ```data_path``` and ```model_path``` parameters in ```mocov2/mocov2_config.yaml```, then run: + +```bash +python swav/train_swav.py swav_config.yaml ``` -The script was tested on machines with either 4 NVidia V100s or P100s. Runtime for a full 200 epochs on CEM500K is 3-4 days (faster if using V100s). + diff --git a/pretraining/mocov2/dataset.py b/pretraining/mocov2/dataset.py index 17fe51b..2fbe046 100644 --- a/pretraining/mocov2/dataset.py +++ b/pretraining/mocov2/dataset.py @@ -7,27 +7,70 @@ from torch.utils.data import Dataset from PIL import Image from PIL import ImageFilter - -class EMData(Dataset): - """ - Dataset class for loading and augmenting unsupervised data. 
- """ - def __init__(self, fpaths_dask_array, tfs): +class EMData(Dataset): + def __init__( + self, + data_dir, + transforms, + weight_gamma=None + ): super(EMData, self).__init__() - #self.fpaths_dask_array = fpaths_dask_array - self.fpaths = np.array(glob(os.path.join(fpaths_dask_array, '*.tiff'))) - self.tfs = tfs + self.data_dir = data_dir + + self.subdirs = [] + for sd in os.listdir(data_dir): + if os.path.isdir(os.path.join(data_dir, sd)): + self.subdirs.append(sd) + + # images and masks as dicts ordered by subdirectory + self.paths_dict = {} + for sd in self.subdirs: + sd_fps = glob(os.path.join(data_dir, f'{sd}/*.tiff')) + if len(sd_fps) > 0: + self.paths_dict[sd] = sd_fps - benchmarks = ['urocell', 'guay', 'cremi', 'perez', 'lucchi', 'kasthuri'] - for bnk in benchmarks: - indices = np.where(np.core.defchararray.find(self.fpaths, bnk) == -1)[0] - self.fpaths = self.fpaths[indices] + # calculate weights per example, if weight gamma is not None + self.weight_gamma = weight_gamma + if weight_gamma is not None: + self.weights = self._example_weights(self.paths_dict, gamma=weight_gamma) + else: + self.weights = None - print(f'Found {len(self.fpaths)} tiff images') + # unpack dicts to lists of images + self.paths = [] + for paths in self.paths_dict.values(): + self.paths.extend(paths) + + print(f'Found {len(self.subdirs)} subdirectories with {len(self.paths)} images.') + + self.tfs = transforms def __len__(self): - return len(self.fpaths) + return len(self.paths) + + @staticmethod + def _example_weights(paths_dict, gamma=0.3): + # counts by source subdirectory + counts = np.array( + [len(paths) for paths in paths_dict.values()] + ) + + # invert and gamma the distribution + weights = 1 / counts + weights = weights ** gamma + + # for interpretation, normalize weights + # s.t. 
they sum to 1 + total_weights = weights.sum() + weights /= total_weights + + # repeat weights per n images + example_weights = [] + for w,c in zip(weights, counts): + example_weights.extend([w] * c) + + return torch.tensor(example_weights) def __getitem__(self, idx): #get the filepath to load @@ -44,6 +87,19 @@ def __getitem__(self, idx): #the channel dimension, we'll split it later return torch.cat([image1, image2], dim=0) + def __getitem__(self, index): + # get the filepath to load + f = self.paths[index] + + # process multiple transformed crops of the image + image = Image.open(f) + + # transform the images + image1 = self.tfs(image) + image2 = self.tfs(image) + + return torch.cat([image1, image2], dim=0) + class GaussianBlur: """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" diff --git a/pretraining/mocov2/mocov2_config.yaml b/pretraining/mocov2/mocov2_config.yaml index 8fb58d4..a95ad21 100644 --- a/pretraining/mocov2/mocov2_config.yaml +++ b/pretraining/mocov2/mocov2_config.yaml @@ -1,7 +1,8 @@ #basic definitions -experiment_name: "Filtered_CellEMNet_MoCoV2_No_Benchmarks" -data_file: "/data/IASEM/conradrw/data/cem500k/" # a .npz dask array of filepaths -model_dir: "/data/conradrw/FEM_No_Benchmarks/" # the directory in which to save model states +experiment_name: "MoCo_CEM" +data_file: "cem_dataset/" +model_dir: "models/" + arch: "resnet50" workers: 16 epochs: 200 @@ -15,7 +16,7 @@ schedule: momentum: 0.9 weight_decay: 0.0001 -resume: "/data/conradrw/FEM_No_Benchmarks/current.pth.tar" +resume: "" world_size: 1 rank: 0 dist_url: "tcp://localhost:10001" @@ -32,5 +33,5 @@ cos: False logging: True norms: - mean: 0.58331613 - std: 0.09966064 \ No newline at end of file + mean: 0.57287007 + std: 0.12740536 \ No newline at end of file diff --git a/pretraining/mocov2/train_mocov2.py b/pretraining/mocov2/train_mocov2.py index 15890c4..06f1520 100755 --- a/pretraining/mocov2/train_mocov2.py +++ b/pretraining/mocov2/train_mocov2.py @@ -33,12 +33,11 @@ import torch.multiprocessing as mp import torch.utils.data import torch.utils.data.distributed -import torchvision.models as models import torchvision.transforms as tf import mocov2.builder as builder from mocov2.dataset import EMData, GaussianBlur, GaussNoise -from ..resnet import resnet50 +import resnet as models import mlflow @@ -105,9 +104,8 @@ def print_pass(*args): print("=> creating model '{}'".format(config['arch'])) - #hardcoding the resnet50 for the time being model = builder.MoCo( - resnet50, + models.__dict[config['arch']], config['moco_dim'], config['moco_k'], config['moco_m'], config['moco_t'], config['mlp'] ) @@ -191,7 +189,7 @@ def print_pass(*args): normalize ]) - train_dataset = EMData(config['data_file'], augmentation) + train_dataset = EMData(config['data_path'], augmentation) if config['distributed']: train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) @@ -214,7 +212,7 @@ def print_pass(*args): #we don't want to add everything in the config #to mlflow parameters, we'll just add the most #likely to change parameters - mlflow.log_param('data_file', config['data_file']) + mlflow.log_param('data_path', config['data_path']) mlflow.log_param('architecture', config['arch']) mlflow.log_param('epochs', config['epochs']) mlflow.log_param('batch_size', config['batch_size']) diff --git a/pretraining/swav/models/__init__.py b/pretraining/swav/models/__init__.py index f5edcd6..22919eb 100644 --- a/pretraining/swav/models/__init__.py +++ b/pretraining/swav/models/__init__.py @@ -1,2 +1,2 @@ 
-from models.resnet50 import * +from models.resnet import * from models.regnet import * diff --git a/pretraining/swav/models/resnet50.py b/pretraining/swav/models/resnet.py similarity index 100% rename from pretraining/swav/models/resnet50.py rename to pretraining/swav/models/resnet.py diff --git a/pretraining/swav/swav_config.yaml b/pretraining/swav/swav_config.yaml index 33f5cc3..794dd6f 100644 --- a/pretraining/swav/swav_config.yaml +++ b/pretraining/swav/swav_config.yaml @@ -1,26 +1,26 @@ # training parameters -experiment_name: "SWaV_CEM1.5M" -data_path: "/data/IASEM/conradrw/data/cem_datasets/v1_cem/cem1.4M/" -model_path: "/data/IASEM/conradrw/models/SWaV_cem1.4M/" +experiment_name: "SWaV_CEM" +data_path: "cem_dataset/" +model_path: "models/" -print_freq: 1 +print_freq: 500 arch: "resnet50" hidden_mlp: 2048 -workers: 16 +workers: 8 checkpoint_freq: 25 use_fp16: True seed: 1447 resume: null -epochs: 400 +epochs: 200 warmup_epochs: 0 start_warmup: 0 batch_size: 64 -base_lr: 0.015 +base_lr: 0.6 final_lr: 0.0006 -wd: 0.000001 freeze_prototypes_niters: 5005 +wd: 0.000001 # distributed training parameters world_size: 1 @@ -53,6 +53,7 @@ nmb_prototypes: 3000 queue_length: 3840 epoch_queue_starts: 15 +weight_gamma: 0.5 norms: - - 0.58331613 - - 0.09966064 + - 0.575710 + - 0.127650 From 92a7327c3e7411d1cd88c024a68e80fcfb7dc2eb Mon Sep 17 00:00:00 2001 From: conradry Date: Fri, 15 Apr 2022 12:58:22 -0400 Subject: [PATCH 14/19] weights to zenodo --- evaluation/resources/utils.py | 48 +++++++++++++++++++++++++++++++---- 1 file changed, 43 insertions(+), 5 deletions(-) diff --git a/evaluation/resources/utils.py b/evaluation/resources/utils.py index c105b8d..00df5ad 100644 --- a/evaluation/resources/utils.py +++ b/evaluation/resources/utils.py @@ -1,9 +1,10 @@ +import torch import torch.hub __all__ = ['moco_to_unet_prefixes', 'load_pretrained_state_for_unet'] cellemnet_moco_model_urls = { - 'resnet50': 'https://www.dropbox.com/s/bqw4h2x23v3cgup/cellemnet_filtered_moco_v2_200ep.pth.tar?raw=1' + 'resnet50': 'https://zenodo.org/record/6453140/files/cem500k_mocov2_resnet50_200ep.pth.tar?download=1' } imagenet_moco_model_urls = { @@ -24,25 +25,62 @@ def moco_to_unet_prefixes(state_dict): return state_dict +def swav_to_unet_prefixes(state_dict): + # rename moco pre-trained keys + for k in list(state_dict.keys()): + # retain only encoder_q up to before the embedding layer + if k.startswith('module.') and not k.startswith('module.projection_head'): + #for unet, we need to remove module.encoder_q. from the prefix + #and add encoder instead + state_dict['encoder.' + k[len("module."):]] = state_dict[k] + + # delete renamed or unused k + del state_dict[k] + + return state_dict + +def pixpro_to_unet_prefixes(state_dict): + # rename pixpro pre-trained keys + for k in list(state_dict.keys()): + # retain only encoder_q up to before the embedding layer + if k.startswith('module.encoder.backbone.'): + #for unet, we need to remove module.encoder_q. from the prefix + #and add encoder instead + state_dict['encoder.' 
+ k[len("module.encoder.backbone."):]] = state_dict[k] + + # delete renamed or unused k + del state_dict[k] + + return state_dict + def load_pretrained_state_for_unet(model_name='resnet50', pretraining='cellemnet_mocov2'): #validate the pretraining dataset name - if pretraining not in ['cellemnet_mocov2', 'imagenet_mocov2']: + if pretraining not in ['cellemnet_mocov2', 'imagenet_mocov2', 'cellemnet_pixpro']: raise Exception(f'Pretraining must be either cellemnet_mocov2 or imagenet_mocov2, got {pretraining}') #get the url if pretraining == 'cellemnet_mocov2': url = cellemnet_moco_model_urls[model_name] + model_state = torch.hub.load_state_dict_from_url(url) + state_dict = model_state['state_dict'] + state_dict = moco_to_unet_prefixes(state_dict) + + elif pretraining == 'cellemnet_pixpro': + model_state = torch.load('/data/IASEM/conradrw/models/CEM500K_PixPro100/current.pth.tar') + state_dict = model_state['state_dict'] + state_dict = pixpro_to_unet_prefixes(state_dict) else: url = imagenet_moco_model_urls[model_name] + state_dict = torch.hub.load_state_dict_from_url(url) #download and save the weights to TORCH_HOME; after initial download, weight are #loaded from that path instead - model_state = torch.hub.load_state_dict_from_url(url) + #model_state = torch.hub.load_state_dict_from_url(url) # rename moco pre-trained keys - state_dict = model_state['state_dict'] - state_dict = moco_to_unet_prefixes(state_dict) + #state_dict = model_state['state_dict'] + #state_dict = moco_to_unet_prefixes(state_dict) #return both the model state_dict and the norms used during training #if there are no norms, then we return None and will assume ImageNet From 45685798ee7465f5ca2cf1313a5a711d41e6ad9d Mon Sep 17 00:00:00 2001 From: conradry Date: Tue, 13 Dec 2022 12:15:30 -0500 Subject: [PATCH 15/19] flipbook and subvolume reconstruction --- dataset/3d/reconstruct3d.py | 498 +++++++++++------------------------- 1 file changed, 143 insertions(+), 355 deletions(-) diff --git a/dataset/3d/reconstruct3d.py b/dataset/3d/reconstruct3d.py index e450783..6c09bac 100644 --- a/dataset/3d/reconstruct3d.py +++ b/dataset/3d/reconstruct3d.py @@ -2,19 +2,20 @@ Description: ------------ -It is assumed that this script will be run after the cross_section3d.py and -crop_patches.py scripts. Errors are certain to occur if that is not the case. +It is assumed that this script will be run after the patchify3d.py and +classify_patches.py scripts. Errors are certain to occur if that is not the case. -This script takes a dask array of 2d image filepaths and a directory of 3d image -volumes. It is assumed that at least some of the 2d images are cross sections from -the 3d volumes in the given directory. The cross_section3d.py creates a "trail" in -the 2d image filenames that make it easy to find their associated 3d volume. +This script takes a directory of filtered 2D image patches and corresponding directories containing +complete 3D reconstructions. It is assumed that at least some of the images in the filtered +directory are cross-sections from the 3D volumes. Matching is performed based on the subdirectory +names and volume names. Multiple volumes may exist for each subdirectory name. This occurs when +the the volumes are actually ROIs from the same dataset. In these cases, it's essential that the +'-ROI-' identifier exists in each subvolume name. 
-Although it may seem circuitous to go from 3d volumes to 2d images back to 3d -volumes, this method allows us to use the simple and fast 2d filtering algorithms -to effectively filter out uninformative regions of a 3d volume. Reconstructions -are only performed in regions of the volume that contain a few "informative" 2d -patches. +Volumes must be in a format readable by SimpleITK (primarily .mrc, .nrrd, .tif, etc.). THIS +SCRIPT DOES NOT SUPPORT (OME)ZARRs. As a rule, such datasets are usually created because of +their large size. Sparsely sampled ROIs from such NGFF-style datasets can be saved in one of +the supported formats using the ../scraping/ngff_download.py script. Example usage: -------------- @@ -29,380 +30,167 @@ import os, argparse, math import numpy as np -import dask.array as da import SimpleITK as sitk +from skimage import io from glob import glob from multiprocessing import Pool -#main function of the script if __name__ == "__main__": - - #setup the argument parser parser = argparse.ArgumentParser(description='Create dataset for nn experimentation') - parser.add_argument('filtered_impaths_file', type=str, metavar='filtered_impaths_file', help='Dask array file with filtered impaths') - parser.add_argument('volume_dir', type=str, metavar='volume_dir', help='Directory containing source EM volumes') - parser.add_argument('savedir', type=str, metavar='savedir', + parser.add_argument('filtered_dir', type=str, metavar='filtered_dir', help='Filtered image directory') + parser.add_argument('-vd', type=str, dest='volume_dirs', metavar='volume_dirs', nargs='+', + help='Directories containing source EM volumes') + parser.add_argument('-sd', type=str, metavar='savedir', dest='savedir', help='Path to save 3d reconstructions') - parser.add_argument('-nz', dest='nz', type=int, metavar='nz', default=224, + parser.add_argument('-nz', dest='nz', type=int, metavar='nz', default=5, help='Number of z slices in reconstruction') parser.add_argument('-p', '--processes', dest='processes', type=int, metavar='processes', default=32, help='Number of processes to run, more processes run faster but consume memory') - parser.add_argument('--cross-plane', dest='cross_plane', action='store_true', - help='Whether to create 3d volumes sliced orthogonal to imaging plane, useful when nz < image_shape') + parser.add_argument('--limit', dest='limit', type=int, metavar='limit', + help='Maximum number of reconstructions per volume.') - #parse the arguments args = parser.parse_args() - filtered_impaths_file = args.filtered_impaths_file - volume_dir = args.volume_dir + filtered_dir = args.filtered_dir + volume_dirs = args.volume_dirs savedir = args.savedir numberz = args.nz processes = args.processes - cross_plane = args.cross_plane + limit = args.limit - #to avoid running this long script only to get a nasty error - #let's make sure that the savedir exists if not os.path.isdir(savedir): os.mkdir(savedir) - - #glob the volumes - volume_paths = glob(os.path.join(volume_dir, '*')) - print(f'Found {len(volume_paths)} in {volume_dir}') - #extract the volume names - #NOTE: this is the same code used to generate the names - #from cross_section3d.py - volume_names = [] - for vp in volume_paths: - fext = vp.split('.')[-1] if vp[-5:] != 'nii.gz' else 'nii.gz' - volume_names.append(vp.split('/')[-1].split(f'.{fext}')[0]) + # images stored in subdirectories by source volume + img_fdirs = glob(os.path.join(filtered_dir, '*')) + img_fpaths_dict = {} + for fdir in img_fdirs: + source_name = fdir.split('/')[-1] + fnames = 
np.array([os.path.basename(f) for f in glob(os.path.join(fdir, '*.tiff'))]) + if limit is not None: + fnames = np.random.choice(fnames, min(limit, len(fnames)), replace=False) + + img_fpaths_dict[source_name] = fnames - volume_names = np.array(volume_names) - - #convert filtered to numpy straightaway - #dask.array doesn't have good support for string operations - filtered_impaths = da.from_npy_stack(filtered_impaths_file).compute() - - #the first thing that we need to do it to isolate - #images from 3d source datasets. during creation we - #gave 2d files the handy identifying -LOC-2d- - source3d = np.where(np.core.defchararray.find(filtered_impaths, '-LOC-2d') == -1) - print(f'Isolated {len(source3d[0])} images from 3d volumes out of {len(filtered_impaths)}') - - #overwrite filtered_impaths to save space - #and sort the results such that images from the same - #source datasets are grouped together - filtered_impaths = np.sort(filtered_impaths[source3d]) - - #just as in the deduplication script, we want to group - #together images from the same source volume - def get_dataset_name(imf): - #function to extract the name of a dataset from the patch image file path - #in the cross_section.py script we added the handy -LOC- indicator to - #easily identify the source dataset from location information - return imf.split('/')[-1].split('-LOC-')[0] - - #extract the set of unique dataset names from all the impaths - with Pool(processes) as pool: - datasets = np.sort(pool.map(get_dataset_name, filtered_impaths)) - - #because we sorted the impaths, we know that all images from the - #same dataset will be grouped together. therefore, we only need - #to know the index of the first instance of a unique dataset name - #in order to get the indices of all the patches from that dataset - unq_datasets, indices = np.unique(datasets, return_index=True) - - #we can delete the datasets array - del datasets - - #add the last index for impaths such that we have complete intervals - #len(indices) == len(unq_datasets) + 1 - indices = np.append(indices, len(filtered_impaths)) - - #get the intersect of unq_datasets and volume_names - intersect_datasets, unq_indices, _ = np.intersect1d(unq_datasets, volume_names, return_indices=True) - start_indices = indices[:-1][unq_indices] - end_indices = indices[1:][unq_indices] - - #make groups of image patches by source dataset - groups_impaths = [] - for si, ei in zip(start_indices, end_indices): - #have to call .compute() for a dask array - groups_impaths.append(filtered_impaths[si:ei]) + print(f'Found {len(img_fpaths_dict.keys())} image directories.') - #we can delete the filtered_impaths and the indices - del filtered_impaths, indices + # volumes may be in multiple directories + volume_fpaths = [] + for voldir in volume_dirs: + volume_fpaths.extend(glob(os.path.join(voldir, '*'))) + print(f'Found {len(volume_fpaths)} source volumes.') - #define a function for non-maxmium suppression of - #boxes in 3d - def box_nms(boxes, scores, iou_threshold=0.2): - #order the boxes by scores in descending - #order (i.e. highest scores first) - boxes = boxes[np.argsort(scores)[::-1]] - - #create a new list to save picked boxes - picked_boxes = [] - - #loop over boxes picking the highest scoring box - #at each step and then eliminating any overlapping - #boxes. 
continue until all the boxes have been exhausted - while len(boxes) > 0: - #pick the bounding box with largest confidence score - #which will always be the first one in what's left of the - #array (because we sorted boxes by score earlier) - picked_boxes.append(boxes[0]) - - #extract the coordinates from the boxes - #(N, 6) --> 6 * (N, 1) - z1, y1, x1, z2, y2, x2 = np.split(boxes, 6, axis=1) - - #calculate the volumes of all the remaining boxes - #(N, 1) (by construction in this script, volumes will - #all be the same; this NMS function is generic though). - volumes = (z2 - z1) * (y2 - y1) * (x2 - x1) - - #compute the intersections over union - #between the first box and all boxes - #IoU between the first box and itself will be 1 - #but we've already saved it - zz1 = np.maximum(z1[0], z1) - yy1 = np.maximum(y1[0], y1) - xx1 = np.maximum(x1[0], x1) - zz2 = np.minimum(z2[0], z2) - yy2 = np.minimum(y2[0], y2) - xx2 = np.minimum(x2[0], x2) - - #compute the volume of all the intersections - d = np.maximum(0, zz2 - zz1) - h = np.maximum(0, yy2 - yy1) - w = np.maximum(0, xx2 - xx1) - intersection_volumes = (d * h * w) - - #compute intersection over unions (N, 1) - union_volumes = (volumes[0] + volumes - intersection_volumes) - ious = intersection_volumes / union_volumes + def find_children(vol_fpath): + """ + Finds child images from a source volume. + """ + # name of volume + volname = os.path.basename(vol_fpath) + + # strip the suffix + suffix = volname.split('.')[-1] + assert (suffix in ['mrc', 'nrrd']), \ + f"Found invalid volume file type: {suffix}" + volname = '.'.join(volname.split('.')[:-1]) + + # directory with images will be + # volname or prefix before -ROI- + if '-ROI-' in volname: + dirname = volname.split('-ROI-')[0] + else: + dirname = volname - #indices of boxes to be removed - remove_boxes = np.where(ious >= iou_threshold)[0] - - #update boxes (N, 6) --> (N-RB, 6) - boxes = np.delete(boxes, remove_boxes, axis=0) - - return np.array(picked_boxes).astype('int') + # lookup the img_source + #assert(dirname in img_fpaths_dict), \ + #f"Directory {dirname} not found in image paths!" + if dirname not in img_fpaths_dict: + return [], dirname + + img_fpaths = img_fpaths_dict[dirname] + vol_img_fpaths = [] + for fp in img_fpaths: + if volname in fp: + vol_img_fpaths.append(fp) + + return vol_img_fpaths, dirname - #alright now that all the setup is out of the way we want to - #suppress images that would result in overlapping volumes. - #by default we won't allow any overlap from stacks generated - #from slices in the same plane, but we will allow up to 20% - #overlap when slices come from different planes - #(20% percent is assuming using MoCo pretraining with 20%-100% - #sized crops. This overlap criterion makes it less likely to have - #identical cubes from two separate volumes.) 
- def save_box_volumes(volume, dataset_name, boxes, is_isotropic): - #convert to numpy - volume = sitk.GetArrayFromImage(volume) - - if len(volume.shape) == 4: - volume = volume[:, :, :, 0] - - #for filenames, get the digit padding - zpad = math.ceil(math.log(volume.shape[0], 10)) - ypad = math.ceil(math.log(volume.shape[1], 10)) - xpad = math.ceil(math.log(volume.shape[2], 10)) - - #the box indices are in the right format for - #numpy dimensions because the each slices metadata - #was constructed based on numpy dimensions (see cross_section3d.py) - for box in boxes: - z1, y1, x1, z2, y2, x2 = box - zstr, ystr, xstr = str(z1).zfill(zpad), str(y1).zfill(ypad), str(x1).zfill(xpad) + def extract_subvolume(volume, img_fpath): + # extract location of image from filename + img_fpath = os.path.basename(img_fpath) + volname, loc = img_fpath.split('-LOC-') + loc = loc.split('.tiff')[0] + + # first the axis + axis, index, yrange, xrange = loc.split('_') + + # convert to integers + # NOTE: these indices are for a numpy array! + axis = int(axis) + index = int(index) + yslice = slice(*[int(s) for s in yrange.split('-')]) + xslice = slice(*[int(s) for s in xrange.split('-')]) + + # expand the to include range + # around index + span = numberz // 2 + lowz = index - span + highz = index + span + 1 + + # pass images that don't have enough + # context to be annotated as a flipbook + if lowz < 0 or highz >= volume.shape[axis]: + return None, None + else: + axis_span = slice(lowz, highz) - #extract the subvolume - subvolume = volume[z1:z2, y1:y2, x1:x2] - - #if we're allowing cross plane, then we transpose - #the dimensions such the - if cross_plane: - #order the dimensions from smallest to largest - dim_order = np.argsort(subvolume.shape) - dim_names = {0: 'z', 1: 'y', 2: 'x'} - dim_str = ''.join([dim_names[d] for d in dim_order]) - subvolume = np.transpose(subvolume, tuple(dim_order)) + if axis == 0: + flipbook = volume[axis_span, yslice, xslice] + elif axis == 1: + flipbook = volume[yslice, axis_span, xslice] + flipbook = flipbook.transpose(1, 0, 2) + elif axis == 2: + flipbook = volume[yslice, xslice, axis_span] + flipbook = flipbook.transpose(2, 0, 1) else: - dim_str = 'zyx' - - #create slightly different file names if the dataset is - #isotropic or not - if is_isotropic: - fname = f'{dataset_name}-LOC-3d-ISO-{zstr}_{ystr}_{xstr}_{dim_str}.npy' - else: - fname = f'{dataset_name}-LOC-3d-ANISO-{zstr}_{ystr}_{xstr}_{dim_str}.npy' + raise Exception(f'Axis cannot be {axis}, must be in [0, 1, 2]') + flipbook_fname = f'{volname}-LOC-{axis}_{lowz}-{highz}_{yrange}_{xrange}' - #make sure that all dimensions are at least greater than - #numberz // 2 - if all(s >= numberz // 2 for s in subvolume.shape): - np.save(os.path.join(savedir, fname), subvolume) - - - #first handle images from the same planes - def overlap_suppression(impath_group): - #at this point impath_group contains a bunch of paths - #to images from th same dataset. in order to - #extract areas in the 3d dataset to sample, we next want - #to group them by "columns". a column is a group of images - #that were sliced from the same plane and in the same - #y and x location from within that plane. 
before we can - #handle any of this we need to extract the metadata from the - #filenames that appears after the -LOC- indicator - #ex: {dataset_name}-LOC-{axis}_{slice_index}_{ys}_{xs}.tiff - dataset_name = impath_group[0].split('/')[-1].split('-LOC-')[0] - - #if the dataset_name is not in volume names, then - #we're going to skip it (this is in case the there are multiple - #directories that containing source volumes that generated the images) - #for example, we used 1 directory of volumes called 'internal' and - #another called 'external'. meaning that this script needs to be - #run twice (specifying the different volume_dir for each one) - #find the index of the dataset_name - #in volume_names - try: - vol_index = np.where(volume_names == dataset_name)[0][0] - print(f'Found volume {dataset_name} at {vol_index}') - except: - print(f'Could not find volume {dataset_name}') - return None - - #load the volume - volume = sitk.ReadImage(volume_paths[vol_index]) - - #get it's size in numpy style - volsize = sitk.GetArrayFromImage(volume).shape - - axes = [] - slice_indices = [] - ys = [] - xs = [] - for f in impath_group: - location_info = f.split('-LOC-')[-1].split('_') - axes.append(int(location_info[0])) - slice_indices.append(int(location_info[1])) - ys.append(int(location_info[2])) - xs.append(int(location_info[3].split('.tiff')[0])) - - axes = np.array(axes) - slice_indices = np.array(slice_indices) - ys = np.array(ys) - xs = np.array(xs) - - #we're going to make a reasonable assumption that - #if the set of images from a source volume only - #has slices from axis 0 then that volume is likely - #to be anisotropic. Ideally, this information would - #be available from metadata on the volume; however, - #this isn't alway the case. When we perform the cross - #sectioning and filtering though, we tend to remove - #images that have anisotropic pixels (the filtering nn - #was trained to recognize these images as containing "artifacts") - #so let's use this info: if unique axes are 0, then anisotropic - #volume, otherwise, isotropic volume - is_isotropic = True - unique_axes = np.unique(axes) - if len(unique_axes) == 1 and 0 in unique_axes: - is_isotropic = False - - #there are a fewer nested loops that we need to run - #through. 
the first are the plane axes: 0-yx, 1-zx, 2-zy - boxes = [] - scores = [] - for axis in [0, 1, 2]: - #get all the indices of images sliced from - #the given axis (because of sorting they ought - #to be contiguous) - axis_indices = np.where(axes == axis)[0] - if len(axis_indices) == 0: - continue + return flipbook, flipbook_fname + + def create_flipbooks(vp): + children, dirname = find_children(vp) + + vol_savedir = os.path.join(savedir, dirname) + if os.path.isdir(vol_savedir): + print('Skipping', dirname) + return + + if children: + # load the volume and convert to numpy + volume = sitk.ReadImage(vp) + volume = sitk.GetArrayFromImage(volume) - #extract unique pairs of y and x coordinates from - #this axis - axis_slice_indices = slice_indices[axis_indices] - axis_ys = ys[axis_indices] - axis_xs = xs[axis_indices] - unique_2d = np.unique(np.stack([axis_ys, axis_xs], axis=1), axis=0) + if volume.ndim > 3: + volume = volume[..., 0] - #the inner loop is to go through pairs of ys and xs - #the so called "columns" - for y,x in unique_2d: - #extract the indices of the slices that - #remained in the filtered impaths - column_indices = np.where(np.logical_and(axis_ys == y, axis_xs == x))[0] - column_slice_indices = axis_slice_indices[column_indices] - - #construct intervals of numberz thickness that run - #from 0 to the largest slice index. - intervals = np.arange(0, volsize[axis], numberz) - - #complete the intervals by appending the last index - #from an additional interval that includes the maximum - #slice index - bins = np.append(intervals, volsize[axis]) - - #we're about finished. last step is to append - #the left edge of every interval that contains at least 1 - #one of the axis_slice_indices - #the simplest way to do this is with a histogram - counts, _ = np.histogram(column_slice_indices, bins=bins) - #print(counts) - - #define the minimum number of informative - #slices that must exist within an interval for - #it to be considered informative - min_count = numberz // 10 - - #size is (n_good_intervals, 6) - #(zstart, ystart, xstart, zend, yend, xend) - column_boxes = np.zeros((len(intervals[counts > min_count]), 6)) - - #size is (n_good_intervals, 6) - #(zstart, ystart, xstart, zend, yend, xend) - column_boxes = np.zeros((len(intervals[counts > min_count]), 6)) - column_boxes[:, axis] = intervals[counts > min_count] - column_boxes[:, axis + 3] = intervals[counts > min_count] + numberz - - #of all the axes remove the one we're currently analyzing - #e.g. testing axis 1: axes == [0, 1, 2] --> axes == [0, 2] - #--> y_axis == 0, x_axis == 2 - y_axis, x_axis = np.delete(np.arange(3), axis) - - #this assumes that the cropped image is 224x224 - #TODO: can this be adaptive without loading the image? - column_boxes[:, y_axis] = y - column_boxes[:, y_axis + 3] = y + 224 - column_boxes[:, x_axis] = x - column_boxes[:, x_axis + 3] = x + 224 + if np.any(np.array(volume.shape) < numberz): + raise Exception(f'Flipbooks of size {numberz} cannot be created from {vp} with size {volume.shape}') - scores.extend(counts[counts > min_count]) - boxes.extend(column_boxes) + # directory in which to save flipbooks + # from this volume dataset + if not os.path.isdir(vol_savedir): + os.makedirs(vol_savedir, exist_ok=True) - #convert boxes to an array - boxes = np.array(boxes) #(N, 6) - scores = np.array(scores) #(N,) - - #at this juncture the 3d boxes are such that cut boxes - #will have no overlap that is not transposed. - #the last step is to remove overlaps that are transposed. 
- #overlap is measured by the intersection-over-union of - #two boxes. given a choice, we prefer boxes that contained - #more informative images - boxes = box_nms(boxes, scores, iou_threshold=0.1) - - #save results - save_box_volumes(volume, dataset_name, boxes, is_isotropic) - - return None - - #get the sets of all boxes for reconstructing 3d data from - #all the datasets + # extract and save flipbooks + count = 0 + for child in children: + if count >= 50: + break + flipbook, flipbook_fname = extract_subvolume(volume, child) + if flipbook_fname is not None: + io.imsave(os.path.join(vol_savedir, flipbook_fname + '.tif'), + flipbook, check_contrast=False) + count += 1 + with Pool(processes) as pool: - result = list(pool.map(overlap_suppression, groups_impaths)) - - print('Finished') \ No newline at end of file + output = pool.map(create_flipbooks, volume_fpaths) \ No newline at end of file From 205f61456c251d7544841e3e34984e56cc101084 Mon Sep 17 00:00:00 2001 From: conradry Date: Tue, 13 Dec 2022 12:16:48 -0500 Subject: [PATCH 16/19] update links and info in readme --- README.md | 41 +++++++++++++++++++++++------------------ environment.yml | 1 + pretraining/README.md | 2 +- 3 files changed, 25 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index a421f1e..af5dd11 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,7 @@ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/cem500k-a-large-scale-heterogeneous-unlabeled/electron-microscopy-image-segmentation-on-1)](https://paperswithcode.com/sota/electron-microscopy-image-segmentation-on-1?p=cem500k-a-large-scale-heterogeneous-unlabeled) - -Code for the paper: [CEM500K - A large-scale heterogeneous unlabeled cellular electron microscopy image dataset for deep learning.](https://www.biorxiv.org/content/10.1101/2020.12.11.421792v1) +Code for the paper: [CEM500K, a large-scale heterogeneous unlabeled cellular electron microscopy image dataset for deep learning](https://elifesciences.org/articles/65894) ## Getting Started @@ -11,7 +10,7 @@ Code for the paper: [CEM500K - A large-scale heterogeneous unlabeled cellular el First clone this repository: ``` -git clone https://github.com/volume-em/cellemnet +git clone https://github.com/volume-em/cem-dataset.git ``` If using conda, install dependencies in a new environment: @@ -32,15 +31,20 @@ Otherwise, required dependencies can be installed with another package manager ( - scikit-learn - imagehash -## Download CEM500K +## Download the Dataset -The CEM500K dataset, metadata and pretrained_weights are available through [EMPIAR ID 10592](https://www.ebi.ac.uk/pdbe/emdb/empiar/entry/10592/). +The latest iteration of the CEM dataset is CEM1.5M. Images and metadata are available for download through [EMPIAR ID 11035](https://www.ebi.ac.uk/empiar/EMPIAR-11035/). -## Use the pre-trained weights +## Pre-trained weights Currently, pre-trained weights are only available for PyTorch. For an example of how to use them see ```evaluation/benchmark_configs``` and ```notebooks/pretrained_weights.ipynb```. -We're working to convert the weights for use with TensorFlow/Keras. If you have any experience with this kind of conversion and would like to help with testing, please open an issue. 
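As a minimal, hedged sketch, loading the CEM500K MoCoV2 checkpoint linked in the table that follows into a torchvision ResNet-50 might look like the snippet below. This assumes the checkpoint layout handled by moco_to_unet_prefixes in evaluation/resources/utils.py, i.e. weights stored under 'state_dict' with MoCo-style 'module.encoder_q.' key prefixes; adjust the prefix handling if your checkpoint differs.

```python
# Sketch only: pull the CEM500K MoCoV2 checkpoint from Zenodo and load the
# query-encoder weights into a torchvision ResNet-50 backbone.
# Assumes MoCo-style key names ('module.encoder_q.*').
import torch
import torch.hub
from torchvision.models import resnet50

url = ('https://zenodo.org/record/6453140/files/'
       'cem500k_mocov2_resnet50_200ep.pth.tar?download=1')
checkpoint = torch.hub.load_state_dict_from_url(url, map_location='cpu')
state_dict = checkpoint['state_dict']

# keep only encoder_q weights; drop the MoCo prefix and the projection head
prefix = 'module.encoder_q.'
state_dict = {
    k[len(prefix):]: v for k, v in state_dict.items()
    if k.startswith(prefix) and not k.startswith(prefix + 'fc')
}

model = resnet50()
missing = model.load_state_dict(state_dict, strict=False)
print(missing.missing_keys)  # ideally only the final fc layer
```

The SWaV checkpoint uses different key prefixes (see swav_to_unet_prefixes above), so the same stripping step would need to change accordingly.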
+| Model architecture | Pre-training method | Dataset | Link | +| ------------------- | ------------------- | ----------- | ---------------------------------------------- | +| ResNet50 | MoCoV2 | CEM500K | https://zenodo.org/record/6453140#.Y5inAC2B1Qg | +| ResNet50 | SWaV | CEM1.5M | https://zenodo.org/record/6453140#.Y5inAC2B1Qg | + + ## Data Curation @@ -49,16 +53,17 @@ For image deduplication and filtering routines see the ```dataset``` directory R ## Citing this work Please cite this work. -``` -@article {Conrad2020.12.11.421792, - author = {Conrad, Ryan W and Narayan, Kedar}, - title = {CEM500K - A large-scale heterogeneous unlabeled cellular electron microscopy image dataset for deep learning.}, - elocation-id = {2020.12.11.421792}, - year = {2020}, - doi = {10.1101/2020.12.11.421792}, - publisher = {Cold Spring Harbor Laboratory}, - URL = {https://www.biorxiv.org/content/early/2020/12/11/2020.12.11.421792}, - eprint = {https://www.biorxiv.org/content/early/2020/12/11/2020.12.11.421792.full.pdf}, - journal = {bioRxiv} + +```bibtex +@article {Conrad2021, + author = {Conrad, Ryan and Narayan, Kedar}, + doi = {10.7554/eLife.65894}, + issn = {2050-084X}, + journal = {eLife}, + month = {apr}, + title = {{CEM500K, a large-scale heterogeneous unlabeled cellular electron microscopy image dataset for deep learning}}, + url = {https://elifesciences.org/articles/65894}, + volume = {10}, + year = {2021} } ``` \ No newline at end of file diff --git a/environment.yml b/environment.yml index 7e4bc35..349d3a6 100644 --- a/environment.yml +++ b/environment.yml @@ -4,6 +4,7 @@ channels: - conda-forge - simpleitk dependencies: + - python=3.9 - pip - pytorch - torchvision diff --git a/pretraining/README.md b/pretraining/README.md index ca2f41d..50ba058 100644 --- a/pretraining/README.md +++ b/pretraining/README.md @@ -16,7 +16,7 @@ python swav/train_swav.py swav_config.yaml To run pretraining with MoCoV2, first update the ```data_path``` and ```model_path``` parameters in ```mocov2/mocov2_config.yaml```, then run: ```bash -python swav/train_swav.py swav_config.yaml +python mocov2/train_mocov2.py mocov2_config.yaml ``` From b22665ffe2de314bfbe44bd22f0948876223a696 Mon Sep 17 00:00:00 2001 From: conradry Date: Tue, 13 Dec 2022 12:17:52 -0500 Subject: [PATCH 17/19] replace cem1.5m weights link --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index af5dd11..fe5a1aa 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ Currently, pre-trained weights are only available for PyTorch. 
For an example of | Model architecture | Pre-training method | Dataset | Link | | ------------------- | ------------------- | ----------- | ---------------------------------------------- | | ResNet50 | MoCoV2 | CEM500K | https://zenodo.org/record/6453140#.Y5inAC2B1Qg | -| ResNet50 | SWaV | CEM1.5M | https://zenodo.org/record/6453140#.Y5inAC2B1Qg | +| ResNet50 | SWaV | CEM1.5M | https://zenodo.org/record/6453160#.Y5iznS2B1Qh | From c3a2b96e49a717b991348362f1a0c57a9b0f082f Mon Sep 17 00:00:00 2001 From: conradry Date: Tue, 13 Dec 2022 14:25:51 -0500 Subject: [PATCH 18/19] scripts for scraping data from ngffs --- dataset/scraping/crop_rois_from_volume.py | 160 ++++++++ dataset/scraping/h01_download.py | 72 ++++ dataset/scraping/ngff_datasets.csv | 46 +++ dataset/scraping/ngff_download.py | 480 ++++++++++++++++++++++ 4 files changed, 758 insertions(+) create mode 100644 dataset/scraping/crop_rois_from_volume.py create mode 100644 dataset/scraping/h01_download.py create mode 100644 dataset/scraping/ngff_datasets.csv create mode 100644 dataset/scraping/ngff_download.py diff --git a/dataset/scraping/crop_rois_from_volume.py b/dataset/scraping/crop_rois_from_volume.py new file mode 100644 index 0000000..f8aa998 --- /dev/null +++ b/dataset/scraping/crop_rois_from_volume.py @@ -0,0 +1,160 @@ +import os, sys, math +import argparse +import numpy as np +import SimpleITK as sitk +from glob import glob +from tqdm import tqdm + +def convert_to_byte(image): + """ + Verify that image is byte type + """ + if image.dtype == np.uint8: + return image + else: + image = image.astype(np.float32) + image -= image.min() + im_max = image.max() + if im_max > 0: # avoid zero division + image /= im_max + + image *= 255 + + return image.astype(np.uint8) + +def sparse_roi_boxes( + reference_volume, + roi_size, + padding_value=0 + min_frac=0.7 +): + """ + Finds all ROIs of a given size within an overview volume + that contain at least some non-padding values. + + Arguments: + ---------- + reference_volume (np.ndarray): A low-resolution overview image + that can fit in memory. Typically this is the lowest-resolution + available in an image pyramid. + + roi_size (Tuple[d, h, w]): Size of ROIs in voxels relative to + the reference volume. + + padding_value (float): Value used to pad the reference volume. + Must be confirmed by manual inspection. + + min_frac (float): Fraction from 0-1 of and ROI that must be + non-padding values. + + Returns: + -------- + roi_boxes (np.ndarray): Array of (N, 6) defining bounding boxes + for ROIs that passed that min_frac condition. 
+ + """ + # grid for box indices + xcs, ycs, zcs = roi_size + xsize, ysize, zsize = reference_volume.shape + + xgrid = np.arange(0, xsize + 1, xcs) + ygrid = np.arange(0, ysize + 1, ycs) + zgrid = np.arange(0, zsize + 1, zcs) + + max_padding = (1 - min_frac) * np.prod(roi_size) + + # make sure that there's always an ending index + # so we have complete ranges + if len(xgrid) < 2 or xsize % xcs > 0.5 * xcs: + xgrid = np.append(xgrid, np.array(xsize)[None], axis=0) + if len(ygrid) < 2 or ysize % ycs > 0.5 * ycs: + ygrid = np.append(ygrid, np.array(ysize)[None], axis=0) + if len(zgrid) < 2 or zsize % ycs > 0.5 * zcs: + zgrid = np.append(zgrid, np.array(zsize)[None], axis=0) + + roi_boxes = [] + for xi, xf in zip(xgrid[:-1], xgrid[1:]): + for yi, yf in zip(ygrid[:-1], ygrid[1:]): + for zi, zf in zip(zgrid[:-1], zgrid[1:]): + box_slices = tuple([slice(xi, xf), slice(yi, yf), slice(zi, zf)]) + n_not_padding = np.count_nonzero( + reference_volume[box_slices] == padding_value + ) + if n_not_padding < max_padding: + roi_boxes.append([xi, yi, zi, xf, yf, zf]) + + return np.array(roi_boxes) + +def crop_volume( + volume, + volume_name, + resolution, + save_path, + cube_size=256, + padding_value=0, + min_frac=0.7 +): + # find possible ROIs that are not just blank padding + # they may still be uniform resin though + roi_boxes = sparse_roi_boxes( + volume, cube_size, padding_value, min_frac + ) + + # randomly select n_cubes ROI boxes + box_indices = np.random.choice( + range(len(roi_boxes)), size=(min(n_cubes, len(roi_boxes)),), replace=False + ) + roi_boxes = roi_boxes[box_indices] + + # loop through the boxes that we selected + for bbox in tqdm(roi_boxes): + x1, y1, z1, x2, y2, z2 = bbox + bbox_slices = [ + slice(x1, x2), + slice(y1, y2), + slice(z1, z2) + ] + + # crop the cube + cube = volume[bbox_slices] + cube = convert_to_byte(cube) + + cube_fname = f'{volume_name}-ROI-x{x1}-{x2}_y{y1}-{y2}_z{z1}-{z2}.nrrd' + cube_fpath = os.path.join(save_path, cube_fname) + + cube = sitk.GetImageFromArray(cube) + cube.SetSpacing(resolution) + sitk.WriteImage(cube, cube_fpath) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('fpath', type=str, help='path to a volume file (nrrd or mrc)') + parser.add_argument('save_path', type=str, help='path to save the volumes') + parser.add_argument('-cs', type=int, default=256, help='dimension of cubes to crop') + parser.add_argument('-gb', type=float, default=5, help='maximum number of GBs to crop') + args = parser.parse_args() + + directory = args.directory + save_path = args.save_path + target_dir = args.target_dir + cube_size = args.cs + max_gbs = args.max_gbs + + os.makedirs(save_path, exist_ok=True) + os.makedirs(target_dir, exist_ok=True) + + # load the image into array + volume = sitk.ReadImage(fp) + resolution = volume.GetSpacing() + volume = sitk.GetArrayFromImage(volume) + + # crop into cubes + volname = '.'.join(os.path.basename(fp).split('.')[:-1]) + crop_volume( + volume, + volname, + resolution, + save_path, + cube_size, + padding_value=0 + ) \ No newline at end of file diff --git a/dataset/scraping/h01_download.py b/dataset/scraping/h01_download.py new file mode 100644 index 0000000..70690d1 --- /dev/null +++ b/dataset/scraping/h01_download.py @@ -0,0 +1,72 @@ +""" +Download script for the H01 Dataset +See: https://h01-release.storage.googleapis.com/landing.html + +""" + +import os +import numpy as np +import SimpleITK as sitk +# Ensure tensorstore does not attempt to use GCE credentials +os.environ['GCE_METADATA_ROOT'] = 
'metadata.google.internal.invalid' +import tensorstore as ts + + +context = ts.Context({'cache_pool': {'total_bytes_limit': 1000000000}}) +volname = 'shapson-coe2021_h01' + +em_8nm = ts.open({ + 'driver': 'neuroglancer_precomputed', + 'kvstore': {'driver': 'gcs', 'bucket': 'h01-release'}, + 'path': 'data/20210601/8nm_raw'}, + read=True, context=context).result()[ts.d['channel'][0]] + +xsize, ysize, zsize = em_8nm.shape + +# best option is to randomly pick cubes +# of given size +crop_size = 512 +max_gbs = 5 + +# compute the number of cubes from the crop_size +bytes_per_cube = crop_size ** 3 +n_cubes = int((max_gbs * 1024 ** 3) / bytes_per_cube) + +save_path = './H01' + +if not os.path.isdir(save_path): + os.mkdir(save_path) + +# crop and save many will be blank padding +n_cropped = 0 +while n_cropped < n_cubes: + # pick indices for n_cubes + # ranges set manually from checking + # overview in neuroglancer + x = np.random.randint(70000, 450000) + y = np.random.randint(40000, 300000) + z = np.random.randint(0, zsize - crop_size) + xf = x + crop_size + yf = y + crop_size + zf = z + crop_size + + # check if first pixel is + # zero, if yes then skip + if em_8nm[x, y, z].read().result() == 0: + continue + else: + cube = em_8nm[x:xf, y:yf, z:zf].read().result() + + # this dataset is already uint8 and inverted + # we only need to transpose from xyz to zyx + cube = cube.transpose(2, 1, 0) + + # save the result with given resolution + cube_fname = f'{volname}-ROI-x{x}-{xf}_y{y}-{yf}_z{z}-{zf}.nrrd' + cube_fpath = os.path.join(save_path, cube_fname) + + cube = sitk.GetImageFromArray(cube) + cube.SetSpacing([8, 8, 30]) + sitk.WriteImage(cube, cube_fpath) + + n_cropped += 1 \ No newline at end of file diff --git a/dataset/scraping/ngff_datasets.csv b/dataset/scraping/ngff_datasets.csv new file mode 100644 index 0000000..ca14278 --- /dev/null +++ b/dataset/scraping/ngff_datasets.csv @@ -0,0 +1,46 @@ +url,source,api,download_url,volume_name,mip,voxel_x,voxel_y,voxel_z,crop,crop_size,invert,padding_value +https://openorganelle.janelia.org/datasets/jrc_hela-1,openorganelle,xarray,s3://janelia-cosem/jrc_hela-1/neuroglancer/em/fibsem-uint8.precomputed,jrc_hela-1_openorganelle,1,8,8,8,TRUE,256,TRUE,0 +https://openorganelle.janelia.org/datasets/jrc_hela-h89-1,openorganelle,xarray,s3://janelia-cosem/jrc_hela-h89-1/jrc_hela-h89-1.n5/em/fibsem-uint16,jrc_hela-h89-1_openorganelle,1,8,8,8,TRUE,256,TRUE,0 +https://openorganelle.janelia.org/datasets/jrc_ctl-id8-1,openorganelle,xarray,s3://janelia-cosem/jrc_ctl-id8-1/jrc_ctl-id8-1.n5/em/fibsem-uint16,jrc_ctl-id8-1_openorganelle,2,4,4,3.48,TRUE,256,TRUE,0 +https://openorganelle.janelia.org/datasets/jrc_hela-h89-2,openorganelle,xarray,s3://janelia-cosem/jrc_hela-h89-2/jrc_hela-h89-2.n5/em/fibsem-uint16,jrc_hela-h89-2_openorganelle,1,8,8,8,TRUE,256,TRUE,0 +https://openorganelle.janelia.org/datasets/jrc_hela-21,openorganelle,xarray,s3://janelia-cosem/jrc_hela-21/jrc_hela-21.n5/em/fibsem-uint8,jrc_hela-21_openorganelle,0,8,8,8,TRUE,256,FALSE,0 +https://openorganelle.janelia.org/datasets/jrc_hela-2,openorganelle,xarray,s3://janelia-cosem/jrc_hela-2/neuroglancer/em/fibsem-uint8.precomputed,jrc_hela-2_openorganelle,1,4,4,5.24,TRUE,256,FALSE,0 +https://openorganelle.janelia.org/datasets/jrc_cos7-11,openorganelle,xarray,s3://janelia-cosem/jrc_cos7-11/jrc_cos7-11.n5/em/fibsem-uint16,jrc_cos7-11_openorganelle,1,8,8,8,TRUE,256,FALSE,0 
+https://openorganelle.janelia.org/datasets/jrc_hela-3,openorganelle,xarray,s3://janelia-cosem/jrc_hela-3/neuroglancer/em/fibsem-uint8.precomputed,jrc_hela-3_openorganelle,1,4,4,3.24,TRUE,256,FALSE,0 +https://openorganelle.janelia.org/datasets/jrc_choroid-plexus-2,openorganelle,xarray,s3://janelia-cosem/jrc_choroid-plexus-2/neuroglancer/em/fibsem-uint8.precomputed,jrc_choroid-plexus-2_openorganelle,0,8,8,8,TRUE,256,FALSE,0 +https://openorganelle.janelia.org/datasets/jrc_hela-4,openorganelle,xarray,s3://janelia-cosem/jrc_hela-4/neuroglancer/em/fibsem-uint8.precomputed,jrc_hela-4_openorganelle,1,4,4,4.28,TRUE,256,FALSE,0 +https://openorganelle.janelia.org/datasets/jrc_hela-22,openorganelle,xarray,s3://janelia-cosem/jrc_hela-22/jrc_hela-22.n5/em/fibsem-uint8,jrc_hela-22_openorganelle,0,8,8,8,TRUE,256,TRUE,0 +https://openorganelle.janelia.org/datasets/jrc_macrophage-2,openorganelle,xarray,s3://janelia-cosem/jrc_macrophage-2/neuroglancer/em/fibsem-uint8.precomputed,jrc_macrophage-2_openorganelle,1,4,4,3.36,TRUE,256,FALSE,0 +https://openorganelle.janelia.org/datasets/jrc_fly-fsb-1,openorganelle,xarray,s3://janelia-cosem/jrc_fly-fsb-1/jrc_fly-fsb-1.n5/em/fibsem-uint16,jrc_fly-fsb-1_openorganelle,2,4,4,4,TRUE,256,TRUE,0 +https://openorganelle.janelia.org/datasets/jrc_hela-bfa,openorganelle,xarray,s3://janelia-cosem/jrc_hela-bfa/jrc_hela-bfa.n5/em/fibsem-uint8,jrc_hela-bfa_openorganelle,0,8,8,8,TRUE,256,FALSE,0 +https://openorganelle.janelia.org/datasets/jrc_jurkat-1,openorganelle,xarray,s3://janelia-cosem/jrc_jurkat-1/neuroglancer/em/fibsem-uint8.precomputed,jrc_jurkat-1_openorganelle,1,4,4,3.44,TRUE,256,FALSE,0 +https://openorganelle.janelia.org/datasets/jrc_mus-pancreas-1,openorganelle,xarray,s3://janelia-cosem/jrc_mus-pancreas-1/neuroglancer/em/fibsem-uint8.precomputed,jrc_mus-pancreas-1_openorganelle,1,4,4,3.4,TRUE,256,FALSE,0 +https://openorganelle.janelia.org/datasets/jrc_sum159-1,openorganelle,xarray,s3://janelia-cosem/jrc_sum159-1/neuroglancer/em/fibsem-uint8.precomputed,jrc_sum159-1_openorganelle,1,4,4,4.56,TRUE,256,FALSE,0 +https://openorganelle.janelia.org/datasets/jrc_ctl-id8-5,openorganelle,xarray,s3://janelia-cosem/jrc_ctl-id8-5/jrc_ctl-id8-5.n5/em/fibsem-uint8,jrc_ctl-id8-5_openorganelle,0,8,8,8,TRUE,256,FALSE,0 +https://openorganelle.janelia.org/datasets/jrc_fly-acc-calyx-1,openorganelle,xarray,s3://janelia-cosem/jrc_fly-acc-calyx-1/jrc_fly-acc-calyx-1.n5/em/fibsem-uint16,jrc_fly-acc-calyx-1_openorganelle,2,4,4,3.72,TRUE,256,TRUE,0 +https://openorganelle.janelia.org/datasets/jrc_ctl-id8-4,openorganelle,xarray,s3://janelia-cosem/jrc_ctl-id8-4/jrc_ctl-id8-4.n5/em/fibsem-uint8,jrc_ctl-id8-4_openorganelle,0,8,8,8,TRUE,256,FALSE,0 +https://openorganelle.janelia.org/datasets/jrc_ctl-id8-2,openorganelle,xarray,s3://janelia-cosem/jrc_ctl-id8-2/jrc_ctl-id8-2.n5/em/fibsem-uint8,jrc_ctl-id8-2_openorganelle,0,8,8,8,TRUE,256,FALSE,0 +https://openorganelle.janelia.org/datasets/jrc_ctl-id8-3,openorganelle,xarray,s3://janelia-cosem/jrc_ctl-id8-3/jrc_ctl-id8-3.n5/em/fibsem-uint8,jrc_ctl-id8-3_openorganelle,0,8,8,8,TRUE,256,FALSE,0 +https://github.com/mobie/platybrowser-datasets,mobie,xarray,https://s3.embl.de/platybrowser/rawdata/sbem-6dpf-1-whole-raw.n5/setup0/timepoint0,platybrowser_mobie,0,10,10,25,TRUE,256,FALSE,0 +https://github.com/mobie/sponge-fibsem-project,mobie,xarray,https://s3.embl.de/sponge-fibsem/cell1/images/local/fibsem-raw.n5/setup0/timepoint0,sponge-fibsem-cell1_mobie,0,8,8,8,FALSE,256,FALSE,0 
+https://github.com/mobie/sponge-fibsem-project,mobie,xarray,https://s3.embl.de/sponge-fibsem/cell3/images/local/fibsem-raw.n5/setup0/timepoint0,sponge-fibsem-cell3_mobie,0,5,5,8,FALSE,256,FALSE,0 +https://github.com/mobie/sponge-fibsem-project,mobie,xarray,https://s3.embl.de/sponge-fibsem/choanocyte-chamber/images/local/fibsem-raw.n5/setup0/timepoint0,sponge-fibsem-choanocyte-chamber_mobie,0,15,15,15,FALSE,256,FALSE,0 +https://github.com/mobie/yeast-clem-datasets,mobie,xarray,https://s3.embl.de/yeast-clem/yeast/images/local/em-raw-overview.n5/setup0/timepoint0,yeast-clem_mobie,0,10,10,2000,FALSE,256,FALSE,0 +https://www.microns-explorer.org/phase1,microns,CloudVolume,precomputed://https://bossdb-open-data.s3.amazonaws.com/microns/minnie/minnie65-phase3-em/aligned/v1,microns_cortical_mm3,1,4,4,40,TRUE,256,FALSE,0 +,openconnectome,CloudVolume,s3://open-neurodata/bock11/image,bock2011,1,4,4,x,TRUE,256,FALSE,0 +,openconnectome,CloudVolume,s3://open-neurodata/bhatla/ritaN2/image,bhatla2015,2,x,x,x,TRUE,256,FALSE,0 +,openconnectome,CloudVolume,s3://open-neurodata/bloss/bloss18/image,bloss2018,2,4,4,50,TRUE,256,FALSE,0 +,openconnectome,CloudVolume,s3://open-neurodata/bumbarger/bumbarger13/image,bumbarger2013,1,x,x,x,TRUE,256,FALSE,0 +,openconnectome,CloudVolume,s3://open-neurodata/collman/collman15v2/EM25K,collman2015,2,2,2,x,TRUE,256,FALSE,0 +,openconnectome,CloudVolume,s3://open-neurodata/kharris15/spine/em,kharris2015_spine,2,2,2,x,TRUE,256,FALSE,0 +,openconnectome,CloudVolume,s3://open-neurodata/kharris15/apical/em,kharris2015_apical,2,2,2,x,TRUE,256,FALSE,0 +,openconnectome,CloudVolume,s3://open-neurodata/kharris15/oblique/em,kharris2015_oblique,2,2,2,x,TRUE,256,FALSE,0 +,openconnectome,CloudVolume,s3://open-neurodata/hildebrand/130201zf142/160515_SWiFT_60nmpx,hildebrand2017,1,4,4,x,TRUE,256,FALSE,250 +,openconnectome,CloudVolume,s3://open-neurodata/kasthuri/kasthuri11/image,kasthuri2011,2,3,3,x,TRUE,256,FALSE,0 +,openconnectome,CloudVolume,s3://open-neurodata/kasthuri/kasthuri14Maine/image,kasthuri2014_maine,2,3,3,x,TRUE,256,FALSE,0 +,openconnectome,CloudVolume,s3://open-neurodata/kasthuri/kasthuri14s1colEM/image,kasthuri2014_column,2,3,3,x,TRUE,256,FALSE,0 +,openconnectome,CloudVolume,s3://open-neurodata/lee/lee16/image,lee2016,1,4,4,x,TRUE,256,FALSE,0 +,openconnectome,CloudVolume,s3://open-neurodata/takemura/takemura13/image,takemura2013,1,x,x,x,TRUE,256,FALSE,0 +,openconnectome,CloudVolume,s3://open-neurodata/templier/Wafer1/C1_EM,templier_wafer1,1,x,x,x,TRUE,256,FALSE,0 +,openconnectome,CloudVolume,s3://open-neurodata/templier/Wafer3/EM,templier_wafer3,0,x,x,x,TRUE,256,FALSE,0 +,openconnectome,CloudVolume,s3://open-neurodata/wanner16/AA201605/SBEM1,wanner2016,1,x,x,x,TRUE,256,FALSE,0 \ No newline at end of file diff --git a/dataset/scraping/ngff_download.py b/dataset/scraping/ngff_download.py new file mode 100644 index 0000000..5230270 --- /dev/null +++ b/dataset/scraping/ngff_download.py @@ -0,0 +1,480 @@ +import os, sys, math +import argparse +import numpy as np +import pandas as pd +import SimpleITK as sitk +from skimage import io +from fibsem_tools.io import read_xarray +from cloudvolume import CloudVolume, Bbox +from tqdm import tqdm + +def convert_to_byte(image): + """ + Verify that image is byte type + """ + if image.dtype == np.uint8: + return image + else: + image = image.astype(np.float32) + image -= image.min() + im_max = image.max() + if im_max > 0: # avoid zero division + image /= im_max + + image *= 255 + + return image.astype(np.uint8) + +def load_volume(volume, 
invert=False): + """ + Takes an xarray or CloudVolume and loads it + as byte array. Optionally, inverts contrast. + """ + volume = np.squeeze(np.array(volume[:])) + volume = convert_to_byte(volume) + if invert: + volume = np.invert(volume) + + return volume + +def sparse_roi_boxes( + reference_volume, + roi_size, + padding_value=0 + min_frac=0.7 +): + """ + Finds all ROIs of a given size within an overview volume + that contain at least some non-padding values. + + Arguments: + ---------- + reference_volume (np.ndarray): A low-resolution overview image + that can fit in memory. Typically this is the lowest-resolution + available in an image pyramid. + + roi_size (Tuple[d, h, w]): Size of ROIs in voxels relative to + the reference volume. + + padding_value (float): Value used to pad the reference volume. + Must be confirmed by manual inspection. + + min_frac (float): Fraction from 0-1 of and ROI that must be + non-padding values. + + Returns: + -------- + roi_boxes (np.ndarray): Array of (N, 6) defining bounding boxes + for ROIs that passed that min_frac condition. + + """ + # grid for box indices + xcs, ycs, zcs = roi_size + xsize, ysize, zsize = reference_volume.shape + + xgrid = np.arange(0, xsize + 1, xcs) + ygrid = np.arange(0, ysize + 1, ycs) + zgrid = np.arange(0, zsize + 1, zcs) + + max_padding = (1 - min_frac) * np.prod(roi_size) + + # make sure that there's always an ending index + # so we have complete ranges + if len(xgrid) < 2 or xsize % xcs > 0.5 * xcs: + xgrid = np.append(xgrid, np.array(xsize)[None], axis=0) + if len(ygrid) < 2 or ysize % ycs > 0.5 * ycs: + ygrid = np.append(ygrid, np.array(ysize)[None], axis=0) + if len(zgrid) < 2 or zsize % ycs > 0.5 * zcs: + zgrid = np.append(zgrid, np.array(zsize)[None], axis=0) + + roi_boxes = [] + for xi, xf in zip(xgrid[:-1], xgrid[1:]): + for yi, yf in zip(ygrid[:-1], ygrid[1:]): + for zi, zf in zip(zgrid[:-1], zgrid[1:]): + box_slices = tuple([slice(xi, xf), slice(yi, yf), slice(zi, zf)]) + n_not_padding = np.count_nonzero( + reference_volume[box_slices] == padding_value + ) + if n_not_padding < max_padding: + roi_boxes.append([xi, yi, zi, xf, yf, zf]) + + return np.array(roi_boxes) + +def crop_cloud_volume( + url, + save_path, + volume_name, + target_mip, + n_cubes=100, + cube_size=256, + invert=False, + padding_value=0, + min_frac=0.7 +): + """ + Crops non-empty ROIs from a CloudVolume and saves + the results as .nrrd images with correct voxel size. + + Arguments: + ---------- + url (str): URL from which to load the CloudVolume. + + save_path (str): Directory in which to save crops. + + volume_name (str): Name used to identify this volume. It + will be the prefix of all crop filenames. + + target_mip (int): The mip level from which to crop data. + + n_cubes (int): The maximum number of subvolumes to crop + from the CloudVolume. + + cube_size (int): The size of crops. Assumes crops are have + cubic dimensions. + + invert (bool): Whether to invert the intensity of the image. + + padding_value (float): Value in the CloudVolume used as image + padding. If invert is True, this value will be inverted as well. + + min_frac (float): Fraction from 0-1 of and ROI that must be + non-padding values. 
+ + """ + # check available mip levels + high_mip = target_mip + volume_high = CloudVolume(url, mip=high_mip, use_https=True, + fill_missing=True, progress=False) + + cv_box = np.array(volume_high.bounds.to_list()) + x0, y0, z0 = cv_box[:3] + + if invert: + padding_value = 255 - padding_value + + # get the nm resolution + high_resolution = list(volume_high.available_resolutions)[high_mip] + + # use lowest resolution as reference, unless + # it's a factor of cube_size smaller than the scale + # at the high mip level + low_mip = max(list(volume_high.available_mips)) + factor = (2 ** (low_mip + 1)) / (2 ** (high_mip + 1)) + if factor >= cube_size: + max_mip_diff = int(math.log(cube_size, 2)) - 1 + low_mip = high_mip + max_mip_diff + + volume_low = CloudVolume(url, mip=low_mip, use_https=True, + fill_missing=True, progress=True) + + # check whether the reference volume already exists + # create it if not + reference_fpath = os.path.join(save_path, f'{volume_name}_reference_mip{low_mip}.tif') + if not os.path.exists(reference_fpath): + reference_volume = load_volume(volume_low, invert) + io.imsave(reference_fpath, reference_volume) + + reference_volume = io.imread(reference_fpath) + + # find possible ROIs that are not just blank padding + factors = np.array(volume_high.shape[:3]) / np.array(volume_low.shape[:3]) + low_res_roi_size = np.floor(cube_size / factors).astype('int') + roi_boxes = sparse_roi_boxes( + reference_volume, low_res_roi_size, padding_value, min_frac + ) + + # randomly select n_cubes ROI boxes + box_indices = np.random.choice( + range(len(roi_boxes)), size=(min(n_cubes, len(roi_boxes)),), replace=False + ) + bboxes_low = roi_boxes[box_indices] + + # loop through the boxes that we selected + for bbox_low in tqdm(bboxes_low): + # convert between mips + bbox_low = Bbox(bbox_low[:3], bbox_low[3:]) + + # convert the ROI box from low to high resolution scale + bbox_high = np.array( + volume_high.bbox_to_mip(bbox_low, low_mip, high_mip).to_list() + ) + + # add bounding indices as an offset + bbox_high[0] += x0 + bbox_high[3] += x0 + bbox_high[1] += y0 + bbox_high[4] += y0 + bbox_high[2] += z0 + bbox_high[5] += z0 + + # boundaries aren't always consistent + # clip the bbox by the volume bounds + bbox_high1 = np.clip(bbox_high[:3], cv_box[:3], cv_box[3:]) + bbox_high2 = np.clip(bbox_high[3:], cv_box[:3], cv_box[3:]) + + bbox_high = np.concatenate([bbox_high1, bbox_high2]) + + bbox_high_slices = tuple([ + slice(bbox_high[0], bbox_high[3]), + slice(bbox_high[1], bbox_high[4]), + slice(bbox_high[2], bbox_high[5]) + ]) + + # handle case of an unitary channel + if len(volume_high.shape) == 4: + bbox_high_slices += (0,) + + # crop the high-resolution cube + cube = load_volume(volume_high[bbox_high_slices], invert) + + # extract ranges for filename + x1, x2 = bbox_high[0], bbox_high[3] + y1, y2 = bbox_high[1], bbox_high[4] + z1, z2 = bbox_high[2], bbox_high[5] + + cube_fname = f'{volume_name}-ROI-x{x1}-{x2}_y{y1}-{y2}_z{z1}-{z2}.nrrd' + cube_fpath = os.path.join(save_path, cube_fname) + + # transpose from xyz to zyx + cube = cube.transpose(2, 1, 0) + + # save as nrrd with appropriate voxel size in header + cube = sitk.GetImageFromArray(cube) + cube.SetSpacing(high_resolution) + sitk.WriteImage(cube, cube_fpath) + +def crop_xarray( + url, + save_path, + volume_name, + source, + target_mip, + high_resolution, + n_cubes=100, + cube_size=256, + invert=False, + padding_value=0, + storage_options=None +): + """ + Crops non-empty ROIs from a xarray and saves + the results as .nrrd images with 
correct voxel size. + + Arguments: + ---------- + url (str): URL from which to load the xarray. + + save_path (str): Directory in which to save crops. + + volume_name (str): Name used to identify this volume. It + will be the prefix of all crop filenames. + + target_mip (int): The mip level from which to crop data. + + n_cubes (int): The maximum number of subvolumes to crop + from the CloudVolume. + + cube_size (int): The size of crops. Assumes crops are have + cubic dimensions. + + invert (bool): Whether to invert the intensity of the image. + + padding_value (float): Value in the CloudVolume used as image + padding. If invert is True, this value will be inverted as well. + + min_frac (float): Fraction from 0-1 of and ROI that must be + non-padding values. + + storage_options (Dict): Storage options for reading an xarray. + + """ + # url to the dataset we want to download + high_mip = target_mip + mip_str = f's{target_mip}' + mip_url = os.path.join(url, mip_str) + + volume_high = read_xarray(mip_url, storage_options=storage_options) + + # only way to check available mip levels + # is 1 by 1 start from the smallest desirable + max_mip_diff = int(math.log(cube_size, 2)) - 1 + low_mip = high_mip + max_mip_diff + + for mip in list(range(low_mip, high_mip - 1, -1)): + mip_str = f's{mip}' + mip_url = os.path.join(url, mip_str) + try: + volume_high = read_xarray(mip_url, storage_options=storage_options) + low_mip = mip + break + except Exception as err: + continue + + if low_mip <= high_mip: + raise Exception('No low resolution volumes found! Are you sure?') + + # check whether the reference volume already exists + # create it if not + reference_fpath = os.path.join(save_path, f'{volume_name}_reference_mip{low_mip}.tif') + if not os.path.exists(reference_fpath): + reference_volume = load_volume(volume_low, invert) + io.imsave(reference_fpath, reference_volume) + + reference_volume = io.imread(reference_fpath) + + # find possible ROIs that are not just blank padding + # they may still be uniform resin though (filtering + # will happen later) + factors = np.array(volume_high.shape[:3]) / np.array(volume_low.shape[:3]) + low_res_roi_size = np.floor(cube_size / factors).astype('int') + roi_boxes = sparse_roi_boxes( + reference_volume, low_res_roi_size, padding_value, min_frac + ) + + # randomly select n_cubes ROI boxes + box_indices = np.random.choice( + range(len(roi_boxes)), size=(min(n_cubes, len(roi_boxes)),), replace=False + ) + bboxes_low = roi_boxes[box_indices] + + # loop through the boxes that we selected + for bbox_low in tqdm(bboxes_low): + # convert between mips + bbox_high = np.concatenate( + [bbox_low[:3] * factors, bbox_low[3:] * factors] + ).astype('int') + + # boundaries aren't always consistent + # clip the bbox by the volume size + bbox_high1 = np.clip(bbox_high[:3], 0, None) + bbox_high2 = np.clip(bbox_high[3:], None, volume_high.shape[:3]) + bbox_high = np.concatenate([bbox_high1, bbox_high2]) + + # make sure it's actually a volume with 5 slices + # skip this bounding box otherwise (it's on the edge) + if np.any(bbox_high[3:] - bbox_high[:3] < 5): + continue + + bbox_high_slices = tuple([ + slice(bbox_high[0], bbox_high[3]), + slice(bbox_high[1], bbox_high[4]), + slice(bbox_high[2], bbox_high[5]) + ]) + + # handle case of a unitary channel + if volume_high.ndim == 4: + bbox_high_slices += (0,) + + # crop the cube + cube = load_volume(volume_high[bbox_high_slices], invert) + + x1, x2 = bbox_high[0], bbox_high[3] + y1, y2 = bbox_high[1], bbox_high[4] + z1, z2 = bbox_high[2], 
+        cube_fname = f'{volume_name}-ROI-x{x1}-{x2}_y{y1}-{y2}_z{z1}-{z2}.nrrd'
+        cube_fpath = os.path.join(save_path, cube_fname)
+
+        cube = sitk.GetImageFromArray(cube)
+        cube.SetSpacing(high_resolution)
+        sitk.WriteImage(cube, cube_fpath)
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('csv', type=str, help='path to url csv file')
+    parser.add_argument('save_path', type=str, help='path to save the volumes')
+    parser.add_argument('-gb', '--max_gbs', type=float, default=5, help='maximum number of GBs to crop')
+    args = parser.parse_args()
+
+    # load the csv
+    df = pd.read_csv(args.csv)
+    save_path = args.save_path
+    max_gbs = args.max_gbs
+
+    os.makedirs(save_path, exist_ok=True)
+
+    for i, row in df.iterrows():
+        source = row['source']
+        url = row['download_url']
+        volname = row['volume_name']
+        mip = row['mip']
+        crop = row['crop']
+        crop_size = row['crop_size']
+        invert = row['invert']
+        padding_value = row['padding_value']
+        api = row['api']
+
+        # read resolution from csv
+        # CloudVolume api reads it directly from metadata
+        if api in ['xarray']:
+            xyz = [
+                float(row['voxel_x']), float(row['voxel_y']), float(row['voxel_z'])
+            ]
+            resolution = (mip + 1) * np.array(xyz)
+            resolution = resolution.tolist()
+            if source in ['openorganelle']:
+                storage_options = {'anon' : True}
+            else:
+                storage_options = None
+
+        # compute the number of cubes from the crop_size
+        # (assumes unsigned 8-bit images, i.e. 1 byte per voxel)
+        bytes_per_cube = crop_size ** 3
+        n_cubes = int(max(1, (max_gbs * 1024 ** 3) // bytes_per_cube))
+
+        print(f'Downloading from {url}')
+        if crop and api in ['CloudVolume']:
+            crop_cloud_volume(
+                url,
+                save_path,
+                volname,
+                mip,
+                n_cubes,
+                crop_size,
+                invert,
+                padding_value
+            )
+        elif crop and api in ['xarray']:
+            crop_xarray(
+                url,
+                save_path,
+                volname,
+                source,
+                mip,
+                resolution,
+                n_cubes,
+                crop_size,
+                invert,
+                padding_value,
+                storage_options
+            )
+        elif not crop and api in ['CloudVolume']:
+            volume = CloudVolume(url, mip=mip, use_https=True,
+                                 fill_missing=True, progress=False)
+
+            # get the nm resolution
+            resolution = list(volume.available_resolutions)[mip]
+
+            # download and process the volume
+            volume = load_volume(volume, invert)
+
+            # transpose from xyz to zyx
+            volume = volume.transpose(2, 1, 0)
+            vol_fpath = os.path.join(save_path, f'{volname}.nrrd')
+
+            volume = sitk.GetImageFromArray(volume)
+            volume.SetSpacing(resolution)
+            sitk.WriteImage(volume, vol_fpath)
+        elif not crop and api in ['xarray']:
+            mip_str = f's{mip}'
+            mip_url = os.path.join(url, mip_str)
+
+            volume = read_xarray(mip_url, storage_options=storage_options)
+
+            # download and process the volume
+            volume = load_volume(volume, invert)
+            vol_fpath = os.path.join(save_path, f'{volname}.nrrd')
+
+            volume = sitk.GetImageFromArray(volume)
+            volume.SetSpacing(resolution)
+            sitk.WriteImage(volume, vol_fpath)
+        else:
+            raise Exception('Nothing to do with this dataset!')
\ No newline at end of file

From 79ad7dd0c614020c387fe4cd7fc45e804580fa1b Mon Sep 17 00:00:00 2001
From: conradry
Date: Tue, 13 Dec 2022 14:52:53 -0500
Subject: [PATCH 19/19] updated readme with flipbooks and scraping

---
 dataset/3d/reconstruct3d.py | 69 +++++++++++++++++++++----------------
 dataset/README.md           | 53 ++++++++++++++++++++++++----
 2 files changed, 87 insertions(+), 35 deletions(-)

diff --git a/dataset/3d/reconstruct3d.py b/dataset/3d/reconstruct3d.py
index 6c09bac..8b4e9cb 100644
--- a/dataset/3d/reconstruct3d.py
+++ b/dataset/3d/reconstruct3d.py
@@ -13,14 +13,27 @@
 '-ROI-' identifier exists in each subvolume name.
Volumes must be in a format readable by SimpleITK (primarily .mrc, .nrrd, .tif, etc.). THIS -SCRIPT DOES NOT SUPPORT (OME)ZARRs. As a rule, such datasets are usually created because of -their large size. Sparsely sampled ROIs from such NGFF-style datasets can be saved in one of -the supported formats using the ../scraping/ngff_download.py script. +SCRIPT DOES NOT SUPPORT NGFFs. As a rule, such datasets are usually created because of +their large size. Sparsely sampled ROIs from such NGFF datasets can be downloaded and saved +in one of the supported formats using the ../scraping/ngff_download.py script. Example usage: -------------- -python reconstruct3d.py {impaths_file} {volume_dir} {savedir} -nz 224 -p 4 --cross-plane +python reconstruct3d.py {filtered_dir} \ + -vd {volume_dir1} {volume_dir2} {volume_dir3} \ + -sd {savedir} -nz 224 -p 4 --limit 100 + +Reconstruct a maximum of 100 subvolumes with 224 z-slices from each +dataset represented in {filtered_dir}. Save them in {savedir}, which +will contain a separate subdirectory corresponding to each dataset. + +Note1: For generating flipbooks, -nz should always be odd. While even +numbers strictly can be used, they're likely to cause confusion at +annotation time because there isn't a "real" middle slice. + +Note2: Z-slices will always be the first dimension in the subvolume +(this is essential for generating flipbooks). For help with arguments: ------------------------ @@ -99,9 +112,6 @@ def find_children(vol_fpath): else: dirname = volname - # lookup the img_source - #assert(dirname in img_fpaths_dict), \ - #f"Directory {dirname} not found in image paths!" if dirname not in img_fpaths_dict: return [], dirname @@ -114,6 +124,12 @@ def find_children(vol_fpath): return vol_img_fpaths, dirname def extract_subvolume(volume, img_fpath): + """ + Extracts the correct subvolume from the + full volumetric dataset based on the name + of a given image which must include the -LOC- + identifier. 
+ """ # extract location of image from filename img_fpath = os.path.basename(img_fpath) volname, loc = img_fpath.split('-LOC-') @@ -135,29 +151,28 @@ def extract_subvolume(volume, img_fpath): lowz = index - span highz = index + span + 1 - # pass images that don't have enough - # context to be annotated as a flipbook + # pass images that don't have enough context if lowz < 0 or highz >= volume.shape[axis]: return None, None else: axis_span = slice(lowz, highz) if axis == 0: - flipbook = volume[axis_span, yslice, xslice] + subvol = volume[axis_span, yslice, xslice] elif axis == 1: - flipbook = volume[yslice, axis_span, xslice] - flipbook = flipbook.transpose(1, 0, 2) + subvol = volume[yslice, axis_span, xslice] + subvol = subvol.transpose(1, 0, 2) elif axis == 2: - flipbook = volume[yslice, xslice, axis_span] - flipbook = flipbook.transpose(2, 0, 1) + subvol = volume[yslice, xslice, axis_span] + subvol = subvol.transpose(2, 0, 1) else: raise Exception(f'Axis cannot be {axis}, must be in [0, 1, 2]') - flipbook_fname = f'{volname}-LOC-{axis}_{lowz}-{highz}_{yrange}_{xrange}' + subvol_fname = f'{volname}-LOC-{axis}_{lowz}-{highz}_{yrange}_{xrange}' - return flipbook, flipbook_fname + return subvol, subvol_fname - def create_flipbooks(vp): + def create_subvols(vp): children, dirname = find_children(vp) vol_savedir = os.path.join(savedir, dirname) @@ -174,23 +189,19 @@ def create_flipbooks(vp): volume = volume[..., 0] if np.any(np.array(volume.shape) < numberz): - raise Exception(f'Flipbooks of size {numberz} cannot be created from {vp} with size {volume.shape}') + raise Exception(f'Subvolume of size {numberz} cannot be created from {vp} with size {volume.shape}') - # directory in which to save flipbooks + # directory in which to save subvols # from this volume dataset if not os.path.isdir(vol_savedir): os.makedirs(vol_savedir, exist_ok=True) - # extract and save flipbooks - count = 0 + # extract and save subvols for child in children: - if count >= 50: - break - flipbook, flipbook_fname = extract_subvolume(volume, child) - if flipbook_fname is not None: - io.imsave(os.path.join(vol_savedir, flipbook_fname + '.tif'), - flipbook, check_contrast=False) - count += 1 + subvol, subvol_fname = extract_subvolume(volume, child) + if subvol_fname is not None: + io.imsave(os.path.join(vol_savedir, subvol_fname + '.tif'), + subvol, check_contrast=False) with Pool(processes) as pool: - output = pool.map(create_flipbooks, volume_fpaths) \ No newline at end of file + output = pool.map(create_subvols, volume_fpaths) \ No newline at end of file diff --git a/dataset/README.md b/dataset/README.md index 5edb95c..6151658 100644 --- a/dataset/README.md +++ b/dataset/README.md @@ -39,24 +39,32 @@ python preprocess/vid2stack.py {dir_of_videos} ## 3D Data Preparation -3D datasets are expected to be in a single directory (this includes any video stacks created in the previous section). Supported formats are anything that can be [read by SimpleITK](https://simpleitk.readthedocs.io/en/v1.2.3/Documentation/docs/source/IO.html). It's important that if any volumes are in ```.mrc``` format they be converted to unsigned bytes. With IMOD installed this can be done using: +3D datasets are expected to be in a single directory (this includes any video stacks created in the previous section). +Supported formats are anything that can be [read by SimpleITK](https://simpleitk.readthedocs.io/en/v1.2.3/Documentation/docs/source/IO.html). It's important that if any volumes are in +```.mrc``` format they be converted to unsigned bytes. 
With IMOD installed this can be done using:
 
 ```bash
 python preprocess/mrc2byte.py {dir_of_mrc_files}
 ```
 
-Next, cross-section, patch, and deduplicate volume files. If processing a combination of isotropic and anisotropic volumes, it's crucial that each dataset has a correct header recording the voxel size. If Z resolution is greater that 25%
-different from xy resolution, then cross-sections will only be cut from the xy plane, even if axes 0, 1, 2 are passed to the script (see usage example below).
+Next, cross-section, patch, and deduplicate volume files. If processing a combination of isotropic and anisotropic volumes,
+it's crucial that each dataset has a correct header recording the voxel size. If the Z resolution differs from the xy
+resolution by more than 25%, then cross-sections will only be cut from the xy plane, even if axes 0, 1, 2 are passed to
+the script (see usage example below).
 
 ```bash
 python patchify3d.py {dir_of_3d_datasets} {patch_dir} -cs 224 --axes 0 1 2 --processes 4
 ```
 
-The ```patchify3d.py``` script will save a ```.pkl``` file with the name of each volume file. Pickle files contain a dictionary of patches along with corresponding filenames. These files are ready for filtering (see below).
+The ```patchify3d.py``` script will save a ```.pkl``` file with the name of each volume file. Pickle files contain a
+dictionary of patches along with corresponding filenames. These files are ready for filtering (see below).
 
 ## Filtering
 
-2D, video, and 3D datasets can be filtered with the same script just put all the ```.pkl``` files in the same directory. By default, filtering uses a ResNet34 model that was trained on 12,000 manually annotated patches. The weights for this model are downloaded from [Zenodo](https://zenodo.org/record/6458015#.YlmNaS-cbTR) automatically. A new model can be trained, if needed, using the ```train_patch_classifier.py``` script.
+2D, video, and 3D datasets can be filtered with the same script; just put all the ```.pkl``` files in the same directory.
+By default, filtering uses a ResNet34 model that was trained on 12,000 manually annotated patches. The weights for this
+model are downloaded from [Zenodo](https://zenodo.org/record/6458015#.YlmNaS-cbTR) automatically. A new model can be
+trained, if needed, using the ```train_patch_classifier.py``` script.
 
 Filtering will be fastest with a GPU installed, but it's not required.
 
@@ -64,6 +72,39 @@ python classify_patches.py {patch_dir} {save_dir}
 ```
 
-After running filtering, the ```save_dir``` with have one subdirectory for each of the ```.pkl``` files that were processed. Each subdirectory contains single channel grayscale, unsigned 8-bit tiff images.
+After running filtering, the ```save_dir``` will have one subdirectory for each of the ```.pkl``` files that were
+processed. Each subdirectory contains single channel grayscale, unsigned 8-bit tiff images.
+
+## Reconstructing subvolumes and flipbooks
+
+Although the curation process always results in 2D image patches, it's possible to retrieve 3D subvolumes as long as one
+has access to the original 3D datasets. Patch filenames from 3D datasets always include a suffix denoted by '-LOC-' that
+records the slicing plane, the index of the slice, and the x and y positions of the patch. To extract a subvolume around
+a patch, use the ```3d/reconstruct3d.py``` script.
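+
+As a concrete illustration of the '-LOC-' naming scheme, the snippet below sketches how a patch filename can be decoded.
+This helper is not part of the repository; it assumes a patch name laid out as
+```{volume}-LOC-{axis}_{index}_{y1}-{y2}_{x1}-{x2}.tiff```, which mirrors the parsing done inside ```3d/reconstruct3d.py```.
+
+```python
+import os
+
+def parse_loc_suffix(fname):
+    # strip the extension and split off the '-LOC-' suffix
+    stem = os.path.splitext(os.path.basename(fname))[0]
+    volname, loc = stem.split('-LOC-')
+
+    # axis is the slicing plane (0, 1, or 2), index is the slice position
+    # along that axis, and yrange/xrange are the patch bounds in that plane
+    axis, index, yrange, xrange = loc.split('_')
+    y1, y2 = map(int, yrange.split('-'))
+    x1, x2 = map(int, xrange.split('-'))
+    return volname, int(axis), int(index), (y1, y2), (x1, x2)
+```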
+
+For example, to create short flipbooks of 5 consecutive images from a directory of curated patches:
+
+```bash
+python reconstruct3d.py {filtered_patch_dir} \
+    -vd {volume_dir1} {volume_dir2} {volume_dir3} \
+    -sd {savedir} -nz 5 -p 4
+```
+
+See the script header for more details.
+
+## Scraping large online datasets
+
+The patching, deduplication, and filtering pipeline works for volumes in nrrd, mrc, and tif formats. However, datasets
+like those generated for connectomics research are often too large to practically download and store in memory.
+Instead, they are commonly stored as NGFFs. Our workflow assumes that these datasets will be sparsely sampled.
+The ```scraping/ngff_download.py``` script will download sparsely cropped cubes of image data and save them in the
+nrrd format for compatibility with the rest of this workflow.
+
+For example, to download 5 gigabytes of image data from a list of datasets:
+
+```bash
+python ngff_download.py ngff_datasets.csv {save_path} -gb 5
+```
+
+Similarly, large datasets that are not stored in NGFF but are over some size threshold (we've used 5 GB in our work)
+can be cropped into smaller ROIs with the ```crop_rois_from_volume.py``` script.
\ No newline at end of file
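+
+For reference, the CSV passed to ```ngff_download.py``` is expected to contain the columns read in the script's main
+block: ```source```, ```download_url```, ```volume_name```, ```mip```, ```crop```, ```crop_size```, ```invert```,
+```padding_value```, ```api```, and, for xarray sources, ```voxel_x```, ```voxel_y``` and ```voxel_z```. The row below
+is purely illustrative; the URL and voxel sizes are placeholders, not a real dataset:
+
+```csv
+source,download_url,volume_name,mip,crop,crop_size,invert,padding_value,api,voxel_x,voxel_y,voxel_z
+openorganelle,s3://{bucket}/{dataset}.zarr,example_volume,0,True,256,False,0,xarray,4.0,4.0,4.0
+```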