diff --git a/src/sbmlsim/Batch.py b/src/sbmlsim/Batch.py index 21e37c2..7bc211b 100644 --- a/src/sbmlsim/Batch.py +++ b/src/sbmlsim/Batch.py @@ -203,7 +203,7 @@ def generate( number_susceptible, selected_susceptible_mutations, ) = self._get_susceptible_mutations( - n_sus, remaining_aa_positions, sample_gene + n_sus, remaining_aa_positions, sample_gene, sample_label=label ) else: selected_susceptible_mutations = [] @@ -342,7 +342,9 @@ def _get_resistant_mutations(self, n_res, distribution, sample_gene): number_resistant, selected_resistant_mutations, selected_resistant_positions, - ) = self._get_defined_mutations(label="R", n=n_res, distribution=distribution) + ) = self._get_defined_mutations( + label="R", n=n_res, sample_label="R", distribution=distribution + ) # Get amino acid positions that are not altered by selected resistant mutations remaining_aa_positions = [ @@ -354,7 +356,9 @@ def _get_resistant_mutations(self, n_res, distribution, sample_gene): return number_resistant, selected_resistant_mutations, remaining_aa_positions - def _get_susceptible_mutations(self, n_sus, remaining_aa_positions, sample_gene): + def _get_susceptible_mutations( + self, n_sus, remaining_aa_positions, sample_gene, sample_label + ): # choose susceptible mutations for a sample if self.define_susceptibles: @@ -366,6 +370,7 @@ def _get_susceptible_mutations(self, n_sus, remaining_aa_positions, sample_gene) label="S", n=n_sus, distribution="poisson", + sample_label=sample_label, remaining_aa_positions=remaining_aa_positions, ) @@ -434,9 +439,13 @@ def _get_susceptible_mutations(self, n_sus, remaining_aa_positions, sample_gene) return number_susceptible, selected_susceptible_mutations def _get_defined_mutations( - self, label, n, distribution, remaining_aa_positions=None + self, label, n, distribution, sample_label, remaining_aa_positions=None ): + assert not ( + label == "R" and sample_label == "S" + ), "Cannot have resistant mutations in a susceptible sample" + if label == "R": mutation_positions = self.resistant_positions mutations = self.resistant_mutations @@ -457,12 +466,14 @@ def _get_defined_mutations( while True: number_mutations = numpy.random.poisson(n) if label == "R": - if number_mutations > 0 and number_mutations <= len( + if 0 < number_mutations <= len(mutation_positions): + break + if label == "S": + if sample_label == "S" and 0 < number_mutations <= len( mutation_positions ): break - if label == "S": - if number_mutations > 0 and number_mutations <= len( + if sample_label == "R" and number_mutations <= len( mutation_positions ): break