diff --git a/cebra/data/datasets.py b/cebra/data/datasets.py index dbb2f1f5..1b074881 100644 --- a/cebra/data/datasets.py +++ b/cebra/data/datasets.py @@ -66,12 +66,14 @@ def __init__(self, neural: Union[torch.Tensor, npt.NDArray], continuous: Union[torch.Tensor, npt.NDArray] = None, discrete: Union[torch.Tensor, npt.NDArray] = None, + discrete_time: Union[torch.Tensor, npt.NDArray] = None, offset: Offset = Offset(0, 1), device: str = "cpu"): super().__init__(device=device) self.neural = self._to_tensor(neural, check_dtype="float").float() self.continuous = self._to_tensor(continuous, check_dtype="float") self.discrete = self._to_tensor(discrete, check_dtype="int") + self.discrete_time = self._to_tensor(discrete_time, check_dtype="int") if self.continuous is None and self.discrete is None: raise ValueError( "You have to pass at least one of the arguments 'continuous' or 'discrete'." diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py index 7802b787..1c1e934d 100644 --- a/cebra/data/single_session.py +++ b/cebra/data/single_session.py @@ -105,6 +105,8 @@ class DiscreteDataLoader(cebra_data.Loader): """, ) + num_negatives: int = dataclasses.field(default=None) + @property def index(self): """The (discrete) dataset index.""" @@ -148,8 +150,14 @@ def get_indices(self, num_samples: int) -> BatchIndex: Returns: Indices for reference, positive and negatives samples. """ - reference_idx = self.distribution.sample_prior(num_samples * 2) - negative_idx = reference_idx[num_samples:] + if self.num_negatives is None: + num_negatives = num_samples + else: + num_negatives = self.num_negatives + + reference_idx = self.distribution.sample_prior(num_negatives + + num_samples) + negative_idx = reference_idx[num_negatives:] reference_idx = reference_idx[:num_samples] reference = self.index[reference_idx] positive_idx = self.distribution.sample_conditional(reference) @@ -158,6 +166,407 @@ def get_indices(self, num_samples: int) -> BatchIndex: negative=negative_idx) +@dataclasses.dataclass +class DiscreteTimeDataLoader(cebra_data.Loader): + + prior: str = dataclasses.field( + default="empirical", + doc="""Re-sampling mode for the discrete index. + + The option `empirical` uses label frequencies as they appear in the dataset. + The option `uniform` re-samples the dataset and adjust the frequencies of less + common class labels. + For balanced datasets, it is typically more accurate to stick to the `empirical` + option. 
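+
+    Note: regardless of this setting, ``_init_distribution`` below currently
+    always builds a ``DiscreteEmpirical`` distribution, so the ``uniform``
+    option is not yet wired up for this loader.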
+ """, + ) + + num_negatives: int = dataclasses.field(default=None) + + time_offset: int = dataclasses.field(default=10) + + @property + def index(self): + """The (discrete) dataset index.""" + return self.dataset.discrete_index + + @property + def index_time(self): + return self.dataset.discrete_time + + @property + def max_index_time(self): + return self.index_time.max() + + def __post_init__(self): + super().__post_init__() + if self.dataset.discrete_index is None: + raise ValueError("Dataset does not provide a discrete index.") + self._init_distribution() + + def _init_distribution(self): + self.distribution = cebra.distributions.discrete.DiscreteEmpirical( + self.index) + + assert len(self.index) == len(self.index_time) + self.num_samples = len(self.index_time) + + # def sample_indices(self, index_label, index_time, reference_label, + # reference_time): + # indices = [] + + # # Create the full boolean mask for all indices at once + # mask = (index_label.unsqueeze(1) == reference_label) & ( + # index_time.unsqueeze(1) == reference_time + # ) #mask is of shape (len(index_label), batch_size) + + # non_zero_indices = mask.nonzero(as_tuple=True) + + # # Iterate over unique pairs in mask and randomly sample + # for i in range(len(reference_label)): + + # # Extract indices for current pair + # idx_ = non_zero_indices[0][non_zero_indices[1] == i] + + # # Sample a random index from these matching indices + # #assert len(idx_) == 150, print("Index length", len(idx_)) + # random_idx = torch.randint(0, len(idx_), (1,)).item() + # indices.append(idx_[random_idx]) + + # indices = torch.stack(indices) + # return indices + + def sample_indices(self, index_label, index_time, reference_label, + reference_time): + + # Create the full boolean mask, shape: (len(index_label), batch_size) + mask = (index_label.unsqueeze(1) == reference_label) & ( + index_time.unsqueeze(1) == reference_time) + + # Identify all non-zero indices + non_zero_indices = mask.nonzero(as_tuple=True) + + # Use unique labels in the reference to identify matching groups + unique_reference_idx = non_zero_indices[1].unique() + #assert unique_reference_idx.shape == (num_samples,) + + # Get a count of indices for each unique reference label for sampling + counts = torch.bincount(non_zero_indices[1]) + #NOTE: should the counts be the same? the way its written, it assumes they are. + + # Generate a random choice per reference label + #NOTE: right now im creating the randon_index in cuda, what is the best way to do this? + random_index = (torch.rand(len(counts), device=self.device) * + counts).int() + + # Map random offsets to indices and store results + # If counts == 100, index_offsets == [0, 100, 200, 300, ...] + index_offsets = torch.cumsum( + torch.cat((torch.tensor([0], device=self.device), counts[:-1])), 0) + selected_indices = non_zero_indices[0][index_offsets + + random_index % counts] + + return selected_indices + + #def get_ref_index() + + def get_indices(self, num_samples: int) -> BatchIndex: + + if self.num_negatives is None: + num_negatives = num_samples + else: + num_negatives = self.num_negatives + + # reference / negative for discrete labels + reference_idx_discrete = self.distribution.sample_prior(num_negatives + + num_samples) + negative_idx_discrete = reference_idx_discrete[num_negatives:] + reference_idx_discrete = reference_idx_discrete[:num_samples] + + # reference / negative for time + #TODO: we only want to have reference_idx_time whee positive is not in the next trial. 
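+        # The upper bound excludes the last `time_offset` samples so that
+        # `positive_idx_time = reference_idx_time + self.time_offset` below stays
+        # in range; it does not yet restrict reference and positive to the same
+        # trial (see the TODO above). Also note that `num_samples * 2` time
+        # indices are drawn here, which implicitly assumes
+        # `num_negatives == num_samples` for the time component.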
+ reference_idx_time = torch.randint(0, + self.num_samples - self.time_offset, + (num_samples * 2, )) + + negative_idx_time = reference_idx_time[num_samples:] + reference_idx_time = reference_idx_time[:num_samples] + + # refence combined + reference_discrete = self.index[reference_idx_discrete] + reference_time = self.index_time[reference_idx_time] + reference_idx_combined = self.sample_indices( + self.index, + self.index_time, + reference_discrete, + reference_time, + ) + + # negative combined + negative_discrete = self.index[negative_idx_discrete] + negative_time = self.index_time[negative_idx_time] + negative_idx_combined = self.sample_indices(self.index, + self.index_time, + negative_discrete, + negative_time) + + # positive combined + positive_idx_time = reference_idx_time + self.time_offset + positive_time = self.index_time[positive_idx_time] + positive_idx_combined = self.sample_indices( + self.index, + self.index_time, + reference_discrete, + positive_time, + ) + + return BatchIndex(reference=reference_idx_combined, + positive=positive_idx_combined, + negative=negative_idx_combined) + + +@dataclasses.dataclass +class DiscreteTimeDataLoaderV2(cebra_data.Loader): + + prior: str = dataclasses.field( + default="empirical", + doc="""Re-sampling mode for the discrete index. + + The option `empirical` uses label frequencies as they appear in the dataset. + The option `uniform` re-samples the dataset and adjust the frequencies of less + common class labels. + For balanced datasets, it is typically more accurate to stick to the `empirical` + option. + """, + ) + + num_negatives: int = dataclasses.field(default=None) + + time_offset: int = dataclasses.field(default=10) + + @property + def index(self): + """The (discrete) dataset index.""" + return self.dataset.discrete_index + + @property + def index_time(self): + return self.dataset.discrete_time + + @property + def max_index_time(self): + return self.index_time.max() + + def __post_init__(self): + super().__post_init__() + if self.dataset.discrete_index is None: + raise ValueError("Dataset does not provide a discrete index.") + self._init_distribution() + + self.valid_indices = self.compute_valid_indices( + self.index, self.index_time, offset=self.time_offset) + + assert self.num_negatives is not None + + def _init_distribution(self): + self.distribution = cebra.distributions.discrete.DiscreteEmpirical( + self.index) + + assert len(self.index) == len(self.index_time) + self.num_samples = len(self.index_time) + + # if self.num_negatives is None: + # num_negatives = num_samples + # else: + # num_negatives = self.num_negatives + + def compute_valid_indices(self, tensor1, tensor2, offset=1): + + # Get unique values in each tensor + unique_vals1 = tensor1.unique() + unique_vals2 = tensor2.unique() + + # Create a meshgrid of all combinations + grid1, grid2 = torch.meshgrid(unique_vals1, + unique_vals2, + indexing='ij') + comb_grid = torch.stack((grid1.flatten(), grid2.flatten()), dim=1) + + # remove indices where tensor2 is 2 + offset = 1 + + comb_grid = comb_grid[comb_grid[:, 1] < (tensor2.max() + 1 - offset)] + + # Stack tensor1 and tensor2 for easier comparison + stacked = torch.stack((tensor1, tensor2), dim=1) + + # For each unique combination, check which rows in stacked match it + matches = (stacked[:, None, :] == comb_grid).all(dim=2) + # shape = (len(index), # unique combinations) + + valid_indices = torch.nonzero(matches, as_tuple=True)[0] + #random_indices = torch.randint(0, len(valid_indices), (batch_size, )) + + return valid_indices 
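+
+    # Worked example for `compute_valid_indices` (hypothetical inputs): with
+    # labels tensor([0, 0, 1, 1]) and times tensor([0, 1, 0, 1]), `comb_grid`
+    # keeps the (label, time) pairs whose time is below `times.max() + 1 - offset`.
+    # With the offset of 1 that the function currently hard-codes (the `offset`
+    # argument is overwritten inside the body), only (0, 0) and (1, 0) survive,
+    # so `valid_indices == tensor([0, 2])`: only samples that still have a
+    # successor `offset` steps later within the indexed time range are kept.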
+ + def get_indices(self, num_samples: int) -> BatchIndex: + + reference_idx = torch.randint(0, len(self.valid_indices), + (self.num_negatives + num_samples, )) + negative_idx = reference_idx[:self.num_negatives] + reference_idx = reference_idx[self.num_negatives:] + + # positive combined + positive_idx_time = reference_idx + self.time_offset + # positive_time = self.index_time[positive_idx_time] + # positive_idx_combined = self.sample_indices( + # self.index, + # self.index_time, + # reference_discrete, + # positive_time, + # ) + #print(positive_idx_time.shape, negative_idx.shape) + return BatchIndex(reference=reference_idx, + positive=positive_idx_time, + negative=negative_idx) + + +@dataclasses.dataclass +class DiscreteTimeDataLoaderV3(cebra_data.Loader): + + prior: str = dataclasses.field( + default="empirical", + doc="""Re-sampling mode for the discrete index. + + The option `empirical` uses label frequencies as they appear in the dataset. + The option `uniform` re-samples the dataset and adjust the frequencies of less + common class labels. + For balanced datasets, it is typically more accurate to stick to the `empirical` + option. + """, + ) + + num_negatives: int = dataclasses.field(default=None) + + time_offset: int = dataclasses.field(default=10) + + @property + def index(self): + """The (discrete) dataset index.""" + return self.dataset.discrete_index + + @property + def index_time(self): + return self.dataset.discrete_time + + @property + def max_index_time(self): + return self.index_time.max() + + def __post_init__(self): + super().__post_init__() + if self.dataset.discrete_index is None: + raise ValueError("Dataset does not provide a discrete index.") + self._init_distribution() + + self.valid_indices = self.compute_valid_indices( + self.index, self.index_time, offset=self.time_offset) + + assert self.num_negatives is not None + + def _init_distribution(self): + self.distribution = cebra.distributions.discrete.DiscreteEmpirical( + self.index) + + assert len(self.index) == len(self.index_time) + self.num_samples = len(self.index_time) + + # if self.num_negatives is None: + # num_negatives = num_samples + # else: + # num_negatives = self.num_negatives + + def sample_indices(self, index_label, index_time, reference_label, + reference_time): + + # Create the full boolean mask, shape: (len(index_label), batch_size) + mask = (index_label.unsqueeze(1) == reference_label) & ( + index_time.unsqueeze(1) == reference_time) + + # Identify all non-zero indices + non_zero_indices = mask.nonzero(as_tuple=True) + + # Use unique labels in the reference to identify matching groups + unique_reference_idx = non_zero_indices[1].unique() + #assert unique_reference_idx.shape == (num_samples,) + + # Get a count of indices for each unique reference label for sampling + counts = torch.bincount(non_zero_indices[1]) + #NOTE: should the counts be the same? the way its written, it assumes they are. + + # Generate a random choice per reference label + #NOTE: right now im creating the randon_index in cuda, what is the best way to do this? + random_index = (torch.rand(len(counts), device=self.device) * + counts).int() + + # Map random offsets to indices and store results + # If counts == 100, index_offsets == [0, 100, 200, 300, ...] 
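+        # The gather below jumps to each group's start offset and adds a random
+        # in-group index, picking one matching row per batch position. Caveat:
+        # `nonzero` returns entries ordered by dataset row (the first mask
+        # dimension), so the entries for a given batch position are not
+        # guaranteed to be contiguous; the offset arithmetic is only valid when
+        # they are (e.g. after sorting `non_zero_indices` by the batch dimension).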
+ index_offsets = torch.cumsum( + torch.cat((torch.tensor([0], device=self.device), counts[:-1])), 0) + selected_indices = non_zero_indices[0][index_offsets + + random_index % counts] + + return selected_indices + + def compute_valid_indices(self, tensor1, tensor2, offset=1): + + # Get unique values in each tensor + unique_vals1 = tensor1.unique() + unique_vals2 = tensor2.unique() + + # Create a meshgrid of all combinations + grid1, grid2 = torch.meshgrid(unique_vals1, + unique_vals2, + indexing='ij') + comb_grid = torch.stack((grid1.flatten(), grid2.flatten()), dim=1) + + # remove indices where tensor2 is 2 + offset = 1 + + comb_grid = comb_grid[comb_grid[:, 1] < (tensor2.max() + 1 - offset)] + + # Stack tensor1 and tensor2 for easier comparison + stacked = torch.stack((tensor1, tensor2), dim=1) + + # For each unique combination, check which rows in stacked match it + matches = (stacked[:, None, :] == comb_grid).all(dim=2) + # shape = (len(index), # unique combinations) + + valid_indices = torch.nonzero(matches, as_tuple=True)[0] + #random_indices = torch.randint(0, len(valid_indices), (batch_size, )) + + return valid_indices + + def get_indices(self, num_samples: int) -> BatchIndex: + + reference_idx = torch.randint(0, len(self.valid_indices), + (self.num_negatives + num_samples, )) + negative_idx = reference_idx[:self.num_negatives] + reference_idx = reference_idx[self.num_negatives:] + + # positive combined + positive_idx_time = reference_idx + self.time_offset + positive_time = self.index_time[positive_idx_time] + positive_idx_combined = self.sample_indices( + self.index, self.index_time, + self.index[reference_idx].to(self.device), + self.index_time[positive_idx_time].to(self.device)) + #print(positive_idx_time.shape, negative_idx.shape) + return BatchIndex(reference=reference_idx, + positive=positive_idx_combined, + negative=negative_idx) + + @dataclasses.dataclass class ContinuousDataLoader(cebra_data.Loader): """Contrastive learning conditioned on a continuous behavior variable. @@ -361,7 +770,9 @@ def __post_init__(self): num_samples=len(self.dataset.neural), device=self.device) self.behavior_distribution = cebra.distributions.TimedeltaDistribution( - self.dataset.continuous_index, self.time_offset, device=self.device) + self.dataset.continuous_index, + self.time_offset, + device=self.device) def get_indices(self, num_samples: int) -> BatchIndex: """Samples indices for reference, positive and negative examples. diff --git a/cebra/distributions/discrete.py b/cebra/distributions/discrete.py index 531dfc16..d43b1f44 100644 --- a/cebra/distributions/discrete.py +++ b/cebra/distributions/discrete.py @@ -45,8 +45,8 @@ class Discrete(abc_.ConditionalDistribution, abc_.HasGenerator): samples: Discrete index used for sampling """ - def _to_numpy_int(self, samples: Union[torch.Tensor, - npt.NDArray]) -> npt.NDArray: + def _to_numpy_int( + self, samples: Union[torch.Tensor, npt.NDArray]) -> npt.NDArray: if isinstance(samples, torch.Tensor): samples = samples.cpu().numpy() if not cebra.helper._is_integer(samples): @@ -79,8 +79,9 @@ def num_samples(self) -> int: def _init_transform(self): self.counts = np.bincount(self.samples) - self.cdf = np.zeros((len(self.counts) + 1,)) + self.cdf = np.zeros((len(self.counts) + 1, )) self.cdf[1:] = np.cumsum(self.counts) + #print("cdf", self.cdf) # NOTE(stes): This is the only use of a scipy function in the entire code # base for now. 
Replacing scipy.interpolate.interp1d with an equivalent # function from torch would make it possible to drop scipy as a dependency @@ -104,7 +105,7 @@ def sample_uniform(self, num_samples: int) -> torch.Tensor: index samples of this instance with the returned in indices will yield a uniform distribution across the discrete values. """ - samples = np.random.uniform(0, self.num_samples, (num_samples,)) + samples = np.random.uniform(0, self.num_samples, (num_samples, )) samples = self.transform(samples).astype(int) return self.sorted_idx[samples] @@ -118,10 +119,11 @@ def sample_empirical(self, num_samples: int) -> torch.Tensor: A batch of indices from the empirical distribution, which is the uniform distribution over ``[0, N-1]``. """ - samples = np.random.randint(0, self.num_samples, (num_samples,)) + samples = np.random.randint(0, self.num_samples, (num_samples, )) return self.sorted_idx[samples] - def sample_conditional(self, reference_index: torch.Tensor) -> torch.Tensor: + def sample_conditional(self, + reference_index: torch.Tensor) -> torch.Tensor: """Draw samples conditional on template samples. Args: diff --git a/cebra/models/model.py b/cebra/models/model.py index 7631ba86..aac6036e 100644 --- a/cebra/models/model.py +++ b/cebra/models/model.py @@ -101,7 +101,8 @@ def __init__( super().__init__() if num_input < 1: raise ValueError( - f"Input dimension needs to be at least 1, but got {num_input}.") + f"Input dimension needs to be at least 1, but got {num_input}." + ) if num_output < 1: raise ValueError( f"Output dimension needs to be at least 1, but got {num_output}." @@ -216,8 +217,8 @@ def __init__(self, super().__init__(num_input=num_input, num_output=num_output) if normalize: - layers += (cebra_layers._Norm(),) - layers += (cebra_layers.Squeeze(),) + layers += (cebra_layers._Norm(), ) + layers += (cebra_layers.Squeeze(), ) self.net = nn.Sequential(*layers) # TODO(stes) can this layer be removed? it is already added to # the self.net @@ -249,8 +250,8 @@ def num_parameters(self) -> int: @property def num_trainable_parameters(self) -> int: """Number of trainable parameters.""" - return sum( - param.numel() for param in self.parameters() if param.requires_grad) + return sum(param.numel() for param in self.parameters() + if param.requires_grad) @register("offset10-model") @@ -279,6 +280,43 @@ def get_offset(self) -> cebra.data.datatypes.Offset: return cebra.data.Offset(5, 5) +@register("offset10-model-dropout") +class Offset10ModelDropout(_OffsetModel, ConvolutionalModelMixin): + """CEBRA model with a 10 sample receptive field.""" + + def __init__( + self, + num_neurons, + num_units, + num_output, + dropout_rate, + normalize=True, + ): + if num_units < 1: + raise ValueError( + f"Hidden dimension needs to be at least 1, but got {num_units}." 
+ ) + super().__init__( + nn.Conv1d(num_neurons, num_units, 2), + nn.GELU(), + nn.Dropout1d(dropout_rate), + cebra_layers._Skip(nn.Dropout1d(dropout_rate), + nn.Conv1d(num_units, num_units, 3), nn.GELU()), + cebra_layers._Skip(nn.Dropout1d(dropout_rate), + nn.Conv1d(num_units, num_units, 3), nn.GELU()), + cebra_layers._Skip(nn.Dropout1d(dropout_rate), + nn.Conv1d(num_units, num_units, 3), nn.GELU()), + nn.Conv1d(num_units, num_output, 3), + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + """See :py:meth:`~.Model.get_offset`""" + return cebra.data.Offset(5, 5) + + @register("offset10-model-mse") class Offset10ModelMSE(Offset10Model): """Symmetric model with 10 sample receptive field, without normalization. @@ -290,6 +328,25 @@ def __init__(self, num_neurons, num_units, num_output, normalize=False): super().__init__(num_neurons, num_units, num_output, normalize) +@register("offset10-model-mse-dropout") +class Offset10ModelMSEDropout(Offset10ModelDropout): + """Symmetric model with 10 sample receptive field, without normalization. + + Suitable for use with InfoNCE metrics for Euclidean space. + """ + + def __init__( + self, + num_neurons, + num_units, + num_output, + dropout_rate, + normalize=False, + ): + super().__init__(num_neurons, num_units, num_output, normalize, + dropout_rate) + + @register("offset5-model") class Offset5Model(_OffsetModel, ConvolutionalModelMixin): """CEBRA model with a 5 sample receptive field and output normalization.""" @@ -421,9 +478,11 @@ def __init__(self, num_neurons, num_units, num_output, normalize=True): num_units, ), nn.GELU(), - cebra_layers._Skip(nn.Linear(num_units, num_units), crop=(0, None)), + cebra_layers._Skip(nn.Linear(num_units, num_units), + crop=(0, None)), nn.GELU(), - cebra_layers._Skip(nn.Linear(num_units, num_units), crop=(0, None)), + cebra_layers._Skip(nn.Linear(num_units, num_units), + crop=(0, None)), nn.GELU(), nn.Linear(num_units, num_output), num_input=num_neurons, @@ -497,13 +556,17 @@ def __init__(self, num_neurons, num_units, num_output, normalize=True): num_units, ), nn.GELU(), - cebra_layers._Skip(nn.Linear(num_units, num_units), crop=(0, None)), + cebra_layers._Skip(nn.Linear(num_units, num_units), + crop=(0, None)), nn.GELU(), - cebra_layers._Skip(nn.Linear(num_units, num_units), crop=(0, None)), + cebra_layers._Skip(nn.Linear(num_units, num_units), + crop=(0, None)), nn.GELU(), - cebra_layers._Skip(nn.Linear(num_units, num_units), crop=(0, None)), + cebra_layers._Skip(nn.Linear(num_units, num_units), + crop=(0, None)), nn.GELU(), - cebra_layers._Skip(nn.Linear(num_units, num_units), crop=(0, None)), + cebra_layers._Skip(nn.Linear(num_units, num_units), + crop=(0, None)), nn.GELU(), nn.Linear(num_units, num_output), num_input=num_neurons, @@ -548,7 +611,8 @@ def get_offset(self) -> cebra.data.datatypes.Offset: @register("resample5-model", deprecated=True) @register("offset20-model-4x-subsample") -class Resample5Model(_OffsetModel, ConvolutionalModelMixin, ResampleModelMixin): +class Resample5Model(_OffsetModel, ConvolutionalModelMixin, + ResampleModelMixin): """CEBRA model with 20 sample receptive field, output normalization and 4x subsampling.""" ##120Hz diff --git a/cebra/solver/base.py b/cebra/solver/base.py index ea951c7c..7f68cbb6 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -85,6 +85,7 @@ class Solver(abc.ABC, cebra.io.HasDevice): "accuracy_train": [], "accuracy_valid": [], "temperature": [], + "weight_norm": [], })) 
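+    # Note: "weight_norm" stores the squared L2 norm of all model parameters,
+    # computed once per validation call in `validation()` below.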
tqdm_on: bool = True @@ -212,11 +213,12 @@ def fit( is not None) and (num_steps % valid_frequency == 0) if run_validation: - valid_loss, accuracy_train, accuracy_valid = self.validation( + valid_loss, accuracy_train, accuracy_valid, weight_norm = self.validation( loader, valid_loader) self.log["total_valid"].append(valid_loss) self.log["accuracy_train"].append(accuracy_train) self.log["accuracy_valid"].append(accuracy_valid) + self.log["weight_norm"].append(weight_norm) # validation_metrics = None @@ -326,11 +328,18 @@ def validation(self, accuracy_train = accuracy_score(train_labels, prediction_train) accuracy_valid = accuracy_score(valid_labels, prediction_valid) + n_train_labels = len(np.unique(train_labels)) + n_valid_labels = len(np.unique(valid_labels)) + + weight_norm = 0 + for param in self.model.parameters(): + weight_norm += torch.sum(param**2) + print( - f"Accuracy train: {accuracy_train:.2f}, accuracy test: {accuracy_valid:.2f}" + f"Accuracy train[{n_train_labels} labels]: {accuracy_train:.2f}, accuracy test[{n_valid_labels} labels]: {accuracy_valid:.2f}" ) - return valid_loss, accuracy_train, accuracy_valid + return valid_loss, accuracy_train, accuracy_valid, weight_norm.item() # @torch.no_grad() # def decoding(self, train_loader, valid_loader): diff --git a/tests/test_loader.py b/tests/test_loader.py index 562f64a7..278284ed 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -27,7 +27,7 @@ def parametrize_device(func): - _devices = ("cpu", "cuda") if torch.cuda.is_available() else ("cpu",) + _devices = ("cpu", "cuda") if torch.cuda.is_available() else ("cpu", ) return pytest.mark.parametrize("device", _devices)(func) @@ -49,7 +49,7 @@ class RandomDataset(cebra.data.SingleSessionDataset): def __init__(self, N=100, d=5, device="cpu"): super().__init__(device=device) self._cindex = torch.randint(0, 5, (N, d), device=device).float() - self._dindex = torch.randint(0, 5, (N,), device=device).long() + self._dindex = torch.randint(0, 5, (N, ), device=device).long() self.neural = self._data = torch.randn((N, d), device=device) @property @@ -338,3 +338,55 @@ def test_multisession_loader(data_name, loader_initfunc, device): _check_attributes(batch, is_list=True) for session_batch in batch: assert len(session_batch.positive) == 32 + + +def test_discrete_loader_with_offset(): + + n_stimuli = 4 + block_size = 1000 + neural_data = torch.randn(n_stimuli * block_size, 10) + + index_discrete = torch.tensor( + [i for i in range(n_stimuli) for _ in range(block_size)]) + + trial_length = 200 + index_time = torch.cat([ + torch.arange(trial_length) + for _ in range((n_stimuli * block_size) // trial_length) + ]) + + dataset = cebra.data.TensorDataset( + neural_data.type(torch.FloatTensor), + discrete=index_discrete.type(torch.LongTensor), + discrete_time=index_time.type(torch.LongTensor), + ) + + batch_size = 500 + time_offset = 10 + dataloader = cebra.data.single_session.DiscreteTimeDataLoader( + dataset=dataset, + num_steps=1, + batch_size=batch_size, + time_offset=time_offset) + index = dataloader.get_indices(batch_size) + + index.reference[0] + index.positive[0] + + collect_time = [] + for i in range(batch_size): + assert index_discrete[index.reference[i]] == index_discrete[ + index.positive[i]] + # assert index_time[ + # index.reference[i]] == index_time[index.positive[i]] - time_offset + + collect_time.append(index_time[index.positive[i] - + time_offset] <= trial_length - + time_offset) + # assert index_time[index.positive[i] - + # time_offset] <= trial_length - time_offset + 
import numpy as np
+
+    # Summary of the per-sample checks collected above (draft diagnostic).
+    print(np.sum(collect_time))
+
+    #assert index_time[index.positive[i] - time_offset] <= index_time[index.reference[i]]
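+
+
+# A minimal sketch exercising the new "offset10-model-dropout" registration
+# from cebra/models/model.py. It assumes that `cebra.models.init` forwards the
+# extra `dropout_rate` keyword to the constructor, as it does for the existing
+# architectures; the parameter values below are arbitrary.
+def test_offset10_model_dropout_sketch():
+    import cebra.models
+
+    model = cebra.models.init(
+        "offset10-model-dropout",
+        num_neurons=10,
+        num_units=16,
+        num_output=4,
+        dropout_rate=0.5,
+    )
+    # One receptive field (10 samples) of fake data: (batch, neurons, time).
+    x = torch.randn(2, 10, 10)
+    out = model(x)
+    assert out.shape[:2] == (2, 4)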