Skip to content

Commit

Permalink
differentiable subsampling
Browse files Browse the repository at this point in the history
  • Loading branch information
lpaillet-laas committed Mar 8, 2024
1 parent 633b1d8 commit 057d66d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 21 deletions.
24 changes: 9 additions & 15 deletions optimization_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,7 @@ class JointReconstructionModule_V1(pl.LightningModule):
def __init__(self, model_name,log_dir="tb_logs"):
super().__init__()

# TODO : use a real reconstruction module
self.reconstruction_model = model_generator(model_name, None)
""" if torch.cuda.is_available():
self.reconstruction_model = self.reconstruction_model.cuda()
else:
self.reconstruction_model.to('cpu') """
#self.reconstruction_model = EmptyModule()
self.loss_fn = nn.MSELoss()
self.ssim_loss = SSIM(window_size=11, size_average=True)

Expand Down Expand Up @@ -91,9 +85,9 @@ def forward(self, x):
self.acquired_image1 = self._normalize_data_by_itself(self.acquired_image1)
acquired_cubes = self.acquired_image1.unsqueeze(1).repeat((1, 28, 1, 1)).float().to(self.device) # b x W x R x C

filtering_cubes = subsample(filtering_cube, np.linspace(450, 650, filtering_cube.shape[-1]), np.linspace(450, 650, 28)).permute((0, 3, 1, 2))
filtering_cubes = subsample(filtering_cube, torch.linspace(450, 650, filtering_cube.shape[-1]), torch.linspace(450, 650, 28)).permute((0, 3, 1, 2)).float().to(self.device)

reconstructed_cube = self.reconstruction_model(acquired_cubes, filtering_cubes.to(self.device))
reconstructed_cube = self.reconstruction_model(acquired_cubes, filtering_cubes)


return reconstructed_cube
Expand Down Expand Up @@ -227,14 +221,14 @@ def plot_spectral_filter(self,ref_hyperspectral_cube,recontructed_hyperspectral_
pix_j_col_value = np.random.randint(0,x)

pix_j_ref = ref_hyperspectral_cube[i, pix_j_row_value,pix_j_col_value,:].cpu().detach().numpy()
pixe_j_reconstructed = recontructed_hyperspectral_cube[i, pix_j_row_value,pix_j_col_value,:].cpu().detach().numpy()
axs[i].plot(pixe_j_reconstructed, label="pix reconstructed" + str(j),c=colors[j])
pix_j_reconstructed = recontructed_hyperspectral_cube[i, pix_j_row_value,pix_j_col_value,:].cpu().detach().numpy()
axs[i].plot(pix_j_reconstructed, label="pix reconstructed" + str(j),c=colors[j])
axs[i].plot(pix_j_ref, label="pix" + str(j), linestyle='--',c=colors[j])

axs[i].set_title(f"Reconstruction quality")

axs[i].set_xlabel("Wavelength index")
axs[i].set_ylabel("pxie values")
axs[i].set_ylabel("pix values")
axs[i].grid(True)

plt.legend()
Expand All @@ -257,12 +251,12 @@ def plot_spectral_filter(self,ref_hyperspectral_cube,recontructed_hyperspectral_

def subsample(input, origin_sampling, target_sampling):
[bs, row, col, nC] = input.shape
output = torch.zeros(bs, row, col, len(target_sampling))
indices = torch.zeros(len(target_sampling), dtype=torch.int)
for i in range(len(target_sampling)):
sample = target_sampling[i]
idx = np.abs(origin_sampling-sample).argmin()
output[:,:,:,i] = input[:,:,:,idx]
return output
idx = torch.abs(origin_sampling-sample).argmin()
indices[i] = idx
return input[:,:,:,indices]

def expand_mask_3d(mask_batch):
if len(mask_batch.shape)==3:
Expand Down
19 changes: 13 additions & 6 deletions training_simca_reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from optimization_modules import JointReconstructionModule_V1
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import torch


data_dir = "./datasets_reconstruction/"
data_dir = "./datasets_reconstruction/cave_1024_28"
#data_dir = "/local/users/ademaio/lpaillet/mst_datasets/cave_1024_28"

datamodule = CubesDataModule(data_dir, batch_size=32, num_workers=1)
datamodule = CubesDataModule(data_dir, batch_size=2, num_workers=1)

name = "testing_simca_reconstruction"
model_name = "mst_plus_plus"
Expand All @@ -35,9 +36,15 @@

reconstruction_module = JointReconstructionModule_V1(model_name,log_dir=log_dir+'/'+ name)

trainer = pl.Trainer( logger=logger,
accelerator="gpu",
max_epochs=500,
log_every_n_steps=1)
if torch.cuda.is_available():
trainer = pl.Trainer( logger=logger,
accelerator="gpu",
max_epochs=500,
log_every_n_steps=1)
else:
trainer = pl.Trainer( logger=logger,
accelerator="cpu",
max_epochs=500,
log_every_n_steps=1)

trainer.fit(reconstruction_module, datamodule)

0 comments on commit 057d66d

Please sign in to comment.