Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions src/sbmlsim/Batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def generate(
number_susceptible,
selected_susceptible_mutations,
) = self._get_susceptible_mutations(
n_sus, remaining_aa_positions, sample_gene
n_sus, remaining_aa_positions, sample_gene, sample_label=label
)
else:
selected_susceptible_mutations = []
Expand Down Expand Up @@ -342,7 +342,9 @@ def _get_resistant_mutations(self, n_res, distribution, sample_gene):
number_resistant,
selected_resistant_mutations,
selected_resistant_positions,
) = self._get_defined_mutations(label="R", n=n_res, distribution=distribution)
) = self._get_defined_mutations(
label="R", n=n_res, sample_label="R", distribution=distribution
)

# Get amino acid positions that are not altered by selected resistant mutations
remaining_aa_positions = [
Expand All @@ -354,7 +356,9 @@ def _get_resistant_mutations(self, n_res, distribution, sample_gene):

return number_resistant, selected_resistant_mutations, remaining_aa_positions

def _get_susceptible_mutations(self, n_sus, remaining_aa_positions, sample_gene):
def _get_susceptible_mutations(
self, n_sus, remaining_aa_positions, sample_gene, sample_label
):
# choose susceptible mutations for a sample

if self.define_susceptibles:
Expand All @@ -366,6 +370,7 @@ def _get_susceptible_mutations(self, n_sus, remaining_aa_positions, sample_gene)
label="S",
n=n_sus,
distribution="poisson",
sample_label=sample_label,
remaining_aa_positions=remaining_aa_positions,
)

Expand Down Expand Up @@ -434,9 +439,13 @@ def _get_susceptible_mutations(self, n_sus, remaining_aa_positions, sample_gene)
return number_susceptible, selected_susceptible_mutations

def _get_defined_mutations(
self, label, n, distribution, remaining_aa_positions=None
self, label, n, distribution, sample_label, remaining_aa_positions=None
):

assert not (
label == "R" and sample_label == "S"
), "Cannot have resistant mutations in a susceptible sample"

if label == "R":
mutation_positions = self.resistant_positions
mutations = self.resistant_mutations
Expand All @@ -457,12 +466,14 @@ def _get_defined_mutations(
while True:
number_mutations = numpy.random.poisson(n)
if label == "R":
if number_mutations > 0 and number_mutations <= len(
if 0 < number_mutations <= len(mutation_positions):
break
if label == "S":
if sample_label == "S" and 0 < number_mutations <= len(
mutation_positions
):
break
if label == "S":
if number_mutations > 0 and number_mutations <= len(
if sample_label == "R" and number_mutations <= len(
mutation_positions
):
break
Expand Down