Merge branch 'v2' into master
Ryan Conrad authored Dec 13, 2022
2 parents 8cf5521 + 79ad7dd commit 325b038
Showing 39 changed files with 3,892 additions and 1,498 deletions.
16 changes: 11 additions & 5 deletions README.md
@@ -31,7 +31,7 @@ Using CEM500K for unsupervised pre-training, we demonstrated a significant impro
First clone this repository:

```
- git clone https://github.com/volume-em/cellemnet
git clone https://github.com/volume-em/cem-dataset.git
```

If using conda, install dependencies in a new environment:
@@ -52,15 +52,20 @@ Otherwise, required dependencies can be installed with another package manager (
- scikit-learn
- imagehash

- ## Download CEM500K
## Download the Dataset

- The CEM500K dataset, metadata, and pretrained weights are available through [EMPIAR ID 10592](https://www.ebi.ac.uk/pdbe/emdb/empiar/entry/10592/).
The latest iteration of the CEM dataset is CEM1.5M. Images and metadata are available for download through [EMPIAR ID 11035](https://www.ebi.ac.uk/empiar/EMPIAR-11035/).

- ## Use the pre-trained weights
## Pre-trained weights

Currently, pre-trained weights are only available for PyTorch. For an example of how to use them, see ```evaluation/benchmark_configs``` and ```notebooks/pretrained_weights.ipynb```.

- We're working to convert the weights for use with TensorFlow/Keras. If you have any experience with this kind of conversion and would like to help with testing, please open an issue.
| Model architecture | Pre-training method | Dataset | Link |
| ------------------- | ------------------- | ----------- | ---------------------------------------------- |
| ResNet50 | MoCoV2 | CEM500K | https://zenodo.org/record/6453140#.Y5inAC2B1Qg |
| ResNet50 | SwAV | CEM1.5M | https://zenodo.org/record/6453160#.Y5iznS2B1Qh |
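
As a quick-start illustration, here is a minimal PyTorch sketch for initializing a torchvision ResNet50 from the MoCoV2 checkpoint. The local filename and the ```module.encoder_q.``` key prefix are assumptions based on the usual MoCoV2 checkpoint layout; see ```notebooks/pretrained_weights.ipynb``` for the authoritative loading code.

```python
import torch
from torchvision.models import resnet50

# hypothetical local filename for the Zenodo download
checkpoint = torch.load('cem500k_mocov2_resnet50.pth.tar', map_location='cpu')
state_dict = checkpoint.get('state_dict', checkpoint)

# MoCoV2 checkpoints usually nest the backbone under the query encoder;
# strip that prefix and drop the projection head
backbone = {
    k[len('module.encoder_q.'):]: v
    for k, v in state_dict.items()
    if k.startswith('module.encoder_q.') and not k.startswith('module.encoder_q.fc')
}

model = resnet50()
msg = model.load_state_dict(backbone, strict=False)
print(msg.missing_keys)  # should only list the randomly initialized 'fc' head
```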



## Data Curation

@@ -69,6 +74,7 @@ For image deduplication and filtering routines see the ```dataset``` directory R
## Citing this work

Please cite this work.

```bibtex
@article {Conrad2021,
author = {Conrad, Ryan and Narayan, Kedar},
509 changes: 154 additions & 355 deletions dataset/3d/reconstruct3d.py

Large diffs are not rendered by default.

110 changes: 95 additions & 15 deletions dataset/README.md
@@ -1,30 +1,110 @@
# Dataset Curation

- Beginning with a collection of 2d and 3d EM images, the scripts in this directory handle all of the dataset preprocessing and curation.
The scripts in this directory handle all dataset preprocessing and curation. Below is an example workflow; read the script headers for more details, or see all available parameters with:

- For preliminary data preparation the vid2stack.py and mrc2byte.py scripts in the preprocess directory convert videos into image volumes and mrc volumes from signed to unsigned bytes.
```bash
python {script_name}.py --help
```

- The main curation pipeline starts with the cross-sectioning of 3d volumes into 2d image slices that can be combined together with any 2d EM datasets. Cross-sectioning for 3d data is handled by the raw/cross_section3d.py script and some basic type checking and file renaming for 2d data is done by the raw/cleanup2d.py script. It's recommended that 2d and 3d "corpuses" of data be kept in separate directories in order to ensure that the two scripts run smoothly; however, the outputs from the raw/cleanup2d.py and raw/cross_section3d.py scripts should all be saved in the same directory. The collection of 2d images that results, which are all 8-bit unsigned tiffs, can then be cropped into patches of a given size using the raw/crop_patches.py script. In summary the first step of the workflow is:
**Note: The ```patchify2d.py```, ```patchify3d.py```, and ```classify_patches.py``` scripts are all designed for continuous integration. Datasets that have been processed previously and are in the designated output directories will be ignored by all of them.**

- 1. Run raw/cleanup2d.py on directory of 2d EM images. Save results to *save_dir*
- 2. Run raw/cross_section3d.py on directory of 3d EM images. Save results to *save_dir*
- 3. Run raw/crop_patches.py on images in *save_dir*. Save results to *raw_save_dir*
## 2D Data Preparation

- The completion of this first step in the workflow yields the *Raw* dataset. Note that the raw/crop_patches.py script not only creates tiff images for each of the patches, but also creates a numpy array of the patch's difference hash. The hashes are used for deduplication.
2D images are expected to be organized into directories, where each directory contains a group of images generated
as part of the same imaging project or at least with roughly the same biological metadata.

- Deduplication uses the deduplicated/deduplicate.py script. As input the script expects *raw_save_dir* containing the .tiff images and .npy hashes. If new data is added to the *raw_save_dir* after the deduplication script has already been run, the script will only deduplicate the new datasets. This makes it easy to add new datasets without the somewhat time-consuming burden of rerunning deduplication for the entire *Raw* dataset. In summary:
First, standardize images to single channel grayscale and unsigned 8-bit:

- 1. Run deduplicated/deduplicate.py on *raw_save_dir*. Save results, which are .npy files for each 2d/3d dataset that contain a list of filepaths for exemplar images, to *deduplicated_save_dir*.
```bash
# make copies in new_directory
python preprocess/cleanup2d.py {dir_of_2d_image_groups} -o {new_directory} --processes 4
# or, instead, overwrite images inplace
python preprocess/cleanup2d.py {dir_of_2d_image_groups} --processes 4
```
Second, crop each image into fixed size patches (typically 224x224):

- In addition to .npy files for each dataset, the script also outputs a dask array file called deduplicated_fpaths.npz that contains the list of file paths for exemplar images from all 2d/3d datasets. This collection of file paths defines the *Deduplicated* dataset.
```bash
python patchify2d.py {dir_of_2d_image_groups} {patch_dir} -cs 224 --processes 4
```

- In the last curation step, uninformative patches are filtered out using a ResNet34 classifier. The filtered/train_nn.py script trains the classifier on a collection of manually labeled image files contained in deduplicated_fpaths.npz. It is assumed that the labeling was performed using the labeling.ipynb notebook included in this repository. In general, training a new classifier shouldn't be necessary; we release the weights for the classifier that we trained on 12,000 labeled images. The filtered/classify_nn.py script performs inference on the set of unlabeled images in deduplicated_fpaths.npz. By default, the script will download and use the weights that we released. In summary:
The ```patchify2d.py``` script will save a ```.pkl``` file with the name of each 2D image subdirectory. Pickle files contain a dictionary of patches from all images in the subdirectory along with corresponding filenames. These files are ready for filtering (see below).

- 1. (Optional) Manually label images in deduplicated_fpaths.npz using labeling.ipynb.
- 2. (Optional) Run filtered/train_nn.py to train and evaluate a ResNet34 on the images labeled in step 1.
- 3. Run filtered/classify_nn.py on images in deduplicated_fpaths.npz. Save dask array of all informative images, nn_filtered_fpaths.npz, to *filtered_save_dir*.
## Video Preparation

- These last steps result in the *Filtered* dataset. That's the complete curation pipeline. An optional last step, to generate 3d data, is to run the 3d/reconstruct3d.py script. This script takes the set of filtered images, nn_filtered_fpaths.npz, and the original directory of 3d volumes (i.e. the directory given to cross_section3d.py earlier) and makes data volumes of a given z-thickness. Note that one limitation of this script is that it currently assumes patches are 224x224.
Convert videos in ```.avi``` or ```.mp4``` format to ```.nrrd``` images with the correct naming convention (i.e., the word 'video' in the filename).

```bash
python preprocess/vid2stack.py {dir_of_videos}
```

## 3D Data Preparation

3D datasets are expected to be in a single directory (this includes any video stacks created in the previous section).
Supported formats are anything that can be [read by SimpleITK](https://simpleitk.readthedocs.io/en/v1.2.3/Documentation/docs/source/IO.html). If any volumes are in
```.mrc``` format, it's important that they be converted to unsigned bytes. With IMOD installed, this can be done using:

```bash
python preprocess/mrc2byte.py {dir_of_mrc_files}
```
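
To check which volumes still need conversion, a quick sketch like the following works; it assumes the ```mrcfile``` package, which is not part of this repository's dependencies.

```python
from glob import glob

import mrcfile

# hypothetical directory of .mrc volumes
for path in sorted(glob('dir_of_mrc_files/*.mrc')):
    with mrcfile.open(path, permissive=True, header_only=True) as mrc:
        # MRC mode 0 stores signed 8-bit voxels; these are the volumes
        # that mrc2byte.py needs to convert
        if mrc.header.mode == 0:
            print(f'{path}: signed bytes, needs conversion')
```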

Next, cross-section, patch, and deduplicate volume files. If processing a combination of isotropic and anisotropic volumes,
it's crucial that each dataset has a correct header recording the voxel size. If the z resolution differs from the xy
resolution by more than 25%, then cross-sections will only be cut from the xy plane, even if axes 0, 1, 2 are passed to
the script (see usage example below).

```bash
python patchify3d.py {dir_of_3d_datasets} {patch_dir} -cs 224 --axes 0 1 2 --processes 4
```
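
The 25% rule can be pictured with a short sketch. This is an illustration of the check as described above, not ```patchify3d.py```'s actual code, and the filename is hypothetical.

```python
import SimpleITK as sitk

image = sitk.ReadImage('volume.nrrd')
sx, sy, sz = image.GetSpacing()  # voxel size (x, y, z) from the header

axes = [0, 1, 2]  # requested slicing axes
# if z spacing differs from xy spacing by more than 25%, only cut
# cross-sections from the xy plane (axis 0)
if abs(sz - sx) / sx > 0.25:
    axes = [0]

print(axes)
```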

The ```patchify3d.py``` script will save a ```.pkl``` file with the name of each volume file. Pickle files contain a
dictionary of patches along with corresponding filenames. These files are ready for filtering (see below).
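
To sanity-check the output of either patchify script, the pickle files can be inspected directly. A minimal sketch, where the filename is hypothetical but the ```names```/```patches``` keys match what ```classify_patches.py``` reads:

```python
import pickle

with open('my_volume.pkl', mode='rb') as handle:
    patches_dict = pickle.load(handle)

print(len(patches_dict['names']))        # number of patches in the group
print(patches_dict['patches'][0].shape)  # e.g. (224, 224), uint8 grayscale
```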

## Filtering

2D, video, and 3D datasets can be filtered with the same script; just put all the ```.pkl``` files in the same directory.
By default, filtering uses a ResNet34 model that was trained on 12,000 manually annotated patches. The weights for this
model are downloaded from [Zenodo](https://zenodo.org/record/6458015#.YlmNaS-cbTR) automatically. A new model can be
trained, if needed, using the ```train_patch_classifier.py``` script.

Filtering will be fastest with a GPU installed, but it's not required.

```bash
python classify_patches.py {patch_dir} {save_dir}
```

After running filtering, the ```save_dir``` will have one subdirectory for each of the ```.pkl``` files that were
processed. Each subdirectory contains single channel grayscale, unsigned 8-bit tiff images.
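
For a quick per-dataset tally of how many patches were kept, a sketch like this works (```filtered_patch_dir``` is a hypothetical path):

```python
import os
from glob import glob

for subdir in sorted(glob(os.path.join('filtered_patch_dir', '*'))):
    count = len(glob(os.path.join(subdir, '*.tiff')))
    print(f'{os.path.basename(subdir)}: {count} informative patches')
```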

## Reconstructing subvolumes and flipbooks

Although the curation process always results in 2D image patches, it's possible to retrieve 3D subvolumes as long as one
has access to the original 3D datasets. Patch filenames from 3D datasets always include a suffix denoted by '-LOC-' that
records the slicing plane, the index of the slice, and the x and y positions of the patch. To extract a subvolume around
a patch, use the ```3d/reconstruct3d.py``` script.
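
As an illustration, such a filename might be parsed as below. The exact field order is an assumption for this sketch; consult the header of ```3d/reconstruct3d.py``` for the authoritative format.

```python
# hypothetical example filename; the real field order may differ
fname = 'my_volume-LOC-0_1250_1024_2048.tiff'

dataset, loc = fname[:-len('.tiff')].split('-LOC-')
plane, index, ypos, xpos = (int(v) for v in loc.split('_'))
print(dataset, plane, index, ypos, xpos)
```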

For example, to create short flipbooks of 5 consecutive images from a directory of curated patches:

```bash
python reconstruct3d.py {filtered_patch_dir} \
-vd {volume_dir1} {volume_dir2} {volume_dir3} \
-sd {savedir} -nz -p 4
```

See the script header for more details.

## Scraping large online datasets

The patching, deduplication, and filtering pipeline works for volumes in nrrd, mrc, and tif formats. However, very large
datasets like those generated for connectomics research are often too large to practically download and store in memory.
Instead, they are commonly stored as NGFFs. Our workflow assumes that these datasets will be sparsely sampled.
The ```scraping/ngff_download.py``` script will download sparsely cropped cubes of image data and save them in the
nrrd format for compatibility with the rest of this workflow.

For example, to download 5 gigabytes of image data from a list of datasets:

```bash
python ngff_download.py ngff_datasets.csv {save_path} -gb 5
```

Similarly, large datasets that are not stored in NGFF but are over some size threshold (we've used 5 GB in our work)
can be cropped into smaller ROIs with the ```crop_rois_from_volume.py``` script.
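
The idea is simply to split an oversized volume into manageable blocks. A rough sketch of that operation, as an illustration rather than the script's actual logic, with hypothetical filenames:

```python
import SimpleITK as sitk

volume = sitk.ReadImage('huge_volume.nrrd')
xsize, ysize, zsize = volume.GetSize()

# cut the volume into eight half-size ROIs
size = [xsize // 2, ysize // 2, zsize // 2]
for zi in range(2):
    for yi in range(2):
        for xi in range(2):
            index = [xi * size[0], yi * size[1], zi * size[2]]
            roi = sitk.RegionOfInterest(volume, size, index)
            sitk.WriteImage(roi, f'roi_{zi}{yi}{xi}.nrrd')
```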
155 changes: 155 additions & 0 deletions dataset/classify_patches.py
@@ -0,0 +1,155 @@
"""
Description:
------------
Classifies EM images into "informative" or "uninformative".
Example usage:
--------------
python classify_nn.py {deduped_dir} {savedir} --labels {label_file} --weights {weights_file}
For help with arguments:
------------------------
python classify_nn.py --help
"""

import os
import argparse
import pickle
from glob import glob

import cv2
import numpy as np
from skimage import io

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torchvision.models import resnet34
from torch.utils.data import DataLoader, Dataset

from albumentations import Compose, Normalize, Resize
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm

# pretrained classifier weights released on Zenodo; downloaded when --weights is not given
DEFAULT_WEIGHTS = "https://zenodo.org/record/6458015/files/patch_quality_classifier_nn.pth?download=1"

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Classifies patches as informative or uninformative with a pretrained ResNet34'
    )
    parser.add_argument('dedupe_dir', type=str, help='Directory containing deduplicated .pkl patch files')
    parser.add_argument('savedir', type=str, help='Directory in which to save informative patches')
parser.add_argument('--weights', type=str, metavar='weights',
help='Optional, path to nn weights file. The default is to download weights used in the paper.')

args = parser.parse_args()

# parse the arguments
dedupe_dir = args.dedupe_dir
savedir = args.savedir
weights = args.weights

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# make sure the savedir exists
if not os.path.isdir(savedir):
os.mkdir(savedir)

# list all pkl deduplicated files
fpaths = glob(os.path.join(dedupe_dir, '*.pkl'))

    # set up evaluation transforms (assumes the imagenet-pretrained
    # defaults used when training the classifier)
imsize = 224
normalize = Normalize() #default is imagenet normalization
eval_tfs = Compose([
Resize(imsize, imsize),
normalize,
ToTensorV2()
])

# create the resnet34 model
model = resnet34()

# modify the output layer to predict 1 class only
model.fc = nn.Linear(in_features=512, out_features=1)

    # load the weights from file or from online
if weights is not None:
state_dict = torch.load(weights, map_location='cpu')
else:
state_dict = torch.hub.load_state_dict_from_url(DEFAULT_WEIGHTS)

    # load the weights (strict matching) and prepare the model for inference
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    cudnn.benchmark = True

# make a basic dataset class for loading and
# augmenting images WITHOUT any labels
class SimpleDataset(Dataset):
def __init__(self, image_dict, tfs=None):
super(SimpleDataset, self).__init__()
self.image_dict = image_dict
self.tfs = tfs

def __len__(self):
return len(self.image_dict['names'])

def __getitem__(self, idx):
# load the image
fname = self.image_dict['names'][idx]
image = self.image_dict['patches'][idx]
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

# apply transforms
if self.tfs is not None:
image = self.tfs(image=image)['image']

return {'fname': fname, 'image': image}

for fp in tqdm(fpaths):
dataset_name = os.path.basename(fp)
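        # group '-ROI-' crops from the same source volume under one dataset name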
if '-ROI-' in dataset_name:
dataset_name = dataset_name.split('-ROI-')[0]
else:
dataset_name = dataset_name[:-len('.pkl')]

dataset_savedir = os.path.join(savedir, dataset_name)
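        # skip datasets that already have results (supports continuous integration)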
if not os.path.exists(dataset_savedir):
os.mkdir(dataset_savedir)
else:
continue

# load the patches_dict
with open(fp, mode='rb') as handle:
patches_dict = pickle.load(handle)

        # create a dataset and loader for inference on all patches in the group
        tst_data = SimpleDataset(patches_dict, eval_tfs)
        test = DataLoader(tst_data, batch_size=128, shuffle=False,
                          pin_memory=True, num_workers=4)

# lastly run inference on the entire set of unlabeled images
tst_fnames = []
tst_predictions = []
for data in test:
with torch.no_grad():
# load data onto gpu then forward pass
images = data['image'].to(device, non_blocking=True)
                output = model(images)
predictions = nn.Sigmoid()(output)

predictions = predictions.detach().cpu().numpy()
tst_predictions.append(predictions)
tst_fnames.append(data['fname'])

tst_fnames = np.concatenate(tst_fnames, axis=0)
tst_predictions = np.concatenate(tst_predictions, axis=0)
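        # threshold the sigmoid scores: 1 = informative, 0 = uninformative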
tst_predictions = (tst_predictions[:, 0] > 0.5).astype(np.uint8)

for ix, (fn, img) in enumerate(zip(patches_dict['names'], patches_dict['patches'])):
if tst_predictions[ix] == 1:
io.imsave(os.path.join(dataset_savedir, fn + '.tiff'), img, check_contrast=False)