Skip to content

Commit

Permalink
Merge pull request #46 from BIOP/bugfixes
Browse files Browse the repository at this point in the history
Bugfixes and update to Cellpose 3 logging
  • Loading branch information
lacan authored Apr 30, 2024
2 parents e0ba95c + 2251104 commit 0b046e0
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 43 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ ext.qupathVersion = gradle.ext.qupathVersion

description = 'QuPath extension to use Cellpose'

version = "0.9.2"
version = "0.9.3-SNAPSHOT"

dependencies {
implementation "io.github.qupath:qupath-gui-fx:${qupathVersion}"
Expand Down
105 changes: 73 additions & 32 deletions src/main/java/qupath/ext/biop/cellpose/Cellpose2D.java
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ public class Cellpose2D {
private final static Logger logger = LoggerFactory.getLogger(Cellpose2D.class);

public ImageOp extendChannelOp;
public boolean useGPU;

protected double simplifyDistance = 1.4;

Expand Down Expand Up @@ -412,6 +413,11 @@ public void detectObjectsImpl(ImageData<BufferedImage> imageData, Collection<? e

logger.info("All tiles and objects read, now resolving overlaps");

// In case alltiles is null, we are basically done
if( allTiles == null ) {
logger.info("No results from Cellpose", "There is nothing to recover from cellpose");
}

// Group the candidates per parent object, as this is needed to optimize when checking for overlap
Map<PathObject, List<CandidateObject>> candidatesPerParent = allTiles.values().stream()
.flatMap(t -> t.getCandidates().stream())
Expand Down Expand Up @@ -721,12 +727,15 @@ private VirtualEnvironmentRunner getVirtualEnvironmentRunner() {
*/
private LinkedHashMap<File, TileFile> runCellpose(LinkedHashMap<File, TileFile> allTiles) throws InterruptedException, IOException {


// Need to define the name of the command we are running. We used to be able to use 'cellpose' for both but not since Cellpose v2
String runCommand = this.parameters.containsKey("omni") ? "omnipose" : "cellpose";
VirtualEnvironmentRunner veRunner = getVirtualEnvironmentRunner();

// This is the list of commands after the 'python' call
List<String> cellposeArguments = new ArrayList<>(Arrays.asList("-W", "ignore", "-m", runCommand));
// We want to ignore all warnings to make sure the log is clean (-W ignore)
// We want to be able to call the module by name (-m)
// We want to make sure UTF8 mode is by default (-X utf8)
List<String> cellposeArguments = new ArrayList<>(Arrays.asList("-Xutf8", "-W", "ignore", "-m", runCommand));

cellposeArguments.add("--dir");
cellposeArguments.add("" + this.tempDirectory);
Expand All @@ -746,25 +755,32 @@ private LinkedHashMap<File, TileFile> runCellpose(LinkedHashMap<File, TileFile>

cellposeArguments.add("--no_npy");

cellposeArguments.add("--use_gpu");
if( this.useGPU ) cellposeArguments.add("--use_gpu");

cellposeArguments.add("--verbose");

veRunner.setArguments(cellposeArguments);

// Finally, we can run Cellpose
veRunner.runCommand();
veRunner.runCommand(false);

return processCellposeFiles(veRunner, allTiles);

}

private LinkedHashMap<File, TileFile> processCellposeFiles(VirtualEnvironmentRunner veRunner, LinkedHashMap<File, TileFile> allTiles) throws CancellationException, InterruptedException, IOException {

// Make sure that allTiles is not null, if it is, just return null
// as we are likely just running validation and thus do not need to give any results back
if (allTiles == null ) {
veRunner.getProcess().waitFor();
return null;
}

// Build a thread pool to process reading the images in parallel
ExecutorService executor = Executors.newFixedThreadPool(5);

if (!this.doReadResultsAsynchronously || allTiles == null) {
if (!this.doReadResultsAsynchronously) {
// We need to wait for the process to finish
veRunner.getProcess().waitFor();
allTiles.entrySet().forEach(entry -> {
Expand Down Expand Up @@ -893,7 +909,7 @@ private void runTraining() throws IOException, InterruptedException {
VirtualEnvironmentRunner veRunner = getVirtualEnvironmentRunner();

// This is the list of commands after the 'python' call
List<String> cellposeArguments = new ArrayList<>(Arrays.asList("-W", "ignore", "-m", runCommand));
List<String> cellposeArguments = new ArrayList<>(Arrays.asList( "-Xutf8", "-W", "ignore", "-m", runCommand));

cellposeArguments.add("--train");

Expand All @@ -917,18 +933,15 @@ private void runTraining() throws IOException, InterruptedException {
}
});


cellposeArguments.add("--use_gpu");
// Some people may deactivate this...
if( this.useGPU ) cellposeArguments.add("--use_gpu");

cellposeArguments.add("--verbose");

veRunner.setArguments(cellposeArguments);

// Finally, we can run Cellpose
veRunner.runCommand();

// Wait for the process to finish
veRunner.getProcess().waitFor();
veRunner.runCommand(true);

// Get the log
this.theLog = veRunner.getProcessLog();
Expand Down Expand Up @@ -989,7 +1002,8 @@ private ResultsTable runCellposeQC() throws IOException, InterruptedException {

qcRunner.setArguments(qcArguments);

qcRunner.runCommand();
qcRunner.runCommand(true);


// The results are stored in the validation directory, open them as a results table
File qcResults = new File( getValidationDirectory(), "QC-Results" + File.separator + "Quality_Control for " + this.modelFile.getName() + ".csv");
Expand Down Expand Up @@ -1047,20 +1061,25 @@ private ResultsTable parseTrainingResults() {

if (this.theLog != null) {
// Try to parse the output of Cellpose to give meaningful information to the user. This is very old school
// Look for "Epoch 0, Time 2.3s, Loss 1.0758, Loss Test 0.6007, LR 0.2000"
String epochPattern = ".*Epoch\\s*(\\d+),\\s*Time\\s*(\\d+\\.\\d)s,\\s*Loss\\s*(\\d+\\.\\d+),\\s*Loss Test\\s*(\\d+\\.\\d+),\\s*LR\\s*(\\d+\\.\\d+).*";
// Build Matcher
Pattern pattern = Pattern.compile(epochPattern);
Matcher m;
for (String line : this.theLog) {
m = pattern.matcher(line);
if (m.find()) {
trainingResults.incrementCounter();
trainingResults.addValue("Epoch", Double.parseDouble(m.group(1)));
trainingResults.addValue("Time[s]", Double.parseDouble(m.group(2)));
trainingResults.addValue("Loss", Double.parseDouble(m.group(3)));
trainingResults.addValue("Loss Test", Double.parseDouble(m.group(4)));
trainingResults.addValue("LR", Double.parseDouble(m.group(5)));
Matcher m;
for (LogParser parser : LogParser.values()) {
m = parser.getPattern().matcher(line);
if (m.find()) {
trainingResults.incrementCounter();
trainingResults.addValue("Epoch", Double.parseDouble(m.group("epoch")));
trainingResults.addValue("Time", Double.parseDouble(m.group("time")));
trainingResults.addValue("Loss", Double.parseDouble(m.group("loss")));
if (parser != LogParser.OMNI) { // Omnipose does not provide validation loss
trainingResults.addValue("Validation Loss", Double.parseDouble(m.group("val")));
trainingResults.addValue("LR", Double.parseDouble(m.group("lr")));

} else {
trainingResults.addValue("Validation Loss", Double.NaN);
trainingResults.addValue("LR", Double.NaN);

}
}
}
}
}
Expand Down Expand Up @@ -1104,7 +1123,7 @@ public void showTrainingGraph(boolean show, boolean save) {
//populating the series with data
for (int i = 0; i < output.getCounter(); i++) {
loss.getData().add(new XYChart.Data<>(output.getValue("Epoch", i), output.getValue("Loss", i)));
lossTest.getData().add(new XYChart.Data<>(output.getValue("Epoch", i), output.getValue("Loss Test", i)));
lossTest.getData().add(new XYChart.Data<>(output.getValue("Epoch", i), output.getValue("Validation Loss", i)));

}
lineChart.getData().add(loss);
Expand Down Expand Up @@ -1166,18 +1185,18 @@ private void saveImagePairs(List<PathObject> annotations, String imageName, Imag
if (annotations.isEmpty()) {
return;
}
int downsample = 1;
double downsample;
if (Double.isFinite(pixelSize) && pixelSize > 0) {
downsample = (int) Math.round(pixelSize / originalServer.getPixelCalibration().getAveragedPixelSize().doubleValue());
downsample = pixelSize / originalServer.getPixelCalibration().getAveragedPixelSize().doubleValue();
} else {
downsample = 1.0;
}

AtomicInteger idx = new AtomicInteger();
int finalDownsample = downsample;

annotations.forEach(a -> {
int i = idx.getAndIncrement();

RegionRequest request = RegionRequest.createInstance(originalServer.getPath(), finalDownsample, a.getROI());
RegionRequest request = RegionRequest.createInstance(originalServer.getPath(), downsample, a.getROI());
File imageFile = new File(saveDirectory, imageName + "_region_" + i + ".tif");
File maskFile = new File(saveDirectory, imageName + "_region_" + i + "_masks.tif");
try {
Expand Down Expand Up @@ -1348,6 +1367,7 @@ private Collection<CandidateObject> readObjectsFromFile(TileFile tileFile) throw
}
}
// Ignore the IDs, because they will be the same across different images, and we don't really need them
if(candidates.isEmpty()) return Collections.emptyList();
return candidates.values();
}

Expand Down Expand Up @@ -1424,4 +1444,25 @@ private static class CandidateObject {
geometry = geometry.getGeometryN(index);
}
}
public enum LogParser {

// Cellpose 2 pattern when training : "Look for "Epoch 0, Time 2.3s, Loss 1.0758, Loss Test 0.6007, LR 0.2000"
// Cellpose 3 pattern when training : "5, train_loss=2.6546, test_loss=2.0054, LR=0.1111, time 2.56s"
// Omnipose pattern when training : "Train epoch: 10 | Time: 0.22min | last epoch: 0.74s | <sec/epoch>: 0.73s | <sec/batch>: 0.33s | <Batch Loss>: 5.076259 | <Epoch Loss>: 4.429341"
// WARNING: Currently Omnipose does not provide any output to the validation loss (Test loss in Cellpose)
CP2("Cellpose v2", ".*Epoch\\s*(?<epoch>\\d+),\\s*Time\\s*(?<time>\\d+\\.\\d)s,\\s*Loss\\s*(?<loss>\\d+\\.\\d+),\\s*Loss Test\\s*(?<val>\\d+\\.\\d+),\\s*LR\\s*(?<lr>\\d+\\.\\d+).*"),
CP3( "Cellpose v3", ".* (?<epoch>\\d+), train_loss=(?<loss>\\d+\\.\\d+), test_loss=(?<val>\\d+\\.\\d+), LR=(?<lr>\\d+\\.\\d+), time (?<time>\\d+\\.\\d+)s.*"),
OMNI("Omnipose", ".*Train epoch: (?<epoch>\\d+) \\| Time: (?<time>\\d+\\.\\d+)min .*\\<Epoch Loss\\>: (?<loss>\\d+\\.\\d+).*");

private final String name;
private final Pattern pattern;

LogParser(String name, String regex) {
this.name = name;
this.pattern = Pattern.compile(regex);
}

public String getName() {return this.name;}
public Pattern getPattern() {return this.pattern;}
}
}
14 changes: 14 additions & 0 deletions src/main/java/qupath/ext/biop/cellpose/CellposeBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ public class CellposeBuilder {
private ImageOp extendChannelOp = null;

private boolean doReadResultsAsynchronously = false;
private boolean useGPU = true;

/**
* can create a cellpose builder from a serialized JSON version of this builder.
Expand Down Expand Up @@ -134,9 +135,20 @@ protected CellposeBuilder(String modelPath) {

}

/**
* overwrite use GPU
* @param useGPU add or remove the option
* @return this builder
*/
public CellposeBuilder useGPU( boolean useGPU ) {
this.useGPU = useGPU;

return this;
}

/**
* Specify the training directory
*
*/
public CellposeBuilder groundTruthDirectory(File groundTruthDirectory) {
this.groundTruthDirectory = groundTruthDirectory;
Expand Down Expand Up @@ -771,6 +783,8 @@ public Cellpose2D build() {
// Give it the number of threads to use
cellpose.nThreads = nThreads;

cellpose.useGPU = useGPU;

// Check the model. If it is a file, then it is a custom model
File file = new File(this.modelNameOrPath);
if (file.exists()) {
Expand Down
31 changes: 21 additions & 10 deletions src/main/java/qupath/ext/biop/cmd/VirtualEnvironmentRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class VirtualEnvironmentRunner {
private final EnvType envType;
private WatchService watchService;
private String name;
private String environmentNameOrPath;
private String pythonPath;

private List<String> arguments;

Expand Down Expand Up @@ -61,7 +61,7 @@ public String toString() {
}

public VirtualEnvironmentRunner(String environmentNameOrPath, EnvType type, String name) {
this.environmentNameOrPath = environmentNameOrPath;
this.pythonPath = environmentNameOrPath;
this.envType = type;
this.name = name;
if (envType.equals(EnvType.OTHER))
Expand All @@ -84,27 +84,29 @@ private List<String> getActivationCommand() {
case CONDA:
switch (platform) {
case WINDOWS:
cmd.addAll(Arrays.asList("CALL", "conda.bat", "activate", environmentNameOrPath, "&", "python"));
// Adjust path to the folder with the env name based on the python location. On Windows it's at the root of the environment
cmd.addAll(Arrays.asList("CALL", "conda.bat", "activate", new File(pythonPath).getParent(), "&", "python"));
break;
case UNIX:
case OSX:
cmd.addAll(Arrays.asList("conda", "activate", environmentNameOrPath, ";", "python"));
// Adjust path to the folder with the env name based on the python location. In Linux/MacOS it's in the 'bin' sub folder
cmd.addAll(Arrays.asList("conda", "activate", new File(pythonPath).getParentFile().getParent(), ";", "python"));
break;
}
break;
case VENV:
switch (platform) {
case WINDOWS:
cmd.add(new File(environmentNameOrPath, "Scripts/python").getAbsolutePath());
cmd.add(new File(pythonPath, "Scripts/python").getAbsolutePath());
break;
case UNIX:
case OSX:
cmd.add(new File(environmentNameOrPath, "bin/python").getAbsolutePath());
cmd.add(new File(pythonPath, "bin/python").getAbsolutePath());
break;
}
break;
case EXE:
cmd.add(environmentNameOrPath);
cmd.add(pythonPath);
break;
case OTHER:
return null;
Expand All @@ -123,10 +125,10 @@ public void setArguments(List<String> arguments) {

/**
* This builds, runs the command and outputs it to the logger as it is being run
*
* @throws IOException // In case there is an issue starting the process
* @param waitUntilDone whether to wait for the process to end or not before exiting this method
* @throws IOException in case there is an issue with the process
*/
public void runCommand() throws IOException {
public void runCommand(boolean waitUntilDone) throws IOException {

// Get how to start the command, based on the VENV Type
List<String> command = getActivationCommand();
Expand Down Expand Up @@ -207,6 +209,15 @@ public void run() {


logger.info("Virtual Environment Runner Started");

// If we ask to wait, let's wait directly here rather than handle it outside
if(waitUntilDone) {
try {
this.process.waitFor();
} catch (InterruptedException e) {
logger.error(e.getMessage());
}
}
}

public Process getProcess() {
Expand Down

0 comments on commit 0b046e0

Please sign in to comment.