Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 39 additions & 9 deletions gridnext/gridnet_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def forward(self, x):
# - After applying classifiers to each modalities, concatenates along feature dimension before applying corrector
class GridNetHexMM(GridNetHexOddr):
def __init__(self, image_classifier, count_classifier, image_shape, count_shape, grid_shape, n_classes,
use_bn=True, atonce_patch_limit=None, image_f_dim=None, count_f_dim=None):
use_bn=True, atonce_patch_limit=None, atonce_count_limit=None, device="cpu", delay_sending_to_device=True, image_f_dim=None, count_f_dim=None):
if image_f_dim is None:
image_f_dim = n_classes
if count_f_dim is None:
Expand All @@ -207,17 +207,26 @@ def __init__(self, image_classifier, count_classifier, image_shape, count_shape,
self.count_shape = count_shape
self.image_f_dim = image_f_dim
self.count_f_dim = count_f_dim


self.mm_atonce_patch_limit=atonce_patch_limit
#atone patch limit exists already in the super class
self.mm_atonce_count_limit = atonce_count_limit
self.delay_sending_to_device = delay_sending_to_device
self.mode="count"

def _set_mode(self, mode):
if mode == 'image':
self.patch_classifier = self.image_classifier
self.patch_shape = self.image_shape
self.f_dim = self.image_f_dim
elif mode == 'count':
if mode == 'count':
self.patch_classifier = self.count_classifier
self.patch_shape = self.count_shape
self.f_dim = self.count_f_dim
self.mode="count"
self.atonce_patch_limit=self.mm_atonce_count_limit
elif mode == 'image':
self.patch_classifier = self.image_classifier
self.patch_shape = self.image_shape
self.f_dim = self.image_f_dim
self.mode="image"
self.atonce_patch_limit = self.mm_atonce_patch_limit
else:
self.f_dim = self.count_f_dim + self.image_f_dim

Expand All @@ -227,10 +236,31 @@ def patch_predictions(self, x):
x_image, x_count = x

self._set_mode('count')

if self.delay_sending_to_device:
x_image=x_image.to("cpu") #if the image was in the GPU at this point it may be that
# it remains there and there is a copy in the cpu and we lose the reference to it
x_count=x_count.to(self.device)
torch.cuda.empty_cache()

# LES: I will accelerate this testing temporarily:
ppg_count = super(GridNetHexMM, self).patch_predictions(x_count)

self._set_mode('image')
if self.delay_sending_to_device:
x_image = x_image.to(self.device)
x_count = x_count.to("cpu")
torch.cuda.empty_cache()

ppg_image = super(GridNetHexMM, self).patch_predictions(x_image)
self._set_mode('concat')

del x_count,x_image
torch.cuda.empty_cache()

return torch.cat((ppg_count, ppg_image), dim=1)
if self.delay_sending_to_device:
ppg_image = ppg_image.to(self.device)
ppg_count = ppg_count.to(self.device)

self._set_mode('concat')

return torch.cat((ppg_count, ppg_image), dim=1)
33 changes: 33 additions & 0 deletions gridnext/image_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,39 @@
from gridnext.utils import read_annotfile


class ImageAnnDataset(Dataset):
    '''Dataset of individual spot image patches referenced from an AnnData object.

    Parameters:
    ----------
    adata: AnnData
        AnnData object whose obs dataframe holds per-spot image paths
    obs_img: str
        column in adata.obs containing paths to individual spot images
    img_transforms: torchvision.Transform or None
        preprocessing transforms to apply to image patches after loading;
        defaults to ToTensor() only
    '''

    def __init__(self, adata, obs_img='imgpath', img_transforms=None):
        super(ImageAnnDataset, self).__init__()

        # Keep the pandas Series; positional access in __getitem__ uses .iloc
        # so the (typically string) obs index is never mistaken for a position.
        self.imgfiles = adata.obs[obs_img]

        if img_transforms is None:
            self.preprocess = Compose([ToTensor()])
        else:
            self.preprocess = img_transforms

    def __len__(self):
        return len(self.imgfiles)

    def __getitem__(self, idx):
        # .iloc: idx is positional; plain [] on a label-indexed Series is
        # deprecated (and removed in pandas 3) for integer positions.
        # Context manager closes the file handle once the transform has
        # consumed the pixel data.
        with Image.open(self.imgfiles.iloc[idx]) as img:
            x_image = self.preprocess(img).float()

        return x_image

class PatchDataset(Dataset):
def __init__(self, img_files, annot_files=None, position_files=None, Visium=True,
img_transforms=None, afile_delim=',', img_ext='jpg', verbose=False):
Expand Down
39 changes: 23 additions & 16 deletions gridnext/llm/scbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@


# Preprocess raw count data for input to scBERT model
def preprocess_scbert(adata, target_depth=1e4, counts_layer=None, min_genes=None, min_depth=None,
gene_names=None):
def preprocess_scbert(adata, target_depth=1e4, counts_layer=None, min_genes=None,
min_depth=None, gene_symbols=None, target_genes=None):
'''
Parameters:
----------
Expand All @@ -24,50 +24,57 @@ def preprocess_scbert(adata, target_depth=1e4, counts_layer=None, min_genes=None
number of counts to normalize each spot to
counts_layer: str or None
layer of adata containing raw counts, or "None" to default to adata.X
obs_label: str or None
column in adata.obs containing spot labels to train on
min_genes: int or None
filter spots with fewer than min_genes
min_depth: int or None
filter spots with fewer than min_counts (prior to depth normalization)
gene_names: path or None
gene_symbols: str or None
column name in adata.var storing gene_symbols matching target_genes
target_genes: path or None
path to single-column CSV file containing ordered list of gene names to pull from adata,
or "None" to default to the default list of gene2vec.
'''
if gene_names is None:
if target_genes is None:
ref_data = pkgutil.get_data('gridnext.llm', 'gene2vec_names.csv').decode('utf-8')
ref_data = StringIO(ref_data)
else:
ref_data = gene_names
ref_data = target_genes
ref_names = pd.read_csv(ref_data, header=None, index_col=0).index

if counts_layer is None:
X = adata.X
else:
X = adata.layers[counts_layer]
counts = sparse.lil_matrix((X.shape[0],len(ref_names)),dtype=np.float32)
ref = ref_names.tolist()
obj = adata.var_names.tolist()

for i in range(len(ref)):
if ref[i] in obj:
loc = obj.index(ref[i])
counts[:,i] = X[:,loc]

counts = counts.tocsr()
counts = sparse.csr_matrix((X.shape[0], len(ref_names)), dtype=np.float32)
new = ad.AnnData(X=counts)
new.var_names = ref
new.obs_names = adata.obs_names
new.obs = adata.obs

var_old=None
# AnnData-based way of populating empty counts matrix:
if gene_symbols is not None:
var_old = adata.var.copy()
adata.var = adata.var.set_index(gene_symbols)
adata.var.index=adata.var.index.astype(str) #LES: Quickfix instead of mod anndata
adata.var_names_make_unique() # handles multiple ENSEMBL with same common name

genes_shared = set(adata.var.index.intersection(ref))
genes_shared=pd.CategoricalIndex(genes_shared)
new[:, genes_shared].X = adata[:, genes_shared].X

if gene_symbols is not None:
adata.var = var_old # undo modification of original AnnData

if min_genes is not None or min_depth is not None:
sc.pp.filter_cells(new, min_genes=min_genes, min_counts=min_depth)

sc.pp.normalize_total(new, target_sum=target_depth)
sc.pp.log1p(new, base=2)

return new


# scBERT model class; functional wrapper around PerformerLM
class scBERT(PerformerLM):
Expand Down
24 changes: 24 additions & 0 deletions gridnext/multimodal_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,30 @@ def __getitem__(self, idx):

return (x_image, x_count), y

# Multimodal grid dataset for the gmm(f_counts + f(UNI)) setting: pairs the
# parent class's count grid with a grid of precomputed image-feature embeddings.
class MMFeatureGridDataset(AnnGridDataset):
    '''Yield ((image_feature_grid, count_grid), label_grid) tuples per array.

    Image features are read from adata.obsm[obsm_img] and scattered into a
    (n_features, h_st, w_st) tensor at each spot's grid coordinates.
    '''

    def __init__(self, adata, obs_label, obsm_img, obs_arr='array', obs_x='x', obs_y='y',
        h_st=78, w_st=64, vis_coords=True):
        super(MMFeatureGridDataset, self).__init__(adata, obs_label, obs_arr, obs_x=obs_x, obs_y=obs_y,
            h_st=h_st, w_st=w_st, use_pcs=False, vis_coords=vis_coords)
        self.obsm_img = obsm_img
        # feature dimensionality of the per-spot image embeddings
        self.nfeats_img = adata.obsm[obsm_img].shape[1]

    def __getitem__(self, idx):
        # Parent provides the count grid and the label grid for this array.
        x_count, y_count = super(MMFeatureGridDataset, self).__getitem__(idx)

        arr_mask = self.adata.obs[self.obs_arr] == self.arrays[idx]
        arr_data = self.adata[arr_mask]

        feat_grid = torch.zeros(self.nfeats_img, self.h_st, self.w_st)
        spot_feats = arr_data.obsm[self.obsm_img]
        spot_xs = arr_data.obs[self.obs_x]
        spot_ys = arr_data.obs[self.obs_y]

        for feats, cx, cy in zip(spot_feats, spot_xs, spot_ys):
            if self.vis_coords:
                # Visium pseudo-hex coordinates -> odd-right grid coordinates
                col, row = pseudo_hex_to_oddr(cx, cy)
            else:
                col, row = cx, cy
            feat_grid[:, row, col] = torch.from_numpy(feats)

        return (feat_grid, x_count), y_count

############ CURRENTLY DEFUNCT ############

Expand Down
31 changes: 24 additions & 7 deletions gridnext/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from sklearn.preprocessing import label_binarize
from sklearn.metrics import confusion_matrix, roc_curve, auc, roc_auc_score

import glob


############### Prediction functions ###############

Expand Down Expand Up @@ -263,15 +265,30 @@ def visium_get_positions_fromfile(position_file):
positions = pd.read_csv(position_file, index_col=0, header=None,
names=["in_tissue", "array_row", "array_col", "pxl_row_in_fullres", "pxl_col_in_fullres"])
return positions

# Given Spaceranger directory, locate file mapping spot barcodes to array/pixel coordinates
def visium_find_position_file(spaceranger_dir):
    '''Locate the file mapping spot barcodes to array/pixel coordinates.

    Recursively searches spaceranger_dir for any CSV whose basename contains
    "tissue_positions":
      - tissue_positions.csv       (Spaceranger >= 2.0)
      - tissue_positions_list.csv  (Spaceranger < 2.0)

    Parameters:
    ----------
    spaceranger_dir: str
        path to a Spaceranger output directory

    Returns:
    -------
    str
        path to the first matching position file (sorted order, so
        tissue_positions.csv is preferred over tissue_positions_list.csv)

    Raises:
    ------
    ValueError
        if no position file is found under spaceranger_dir
    '''
    pattern = os.path.join(spaceranger_dir, '**', '*.csv')
    # Sort for a deterministic choice when multiple candidates exist;
    # "tissue_positions.csv" sorts before "tissue_positions_list.csv".
    candidates = sorted(p for p in glob.glob(pattern, recursive=True)
                        if "tissue_positions" in os.path.basename(p))
    if candidates:
        return candidates[0]
    raise ValueError("Cannot locate position file for %s" % spaceranger_dir)


# Given Spaceranger directory, locate feature_matrix_files
def find_feature_matrix_files(spaceranger_dir):
    '''Locate the market-matrix triplet produced by Spaceranger.

    Recursively searches spaceranger_dir for matrix.mtx.gz, features.tsv.gz
    and barcodes.tsv.gz.

    Parameters:
    ----------
    spaceranger_dir: str
        path to a Spaceranger output directory

    Returns:
    -------
    dict
        {"matrix": path, "features": path, "barcodes": path}; for each key the
        first path (in glob order) containing the target filename is used

    Raises:
    ------
    ValueError
        if any of the three files is missing under spaceranger_dir
    '''
    existing_paths = glob.glob(os.path.join(spaceranger_dir, '**'), recursive=True)
    targets = {
        "matrix": "matrix.mtx.gz",
        "features": "features.tsv.gz",
        "barcodes": "barcodes.tsv.gz",
    }
    found = {}
    for key, fname in targets.items():
        for e_path in existing_paths:
            if fname in e_path:
                found[key] = e_path
                break
    if all(k in found for k in targets):
        return found

    # note: message previously said "position file" — copy-paste error fixed
    raise ValueError("Cannot locate feature matrix files for %s" % spaceranger_dir)