From ba8c2d2de8f58fc18277aa65fc22a141b693cd4a Mon Sep 17 00:00:00 2001 From: Leslie Evelyn Solorzano <399351+lesolorzanov@users.noreply.github.com> Date: Tue, 28 Jan 2025 01:30:20 +0100 Subject: [PATCH 1/6] Update image_datasets.py add a dataset "ImageAnnDataset" to bring only images and no counts at all. --- gridnext/image_datasets.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/gridnext/image_datasets.py b/gridnext/image_datasets.py index 360084c..f5c5e1b 100644 --- a/gridnext/image_datasets.py +++ b/gridnext/image_datasets.py @@ -17,6 +17,39 @@ from gridnext.utils import read_annotfile +class ImageAnnDataset(Dataset): + ''' + Parameters: + ---------- + adata: AnnData + AnnData object containing count data (in X/obsm) and image data (in obs) from ST arrays + obs_label: str + column in adata.obs containing the spot labels to predict + obs_img: str + column in adata.obs containing paths to individual spot images + img_transforms: torchvision.Transform + preprocessing transforms to apply to image patches after loading + ''' + + def __init__(self, adata, obs_img='imgpath', img_transforms=None): + super(ImageAnnDataset, self).__init__() + + self.imgfiles = adata.obs[obs_img] + + if img_transforms is None: + self.preprocess = Compose([ToTensor()]) + else: + self.preprocess = img_transforms + + def __len__(self): + return len(self.imgfiles) + + def __getitem__(self, idx): + x_image = Image.open(self.imgfiles[idx]) + x_image = self.preprocess(x_image).float() + + return x_image + class PatchDataset(Dataset): def __init__(self, img_files, annot_files=None, position_files=None, Visium=True, img_transforms=None, afile_delim=',', img_ext='jpg', verbose=False): From 7d4380833aad85681dba627c2e8a6327492726c7 Mon Sep 17 00:00:00 2001 From: Leslie Evelyn Solorzano <399351+lesolorzanov@users.noreply.github.com> Date: Tue, 28 Jan 2025 01:34:04 +0100 Subject: [PATCH 2/6] Update multimodal_datasets.py Adding MMFeatureGridDataset which is needed for the count logits coming from scBERT plus the UNI features to which f_img will be applied to do the final g_mm --- gridnext/multimodal_datasets.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/gridnext/multimodal_datasets.py b/gridnext/multimodal_datasets.py index d23106e..8b5a572 100644 --- a/gridnext/multimodal_datasets.py +++ b/gridnext/multimodal_datasets.py @@ -123,6 +123,30 @@ def __getitem__(self, idx): return (x_image, x_count), y +#Special requirement for the gmm(f_counts+f(UNI)) +class MMFeatureGridDataset(AnnGridDataset): + def __init__(self, adata, obs_label, obsm_img, obs_arr='array', obs_x='x', obs_y='y', + h_st=78, w_st=64, vis_coords=True): + super(MMFeatureGridDataset, self).__init__(adata, obs_label, obs_arr, obs_x=obs_x, obs_y=obs_y, + h_st=h_st, w_st=w_st, use_pcs=False, vis_coords=vis_coords) + self.obsm_img = obsm_img + self.nfeats_img = adata.obsm[obsm_img].shape[1] + + def __getitem__(self, idx): + x_count, y_count = super(MMFeatureGridDataset, self).__getitem__(idx) + adata_arr = self.adata[self.adata.obs[self.obs_arr] == self.arrays[idx]] + x_image = torch.zeros(self.nfeats_img, self.h_st, self.w_st) + + for imfeats, a_x, a_y in zip(adata_arr.obsm[self.obsm_img], + adata_arr.obs[self.obs_x], + adata_arr.obs[self.obs_y]): + if self.vis_coords: + x, y = pseudo_hex_to_oddr(a_x, a_y) + else: + x, y = a_x, a_y + x_image[:, y, x] = torch.from_numpy(imfeats) + + return (x_image, x_count), y_count ############ CURRENTLY DEFUNCT ############ From c02a5da45fd86455620045b747e75c5b4153cc2f Mon Sep 17 00:00:00 2001 From: Leslie Evelyn Solorzano <399351+lesolorzanov@users.noreply.github.com> Date: Tue, 28 Jan 2025 02:28:13 +0100 Subject: [PATCH 3/6] Update preprocess_scbert in scbert.py to have gene_symbols as kwarg in preprocess to ask for a column name in adata.var storing gene_symbols matching target_genes --- gridnext/llm/scbert.py | 39 +++++++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/gridnext/llm/scbert.py b/gridnext/llm/scbert.py index 9651f10..67bab2a 100644 --- a/gridnext/llm/scbert.py +++ b/gridnext/llm/scbert.py @@ -13,8 +13,8 @@ # Preprocess raw count data for input to scBERT model -def preprocess_scbert(adata, target_depth=1e4, counts_layer=None, min_genes=None, min_depth=None, - gene_names=None): +def preprocess_scbert(adata, target_depth=1e4, counts_layer=None, min_genes=None, + min_depth=None, gene_symbols=None, target_genes=None): ''' Parameters: ---------- @@ -24,42 +24,50 @@ def preprocess_scbert(adata, target_depth=1e4, counts_layer=None, min_genes=None number of counts to normalize each spot to counts_layer: str or None layer of adata containing raw counts, or "None" to default to adata.X - obs_label: str or None - column in adata.obs containing spot labels to train on min_genes: int or None filter spots with fewer than min_genes min_depth: int or None filter spots with fewer than min_counts (prior to depth normalization) - gene_names: path or None + gene_symbols: str or None + column name in adata.var storing gene_symbols matching target_genes + target_genes: path or None path to single-column CSV file containing ordered list of gene names to pull from adata, or "None" to default to the default list of gene2vec. ''' - if gene_names is None: + if target_genes is None: ref_data = pkgutil.get_data('gridnext.llm', 'gene2vec_names.csv').decode('utf-8') ref_data = StringIO(ref_data) else: - ref_data = gene_names + ref_data = target_genes ref_names = pd.read_csv(ref_data, header=None, index_col=0).index if counts_layer is None: X = adata.X else: X = adata.layers[counts_layer] - counts = sparse.lil_matrix((X.shape[0],len(ref_names)),dtype=np.float32) ref = ref_names.tolist() - obj = adata.var_names.tolist() - for i in range(len(ref)): - if ref[i] in obj: - loc = obj.index(ref[i]) - counts[:,i] = X[:,loc] - - counts = counts.tocsr() + counts = sparse.csr_matrix((X.shape[0], len(ref_names)), dtype=np.float32) new = ad.AnnData(X=counts) new.var_names = ref new.obs_names = adata.obs_names new.obs = adata.obs + var_old=None + # AnnData-based way of populating empty counts matrix: + if gene_symbols is not None: + var_old = adata.var.copy() + adata.var = adata.var.set_index(gene_symbols) + adata.var.index=adata.var.index.astype(str) #LES: Quickfix instead of mod anndata + adata.var_names_make_unique() # handles multiple ENSEMBL with same common name + + genes_shared = set(adata.var.index.intersection(ref)) + genes_shared=pd.CategoricalIndex(genes_shared) + new[:, genes_shared].X = adata[:, genes_shared].X + + if gene_symbols is not None: + adata.var = var_old # undo modification of original AnnData + if min_genes is not None or min_depth is not None: sc.pp.filter_cells(new, min_genes=min_genes, min_counts=min_depth) @@ -67,7 +75,6 @@ def preprocess_scbert(adata, target_depth=1e4, counts_layer=None, min_genes=None sc.pp.log1p(new, base=2) return new - # scBERT model class; functional wrapper around PerformerLM class scBERT(PerformerLM): From 3c5d4abd86dd93b87a3f5256634cb3c172224f7b Mon Sep 17 00:00:00 2001 From: Leslie Evelyn Solorzano <399351+lesolorzanov@users.noreply.github.com> Date: Tue, 28 Jan 2025 02:29:33 +0100 Subject: [PATCH 4/6] Update file-finding abilities utils.py using glob instead of a hardcoded locations to find spaceranger files --- gridnext/utils.py | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/gridnext/utils.py b/gridnext/utils.py index 4a74ff4..3821f75 100644 --- a/gridnext/utils.py +++ b/gridnext/utils.py @@ -263,15 +263,30 @@ def visium_get_positions_fromfile(position_file): positions = pd.read_csv(position_file, index_col=0, header=None, names=["in_tissue", "array_row", "array_col", "pxl_row_in_fullres", "pxl_col_in_fullres"]) return positions - # Given Spaceranger directory, locate file mapping spot barcodes to array/pixel coordinates def visium_find_position_file(spaceranger_dir): - position_paths = [ - os.path.join("outs", "spatial", "tissue_positions.csv"), # Spaceranger >=2.0 - os.path.join("outs", "spatial", "tissue_positions_list.csv") # Spaceranger <2.0 - ] + position_paths = glob.glob(spaceranger_dir+'/**/*.csv', recursive = True) + # tissue_positions.csv # Spaceranger >=2.0 + # tissue_positions_list.csv # Spaceranger <2.0 for pos_path in position_paths: - if os.path.exists(os.path.join(spaceranger_dir, pos_path)): - return os.path.join(spaceranger_dir, pos_path) + if os.path.exists(pos_path) and "tissue_positions" in pos_path: + return pos_path + raise ValueError("Cannot location position file for %s" % spaceranger_dir) + + +# Given Spaceranger directory, locate feature_matrix_files +def find_feature_matrix_files(spaceranger_dir): + existing_paths = glob.glob(spaceranger_dir + '/**', recursive=True) + found = {} + keys = ["matrix", "features", "barcodes"] + values = ["matrix.mtx.gz", "features.tsv.gz", "barcodes.tsv.gz"] + for k, v in zip(keys, values): + for e_path in existing_paths: + if v in e_path: + found[k] = e_path + break + if all(k in found for k in keys): + return found + raise ValueError("Cannot location position file for %s" % spaceranger_dir) From 6a686dbe4991d5545f21287acd58b85665ac3ad3 Mon Sep 17 00:00:00 2001 From: Leslie Evelyn Solorzano <399351+lesolorzanov@users.noreply.github.com> Date: Tue, 28 Jan 2025 02:43:49 +0100 Subject: [PATCH 5/6] Add missing glob import in utils.py adding the glob library missing in utils --- gridnext/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gridnext/utils.py b/gridnext/utils.py index 3821f75..7cd3a57 100644 --- a/gridnext/utils.py +++ b/gridnext/utils.py @@ -12,6 +12,8 @@ from sklearn.preprocessing import label_binarize from sklearn.metrics import confusion_matrix, roc_curve, auc, roc_auc_score +import glob + ############### Prediction functions ############### From 524766423c26ec393e2434cd6e7c90c2df7a8d3b Mon Sep 17 00:00:00 2001 From: Leslie Evelyn Solorzano <399351+lesolorzanov@users.noreply.github.com> Date: Wed, 29 Jan 2025 00:30:48 +0100 Subject: [PATCH 6/6] A small attempt of GPU memory optimization gridnet_models.py added parameters atonce_count_limit=None, device="cpu", delay_sending_to_device=True, to GridNetHexMM on the hopes of delaying sending ton GPU when the counts and images are full size. Recently, f_count is the logits and f_img is a 1024 vector of UNI instead of the image itself. --- gridnext/gridnet_models.py | 48 +++++++++++++++++++++++++++++++------- 1 file changed, 39 insertions(+), 9 deletions(-) diff --git a/gridnext/gridnet_models.py b/gridnext/gridnet_models.py index 794ff18..31336c3 100644 --- a/gridnext/gridnet_models.py +++ b/gridnext/gridnet_models.py @@ -192,7 +192,7 @@ def forward(self, x): # - After applying classifiers to each modalities, concatenates along feature dimension before applying corrector class GridNetHexMM(GridNetHexOddr): def __init__(self, image_classifier, count_classifier, image_shape, count_shape, grid_shape, n_classes, - use_bn=True, atonce_patch_limit=None, image_f_dim=None, count_f_dim=None): + use_bn=True, atonce_patch_limit=None, atonce_count_limit=None, device="cpu", delay_sending_to_device=True, image_f_dim=None, count_f_dim=None): if image_f_dim is None: image_f_dim = n_classes if count_f_dim is None: @@ -207,17 +207,26 @@ def __init__(self, image_classifier, count_classifier, image_shape, count_shape, self.count_shape = count_shape self.image_f_dim = image_f_dim self.count_f_dim = count_f_dim - + + self.mm_atonce_patch_limit=atonce_patch_limit + #atone patch limit exists already in the super class + self.mm_atonce_count_limit = atonce_count_limit + self.delay_sending_to_device = delay_sending_to_device + self.mode="count" def _set_mode(self, mode): - if mode == 'image': - self.patch_classifier = self.image_classifier - self.patch_shape = self.image_shape - self.f_dim = self.image_f_dim - elif mode == 'count': + if mode == 'count': self.patch_classifier = self.count_classifier self.patch_shape = self.count_shape self.f_dim = self.count_f_dim + self.mode="count" + self.atonce_patch_limit=self.mm_atonce_count_limit + elif mode == 'image': + self.patch_classifier = self.image_classifier + self.patch_shape = self.image_shape + self.f_dim = self.image_f_dim + self.mode="image" + self.atonce_patch_limit = self.mm_atonce_patch_limit else: self.f_dim = self.count_f_dim + self.image_f_dim @@ -227,10 +236,31 @@ def patch_predictions(self, x): x_image, x_count = x self._set_mode('count') + + if self.delay_sending_to_device: + x_image=x_image.to("cpu") #if the image was in the GPU at this point it may be that + # it remains there and there is a copy in the cpu and we lose the reference to it + x_count=x_count.to(self.device) + torch.cuda.empty_cache() + + # LES: I will accelerate this testing temporarily: ppg_count = super(GridNetHexMM, self).patch_predictions(x_count) + self._set_mode('image') + if self.delay_sending_to_device: + x_image = x_image.to(self.device) + x_count = x_count.to("cpu") + torch.cuda.empty_cache() + ppg_image = super(GridNetHexMM, self).patch_predictions(x_image) - self._set_mode('concat') + + del x_count,x_image + torch.cuda.empty_cache() - return torch.cat((ppg_count, ppg_image), dim=1) + if self.delay_sending_to_device: + ppg_image = ppg_image.to(self.device) + ppg_count = ppg_count.to(self.device) + self._set_mode('concat') + + return torch.cat((ppg_count, ppg_image), dim=1)