diff --git a/src/alignment_processor.py b/src/alignment_processor.py
index 8e130ed2..e584649a 100644
--- a/src/alignment_processor.py
+++ b/src/alignment_processor.py
@@ -307,6 +307,7 @@ def process_intergenic(self, alignment_storage):
             read_assignment.polya_info = alignment_info.polya_info
             read_assignment.cage_found = len(alignment_info.cage_hits) > 0
             read_assignment.exons = alignment_info.read_exons
+            read_assignment.cigartuples = alignment.cigartuples
             read_assignment.corrected_exons = alignment_info.read_exons
             read_assignment.corrected_introns = junctions_from_blocks(read_assignment.corrected_exons)
@@ -358,6 +359,7 @@ def process_genic(self, alignment_storage, gene_info):
             read_assignment.polya_info = alignment_info.polya_info
             read_assignment.cage_found = len(alignment_info.cage_hits) > 0
             read_assignment.exons = alignment_info.read_exons
+            read_assignment.cigartuples = alignment.cigartuples
             read_assignment.corrected_exons = exon_corrector.correct_assigned_read(alignment_info, read_assignment)
             read_assignment.corrected_introns = junctions_from_blocks(read_assignment.corrected_exons)
diff --git a/src/graph_based_model_construction.py b/src/graph_based_model_construction.py
index 7e94114c..991fbd17 100644
--- a/src/graph_based_model_construction.py
+++ b/src/graph_based_model_construction.py
@@ -26,6 +26,11 @@
 from .long_read_profiles import CombinedProfileConstructor
 from .polya_finder import PolyAInfo
+from .transcript_splice_site_corrector import (
+    count_deletions_for_splice_site_locations,
+    correct_splice_site_errors,
+    generate_updated_exon_list
+    )
 
 logger = logging.getLogger('IsoQuant')
@@ -130,6 +135,7 @@ def process(self, read_assignment_storage):
         self.construct_assignment_based_isoforms(read_assignment_storage)
         self.assign_reads_to_models(read_assignment_storage)
         self.filter_transcripts()
+        self.correct_transcripts()
 
         if self.params.genedb:
             self.create_extended_annotation()
@@ -198,6 +204,81 @@ def compare_models_with_known(self):
                 model.add_additional_attribute("alternatives", event_string)
             self.transcript2transcript.append(assignment)
 
+    def correct_transcripts(self):
+        for model in self.transcript_model_storage:
+            exons = model.exon_blocks
+            assigned_reads = self.transcript_read_ids[model.transcript_id]
+            corrected_exons = self.correct_transcript_splice_sites(exons, assigned_reads)
+            if corrected_exons:
+                logger.debug(f"correct_transcripts. Corrected exons: {corrected_exons}, original exons: {exons}")
+                model.exon_blocks = corrected_exons
+
+    def correct_transcript_splice_sites(self, exons: list, assigned_reads: list):
+        # exons: list of coordinate pairs
+        # assigned_reads: list of ReadAssignment, contains read_id and cigartuples
+        # self.chr_record - FASTA record, i.e. a single chromosome from a reference
+        # returns: a list of corrected exons if any correction takes place, otherwise None
+
+        # Constants
+        ACCEPTED_DEL_CASES = [3, 4, 5, 6]
+        SUPPORTED_STRANDS = ['+', '-']
+        THRESHOLD_CASES_AT_LOCATION = 0.7
+        MIN_N_OF_ALIGNED_READS = 5
+        WINDOW_SIZE = 8
+
+        MORE_CONSERVATIVE_STRATEGY = False
+
+        strand = assigned_reads[0].strand
+        if strand not in SUPPORTED_STRANDS:
+            return None
+
+        splice_site_cases = {}
+        # Iterate over assigned reads and count deletions around each splice site location
+        for read_assignment in assigned_reads:
+            read_start = read_assignment.corrected_exons[0][0]
+            read_end = read_assignment.corrected_exons[-1][1]
+            cigartuples = read_assignment.cigartuples
+            if not cigartuples:
+                logger.debug(f"No cigar tuples for read {read_assignment.read_id}")
+                continue
+            count_deletions_for_splice_site_locations(
+                read_start,
+                read_end,
+                cigartuples,
+                exons,
+                splice_site_cases,
+                WINDOW_SIZE)
+
+        corrected_exons = correct_splice_site_errors(
+            splice_site_cases,
+            MIN_N_OF_ALIGNED_READS,
+            ACCEPTED_DEL_CASES,
+            THRESHOLD_CASES_AT_LOCATION,
+            MORE_CONSERVATIVE_STRATEGY,
+            strand,
+            self.chr_record
+        )
+
+        if not corrected_exons:
+            return None
+
+        cases = [str(loc) + ": " + str(splice_site_cases[loc]) for loc in corrected_exons]
+        logger.debug(f"correct_transcript_splice_sites. Corrected locations: {cases}")
+
+        updated_exons = generate_updated_exon_list(
+            splice_site_cases,
+            corrected_exons,
+            exons
+        )
+
+        return updated_exons
+
     def filter_transcripts(self):
         filtered_storage = []
         confirmed_transcipt_ids = set()
diff --git a/src/isoform_assignment.py b/src/isoform_assignment.py
index 47d73552..6e90a18c 100644
--- a/src/isoform_assignment.py
+++ b/src/isoform_assignment.py
@@ -477,6 +477,7 @@ def __init__(self, read_id, assignment_type, match=None):
         self.assignment_id = ReadAssignment.assignment_id_generator.increment()
         self.read_id = read_id
         self.exons = None
+        self.cigartuples = None
         self.corrected_exons = None
         self.corrected_introns = None
         self.gene_info = None
@@ -507,6 +508,9 @@ def deserialize(cls, infile, gene_info):
         read_assignment.assignment_id = read_int(infile)
         read_assignment.read_id = read_string(infile)
         read_assignment.exons = read_list_of_pairs(infile, read_int)
+        read_assignment.cigartuples = read_list_of_pairs(infile, read_int)
+        if not read_assignment.cigartuples:
+            read_assignment.cigartuples = None
         read_assignment.corrected_exons = read_list_of_pairs(infile, read_int)
         read_assignment.corrected_introns = junctions_from_blocks(read_assignment.corrected_exons)
         read_assignment.gene_info = gene_info
@@ -532,6 +536,10 @@ def serialize(self, outfile):
         write_int(self.assignment_id, outfile)
         write_string(self.read_id, outfile)
         write_list_of_pairs(self.exons, outfile, write_int)
+        if self.cigartuples is None:
+            write_list_of_pairs([], outfile, write_int)
+        else:
+            write_list_of_pairs(self.cigartuples, outfile, write_int)
         write_list_of_pairs(self.corrected_exons, outfile, write_int)
         write_bool_array([self.multimapper, self.polyA_found, self.cage_found], outfile)
         write_int_neg(self.polya_info.external_polya_pos, outfile)
diff --git a/src/transcript_splice_site_corrector.py b/src/transcript_splice_site_corrector.py
new file mode 100644
index 00000000..fee5711c
--- /dev/null
+++ b/src/transcript_splice_site_corrector.py
@@ -0,0 +1,314 @@
+import logging
+logger = logging.getLogger('IsoQuant')
+
+def extract_location_from_cigar_string(cigartuples: list,
+                                       read_start: int,
+                                       read_end: int,
+                                       splice_site_location: int):
+    """
+    Map a reference coordinate to the corresponding position within the aligned read.
+
+    Args:
+        cigartuples (list): list of cigar tuples (cigar code, aligned position).
+            See pysam documentation for more information
+        read_start (int): the start location for the read (base-1)
+        read_end (int): the end location for the read (base-1)
+        splice_site_location (int): location of interest (base-1)
+
+    Returns:
+        int: position of splice_site_location within the alignment (soft clips included),
+            or -1 if the location cannot be placed on the alignment
+    """
+    relative_position = splice_site_location - read_start
+    alignment_position = 0
+    ref_position = 0
+
+    for cigar_code in cigartuples:
+
+        if cigar_code[0] in [0, 2, 3, 7, 8]:
+            ref_position += cigar_code[1]
+        if ref_position <= relative_position and not \
+                read_start + ref_position == read_end:
+            alignment_position += cigar_code[1]
+        else:
+            return alignment_position + (cigar_code[1] - (ref_position - relative_position))
+
+    return -1
+
+
+def count_deletions_from_cigar_codes_in_given_window(cigartuples: list,
+                                                     aligned_location: int,
+                                                     location_is_end: bool,
+                                                     splice_site_data: dict,
+                                                     window_size: int):
+    """
+    Count deletions within a window of cigar codes around the aligned location
+    and record them in splice_site_data.
+
+    Args:
+        cigartuples (list): list of cigar tuples (cigar code, aligned position). See
+            pysam documentation for more information
+        aligned_location (int): aligned location
+        location_is_end (bool): True if the location is an exon end, False if it is an exon start
+        splice_site_data (dict): per-location statistics to update
+        window_size (int): number of alignment positions to inspect
+    """
+
+    count_of_deletions = 0
+
+    cigar_code_list = []
+    location = 0
+
+    if location_is_end:
+        aligned_location = aligned_location - window_size + 1
+
+    for cigar_code in cigartuples:
+        if window_size == len(cigar_code_list):
+            break
+        if location + cigar_code[1] > aligned_location:
+            overlap = location + \
+                cigar_code[1] - (aligned_location + len(cigar_code_list))
+            cigar_code_list.extend(
+                [cigar_code[0] for _ in range(min(window_size -
+                                                  len(cigar_code_list), overlap))])
+        location += cigar_code[1]
+
+    for i in range(window_size):
+        if i >= len(cigar_code_list):
+            break
+        if cigar_code_list[i] == 2:
+            count_of_deletions += 1
+            splice_site_data["del_pos_distr"][i] += 1
+
+    if count_of_deletions not in splice_site_data["deletions"]:
+        splice_site_data["deletions"][count_of_deletions] = 0
+
+    splice_site_data["deletions"][count_of_deletions] += 1
+
+
+def extract_splice_site_locations_within_aligned_read(read_start: int, read_end: int, exons: list):
+    matching_locations = []
+    for exon_start, exon_end in exons:
+        if read_start <= exon_start <= read_end:
+            location_is_end = False
+            matching_locations.append((exon_start, location_is_end))
+        if read_start <= exon_end <= read_end:
+            location_is_end = True
+            matching_locations.append((exon_end, location_is_end))
+        if read_end <= exon_end:
+            break
+    return matching_locations
+
+
+def count_deletions_for_splice_site_locations(
+        read_start: int,
+        read_end: int,
+        cigartuples: list,
+        exons: list,
+        splice_site_cases: dict,
+        WINDOW_SIZE: int):
+    """
+    Count deletions around every splice site location that falls within the aligned read.
+
+    Args:
+        read_start (int): start of the aligned read (reference coordinates)
+        read_end (int): end of the aligned read (reference coordinates)
+        cigartuples (list): cigar tuples of the read alignment
+        exons (list): list of exon coordinate tuples (start, end)
+        splice_site_cases (dict): a dictionary for storing splice site cases
+        WINDOW_SIZE (int): size of the window used for deletion counting
+    """
+
+    # Extract splice site locations within aligned read
+    matching_locations = extract_splice_site_locations_within_aligned_read(read_start, read_end, exons)
+
+    logger.debug(f"Matching locations: {matching_locations}")
+    # Count deletions for each splice site location
+    for splice_site_location, location_type in matching_locations:
+        if splice_site_location not in splice_site_cases:
+            splice_site_cases[splice_site_location] = {
+                'location_is_end': location_type,
+                'deletions': {},
+                'del_pos_distr': [0 for _ in range(WINDOW_SIZE)],
+                'most_common_del': -1,
+                'canonical_bases_found': False
+            }
+
+        # Processing cigartuples
+        # 1. Find the aligned location
+        aligned_location = extract_location_from_cigar_string(cigartuples, read_start, read_end, splice_site_location)
+        # 2. Count deletions in a predefined window
+        count_deletions_from_cigar_codes_in_given_window(
+            cigartuples,
+            aligned_location,
+            location_type,
+            splice_site_cases[splice_site_location],
+            WINDOW_SIZE)
+
+
+def compute_most_common_case_of_deletions(deletions: dict, location_is_end: bool):
+    del_most_common_case = [k for k, v in deletions.items()
+                            if v == max(deletions.values())]
+    if len(del_most_common_case) == 1:
+        if location_is_end:
+            return -del_most_common_case[0]
+        return del_most_common_case[0]
+    return -1
+
+
+def extract_nucleotides_from_most_common_del_location(
+        location: int,
+        splice_site_data: dict,
+        chr_record,
+        strand: str):
+    most_common_del = splice_site_data["most_common_del"]
+    idx_correction = 0
+    extraction_start = location + most_common_del + idx_correction
+    extraction_end = location + most_common_del + 2 + idx_correction
+    try:
+        extracted_canonicals = chr_record[extraction_start:extraction_end]
+    except KeyError:
+        extracted_canonicals = 'XX'
+
+    # Expected intron-flanking dinucleotides on the reference strand;
+    # minus-strand entries are the reverse complements of the plus-strand ones.
+    canonical_pairs = {
+        '+': {
+            'start': ['AG', 'AC'],
+            'end': ['GT', 'GC', 'AT']
+        },
+        '-': {
+            'start': ['AC', 'GC', 'AT'],
+            'end': ['CT', 'GT']
+        }
+    }
+    if splice_site_data["location_is_end"]:
+        possible_canonicals = canonical_pairs[strand]['end']
+    else:
+        possible_canonicals = canonical_pairs[strand]['start']
+    if extracted_canonicals in possible_canonicals:
+        splice_site_data["canonical_bases_found"] = True
+
+
+def compute_most_common_del_and_verify_nucleotides(
+        splice_site_location: int,
+        splice_site_data: dict,
+        chr_record,
+        ACCEPTED_DEL_CASES: list,
+        strand: str):
+
+    # Compute most common case of deletions
+    splice_site_data["most_common_del"] = compute_most_common_case_of_deletions(
+        splice_site_data["deletions"],
+        splice_site_data["location_is_end"])
+
+    # Extract nucleotides from most common deletion location if it is an accepted case
+    if abs(splice_site_data["most_common_del"]) in ACCEPTED_DEL_CASES:
+        extract_nucleotides_from_most_common_del_location(
+            splice_site_location,
+            splice_site_data,
+            chr_record,
+            strand)
+
+
+def threshold_for_del_cases_exceeded(
+        del_pos_distr: list,
+        deletions: dict,
+        most_common_del: int,
+        THRESHOLD_CASES_AT_LOCATION):
+    total_cases = sum(deletions.values())
+    nucleotides_exceeding_threshold = 0
+    for value in del_pos_distr:
+        if value > total_cases * THRESHOLD_CASES_AT_LOCATION:
+            nucleotides_exceeding_threshold += 1
+    return nucleotides_exceeding_threshold >= abs(most_common_del)
+
+
+def sublist_largest_values_exists(lst, n):
+    """
+    Verifies that there is a sublist of size n that contains the largest values in the list.
+    Not currently in use, but may be included in the error prediction strategy for stricter prediction.
+
+    Args:
+        lst (list): deletion position distribution
+        n (int): most common case of deletions
+
+    Returns:
+        bool: True if the list contains n consecutive elements that are all among its n largest values
+    """
+    largest_values = set(sorted(lst, reverse=True)[:n])
+    count = 0
+
+    for num in lst:
+        if num in largest_values:
+            count += 1
+            if count >= n:
+                return True
+        else:
+            count = 0
+
+    return False
+
+
+def correct_splice_site_errors(
+        splice_site_cases: dict,
+        MIN_N_OF_ALIGNED_READS: int,
+        ACCEPTED_DEL_CASES: list,
+        THRESHOLD_CASES_AT_LOCATION: float,
+        MORE_CONSERVATIVE_STRATEGY: bool,
+        strand: str,
+        chr_record):
+    """
+    1. Count the most common deletion at each splice site location
+    2. For accepted cases, check the nucleotides at the shifted position
+    3. If canonical nucleotides are found, mark the splice site for correction
+
+    Args:
+        splice_site_cases (dict): collected splice site cases
+        MIN_N_OF_ALIGNED_READS (int): constant for minimum number of aligned reads
+        ACCEPTED_DEL_CASES (list): constant for accepted cases of deletions
+        THRESHOLD_CASES_AT_LOCATION (float): constant for the fraction of reads that must support a deletion position
+        MORE_CONSERVATIVE_STRATEGY (bool): constant for more conservative strategy
+        strand (str): transcript strand (extracted from the first ReadAssignment object in the assigned reads list)
+        chr_record (Fasta): FASTA record, i.e. a single chromosome from a reference
+
+    Returns:
+        list: splice site locations that should be corrected
+    """
+
+    locations_with_errors = []
+    for splice_site_location, splice_site_data in splice_site_cases.items():
+
+        reads = sum(splice_site_data["deletions"].values())
+        if reads < MIN_N_OF_ALIGNED_READS:
+            continue
+
+        compute_most_common_del_and_verify_nucleotides(
+            splice_site_location,
+            splice_site_data,
+            chr_record,
+            ACCEPTED_DEL_CASES,
+            strand
+        )
+        if MORE_CONSERVATIVE_STRATEGY:
+            if not sublist_largest_values_exists(
+                    splice_site_data["del_pos_distr"],
+                    abs(splice_site_data["most_common_del"])):
+                continue
+            if not threshold_for_del_cases_exceeded(
+                    splice_site_data["del_pos_distr"],
+                    splice_site_data["deletions"],
+                    splice_site_data["most_common_del"],
+                    THRESHOLD_CASES_AT_LOCATION):
+                continue
+
+        if splice_site_data["canonical_bases_found"]:
+            locations_with_errors.append(splice_site_location)
+
+    return locations_with_errors
+
+
+def generate_updated_exon_list(
+        splice_site_cases: dict,
+        locations_with_errors: list,
+        exons: list):
+    updated_exons = []
+    for exon in exons:
+        updated_exon = exon
+        if exon[0] in locations_with_errors:
+            corrected_location = exon[0] + splice_site_cases[exon[0]]["most_common_del"]
+            updated_exon = (corrected_location, exon[1])
+        if exon[1] in locations_with_errors:
+            corrected_location = exon[1] + splice_site_cases[exon[1]]["most_common_del"]
+            updated_exon = (updated_exon[0], corrected_location)
+        updated_exons.append(updated_exon)
+    return updated_exons
\ No newline at end of file
diff --git a/tests/test_transcript_splice_site_corrector.py b/tests/test_transcript_splice_site_corrector.py
new file mode 100644
index 00000000..010dc7fb
--- /dev/null
+++ b/tests/test_transcript_splice_site_corrector.py
@@ -0,0 +1,728 @@
+from unittest import TestCase
+from unittest.mock import patch, MagicMock
+
+from src.isoform_assignment import ReadAssignment
+
+from src.graph_based_model_construction import GraphBasedModelConstructor
+from src.transcript_splice_site_corrector import (
+    extract_location_from_cigar_string,
+    count_deletions_from_cigar_codes_in_given_window,
+    extract_splice_site_locations_within_aligned_read,
+    count_deletions_for_splice_site_locations,
+    compute_most_common_case_of_deletions,
+    extract_nucleotides_from_most_common_del_location,
+    compute_most_common_del_and_verify_nucleotides,
+    threshold_for_del_cases_exceeded,
+    sublist_largest_values_exists,
+    correct_splice_site_errors,
+    generate_updated_exon_list,
+)
+
+#######################################################################
+##                                                                   ##
+## Run tests with:                                                   ##
+## python -m unittest tests/test_transcript_splice_site_corrector.py ##
+##                                                                   ##
+#######################################################################
+class TestMoreConservativeStrategyConditions(TestCase):
+
+    def test_threshold_exceeds_returns_true(self):
+        THRESHOLD = 0.7
+        del_pos_distr = [0, 0, 10, 10, 10, 10, 0, 0]
+        deletions = {4: 10}
+        most_common_del = 4
+        result
= threshold_for_del_cases_exceeded( + del_pos_distr, + deletions, + most_common_del, + THRESHOLD) + self.assertTrue(result) + + def test_threshold_not_exceeded_returns_false(self): + THRESHOLD = 0.7 + del_pos_distr = [0, 0, 10, 10, 10, 6, 0, 0] + deletions = {4: 6, 3: 4} + most_common_del = 4 + result = threshold_for_del_cases_exceeded( + del_pos_distr, + deletions, + most_common_del, + THRESHOLD) + self.assertFalse(result) + + def test_sublist_largest_values_exists_returns_true(self): + lst = [0, 0, 10, 10, 10, 10, 0, 0] + n = 4 + result = sublist_largest_values_exists(lst, n) + self.assertTrue(result) + + def test_sublist_largest_values_exists_returns_false(self): + lst = [0, 0, 10, 10, 10, 0, 6, 0] + n = 4 + result = sublist_largest_values_exists(lst, n) + self.assertFalse(result) + + +class TestExtractingLocationFromCigarString(TestCase): + + def test_cigar_string_with_soft_clip_and_one_match_is_parsed_correctly(self): + cigar = [(4, 50), (0, 10)] + reference_start = 100 + reference_end = 160 + location = 105 + expected_output = 55 + result = extract_location_from_cigar_string( + cigar, reference_start, reference_end, location) + self.assertEqual(result, expected_output) + + + def test_cigar_string_with_soft_clip_insertion_and_one_match_is_parsed_correctly(self): + cigar = [(4, 50), (1, 10), (0, 10)] + reference_start = 100 + reference_end = 160 + location = 105 + expected_output = 65 + result = extract_location_from_cigar_string( + cigar, reference_start, reference_end, location) + self.assertEqual(result, expected_output) + + + def test_cigar_str_with_s_d_i_m_gives_correct_output(self): + cigar = [(4, 50), (2, 10), (1, 10), (0, 10)] + reference_start = 100 + reference_end = 160 + location = 115 + expected_output = 75 + result = extract_location_from_cigar_string( + cigar, reference_start, reference_end, location) + self.assertEqual(result, expected_output) + + def test_cigar_str_with_s_d_n_m_gives_correct_output(self): + cigar = [(4, 50), (2, 10), (3, 100), (0, 10)] + reference_start = 100 + reference_end = 160 + location = 215 + expected_output = 165 + result = extract_location_from_cigar_string( + cigar, reference_start, reference_end, location) + self.assertEqual(result, expected_output) + + def test_cigar_str_with_s_m_i_n_m_gives_correct_output(self): + cigar = [(4, 50), (0, 10), (1, 10), (3, 100), (0, 10)] + reference_start = 100 + reference_end = 160 + location = 215 + expected_output = 175 + result = extract_location_from_cigar_string( + cigar, reference_start, reference_end, location) + self.assertEqual(result, expected_output) + + def test_location_outside_of_cigar_str_returns_minus_one(self): + cigar = [(4, 50), (0, 10)] + reference_start = 100 + reference_end = 160 + location = 199 + expected_output = -1 + result = extract_location_from_cigar_string( + cigar, reference_start, reference_end, location) + self.assertEqual(result, expected_output) + + def test_more_complicated_test_returns_correct_position(self): + cigar_tuples = [(4, 156), (0, 12), (2, 3), (0, 2), (2, 2), (0, 10), (2, 2), (0, 4), (2, 3), (0, 7), (1, 1), (0, 16), (1, 4), (0, 23), (1, 1), (0, 7), + (1, 1), (0, 9), (2, 1), (0, 13), (2, 1), (0, 15), (2, 2), (0, 3), (1, 2), (0, 19), (2, 2), (0, 20), (2, 1), (0, 32), (3, 294), (0, 36), (4, 25)] + reference_start = 72822568 + reference_end = 73822568 + position = 72823071 + expected_output = 668 + result = extract_location_from_cigar_string( + cigar_tuples, reference_start, reference_end, position) + self.assertEqual(result, expected_output) + + def 
test_case_that_does_not_consume_any_reference_returns_the_correct_location(self): + cigar = [(4, 50), (0, 10)] + reference_start = 100 + reference_end = 160 + location = 100 + expected_output = 50 + result = extract_location_from_cigar_string( + cigar, reference_start, reference_end, location) + self.assertEqual(result, expected_output) + + def test_case_that_has_no_reference_consuming_codes_returns_minus_one_as_error(self): + cigar = [(4, 50), (1, 10)] + reference_start = 100 + reference_end = 160 + location = 100 + expected_output = -1 + result = extract_location_from_cigar_string( + cigar, reference_start, reference_end, location) + self.assertEqual(result, expected_output) + + def test_case_that_has_no_reference_consuming_codes_at_the_end_returns_minus_one_as_error(self): + cigar = [(4, 50), (0, 10), (1, 10)] + reference_start = 100 + reference_end = 160 + location = 110 + expected_output = -1 + result = extract_location_from_cigar_string( + cigar, reference_start, reference_end, location) + self.assertEqual(result, expected_output) + + def test_case_that_has_it_s_location_at_final_match_returns_correct_value(self): + cigar = [(4, 50), (0, 10), (1, 10)] + reference_start = 100 + reference_end = 110 + location = 110 + expected_output = 60 + result = extract_location_from_cigar_string( + cigar, reference_start, reference_end, location) + self.assertEqual(result, expected_output) + + +class TestIndelCountingFromCigarCodes(TestCase): + + def setUp(self): + self.window_size = 8 + + def test_indel_counter_returns_false_and_an_empty_debug_list_for_given_empty_list(self): + cigar_tuples = [] + aligned_location = 100 + location_is_end = False + splice_site_data = { + 'deletions': {}, + "del_pos_distr": [0] * self.window_size, + } + expected_result = { + 'deletions': {0: 1}, + "del_pos_distr": [0, 0, 0, 0, 0, 0, 0, 0] + } + count_deletions_from_cigar_codes_in_given_window( + cigar_tuples, + aligned_location, + location_is_end, + splice_site_data, + self.window_size) + + self.assertEqual(splice_site_data['deletions'], expected_result['deletions']) + self.assertEqual(splice_site_data['del_pos_distr'], expected_result['del_pos_distr']) + + + + def test_indels_are_counted_correctly(self): + cigar_tuples = [(0, 20), (2, 3), (1, 2), (0, 10)] + aligned_location = 27 + location_is_end = True + splice_site_data = { + 'deletions': {}, + "del_pos_distr": [0] * self.window_size, + } + + + expected_result = { + 'deletions': {3: 1}, + "del_pos_distr": [1, 1, 1, 0, 0, 0, 0, 0] + } + + count_deletions_from_cigar_codes_in_given_window( + cigar_tuples, + aligned_location, + location_is_end, + splice_site_data, + self.window_size) + + self.assertEqual(splice_site_data['deletions'], expected_result['deletions']) + self.assertEqual(splice_site_data['del_pos_distr'], expected_result['del_pos_distr']) + + def test_full_window_of_dels_returns_true_for_errors(self): + cigar_tuples = [(0, 20), (2, 8), (1, 2), (0, 10)] + aligned_location = 20 + location_is_end = False + splice_site_data = { + 'deletions': {}, + "del_pos_distr": [0] * self.window_size, + } + expected_result = { + 'deletions': {8: 1}, + "del_pos_distr": [1, 1, 1, 1, 1, 1, 1, 1] + } + + count_deletions_from_cigar_codes_in_given_window( + cigar_tuples, + aligned_location, + location_is_end, + splice_site_data, + self.window_size) + + self.assertEqual(splice_site_data['deletions'], expected_result['deletions']) + self.assertEqual(splice_site_data['del_pos_distr'], expected_result['del_pos_distr']) + +class 
TestExtractSpliceSiteLocationsFromAlignedRead(TestCase): + + def test_correct_splice_sites_are_extracted(self): + exons = [(1, 10), (20, 30), (40, 50)] + read_start = 20 + read_end = 40 + result = extract_splice_site_locations_within_aligned_read( + read_start, read_end, exons) + expected_output = [(20, False), (30, True) , (40, False)] + self.assertEqual(result, expected_output) + + +class TestExonListUpdater(TestCase): + + def test_error_at_location_start_is_corrected(self): + exons = [(1, 10), (20, 30), (40, 50)] + locations_with_errors = [20] + splice_site_cases = { + 20: { + "most_common_del": 4, + } + } + result = generate_updated_exon_list( + splice_site_cases, locations_with_errors, exons) + expected_result = [(1, 10), (24, 30), (40, 50)] + self.assertEqual(result, expected_result) + + def test_error_at_location_end_is_corrected(self): + exons = [(1, 10), (20, 30), (40, 50)] + locations_with_errors = [30] + splice_site_cases = { + 30: { + "most_common_del": -4, + } + } + result = generate_updated_exon_list( + splice_site_cases, locations_with_errors, exons) + expected_result = [(1, 10), (20, 26), (40, 50)] + self.assertEqual(result, expected_result) + + + pass + +class TestHelperFunctions(TestCase): + + def test_distinct_most_common_case_is_returned_for_location_end(self): + cases = {0: 10, 1: 2, 3: 0, 4: 20, 5: 1} + location_is_end = False + result = compute_most_common_case_of_deletions(cases, location_is_end) + expected_result = 4 + self.assertEqual(result, expected_result) + + def test_distinct_most_common_case_is_returned_for_location_start(self): + cases = {0: 10, 1: 2, 3: 0, 4: 20, 5: 1} + location_is_end = True + result = compute_most_common_case_of_deletions(cases, location_is_end) + expected_result = -4 + self.assertEqual(result, expected_result) + + def test_if_no_distinct_most_commont_del_exists_return_neg_one(self): + cases = {0: 10, 1: 2, 3: 20, 4: 20, 5: 1} + location_is_end = False + result = compute_most_common_case_of_deletions(cases, location_is_end) + expected_result = -1 + self.assertEqual(result, expected_result) + + +class TestCorrectSpliceSiteErrors(TestCase): + + @patch('src.transcript_splice_site_corrector.compute_most_common_case_of_deletions') + def test_errors_are_correctly_returned(self, mock_compute_most_common_case_of_deletions): + splice_site_cases = { + 20: { + "canonical_bases_found": False, + "deletions": {4: 10}, + "location_is_end": False, + "most_common_del": 4, + }, + 30: { + "canonical_bases_found": True, + "deletions": {4: 10}, + "location_is_end": False, + "most_common_del": 4, + }, + } + MIN_N_ALIGNED_READS = 5 + ACCEPTED_DEL_CASES = [4] + THRESHOLD_CASES_AT_LOCATION = 0.7 + MORE_CONSERVATIVE_STRATEGY = False + strand = "+" + chr_record = None + result = correct_splice_site_errors( + splice_site_cases, + MIN_N_ALIGNED_READS, + ACCEPTED_DEL_CASES, + THRESHOLD_CASES_AT_LOCATION, + MORE_CONSERVATIVE_STRATEGY, + strand, + chr_record) + expected_result = [30] + self.assertEqual(result, expected_result) + +class TestCountDeletionsFromSpliceSiteLocations(TestCase): + def test_count_deletions_from_splice_site_locations_extracts_correct_locations(self): + exons = [(1, 10), (20, 30), (40, 50)] + # Cigar codes for indeces 20-40: + # 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 + # [M ,M, M, M, M, M, D, D, D, D, M, M, M, M, M, M, M, M, M, M, M] + cigartuples = [(0, 6), (2, 4), (0, 10)] + read_start = 20 + read_end = 40 + splice_site_cases = {} + WINDOW_SIZE = 8 + count_deletions_for_splice_site_locations( + read_start, + read_end, + 
cigartuples, + exons, + splice_site_cases, + WINDOW_SIZE) + expected_result = { + 20: { + 'location_is_end': False, + 'deletions': {2: 1}, + 'del_pos_distr': [0, 0, 0, 0, 0, 0, 1, 1], + 'most_common_del': -1, + 'canonical_bases_found': False + }, + 30: { + 'location_is_end': True, + 'deletions': {4: 1}, + 'del_pos_distr': [0, 0, 0, 1, 1, 1, 1, 0], + 'most_common_del': -1, + 'canonical_bases_found': False + }, + 40: { + 'location_is_end': False, + 'deletions': {0: 1}, + 'del_pos_distr': [0, 0, 0, 0, 0, 0, 0, 0], + 'most_common_del': -1, + 'canonical_bases_found': False + }, + } + self.assertEqual(splice_site_cases, expected_result) + + +class TestNucleotideExtraction(TestCase): + + def test_canonical_nucleotides_for_loc_start_pos_strand_are_extracted_correctly(self): + location = 10 + splice_site_data = { + "most_common_del": 4, + "location_is_end": False, + "canonical_bases_found": False, + } + chr_record = "AAAAAAAAAAAAAAAG" + + strand = "+" + extract_nucleotides_from_most_common_del_location( + location, + splice_site_data, + chr_record, + strand) + self.assertTrue(splice_site_data["canonical_bases_found"]) + + def test_canonical_nucleotides_for_loc_end_pos_strand_are_extracted_correctly(self): + location = 10 + splice_site_data = { + "most_common_del": -4, + "location_is_end": True, + "canonical_bases_found": False, + } + + # Fasta 1-based index extraction location: + # 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + # offset of -4 ^ + # | | + # v start pos + # A A A A A G C A A A A A A A A + chr_record = "AAAAAAGCAAAAAAAA" + + strand = "+" + extract_nucleotides_from_most_common_del_location( + location, + splice_site_data, + chr_record, + strand) + self.assertTrue(splice_site_data["canonical_bases_found"]) + + def test_canonical_nucleotides_for_loc_start_neg_strand_are_extracted_correctly(self): + location = 10 + splice_site_data = { + "most_common_del": 4, + "location_is_end": False, + "canonical_bases_found": False, + } + chr_record = "AAAAAAAAAAAAAAAC" + + strand = "-" + extract_nucleotides_from_most_common_del_location( + location, + splice_site_data, + chr_record, + strand) + self.assertTrue(splice_site_data["canonical_bases_found"]) + + def test_canonical_nucleotides_for_loc_end_neg_strand_are_extracted_correctly(self): + location = 10 + splice_site_data = { + "most_common_del": -4, + "location_is_end": True, + "canonical_bases_found": False, + } + chr_record = "AAAAAACTAAAAAAAA" + + strand = "-" + extract_nucleotides_from_most_common_del_location( + location, + splice_site_data, + chr_record, + strand) + self.assertTrue(splice_site_data["canonical_bases_found"]) + + +class TestDeletionComputationAndBaseExtraction(TestCase): + + def test_for_accepted_del_case_nucleotides_are_vefiried(self): + splice_site_location = 10 + splice_site_data = { + "most_common_del": -1, + "location_is_end": False, + "canonical_bases_found": False, + "deletions": {4: 1}, + "del_pos_distr": [0, 0, 0, 0, 0, 0, 0, 0], + } + + chr_record = "AAAAAAAAAAAAAAAG" + ACCEPTED_DEL_CASES = [4] + strand = "+" + compute_most_common_del_and_verify_nucleotides( + splice_site_location, + splice_site_data, + chr_record, + ACCEPTED_DEL_CASES, + strand) + expected_result = { + "most_common_del": 4, + "location_is_end": False, + "canonical_bases_found": True, + "deletions": {4: 1}, + "del_pos_distr": [0, 0, 0, 0, 0, 0, 0, 0], + } + self.assertEqual(splice_site_data, expected_result) + + + def test_for_not_accepted_del_case_nucleotides_are_not_vefiried(self): + splice_site_location = 10 + splice_site_data = { + "most_common_del": -1, 
+ "location_is_end": False, + "canonical_bases_found": False, + "deletions": {2: 1}, + "del_pos_distr": [0, 0, 0, 0, 0, 0, 0, 0], + } + + chr_record = "AAAAAAAAAAAAAAAG" + ACCEPTED_DEL_CASES = [4] + strand = "+" + compute_most_common_del_and_verify_nucleotides( + splice_site_location, + splice_site_data, + chr_record, + ACCEPTED_DEL_CASES, + strand) + expected_result = { + "most_common_del": 2, + "location_is_end": False, + "canonical_bases_found": False, + "deletions": {2: 1}, + "del_pos_distr": [0, 0, 0, 0, 0, 0, 0, 0], + } + self.assertEqual(splice_site_data, expected_result) + + def test_for_accepted_del_case_non_canonical_nucleotides_return_false(self): + splice_site_location = 10 + splice_site_data = { + "most_common_del": -1, + "location_is_end": False, + "canonical_bases_found": False, + "deletions": {4: 1}, + "del_pos_distr": [0, 0, 0, 0, 0, 0, 0, 0], + } + + chr_record = "AAAAAAAAAAAAAAXX" + ACCEPTED_DEL_CASES = [4] + strand = "+" + compute_most_common_del_and_verify_nucleotides( + splice_site_location, + splice_site_data, + chr_record, + ACCEPTED_DEL_CASES, + strand) + expected_result = { + "most_common_del": 4, + "location_is_end": False, + "canonical_bases_found": False, + "deletions": {4: 1}, + "del_pos_distr": [0, 0, 0, 0, 0, 0, 0, 0], + } + self.assertEqual(splice_site_data, expected_result) + +class TestSpliceSiteCorrector(TestCase): + + + def test_error_in_start_on_pos_strand_is_corrected(self): + assigned_read_1 = ReadAssignment(read_id="1", assignment_type="test") + assigned_read_1.cigartuples = [(0, 10), (2, 4), (0, 6)] + assigned_read_1.corrected_exons = [(0, 20)] + assigned_read_1.strand = "+" + assigned_reads = [assigned_read_1, assigned_read_1, assigned_read_1, assigned_read_1, assigned_read_1] + exons = [(0, 5), (10, 20)] + + constructor = GraphBasedModelConstructor( + gene_info=MagicMock(), + chr_record= "ABCDEFGHIJKLMNAGQRSTUVWXYZ", + params=MagicMock(), + transcript_counter=0 + ) + result = constructor.correct_transcript_splice_sites(exons, assigned_reads) + + expected_result = [(0, 5), (14, 20)] + self.assertTrue(result == expected_result) + + + def test_error_in_end_on_pos_strand_is_corrected(self): + assigned_read_1 = ReadAssignment(read_id="1", assignment_type="test") + assigned_read_1.cigartuples = [(0, 10), (2, 4), (0, 16)] + assigned_read_1.corrected_exons = [(0, 20)] + assigned_read_1.strand = "+" + assigned_reads = [assigned_read_1, assigned_read_1, assigned_read_1, assigned_read_1, assigned_read_1] + exons = [(0, 14), (20, 30)] + + constructor = GraphBasedModelConstructor( + gene_info=MagicMock(), + chr_record= "ABCDEFGHIJGCMNOPQRSTUVWXYZ", + params=MagicMock(), + transcript_counter=0 + ) + result = constructor.correct_transcript_splice_sites(exons, assigned_reads) + + expected_result = [(0, 10), (20, 30)] + self.assertTrue(result == expected_result) + + + def test_error_in_start_on_neg_strand_is_corrected(self): + assigned_read_1 = ReadAssignment(read_id="1", assignment_type="test") + assigned_read_1.cigartuples = [(0, 10), (2, 4), (0, 6)] + assigned_read_1.corrected_exons = [(0, 20)] + assigned_read_1.strand = "-" + assigned_reads = [assigned_read_1, assigned_read_1, assigned_read_1, assigned_read_1, assigned_read_1] + exons = [(0, 5), (10, 20)] + + constructor = GraphBasedModelConstructor( + gene_info=MagicMock(), + chr_record= "ABCDEFGHIJKLMNGCQRSTUVWXYZ", + params=MagicMock(), + transcript_counter=0 + ) + result = constructor.correct_transcript_splice_sites(exons, assigned_reads) + + expected_result = [(0, 5), (14, 20)] + self.assertTrue(result 
== expected_result) + + + def test_error_in_end_on_neg_strand_is_corrected(self): + assigned_read_1 = ReadAssignment(read_id="1", assignment_type="test") + assigned_read_1.cigartuples = [(0, 10), (2, 4), (0, 16)] + assigned_read_1.corrected_exons = [(0, 20)] + assigned_read_1.strand = "-" + assigned_reads = [assigned_read_1, assigned_read_1, assigned_read_1, assigned_read_1, assigned_read_1] + exons = [(0, 14), (20, 30)] + + constructor = GraphBasedModelConstructor( + gene_info=MagicMock(), + chr_record= "ABCDEFGHIJCTMNOPQRSTUVWXYZ", + params=MagicMock(), + transcript_counter=0 + ) + result = constructor.correct_transcript_splice_sites(exons, assigned_reads) + + expected_result = [(0, 10), (20, 30)] + self.assertTrue(result == expected_result) + + def test_error_in_end_on_neg_strand_and_min_accepted_del_cases_is_corrected(self): + assigned_read_1 = ReadAssignment(read_id="1", assignment_type="test") + assigned_read_1.cigartuples = [(0, 10), (2, 3), (0, 17)] + assigned_read_1.corrected_exons = [(0, 20)] + assigned_read_1.strand = "-" + assigned_reads = [assigned_read_1, assigned_read_1, assigned_read_1, assigned_read_1, assigned_read_1] + exons = [(0, 14), (20, 30)] + + constructor = GraphBasedModelConstructor( + gene_info=MagicMock(), + chr_record= "ABCDEFGHIJKCTNOPQRSTUVWXYZ", + params=MagicMock(), + transcript_counter=0 + ) + result = constructor.correct_transcript_splice_sites(exons, assigned_reads) + + expected_result = [(0, 11), (20, 30)] + self.assertTrue(result == expected_result) + + def test_error_in_end_on_neg_strand_and_max_accepted_del_cases_is_corrected(self): + assigned_read_1 = ReadAssignment(read_id="1", assignment_type="test") + assigned_read_1.cigartuples = [(0, 8), (2, 6), (0, 16)] + assigned_read_1.corrected_exons = [(0, 20)] + assigned_read_1.strand = "-" + assigned_reads = [assigned_read_1, assigned_read_1, assigned_read_1, assigned_read_1, assigned_read_1] + exons = [(0, 14), (20, 30)] + + constructor = GraphBasedModelConstructor( + gene_info=MagicMock(), + chr_record= "ABCDEFGHCTKLMNOPQRSTUVWXYZ", + params=MagicMock(), + transcript_counter=0 + ) + result = constructor.correct_transcript_splice_sites(exons, assigned_reads) + + + expected_result = [(0, 8), (20, 30)] + self.assertTrue(result == expected_result) + + + def test_case_with_dels_but_no_canonicals_in_end_on_neg_strand_returns_none(self): + assigned_read_1 = ReadAssignment(read_id="1", assignment_type="test") + assigned_read_1.cigartuples = [(0, 10), (2, 4), (0, 16)] + assigned_read_1.corrected_exons = [(0, 20)] + assigned_read_1.strand = "-" + assigned_reads = [assigned_read_1, assigned_read_1, assigned_read_1, assigned_read_1, assigned_read_1] + exons = [(0, 14), (20, 30)] + + constructor = GraphBasedModelConstructor( + gene_info=MagicMock(), + chr_record= "ABCDEFGHIJKLMNOPQRSTUVWXYZ", + params=MagicMock(), + transcript_counter=0 + ) + result = constructor.correct_transcript_splice_sites(exons, assigned_reads) + + expected_result = None + self.assertTrue(result == expected_result) + + def test_case_with_not_enough_dels_but_canonicals_in_end_on_pos_strand_returns_none(self): + assigned_read_1 = ReadAssignment(read_id="1", assignment_type="test") + assigned_read_1.cigartuples = [(0, 10), (2, 2), (0, 18)] + assigned_read_1.corrected_exons = [(0, 20)] + assigned_read_1.strand = "-" + assigned_reads = [assigned_read_1, assigned_read_1, assigned_read_1, assigned_read_1, assigned_read_1] + exons = [(0, 14), (20, 30)] + + constructor = GraphBasedModelConstructor( + gene_info=MagicMock(), + chr_record= 
"ABCDEFGHIJGCMNOPQRSTUVWXYZ", + params=MagicMock(), + transcript_counter=0 + ) + result = constructor.correct_transcript_splice_sites(exons, assigned_reads) + + expected_result = None + self.assertTrue(result == expected_result)