diff --git a/pipeline/download_gene.py b/pipeline/download_gene.py index 54f09d7..a39336b 100644 --- a/pipeline/download_gene.py +++ b/pipeline/download_gene.py @@ -51,6 +51,14 @@ def parse_args(): Path to output directory to save results. """, ) + parser.add_argument( + "--downsample-ref", + type=int, + default=25, + help="""\ + Downsampling factor of the reference space. + """, + ) parser.add_argument( "--downsample-img", type=int, @@ -79,6 +87,7 @@ def postprocess_dataset( None, ], n_images: int, + downsample_ref: int, ) -> Tuple[np.ndarray, np.ndarray, dict[str, Any]]: """Post process given dataset. @@ -116,7 +125,7 @@ def postprocess_dataset( # TODO: maybe notify the user somehow? continue - section_numbers.append(section_coordinate // 25) + section_numbers.append(section_coordinate // downsample_ref) image_ids.append(img_id) warped_img = 255 - df.warp(img, border_mode="constant", c=img[0, 0, :].tolist()) dataset_np.append(warped_img) @@ -140,6 +149,7 @@ def postprocess_dataset( def main( experiment_id: int, output_dir: Path | str, + downsample_ref: int, downsample_img: int, expression: bool = True, ) -> int: @@ -151,6 +161,10 @@ def main( Gene ID to download. output_dir Directory when results are going to be saved. + downsample_ref + Downscaling of the reference space grid. If set to 1 no + downsampling takes place. The higher the value the smaller the grid + in the reference space and the faster the matrix multiplication. downsample_img Downsampling factor given to Allen API when downloading the images. This factor is going to reduce the size. @@ -174,19 +188,23 @@ def main( logger.info(f"Start downloading experiment ID {experiment_id}") dataset = DatasetDownloader( - experiment_id, downsample_img=downsample_img, include_expression=expression + experiment_id, + downsample_img=downsample_img, + include_expression=expression, + downsample_ref=downsample_ref, ) dataset.fetch_metadata() dataset_gen = dataset.run() axis = CommonQueries.get_axis(experiment_id) dataset_np, expression_np, metadata_dict = postprocess_dataset( - dataset_gen, len(dataset) + dataset_gen, len(dataset), downsample_ref ) metadata_dict["axis"] = axis + metadata_dict["downsample-ref"] = downsample_ref logger.info(f"Saving results of experiment ID {experiment_id}") - np.save(output_dir / f"{experiment_id}.npy", dataset_np) - with open(output_dir / f"{experiment_id}.json", "w") as f: + np.save(output_dir / f"{experiment_id}-{downsample_ref}.npy", dataset_np) + with open(output_dir / f"{experiment_id}-{downsample_ref}.json", "w") as f: json.dump(metadata_dict, f, indent=True, sort_keys=True) if expression_np is not None: diff --git a/pipeline/gene_to_nissl.py b/pipeline/gene_to_nissl.py index 046925b..e7a0804 100644 --- a/pipeline/gene_to_nissl.py +++ b/pipeline/gene_to_nissl.py @@ -111,7 +111,7 @@ def registration( except IndexError: logger.warn( f"One of the gene slice has a section number ({section_number})" - f"out of nissl volume shape {nissl_volume.shape}. This slice is" + f"out of nissl volume shape {nissl_volume.shape}. This slice is " "removed from the pipeline." ) section_numbers_kept.append(False) diff --git a/pipeline/interpolate_gene.py b/pipeline/interpolate_gene.py index 82a0ee9..a27271c 100644 --- a/pipeline/interpolate_gene.py +++ b/pipeline/interpolate_gene.py @@ -153,11 +153,12 @@ def main( section_numbers = [int(s) for s in metadata["section_numbers"]] axis = metadata["axis"] + volume_shape = [num // metadata["downsample-ref"] for num in [13200, 8000, 11400]] # Wrap the data into a GeneDataset class gene_dataset = GeneDataset( section_images, section_numbers, - volume_shape=(528, 320, 456, 3), + volume_shape=(*volume_shape, 3), axis=axis, ) @@ -197,7 +198,7 @@ def main( np.save( output_dir - / f"{experiment_id}-{interpolator_name}-interpolated-{image_type}.npy", + / f"{experiment_id}-{metadata['downsample-ref']}-{interpolator_name}-interpolated-{image_type}.npy", predicted_volume, ) diff --git a/pipeline/nissl_to_ccfv3.py b/pipeline/nissl_to_ccfv3.py index 7a9f8b8..c506a13 100644 --- a/pipeline/nissl_to_ccfv3.py +++ b/pipeline/nissl_to_ccfv3.py @@ -84,16 +84,28 @@ def registration( Nissl volume once the registration transformation are applied. """ logger.info("Compute the registration...") - nii_data = register(reference_volume, moving_volume) - logger.info(f"Max displacements: {np.abs(nii_data).max(axis=(0, 1, 2, 3))}") + nissl_warped = [] + warped_volume = [] - logger.info("Apply transformation to Moving Volume...") - warped_volume = transform(moving_volume, nii_data, interpolator="genericLabel") + for i, (reference, moving, nissl) in enumerate(zip(reference_volume, moving_volume, nissl_volume)): + try: + nii_data = register(reference, moving) + logger.info(f"Max displacements: {np.abs(nii_data).max(axis=(0, 1, 2, 3))}") - logger.info("Apply transformation to Nissl Volume...") - nissl_warped = transform(nissl_volume, nii_data) + logger.info("Apply transformation to Moving Volume...") + warped_volume.append(transform(moving, nii_data, interpolator="genericLabel")) - return warped_volume, nissl_warped + logger.info("Apply transformation to Nissl Volume...") + nissl_warped.append(transform(nissl, nii_data)) + except RuntimeError: + logger.info(f"Registration for slice {i} went wrong...") + warped_volume.append(moving) + nissl_warped.append(nissl) + + if (i + 1) % 5 == 0: + logger.info(f" {i + 1} / {reference_volume.shape[0]} registrations done") + + return np.array(warped_volume), np.array(nissl_warped) def main(