diff --git a/.github/workflows/build-tests.yml b/.github/workflows/build-tests.yml index 11dc3093..5c4bbb13 100644 --- a/.github/workflows/build-tests.yml +++ b/.github/workflows/build-tests.yml @@ -23,26 +23,40 @@ jobs: shell: bash --login {0} run: unzip TestData.zip - - name: Set up conda environment + - name: Set up conda (Miniconda only) uses: conda-incubator/setup-miniconda@v2 with: - activate-environment: contextsv - environment-file: environment.yml - python-version: 3.9 - auto-activate-base: false + auto-activate-base: true - - name: Install samtools and bcftools using sudo apt-get + - name: Configure conda channels and create environment + shell: bash -l {0} run: | - sudo apt-get update - sudo apt-get install -y samtools bcftools + conda config --remove channels defaults || true + conda config --add channels conda-forge + conda config --add channels bioconda + conda config --set channel_priority strict + conda info # confirm the change + conda env create -f environment.yml - name: Build C++ code shell: bash --login {0} # --login enables PATH variable access run: | - make + source $(conda info --base)/etc/profile.d/conda.sh + conda activate contextsv + echo "CONDA_PREFIX=$CONDA_PREFIX" + ls -l $CONDA_PREFIX/include/htslib + make CONDA_PREFIX=$CONDA_PREFIX - name: Run unit tests shell: bash --login {0} run: | - mkdir -p tests/output - python -m pytest -s -v tests/test_general.py + source $(conda info --base)/etc/profile.d/conda.sh + conda activate contextsv + ./build/contextsv --version + ./build/contextsv --help + + # run: | + # source $(conda info --base)/etc/profile.d/conda.sh + # conda activate contextsv + # mkdir -p tests/output + # python -m pytest -s -v tests/test_general.py diff --git a/.gitignore b/.gitignore index 50575520..b7478d26 100644 --- a/.gitignore +++ b/.gitignore @@ -55,6 +55,7 @@ CMakeSettings.json # Output folder output/ +python/ # Doxygen docs/html/ @@ -67,12 +68,15 @@ python/dbscan python/agglo linktoscripts tests/data +tests/cpp_module_out # Population allele frequency filepaths data/gnomadv2_filepaths.txt data/gnomadv3_filepaths.txt data/gnomadv4_filepaths.txt +data/gnomadv4_filepaths_ssd.txt data/gnomadv4_hg19_filepaths.txt +data/gnomadv4_hg19_filepaths_ssd.txt # Training data data/sv_scoring_dataset/ @@ -84,3 +88,17 @@ data/hg19ToHg38.over.chain.gz # Test images python/dbscan_clustering*.png python/dist_plots +upset_plot*.png + +# Temporary files +lib/.nfs* +valgrind.log + +# Log files +*.log +*.err +*.out + +# Snakemake files +.snakemake +snakemake_bench/results/ diff --git a/Makefile b/Makefile index b6f7ddab..207f7e09 100644 --- a/Makefile +++ b/Makefile @@ -1,11 +1,53 @@ +# Directories INCL_DIR := $(CURDIR)/include SRC_DIR := $(CURDIR)/src +BUILD_DIR := $(CURDIR)/build LIB_DIR := $(CURDIR)/lib +# Version header +VERSION := $(shell git describe --tags --always) +VERSION_HEADER := $(INCL_DIR)/version.h +.PHONY: $(VERSION_HEADER) + @echo "#pragma once" > $@ + @echo "#define VERSION \"$(VERSION)\"" >> $@ -all: - # Generate the SWIG wrapper (C++ -> Python) - swig -c++ -python -I$(INCL_DIR) -o $(SRC_DIR)/swig_wrapper.cpp -outdir $(LIB_DIR) $(SRC_DIR)/swig_wrapper.i +# Conda environment directories +CONDA_PREFIX := $(shell echo $$CONDA_PREFIX) +CONDA_INCL_DIR := $(CONDA_PREFIX)/include +CONDA_LIB_DIR := $(CONDA_PREFIX)/lib - # Compile the SWIG wrapper using setuptools - python3 setup.py build_ext --build-lib $(LIB_DIR) +# Compiler and Flags +CXX := g++ +CXXFLAGS := -std=c++17 -g -I$(INCL_DIR) -I$(CONDA_INCL_DIR) -Wall -Wextra -pedantic + +# Linker Flags +# 
Ensure that the library paths are set correctly for linking
+LDFLAGS := -L$(LIB_DIR) -L$(CONDA_LIB_DIR) -Wl,-rpath=$(CONDA_LIB_DIR) # Add rpath for shared libraries
+LDLIBS := -lhts # Link with libhts.a or libhts.so
+
+# Sources and Output
+SOURCES := $(filter-out $(SRC_DIR)/swig_wrapper.cpp, $(wildcard $(SRC_DIR)/*.cpp)) # Filter out the SWIG wrapper from the sources
+OBJECTS := $(patsubst $(SRC_DIR)/%.cpp,$(BUILD_DIR)/%.o,$(SOURCES))
+TARGET := $(BUILD_DIR)/contextsv
+
+# Default target
all: $(TARGET)
+
+# Debug target
+debug: CXXFLAGS += -DDEBUG
+debug: all
+
+# Link the executable
+$(TARGET): $(OBJECTS)
+	@mkdir -p $(BUILD_DIR)
+	$(CXX) $(CXXFLAGS) -o $@ $^ $(LDFLAGS) $(LDLIBS)
+
+# Compile source files
+$(BUILD_DIR)/%.o: $(SRC_DIR)/%.cpp
+	@mkdir -p $(BUILD_DIR)
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+
+# Clean the build directory
+clean:
+	rm -rf $(BUILD_DIR)
+	
\ No newline at end of file
diff --git a/README.md b/README.md
index e84f006d..707f6f64 100644
--- a/README.md
+++ b/README.md
@@ -12,33 +12,51 @@ corresponding reference genome (FASTA), a VCF with high-quality SNPs
 Class documentation is available at https://wglab.openbioinformatics.org/ContextSV
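The new Makefile resolves HTSlib headers and libraries through `CONDA_PREFIX`, so builds are expected to run inside an activated conda environment that provides HTSlib. A minimal sketch of the targets it defines (the environment name is an assumption; the CI workflow passes the prefix explicitly):

```
conda activate contextsv            # any environment providing htslib
make                                # release build -> ./build/contextsv
make debug                          # same build with -DDEBUG verbose logging
make clean                          # remove ./build
make CONDA_PREFIX=$CONDA_PREFIX     # pass the prefix explicitly, as in CI
```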

-## Installation (Linux) -### Using Anaconda (recommended) -First, install [Anaconda](https://www.anaconda.com/). +## Installation -Next, create a new environment. This installation has been tested with Python 3.11: - -``` -conda create -n contextsv python=3.11 -conda activate contextsv -``` - -ContextSV can then be installed using the following command: +### Building from source (for testing/development) +ContextSV requires HTSLib as a dependency that can be installed using [Anaconda](https://www.anaconda.com/). Create an environment +containing HTSLib: ``` -conda install -c bioconda -c wglab contextsv=1.0.0 +conda create -n htsenv -c bioconda -c conda-forge htslib +conda activate htsenv ``` -### Building from source (for testing/development) -First install [Anaconda](https://www.anaconda.com/). Then follow the instructions below to install LongReadSum and its dependencies: +Then follow the instructions below to build ContextSV: ``` git clone https://github.com/WGLab/ContextSV cd ContextSV -conda env create -f environment.yml make ``` +ContextSV can then be run: +``` +./build/contextsv --help + +Usage: ./build/contextsv [options] +Options: + -b, --bam Long-read BAM file (required) + -r, --ref Reference genome FASTA file (required) + -s, --snp SNPs VCF file (required) + -o, --outdir Output directory (required) + -c, --chr Chromosome + -r, --region Region (start-end) + -t, --threads Number of threads + -h, --hmm HMM file + -n, --sample-size Sample size for HMM predictions + --min-cnv Minimum CNV length + --eps DBSCAN epsilon + --min-pts-pct Percentage of mean chr. coverage to use for DBSCAN minimum points + -e, --eth ETH file + -p, --pfb PFB file + --save-cnv Save CNV data + --debug Debug mode with verbose logging + --version Print version and exit + -h, --help Print usage and exit +``` + ## Downloading gnomAD SNP population frequencies SNP population allele frequency information is used for copy number predictions in this tool (see @@ -53,7 +71,7 @@ Download links for genome VCF files are located here (last updated April 3, - **gnomAD v2.1.1 (GRCh37)**: https://gnomad.broadinstitute.org/downloads#2 -### Example download +### Script for downloading gnomAD VCFs ``` download_dir="~/data/gnomad/v4.0.0/" @@ -78,71 +96,6 @@ X=~/data/gnomad/v4.0.0/gnomad.genomes.v4.0.sites.chrX.vcf.bgz Y=~/data/gnomad/v4.0.0/gnomad.genomes.v4.0.sites.chrY.vcf.bgz ``` -## Calling structural variants -### Example full script generating a merged VCF of structural variants -``` -# Activate the environment -conda activate contextsv - -# Set the input reference genome -ref_file="~/data/GRCh38.fa" - -# Set the input alignment file (e.g. from minimap2) -long_read_bam="~/data/HG002.GRCh38.bam" - -# Set the input SNPs file (e.g. 
from NanoCaller)
-snps_file="~/data/variant_calls.snps.vcf.gz"
-
-# Set the SNP population frequencies filepath
-pfb_file="~/data/gnomadv4_filepaths.txt"
-
-# Set the output directory
-output_dir=~/data/contextSV_output
-
-# Specify the number of threads (system-specific)
-thread_count=40
-
-# Run SV calling (~3-4 hours for whole-genome, 40 cores)
-python contextsv --threads $thread_count -o $output_dir -lr $long_read_bam --snps $snps_file --reference $ref_file --pfb $pfb_file
-
-# The output VCF filepath is located here:
-output_vcf=$output_dir/sv_calls.vcf
-
-# Merge SVs (~3-4 hours for whole-genome, 40 cores)
-python contextsv --merge $output_vcf
-
-# The final merged VCF filepath is located here:
-merged_vcf=$output_dir/sv_calls.merged.vcf
-```
-
-## Input arguments
-
-```
-python contextsv --help
-
-ContextSV: A tool for integrative structural variant detection.
-
-options:
-  -h, --help            show this help message and exit
-  -lr LONG_READ, --long-read LONG_READ
-                        path to the long read alignment BAM file
-  -g REFERENCE, --reference REFERENCE
-                        path to the reference genome FASTA file
-  -s SNPS, --snps SNPS  path to the SNPs VCF file
-  --pfb PFB             path to the file with SNP population frequency VCF filepaths (see docs for format)
-  -o OUTPUT, --output OUTPUT
-                        path to the output directory
-  -r REGION, --region REGION
-                        region to analyze (e.g. chr1, chr1:1000-2000). If not provided, the entire genome will be analyzed
-  -t THREADS, --threads THREADS
-                        number of threads to use
-  --hmm HMM             path to the PennCNV HMM file
-  --window-size WINDOW_SIZE
-                        window size for calculating log2 ratios for CNV predictions (default: 10 kb)
-  -d, --debug           debug mode (verbose logging)
-  -v, --version         print the version number and exit
-```
-
 ## Revision history
 For release history, please visit [here](https://github.com/WGLab/ContextSV/releases).
diff --git a/__main__.py b/__main__.py
index a888cdbf..3821b8d1 100644
--- a/__main__.py
+++ b/__main__.py
@@ -214,7 +214,6 @@ def main():
     # Set input parameters
     input_data = contextsv.InputData()
     input_data.setVerbose(args.debug)
-    input_data.setShortReadBam(args.short_read)
     input_data.setLongReadBam(args.long_read)
     input_data.setRefGenome(args.reference)
     input_data.setSNPFilepath(args.snps)
diff --git a/environment.yml b/environment.yml
index 26f46822..867f41a4 100644
--- a/environment.yml
+++ b/environment.yml
@@ -1,21 +1,9 @@
 name: contextsv
 channels:
-  - defaults
-  - anaconda
-  - conda-forge
   - bioconda
+  - conda-forge
 dependencies:
-  - python
+  - python=3.10
   - numpy
   - htslib
-  - swig
   - pytest
-  - plotly
-
-# [A] Generate directly from the file:
-# conda env create -f environment.yml -n contextsv
-# [B] Generate after creating a new environment:
-# conda create -n contextsv
-# conda activate contextsv
-# conda env update -f environment.yml --prune # Prune removes unused packages
-
diff --git a/include/ThreadPool.h b/include/ThreadPool.h
new file mode 100644
index 00000000..41832030
--- /dev/null
+++ b/include/ThreadPool.h
@@ -0,0 +1,98 @@
+#ifndef THREAD_POOL_H
+#define THREAD_POOL_H
+
+#include <vector>
+#include <queue>
+#include <memory>
+#include <thread>
+#include <mutex>
+#include <condition_variable>
+#include <future>
+#include <functional>
+#include <stdexcept>
+
+class ThreadPool {
+public:
+    ThreadPool(size_t);
+    template<class F, class... Args>
+    auto enqueue(F&& f, Args&&... args)
+        -> std::future<typename std::result_of<F(Args...)>::type>;
+    ~ThreadPool();
+private:
+    // need to keep track of threads so we can join them
+    std::vector< std::thread > workers;
+    // the task queue
+    std::queue< std::function<void()> > tasks;
+
+    // synchronization
+    std::mutex queue_mutex;
+    std::condition_variable condition;
+    bool stop;
+};
+
+// the constructor just launches some amount of workers
+inline ThreadPool::ThreadPool(size_t threads)
+    : stop(false)
+{
+    for(size_t i = 0;i<threads;++i)
+        workers.emplace_back(
+            [this]
+            {
+                for(;;)
+                {
+                    std::function<void()> task;
+
+                    {
+                        std::unique_lock<std::mutex> lock(this->queue_mutex);
+                        this->condition.wait(lock,
+                            [this]{ return this->stop || !this->tasks.empty(); });
+                        if(this->stop && this->tasks.empty())
+                            return;
+                        task = std::move(this->tasks.front());
+                        this->tasks.pop();
+                    }
+
+                    task();
+                }
+            }
+        );
+}
+
+// add new work item to the pool
+template<class F, class... Args>
+auto ThreadPool::enqueue(F&& f, Args&&... args)
+    -> std::future<typename std::result_of<F(Args...)>::type>
+{
+    using return_type = typename std::result_of<F(Args...)>::type;
+
+    auto task = std::make_shared< std::packaged_task<return_type()> >(
+            std::bind(std::forward<F>(f), std::forward<Args>(args)...)
+        );
+
+    std::future<return_type> res = task->get_future();
+    {
+        std::unique_lock<std::mutex> lock(queue_mutex);
+
+        // don't allow enqueueing after stopping the pool
+        if(stop)
+            throw std::runtime_error("enqueue on stopped ThreadPool");
+
+        tasks.emplace([task](){ (*task)(); });
+    }
+    condition.notify_one();
+    return res;
+}
+
+// the destructor joins all threads
+inline ThreadPool::~ThreadPool()
+{
+    {
+        std::unique_lock<std::mutex> lock(queue_mutex);
+        stop = true;
+    }
+    condition.notify_all();
+    for(std::thread &worker: workers)
+        worker.join();
+}
+
+#endif
diff --git a/include/cnv_caller.h b/include/cnv_caller.h
index c913c24b..afdd78b3 100644
--- a/include/cnv_caller.h
+++ b/include/cnv_caller.h
@@ -6,19 +6,19 @@
 #include "khmm.h"
 #include "input_data.h"
-#include "cnv_data.h"
-#include "sv_data.h"
 #include "sv_types.h"
+#include "sv_object.h"
+#include "utils.h"
 
 /// @cond
 #include
 #include
 #include
 #include
-#include
+// #include
+#include
 #include
-#include "snp_info.h"
 /// @endcond
 
 using namespace sv_types;
 
@@ -26,7 +26,7 @@ using namespace sv_types;
 // SNP data is a struct containing vectors used in predicting copy number
 // states. It is sorted by SNP position.
 struct SNPData {
-    std::vector<int64_t> pos;
+    std::vector<uint32_t> pos;
     std::vector<double> pfb;
     std::vector<double> baf;
     std::vector<double> log2_cov;
@@ -47,105 +47,68 @@ struct SNPData {
 // CNVCaller: Detect CNVs and return the state sequence by SNP position
 class CNVCaller {
   private:
-    InputData* input_data;
-    mutable std::mutex sv_candidates_mtx; // SV candidate map mutex
-    mutable std::mutex snp_data_mtx; // SNP data mutex
-    mutable std::mutex hmm_mtx; // HMM mutex
-    CHMM hmm;
-    SNPData snp_data;
-    SNPInfo snp_info;
-    double mean_chr_cov = 0.0;
-    std::unordered_map pos_depth_map;
+    std::shared_mutex& shared_mutex;
 
-    // Define a map of CNV genotypes by HMM predicted state.
-    // We only use the first 3 genotypes (0/0, 0/1, 1/1) for the VCF output.
- // Each of the 6 state predictions corresponds to a copy number state - // (0=No predicted state) - // 0: 1/1 (Normal diploid: no copy number change, GT: 1/1) - // 1: 0/0 (Two copy loss: homozygous deletion, GT: 0/0) - // 2: 1/0 (One copy loss: heterozygous deletion, GT: 0/1) - // 3: 1/1 (Normal diploid: no copy number change, GT: 1/1) - // 4: 1/1 (Copy neutral LOH: no copy number change, GT: 1/1) - // 5: 2/1 (One copy gain: heterozygous duplication, GT: 1/2->0/1) - // 6: 2/2 (Two copy gain: homozygous duplication, GT: 2/2->1/1) - std ::map cnv_genotype_map = { - {0, "1/1"}, - {1, "0/0"}, - {2, "0/1"}, - {3, "1/1"}, - {4, "1/1"}, - {5, "0/1"}, - {6, "1/1"} - }; + void updateSNPData(SNPData& snp_data, uint32_t pos, double pfb, double baf, double log2_cov, bool is_snp); - // Define a map of CNV types by HMM predicted state (0=No predicted state) - std ::map cnv_type_map = { - {0, sv_types::UNKNOWN}, - {1, sv_types::DEL}, - {2, sv_types::DEL}, - {3, sv_types::UNKNOWN}, - {4, sv_types::UNKNOWN}, - {5, sv_types::DUP}, - {6, sv_types::DUP} - }; - - void updateSNPData(SNPData& snp_data, int64_t pos, double pfb, double baf, double log2_cov, bool is_snp); - - std::pair, double> runViterbi(CHMM hmm, SNPData &snp_data); + void runViterbi(const CHMM& hmm, SNPData& snp_data, std::pair, double>& prediction) const; // Query a region for SNPs and return the SNP data - std::pair querySNPRegion(std::string chr, int64_t start_pos, int64_t end_pos, SNPInfo &snp_info, std::unordered_map &pos_depth_map, double mean_chr_cov); - - // Run copy number prediction for a chunk of SV candidates from CIGAR strings - void runCIGARCopyNumberPredictionChunk(std::string chr, std::map& sv_candidates, std::vector sv_chunk, SNPInfo& snp_info, CHMM hmm, int window_size, double mean_chr_cov, std::unordered_map& pos_depth_map); - - void updateSVCopyNumber(std::map& sv_candidates, SVCandidate key, int sv_type_update, std::string data_type, std::string genotype, double hmm_likelihood); - - void updateDPValue(std::map& sv_candidates, SVCandidate key, int dp_value); + void querySNPRegion(std::string chr, uint32_t start_pos, uint32_t end_pos, const std::vector& pos_depth_map, double mean_chr_cov, SNPData& snp_data, const InputData& input_data) const; // Split a region into chunks for parallel processing - std::vector splitRegionIntoChunks(std::string chr, int64_t start_pos, int64_t end_pos, int chunk_count); - - // Split SV candidates into chunks for parallel processing - std::vector> splitSVCandidatesIntoChunks(std::map& sv_candidates, int chunk_count); - - // Merge the read depths from a chunk into the main read depth map - void mergePosDepthMaps(std::unordered_map& main_map, std::unordered_map& map_update); + std::vector splitRegionIntoChunks(std::string chr, uint32_t start_pos, uint32_t end_pos, int chunk_count) const; public: - CNVCaller(InputData& input_data); + CNVCaller(std::shared_mutex& shared_mutex) : shared_mutex(shared_mutex) {} - // Load file data for a chromosome (SNP positions, BAF values, and PFB values) - void loadChromosomeData(std::string chr); + // Define a map of CNV genotypes by HMM predicted state. + // We only use the first 3 genotypes (0/0, 0/1, 1/1) for the VCF output. 
+ // Each of the 6 state predictions corresponds to a copy number state + // (0=No predicted state) + // 0: Unknown (No predicted state) + // 1: 1/1 (Two copy loss: homozygous deletion, GT: 1/1 for homozygous variant) + // 2: 0/1 (One copy loss: heterozygous deletion, GT: 0/1) + // 3: 0/0 (Normal diploid: no copy number change, GT: 0/0 for homozygous reference) + // 4: 1/1 (Copy neutral LOH: no copy number change, GT: 1/1 for homozygous variant) + // 5: 2/1 (One copy gain: heterozygous duplication, GT: 1/2->0/1) + // 6: 2/2 (Two copy gain: homozygous duplication, GT: 2/2->1/1) + const std::unordered_map StateGenotypeMap = { + {0, Genotype::UNKNOWN}, + {1, Genotype::HOMOZYGOUS_ALT}, + {2, Genotype::HETEROZYGOUS}, + {3, Genotype::HOMOZYGOUS_REF}, + {4, Genotype::HOMOZYGOUS_ALT}, + {5, Genotype::HETEROZYGOUS}, + {6, Genotype::HOMOZYGOUS_ALT} + }; - // Run copy number prediction for a pair of SV candidates, and add only - // the SV candidate with the highest likelihood - std::tuple runCopyNumberPredictionPair(std::string chr, SVCandidate sv_one, SVCandidate sv_two); + // Function to get the genotype string from the state + inline Genotype getGenotypeFromCNState(int cn_state) const { + // return StateGenotypeMap.at(cn_state); + try { + return StateGenotypeMap.at(cn_state); + } catch (const std::out_of_range& e) { + printError("ERROR: Invalid CN state: " + std::to_string(cn_state)); + return Genotype::UNKNOWN; + } + } + + // Run copy number prediction for a single SV candidate, returning the + // likelihood, predicted CNV type, genotype, and whether SNPs were found + std::tuple runCopyNumberPrediction(std::string chr, const CHMM& hmm, uint32_t start_pos, uint32_t end_pos, double mean_chr_cov, const std::vector& pos_depth_map, const InputData& input_data) const; // Run copy number prediction for SVs meeting the minimum length threshold obtained from CIGAR strings - SNPData runCIGARCopyNumberPrediction(std::string chr, std::map& sv_candidates, int min_length); - - void updateSVsFromCopyNumberPrediction(SVData& sv_calls, std::vector>& sv_list, std::string chr); - - // Calculate the mean chromosome coverage - double calculateMeanChromosomeCoverage(std::string chr); - - // Calculate read depths for a region - void calculateDepthsForSNPRegion(std::string chr, int64_t start_pos, int64_t end_pos, std::unordered_map& pos_depth_map); + void runCIGARCopyNumberPrediction(std::string chr, std::vector& sv_candidates, const CHMM& hmm, double mean_chr_cov, const std::vector& pos_depth_map, const InputData& input_data) const; - // Calculate the log2 ratio for a region given the read depths and mean - // chromosome coverage - double calculateLog2Ratio(uint32_t start_pos, uint32_t end_pos, std::unordered_map& pos_depth_map, double mean_chr_cov); + void calculateMeanChromosomeCoverage(const std::vector& chromosomes, std::unordered_map>& chr_pos_depth_map, std::unordered_map& chr_mean_cov_map, const std::string& bam_filepath, int thread_count) const; - // Read SNP positions and BAF values from the VCF file of SNP calls - void readSNPAlleleFrequencies(std::string chr, std::string filepath, SNPInfo& snp_info); + void readSNPAlleleFrequencies(std::string chr, uint32_t start_pos, uint32_t end_pos, std::vector& snp_pos, std::unordered_map& snp_baf, std::unordered_map& snp_pfb, const InputData& input_data) const; - // Read SNP population frequencies from the PFB file and return a vector - // of population frequencies for each SNP location - void getSNPPopulationFrequencies(std::string chr, SNPInfo& snp_info); + // Save a 
TSV with B-allele frequencies, log2 ratios, and copy number predictions + void saveSVCopyNumberToTSV(SNPData& snp_data, std::string filepath, std::string chr, uint32_t start, uint32_t end, std::string sv_type, double likelihood) const; - // Save a TSV with B-allele frequencies, log 2 ratios, and copy number predictions - void saveSVCopyNumberToTSV(SNPData& snp_data, std::string filepath, std::string chr, int64_t start, int64_t end, std::string sv_type, double likelihood); + void saveSVCopyNumberToJSON(SNPData& before_sv, SNPData& after_sv, SNPData& snp_data, std::string chr, uint32_t start, uint32_t end, std::string sv_type, double likelihood, const std::string& filepath) const; }; #endif // CNV_CALLER_H diff --git a/include/cnv_data.h b/include/cnv_data.h deleted file mode 100644 index a2ebd403..00000000 --- a/include/cnv_data.h +++ /dev/null @@ -1,32 +0,0 @@ -#ifndef CNV_DATA_H -#define CNV_DATA_H - -/// @cond -#include -#include -#include -/// @endcond - -// CNV candidate location map -// (chr, snp_pos) : cnv_type - -using SNPLocation = std::pair; -using SNPToCNVMap = std::map; - - -class CNVData { - private: - SNPToCNVMap cnv_calls; // Map of SNP positions to CNV types - - public: - // Add a CNV call to the map - void addCNVCall(std::string chr, int snp_pos, int cnv_type); - - // Get the most common CNV type within the SV region start and end positions - std::tuple getMostCommonCNV(std::string chr, int start, int end); - - // Load CNV calls from file - void loadFromFile(std::string filepath); -}; - -#endif // CNV_DATA_H diff --git a/include/contextsv.h b/include/contextsv.h index b2a5d6e3..890748da 100644 --- a/include/contextsv.h +++ b/include/contextsv.h @@ -7,19 +7,15 @@ #define CONTEXTSV_H #include "input_data.h" -#include "cnv_data.h" -#include "sv_data.h" class ContextSV { - private: - InputData* input_data; - public: - ContextSV(InputData& input_data); + // explicit ContextSV(InputData& input_data); + ContextSV() = default; // Entry point - int run(); + int run(const InputData& input_data) const; }; #endif // CONTEXTSV_H diff --git a/include/dbscan.h b/include/dbscan.h new file mode 100644 index 00000000..9826a144 --- /dev/null +++ b/include/dbscan.h @@ -0,0 +1,35 @@ +#ifndef DBSCAN_H +#define DBSCAN_H + +#include +#include +#include +#include + +#include "sv_object.h" + +class DBSCAN { + public: + DBSCAN(double epsilon, int minPts) : epsilon(epsilon), minPts(minPts) {} + + // Fit the DBSCAN algorithm to SV calls + void fit(const std::vector& sv_calls); + + const std::vector& getClusters() const; + + private: + double epsilon; + int minPts; + std::vector clusters; + + // Expand the cluster for a given SV call + bool expandCluster(const std::vector& sv_calls, size_t pointIdx, int clusterId); + + // Find the region query for a given SV call + std::vector regionQuery(const std::vector& sv_calls, size_t pointIdx) const; + + // Calculate the distance between two SV calls + double distance(const SVCall& a, const SVCall& b) const; +}; + +#endif // DBSCAN_H diff --git a/include/dbscan1d.h b/include/dbscan1d.h new file mode 100644 index 00000000..07692e65 --- /dev/null +++ b/include/dbscan1d.h @@ -0,0 +1,34 @@ +#ifndef DBSCAN1D_H +#define DBSCAN1D_H + + +#include +#include +#include +#include + + +class DBSCAN1D { + public: + DBSCAN1D(double epsilon, int minPts) : epsilon(epsilon), minPts(minPts) {} + + void fit(const std::vector& points); + + const std::vector& getClusters() const; + + std::vector getLargestCluster(const std::vector &points); + + private: + double epsilon; + int minPts; 
+ std::vector clusters; + + bool expandCluster(const std::vector& points, size_t pointIdx, int clusterId); + + std::vector regionQuery(const std::vector& points, size_t pointIdx) const; + + double distance(int a, int b) const; + +}; + +#endif // DBSCAN1D_H diff --git a/include/debug.h b/include/debug.h new file mode 100644 index 00000000..08038b3c --- /dev/null +++ b/include/debug.h @@ -0,0 +1,23 @@ +// debug.h +#pragma once + +#include +#include +#include +#include +#include + +extern std::mutex debug_mutex; + +#ifdef DEBUG + #define DEBUG_PRINT(x) do { \ + std::lock_guard lock(debug_mutex); \ + auto now = std::chrono::system_clock::now(); \ + std::time_t now_time = std::chrono::system_clock::to_time_t(now); \ + std::ostringstream oss; \ + oss << std::put_time(std::localtime(&now_time), "%Y-%m-%d %H:%M:%S"); \ + std::cerr << oss.str() << " - " << x << std::endl; \ + } while (0) +#else + #define DEBUG_PRINT(x) +#endif diff --git a/include/fasta_query.h b/include/fasta_query.h index 558728bf..4486bdb0 100644 --- a/include/fasta_query.h +++ b/include/fasta_query.h @@ -1,4 +1,4 @@ -// FASTAQuery: A class for querying a FASTA file. +// ReferenceGenome: A class for querying a reference genome FASTA file. #ifndef FASTA_QUERY_H #define FASTA_QUERY_H @@ -8,27 +8,35 @@ #include #include #include +// #include +#include +#include /// @endcond -class FASTAQuery { +class ReferenceGenome { private: std::string fasta_filepath; std::vector chromosomes; std::unordered_map chr_to_seq; + std::map chr_to_length; + std::shared_mutex& shared_mutex; public: + ReferenceGenome(std::shared_mutex& shared_mutex) : shared_mutex(shared_mutex) {} + int setFilepath(std::string fasta_filepath); - std::string getFilepath(); - std::string query(std::string chr, int64_t pos_start, int64_t pos_end); + std::string getFilepath() const; + std::string_view query(const std::string& chr, uint32_t pos_start, uint32_t pos_end) const; + bool compare(const std::string& chr, uint32_t pos_start, uint32_t pos_end, const std::string& compare_seq, float match_threshold) const; // Get the chromosome contig lengths in VCF header format - std::string getContigHeader(); + std::string getContigHeader() const; // Get the list of chromosomes, used for whole genome analysis - std::vector getChromosomes(); + std::vector getChromosomes() const; // Get the length of a chromosome - int64_t getChromosomeLength(std::string chr); + uint32_t getChromosomeLength(std::string chr) const; }; #endif // FASTA_QUERY_H diff --git a/include/input_data.h b/include/input_data.h index 1042d664..1e2c3c1e 100644 --- a/include/input_data.h +++ b/include/input_data.h @@ -23,77 +23,72 @@ class InputData { public: InputData(); - std::string getShortReadBam(); + void printParameters() const; - void setShortReadBam(std::string filepath); - - std::string getLongReadBam(); + std::string getLongReadBam() const; void setLongReadBam(std::string filepath); // Set the filepath to the HMM parameters. void setHMMFilepath(std::string filepath); - std::string getHMMFilepath(); + std::string getHMMFilepath() const; // Set the filepath to the reference genome FASTA file. - void setRefGenome(std::string fasta_filepath); - - // Return a reference to the FASTAQuery object. - const FASTAQuery& getRefGenome() const; - // FASTAQuery getRefGenome(); - - // Query the reference genome for a sequence. - std::string queryRefGenome(std::string chr, int64_t pos_start, int64_t pos_end); - - // Get the chromosomes in the reference genome. 
- std::vector getRefGenomeChromosomes(); - - // Get a chromosome's length in the reference genome. - int64_t getRefGenomeChromosomeLength(std::string chr); + void setRefGenome(std::string filepath); + std::string getRefGenome() const; // Set the filepath to the text file containing the locations of the // VCF files with population frequencies for each chromosome. void setAlleleFreqFilepaths(std::string filepath); - - // Get the chromosome's VCF filepath with population frequencies. - std::string getAlleleFreqFilepath(std::string chr); - - // Get the population frequency map. - // PFBMap getPFBMap(); + std::string getAlleleFreqFilepath(std::string chr) const; // Set the filepath to the VCF file with SNP calls used for CNV // detection with the HMM. void setSNPFilepath(std::string filepath); - std::string getSNPFilepath(); + std::string getSNPFilepath() const; // Set the ethnicity for SNP population frequencies. void setEthnicity(std::string ethnicity); - std::string getEthnicity(); + std::string getEthnicity() const; + + // Set the assembly gaps file. + void setAssemblyGaps(std::string filepath); + std::string getAssemblyGaps() const; - // Set the window size for the log2 ratio calculation. - void setWindowSize(int window_size); - int getWindowSize(); + // Set the sample size for HMM predictions. + void setSampleSize(int sample_size); + int getSampleSize() const; // Set the minimum CNV length to use for copy number predictions. void setMinCNVLength(int min_cnv_length); - int getMinCNVLength(); + uint32_t getMinCNVLength() const; + + // Set the epsilon parameter for DBSCAN clustering. + void setDBSCAN_Epsilon(double epsilon); + double getDBSCAN_Epsilon() const; + + // Set the percentage of mean chromosome coverage to use for DBSCAN + // minimum points. + void setDBSCAN_MinPtsPct(double min_pts_pct); + double getDBSCAN_MinPtsPct() const; // Set the chromosome to analyze. void setChromosome(std::string chr); - std::string getChromosome(); + std::string getChromosome() const; + bool isSingleChr() const; // Set the region to analyze. void setRegion(std::string region); - std::pair getRegion(); - bool isRegionSet(); + std::pair getRegion() const; + bool isRegionSet() const; // Set the output directory where the results will be written. void setOutputDir(std::string dirpath); - std::string getOutputDir(); + std::string getOutputDir() const; // Set the number of threads to use when parallelization is possible. void setThreadCount(int thread_count); - int getThreadCount(); + int getThreadCount() const; // Set the verbose flag to true if verbose output is desired. void setVerbose(bool verbose); @@ -102,27 +97,34 @@ class InputData { // Set whether to extend the SNP CNV regions around the SV breakpoints // (+/- 1/2 SV length), save a TSV file, and generate HTML reports. 
void saveCNVData(bool save_cnv_data); - bool getSaveCNVData(); + bool getSaveCNVData() const; + + void setCNVOutputFile(std::string filepath); + std::string getCNVOutputFile() const; private: - std::string short_read_bam; std::string long_read_bam; std::string ref_filepath; std::string snp_vcf_filepath; std::string ethnicity; std::unordered_map pfb_filepaths; // Map of population frequency VCF filepaths by chromosome - FASTAQuery fasta_query; std::string output_dir; - int window_size; - int min_cnv_length; + int sample_size; + uint32_t min_cnv_length; + int min_reads; + double dbscan_epsilon; + double dbscan_min_pts_pct; std::string chr; // Chromosome to analyze std::pair start_end; // Region to analyze bool region_set; // True if a region is set int thread_count; std::string hmm_filepath; std::string cnv_filepath; + std::string assembly_gaps; // Assembly gaps file bool verbose; // True if verbose output is enabled bool save_cnv_data; // True if SNP CNV regions should be extended around SV breakpoints, and saved to a TSV file (Large performance hit) + bool single_chr; + std::string cnv_output_file; }; #endif // INPUT_DATA_H diff --git a/include/khmm.h b/include/khmm.h index 2f7ebe14..9585635f 100644 --- a/include/khmm.h +++ b/include/khmm.h @@ -10,28 +10,26 @@ #include /// @endcond -typedef struct { - int N; /* number of states; Q={1,2,...,N} */ - int M; /* number of observation symbols; V={1,2,...,M}*/ - double **A; /* A[1..N][1..N]. a[i][j] is the transition prob - of going from state i at time t to state j - at time t+1 */ - double **B; /* B[1..N][1..M]. b[j][k] is the probability of - of observing symbol k in state j */ - double *pi; /* pi[1..N] pi[i] is the initial state distribution. */ - double *B1_mean; /* B1_mean[1..N] mean of a continuous Gaussian distribution for state 1 through N*/ - double *B1_sd; /*B1_sd standard deviation of B1 values, which is the same for all states*/ - double B1_uf; /*B1_uniform_fraction: the contribution of uniform distribution to the finite mixture model */ - double *B2_mean; /* B2_mean[1..4] is the average of B_allele_freq*/ - double *B2_sd; /* B2_sd[1..4] is the standard deviation of four B_allele_freq, B2_sd[5] is specially for state1, where B is modelled as a wide normal distribution */ - double B2_uf; /* B2_uniform_fraction: the fraction of uniform distribution in the finite mixture model */ - - int NP_flag; /*flag of 1 and 0 to indicate whether Non-Polymorhpic marker information is contained with HMM file*/ - double *B3_mean; /* B3_mean[1..N] mean of non-polymorphic probe for state 1 through N*/ - double *B3_sd; /* B3_sd[1..4] is the standard deviation of B3 values*/ - double B3_uf; /* B3_uniform_fraction: */ - int dist; /* new parameter to facilitate CNV calling from resequencing data (2009 April) */ -} CHMM; +// Struct for HMM (C++ RAII style) +struct CHMM +{ + int N = 0; // Number of states + int M = 0; // Number of observation symbols + std::vector> A; // Transition probability matrix + std::vector> B; // Emission probability matrix + std::vector pi; // Initial state distribution + std::vector B1_mean; // Mean of a continuous Gaussian distribution for state 1 through N + std::vector B1_sd; // Standard deviation of B1 values, which is the same for all states + double B1_uf = 0.0; // B1_uniform_fraction: the contribution of uniform distribution to the finite mixture model + std::vector B2_mean; // B2_mean[1..4] is the average of B_allele_freq + std::vector B2_sd; // B2_sd[1..4] is the standard deviation of four B_allele_freq, B2_sd[5] is 
specially for state1, where B is modelled as a wide normal distribution + double B2_uf = 0.0; // B2_uniform_fraction: the fraction of uniform distribution in the finite mixture model + int NP_flag = 0; + std::vector B3_mean; + std::vector B3_sd; + double B3_uf = 0.0; + int dist = 0; +}; /************************************ @@ -39,10 +37,13 @@ typedef struct { ************************************/ /// Read an HMM from a file -CHMM ReadCHMM (const char *filename); +CHMM ReadCHMM (const std::string filename); -// /// Free the memory allocated for an HMM -// void FreeCHMM(CHMM *phmm); +// Read a matrix +std::vector> readMatrix(std::ifstream& file, int rows, int cols); + +// Read a vector +std::vector readVector(std::ifstream& file, int size); /// Run the main HMM algorithm std::pair, double> testVit_CHMM(CHMM hmm, int T, std::vector& O1, std::vector& O2, std::vector& pfb); diff --git a/include/snp_info.h b/include/snp_info.h deleted file mode 100644 index 0b57a629..00000000 --- a/include/snp_info.h +++ /dev/null @@ -1,51 +0,0 @@ -#ifndef SNP_INFO_H -#define SNP_INFO_H - -#include -#include -#include -#include -#include -#include - -// Define the comparator for the binary search tree by SNP position (first -// element of tuple) -struct SNPCompare { - bool operator()(const std::tuple& a, const std::tuple& b) const { - return std::get<0>(a) < std::get<0>(b); - } -}; - -// Define the data structure for SNP frequencies sorted by position -using BST = std::set, SNPCompare>; - -class SNPInfo { -public: - SNPInfo() {} - - // Insert a SNP into the map with its position and B-allele frequency - void insertSNPAlleleFrequency(std::string chr, int64_t pos, double baf); - - // Insert a SNP into the map with its position and population frequency of - // the B allele - void insertSNPPopulationFrequency(std::string chr, int64_t pos, double pfb); - - // Query SNPs within a range (start, end) and return their BAF and PFB values - std::tuple, std::vector, std::vector> querySNPs(std::string chr, int64_t start, int64_t end); - - // Get the range of SNP positions for a given chromosome - std::pair getSNPRange(std::string chr); - - -private: - // Mutex for reading SNP information - std::mutex snp_info_mtx; - - // Define the map of chromosome to SNP B-allele frequency - std::unordered_map snp_baf_map; - - // Define the map of chromosome to SNP population frequency - std::unordered_map> snp_pfb_map; -}; - -#endif // SNP_INFO_H diff --git a/include/sv_caller.h b/include/sv_caller.h index c461d101..a0883e6c 100644 --- a/include/sv_caller.h +++ b/include/sv_caller.h @@ -5,58 +5,107 @@ #include "cnv_caller.h" #include "input_data.h" -#include "cnv_data.h" -#include "sv_data.h" +#include "sv_object.h" +#include "fasta_query.h" #include /// @cond -#include +// #include +#include #include #include /// @endcond -// SV candidate alignment data (chr, start, end, sequence, query start, query -// end, mismatch map) -using AlignmentData = std::tuple>; -using AlignmentVector = std::vector; - -// Query map (query name, alignment vector) -using PrimaryMap = std::unordered_map; -using SuppMap = std::unordered_map; -using RegionData = std::tuple; class SVCaller { private: - int min_sv_size = 50; // Minimum SV size to be considered + struct GenomicRegion { + int tid; + int start; + int end; + int query_start; + int query_end; + bool strand; + int cluster_size; // Number of alignments used for this region + }; + + struct PrimaryAlignment { + int start; + int end; + int query_start; + int query_end; + bool strand; + int cluster_size; // 
Number of alignments used for this region + }; + + struct SuppAlignment { + int tid; + int start; + int end; + int query_start; + int query_end; + bool strand; + }; + + struct SplitSignature { + int tid; + int start; + int end; + bool strand; + int query_start; + int query_end; + }; + + // Interval Tree Node + struct IntervalNode { + PrimaryAlignment region; + std::string qname; + int max_end; // To optimize queries + std::unique_ptr left; + std::unique_ptr right; + + IntervalNode(PrimaryAlignment r, std::string name) + : region(r), qname(name), max_end(r.end), left(nullptr), right(nullptr) {} + }; + int min_mapq = 20; // Minimum mapping quality to be considered - InputData* input_data; - std::mutex sv_mtx; // Mutex for locking the SV data + mutable std::shared_mutex shared_mutex; // Shared mutex for thread safety - // Detect SVs from the CIGAR string of a read alignment, and return the - // mismatch rate, and the start and end positions of the query sequence - std::tuple, int32_t, int32_t> detectSVsFromCIGAR(bam_hdr_t* header, bam1_t* alignment, SVData& sv_calls, bool is_primary); + std::vector getChromosomes(const std::string& bam_filepath); - // Detect SVs at a region from long read alignments. This is used for - // whole genome analysis running in parallel. - RegionData detectSVsFromRegion(std::string region); + void findSplitSVSignatures(std::unordered_map>& sv_calls, const InputData& input_data); + + // Process a single CIGAR record and find candidate SVs + void processCIGARRecord(bam_hdr_t* header, bam1_t* alignment, std::vector& sv_calls, const std::vector& pos_depth_map); + + std::pair getAlignmentReadPositions(bam1_t* alignment); + + void processChromosome(const std::string& chr, std::vector& combined_sv_calls, const InputData& input_data, const std::vector& chr_pos_depth_map, double mean_chr_cov); + + void findCIGARSVs(samFile* fp_in, hts_idx_t* idx, bam_hdr_t* bamHdr, const std::string& region, std::vector& sv_calls, const std::vector& pos_depth_map); // Read the next alignment from the BAM file in a thread-safe manner int readNextAlignment(samFile *fp_in, hts_itr_t *itr, bam1_t *bam1); - // Detect SVs from split alignments - void detectSVsFromSplitReads(SVData& sv_calls, PrimaryMap& primary_map, SuppMap& supp_map, CNVCaller& cnv_caller); + void runSplitReadCopyNumberPredictions(const std::string& chr, std::vector& split_sv_calls, const CNVCaller &cnv_caller, const CHMM &hmm, double mean_chr_cov, const std::vector &pos_depth_map, const InputData &input_data); - // Calculate the mismatch rate given a map of query positions to - // match/mismatch (1/0) values within a specified range of the query - // sequence - double calculateMismatchRate(std::unordered_map& mismatch_map, int32_t start, int32_t end); + void saveToVCF(const std::unordered_map> &sv_calls, const InputData &input_data, const ReferenceGenome &ref_genome, const std::unordered_map> &chr_pos_depth_map) const; + // void saveToVCF(const std::unordered_map> &sv_calls, const std::string &output_dir, const ReferenceGenome &ref_genome, const std::unordered_map>& chr_pos_depth_map) const; + + // Query the read depth (INFO/DP) at a position + int getReadDepth(const std::vector& pos_depth_map, uint32_t start) const; public: - SVCaller(InputData& input_data); + SVCaller() = default; // Detect SVs and predict SV type from long read alignments and CNV calls - SVData run(); + void run(const InputData& input_data); + + // Interval tree + void findOverlaps(const std::unique_ptr& root, const PrimaryAlignment& query, std::vector& 
result); + + void insert(std::unique_ptr& root, const PrimaryAlignment& region, std::string qname); }; #endif // SV_CALLER_H diff --git a/include/sv_data.h b/include/sv_data.h deleted file mode 100644 index 414d2eda..00000000 --- a/include/sv_data.h +++ /dev/null @@ -1,62 +0,0 @@ -#ifndef SV_DATA_H -#define SV_DATA_H - -#include "fasta_query.h" // For querying the reference genome - -/// @cond -#include -#include -#include -#include - -#include "sv_types.h" -/// @endcond - -// Include the SV types namespace -using namespace sv_types; - -// SV data class -class SVData { - private: - SVDepthMap sv_calls; - - // Map of clipped base support by position (chr, pos) : depth - std::map, int> clipped_base_support; - - // SV type to string map for VCF output - std::map sv_type_map = { - {0, "DEL"}, - {1, "DUP"}, - {2, "INV"}, - {3, "INS"}, - {4, "BND"}, - {5, "DUP"} - }; - - public: - SVData() {}; - - int add(std::string chr, int64_t start, int64_t end, int sv_type, std::string alt_allele, std::string data_type, std::string genotype, double hmm_likelihood); - - void concatenate(const SVData& sv_data); - - // Update clipped base support for a given breakpoint location - void updateClippedBaseSupport(std::string chr, int64_t pos); - - int getClippedBaseSupport(std::string chr, int64_t pos, int64_t end); - - void saveToVCF(FASTAQuery& ref_genome, std::string output_dir); - - std::map& getChromosomeSVs(std::string chr); - - std::set getChromosomes(); - - // Begin and end iterators for the SV candidate map - SVDepthMap::iterator begin() { return this->sv_calls.begin(); } - SVDepthMap::iterator end() { return this->sv_calls.end(); } - - // Get the total number of calls (For summary purposes) - int totalCalls(); -}; - -#endif // SV_DATA_H diff --git a/include/sv_object.h b/include/sv_object.h new file mode 100644 index 00000000..8b9b2347 --- /dev/null +++ b/include/sv_object.h @@ -0,0 +1,50 @@ +#ifndef SV_OBJECT_H +#define SV_OBJECT_H + +#include +#include +#include +#include +#include +#include +#include + +#include "sv_types.h" + +using namespace sv_types; + +struct SVCall { + uint32_t start = 0; + uint32_t end = 0; + SVType sv_type = SVType::UNKNOWN; + std::string alt_allele = "."; + // SVDataType data_type = SVDataType::UNKNOWN; + SVEvidenceFlags aln_type; + Genotype genotype = Genotype::UNKNOWN; + double hmm_likelihood = 0.0; + int cn_state = 0; // Copy number state + int aln_offset = 0; // Alignment offset (read vs. 
reference distance factor) + int cluster_size = 0; // Number of SV calls in the cluster + + bool operator<(const SVCall& other) const; + + SVCall() = default; + + SVCall(uint32_t start, uint32_t end, SVType sv_type, const std::string& alt_allele, SVEvidenceFlags aln_type, Genotype genotype, double hmm_likelihood, int cn_state, int aln_offset, int cluster_size) : + start(start), end(end), sv_type(sv_type), alt_allele(alt_allele), aln_type(aln_type), genotype(genotype), hmm_likelihood(hmm_likelihood), cn_state(cn_state), aln_offset(aln_offset), cluster_size(cluster_size) {} +}; + +void addSVCall(std::vector& sv_calls, SVCall& sv_call); + +// Merge SVs with identical start positions, and sum the cluster sizes +void mergeDuplicateSVs(std::vector& sv_calls); + +uint32_t getSVCount(const std::vector& sv_calls); + +// Merge SVs using DBSCAN clustering +void mergeSVs(std::vector &sv_calls, double epsilon, int min_pts, bool keep_noise, const std::string& json_filepath = ""); + +// Save clusters of SV calls to a JSON file +void saveClustersToJSON(const std::string& filename, const std::map>& clusters); + +#endif // SV_OBJECT_H diff --git a/include/sv_types.h b/include/sv_types.h index 7e002777..359d0dc9 100644 --- a/include/sv_types.h +++ b/include/sv_types.h @@ -7,50 +7,152 @@ #include #include #include +#include /// @endcond namespace sv_types { + // Define constants for SV types - static const int DEL = 0; - static const int DUP = 1; - static const int INV = 2; - static const int INS = 3; - static const int BND = 4; - static const int TANDUP = 5; // Tandem duplication - static const int UNKNOWN = -1; - - // Define SVTypeString for SV types - static const std::string SVTypeString[] = {"DEL", "DUP", "INV", "INS", "BND", "DUP"}; - - // Create a struct for storing SV information - struct SVInfo { - int sv_type; - int read_support; // Number of reads supporting the SV breakpoints - int read_depth; // Read depth at the SV start position - std::set data_type; // Alignment type used to call the SV - int sv_length; - std::string genotype = "./."; // Default genotype (no call) - double hmm_likelihood = 0.0; // HMM likelihood score for the state sequence - - SVInfo() : - sv_type(-1), read_support(0), read_depth(0), data_type({}), sv_length(0), genotype("./."), hmm_likelihood(0.0){} - - SVInfo(int sv_type, int read_support, int read_depth, std::string data_type, int sv_length, std::string genotype, double hmm_likelihood) : - sv_type(sv_type), read_support(read_support), read_depth(read_depth), data_type({data_type}), sv_length(sv_length), genotype(genotype), hmm_likelihood(hmm_likelihood) {} - }; - - // SV (start, end, alt_allele) - using SVCandidate = std::tuple; - - // Chromosome to SV candidate to read depth map - using SVDepthMap = std::unordered_map>; - - // Define a map for storing copy number calls by SV candidate - using SVCopyNumberMap = std::map>; - - // Create a type for storing SV update information from copy number caller - // (SVCandidate, SV type, genotype, data type) - using SVUpdate = std::tuple; + enum class SVType { + UNKNOWN = -1, + DEL = 0, + DUP = 1, + INV = 2, + INS = 3, + BND = 4, + NEUTRAL = 5, // Neutral copy number with unknown type + LOH = 6 // Loss of heterozygosity + }; + + // Mapping of SV types to strings + const std::unordered_map SVTypeString = { + {SVType::UNKNOWN, "UNKNOWN"}, + {SVType::DEL, "DEL"}, + {SVType::DUP, "DUP"}, + {SVType::INV, "INV"}, + {SVType::INS, "INS"}, + {SVType::BND, "BND"}, + {SVType::NEUTRAL, "NEUTRAL"}, + {SVType::LOH, "LOH"} + }; + + // Mapping of SV 
types to symbols + const std::unordered_map SVTypeSymbol = { + {SVType::UNKNOWN, "."}, + {SVType::DEL, ""}, + {SVType::DUP, ""}, + {SVType::INV, ""}, + {SVType::INS, ""}, + {SVType::BND, ""}, + }; + + // Define constants for genotypes + enum class Genotype { + HOMOZYGOUS_REF = 0, + HETEROZYGOUS = 1, + HOMOZYGOUS_ALT = 2, + UNKNOWN = 3 + }; + + // Mapping of genotypes to strings + const std::unordered_map GenotypeString = { + {Genotype::HOMOZYGOUS_REF, "0/0"}, + {Genotype::HETEROZYGOUS, "0/1"}, + {Genotype::HOMOZYGOUS_ALT, "1/1"}, + {Genotype::UNKNOWN, "./."} + }; + + // Define constants for SV data types (evidence types) + enum class SVDataType { + CIGARINS = 0, + CIGARDEL = 1, + CIGARCLIP = 2, + SPLIT = 3, + SPLITDIST1 = 4, + SPLITDIST2 = 5, + SPLITINV = 6, + SUPPINV = 7, + HMM = 8, + UNKNOWN = 9 + }; + + using SVEvidenceFlags = std::bitset<10>; // Bitset for SV data types + + // Mapping of SV data types to strings + const std::unordered_map SVDataTypeString = { + {SVDataType::CIGARINS, "CIGARINS"}, + {SVDataType::CIGARDEL, "CIGARDEL"}, + {SVDataType::CIGARCLIP, "CIGARCLIP"}, + {SVDataType::SPLIT, "SPLIT"}, + {SVDataType::SPLITDIST1, "SPLITDIST1"}, + {SVDataType::SPLITDIST2, "SPLITDIST2"}, + {SVDataType::SPLITINV, "SPLITINV"}, + {SVDataType::SUPPINV, "SUPPINV"}, + {SVDataType::HMM, "HMM"}, + {SVDataType::UNKNOWN, "UNKNOWN"} + }; + + // Mapping of 6 copy number states to SV types + const std::unordered_map CNVTypeMap = { + {0, SVType::UNKNOWN}, + {1, SVType::DEL}, + {2, SVType::DEL}, + {3, SVType::NEUTRAL}, + {4, SVType::LOH}, + {5, SVType::DUP}, + {6, SVType::DUP} + }; + + // Function to get the SV type string + inline std::string getSVTypeString(SVType sv_type) { + return SVTypeString.at(sv_type); + } + + // Function to get the SV alignment type string from the bitset + inline std::string getSVAlignmentTypeString(SVEvidenceFlags aln_type) { + std::string result; + for (size_t i = 0; i < SVDataTypeString.size(); ++i) { + if (aln_type.test(i)) { + result += SVDataTypeString.at(static_cast(i)) + ","; + } + } + if (!result.empty()) { + result.pop_back(); // Remove the trailing comma + } + return result; + } + + // Function to get the SV type from the CNV state + inline SVType getSVTypeFromCNState(int cn_state) { + return CNVTypeMap.at(cn_state); + } + + // Function to get the genotype string + inline std::string getGenotypeString(Genotype genotype) { + return GenotypeString.at(genotype); + } + + // Function to get the SV data type string + // inline std::string getSVDataTypeString(SVDataType data_type) { + // return SVDataTypeString.at(data_type); + // } + + // Function to get the SV type symbol + inline std::string getSVTypeSymbol(SVType sv_type) { + return SVTypeSymbol.at(sv_type); + } + + // Function to check if an SV type is a valid update from copy number predictions + inline bool isValidCopyNumberUpdate(SVType sv_type, SVType updated_sv_type) { + if (updated_sv_type == SVType::UNKNOWN) { + return false; + } else if (sv_type == SVType::DEL && updated_sv_type != SVType::DEL) { + return false; + } else if (sv_type == SVType::INS && updated_sv_type != SVType::DUP) { + return false; + } + return true; + } } #endif // SV_TYPES_H diff --git a/include/swig_interface.h b/include/swig_interface.h index c7f163ae..578f4653 100644 --- a/include/swig_interface.h +++ b/include/swig_interface.h @@ -12,6 +12,6 @@ #include /// @endcond -int run(InputData input_data); +int run(const InputData& input_data); #endif // SWIG_INTERFACE_H diff --git a/include/utils.h b/include/utils.h index 41efb411..d95f0a8a 
100644 --- a/include/utils.h +++ b/include/utils.h @@ -3,12 +3,44 @@ #ifndef UTILS_H #define UTILS_H +#include +#include + /// @cond #include #include #include /// @endcond + +// Guard to close the BAM file +// struct BamFileGuard { +// samFile* fp_in; +// hts_idx_t* idx; +// bam_hdr_t* bamHdr; + +// BamFileGuard(samFile* fp_in, hts_idx_t* idx, bam_hdr_t* bamHdr) +// : fp_in(fp_in), idx(idx), bamHdr(bamHdr) {} + +// ~BamFileGuard() { +// if (idx) { +// hts_idx_destroy(idx); +// idx = nullptr; +// } +// if (bamHdr) { +// bam_hdr_destroy(bamHdr); +// bamHdr = nullptr; +// } +// if (fp_in) { +// sam_close(fp_in); +// fp_in = nullptr; +// } +// } + +// BamFileGuard(const BamFileGuard&) = delete; // Non-copyable +// BamFileGuard& operator=(const BamFileGuard&) = delete; // Non-assignable +// }; + // Print the progress of a task void printProgress(int progress, int total); @@ -23,4 +55,14 @@ void printError(std::string message); std::string getElapsedTime(std::chrono::high_resolution_clock::time_point start, std::chrono::high_resolution_clock::time_point end); +std::string removeChrPrefix(std::string chr); + +void printMemoryUsage(const std::string &functionName); + +bool fileExists(const std::string &filepath); + +bool isFileEmpty(const std::string &filepath); + +void closeJSON(const std::string & filepath); + #endif // UTILS_H diff --git a/include/vcf_writer.h b/include/vcf_writer.h deleted file mode 100644 index 800df144..00000000 --- a/include/vcf_writer.h +++ /dev/null @@ -1,23 +0,0 @@ -/// @cond -#include -#include -#include -/// @endcond - -class VcfWriter { -public: - // Constructor - VcfWriter(const std::string& filename); - void writeHeader(const std::vector& headerLines); - void writeRecord(const std::string& chrom, int pos, const std::string& id, - const std::string& ref, const std::string& alt, - const std::string& qual, const std::string& filter, - const std::string& info, const std::string& format, - const std::vector& samples); - - // Close the VCF file - void close(); - -private: - std::ofstream file_stream; -}; diff --git a/include/version.h b/include/version.h new file mode 100644 index 00000000..d38178a8 --- /dev/null +++ b/include/version.h @@ -0,0 +1,2 @@ +#pragma once +#define VERSION "v0,1,0-41-gd62fe12" diff --git a/python/cnv_plots.py b/python/cnv_plots.py index ec9ba842..67c831c6 100644 --- a/python/cnv_plots.py +++ b/python/cnv_plots.py @@ -76,7 +76,7 @@ def run(cnv_data_file, output_html): line = f.readline().strip() if '=' in line: key, value = line.split("=") - log.info("Metadata: %s=%s", key, value) + # log.info("Metadata: %s=%s", key, value) value = value.strip() metadata[key] = value diff --git a/python/cnv_plots_json.py b/python/cnv_plots_json.py new file mode 100644 index 00000000..768058e9 --- /dev/null +++ b/python/cnv_plots_json.py @@ -0,0 +1,241 @@ +import os +import argparse +import json +import numpy as np +import plotly +from plotly.subplots import make_subplots + +min_sv_length = 200000 # Minimum SV length in base pairs + +# Set up argument parser +parser = argparse.ArgumentParser(description='Generate CNV plots from JSON data.') +parser.add_argument('json_file', type=str, help='Path to the JSON file containing SV data') +parser.add_argument('chromosome', type=str, help='Chromosome to filter the SVs by (e.g., "chr3")', nargs='?', default=None) +args = parser.parse_args() + +# Load your JSON data +with open(args.json_file) as f: + sv_data = json.load(f) + +# State marker colors +# https://community.plotly.com/t/plotly-colours-list/11730/6 
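+# HMM copy number states follow the PennCNV-style model described in
+# include/cnv_caller.h: 1 = two-copy loss, 2 = one-copy loss, 3 = normal
+# diploid, 4 = copy-neutral LOH, 5 = one-copy gain, 6 = two-copy gain.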
+state_colors_dict = { + '1': 'red', + '2': 'darkred', + '3': 'darkgreen', + '4': 'green', + '5': 'darkblue', + '6': 'blue', +} + +sv_type_dict = { + 'DEL': 'Deletion', + 'DUP': 'Duplication', + 'INV': 'Inversion' +} + +# Loop through each SV (assuming your JSON contains multiple SVs) +for sv in sv_data: + + # If a chromosome is specified, filter the SVs by that chromosome + if args.chromosome and sv['chromosome'] != args.chromosome: + continue + + # Filter out SVs that are smaller than the minimum length + if np.abs(sv['size']) < min_sv_length: + print(f"Skipping SV {sv['chromosome']}:{sv['start']}-{sv['end']} of type {sv['sv_type']} with size {sv['size']} bp (smaller than {min_sv_length} bp)") + continue + + # Extract data for plotting + positions_before = sv['before_sv']['positions'] + b_allele_freq_before = sv['before_sv']['b_allele_freq'] + positions_after = sv['after_sv']['positions'] + b_allele_freq_after = sv['after_sv']['b_allele_freq'] + + # Create a subplot for the CNV plot and the BAF plot. + fig = make_subplots( + rows=2, + cols=1, + shared_xaxes=True, + vertical_spacing=0.05, + subplot_titles=(r"SNP Log2 Ratio", "SNP B-Allele Frequency") + ) + + # Get the chromosome, start, end, and sv_type from the SV data + chromosome = sv['chromosome'] + start = sv['start'] + end = sv['end'] + sv_type = sv['sv_type'] + likelihood = sv['likelihood'] + sv_length = sv['size'] + + # Plot the data for 'before_sv', 'sv', and 'after_sv' + for section in ["before_sv", "sv", "after_sv"]: + positions = sv[section]['positions'] + b_allele_freq = sv[section]['b_allele_freq'] + population_freq = sv[section]['population_freq'] + log2_ratio = sv[section]['log2_ratio'] + is_snp = sv[section]['is_snp'] + + # Set all b-allele frequencies to NaN if not SNPs + b_allele_freq = [freq if is_snp_val else float('nan') for freq, is_snp_val in zip(b_allele_freq, is_snp)] + + if section == "sv": + # is_snp = sv[section]['is_snp'] + states = sv[section]['states'] + state_colors = [state_colors_dict[str(state)] for state in states] + marker_symbols = ['circle' if is_snp_val else 'circle-open' for is_snp_val in is_snp] + + # Set the hover text + hover_text = [] + for i, position in enumerate(positions): + # Add hover text for each point + hover_text.append( + f"Position: {position}
" + f"State: {states[i]}
" + f"Log2 Ratio: {log2_ratio[i]}
" + f"SNP: {is_snp[i]}
" + f"BAF: {b_allele_freq[i]}
" + f"Population Frequency: {population_freq[i]}
" + ) + else: + # is_snp = sv[section]['is_snp'] + state_colors = ['black'] * len(positions) + # marker_symbols = ['circle-open'] * len(positions) + marker_symbols = ['circle' if is_snp_val else 'circle-open' for is_snp_val in is_snp] + hover_text = [] + for i, position in enumerate(positions): + # Add hover text for each point + hover_text.append( + f"Position: {position}
" + f"Log2 Ratio: {log2_ratio[i]}
" + f"BAF: {b_allele_freq[i]}
" + f"Population Frequency: {population_freq[i]}
" + ) + + # Create the log2 trace + log2_trace = plotly.graph_objs.Scatter( + x=positions, + y=log2_ratio, + mode='markers+lines', + name=r'Log2 Ratio', + text=hover_text, + hoverinfo='text', + marker=dict( + color=state_colors, + size=5, + symbol=marker_symbols, + ), + line=dict( + color='black', + width=0 + ), + showlegend=False + ) + + # Create the BAF trace + baf_trace = plotly.graph_objs.Scatter( + x=positions, + y=b_allele_freq, + mode='markers+lines', + name='B-Allele Frequency', + text=hover_text, + hoverinfo='text', + marker=dict( + color=state_colors, + size=5, + symbol=marker_symbols, + ), + line=dict( + color='black', + width=0 + ), + showlegend=False + ) + + if section == "sv": + # Create a shaded rectangle for the CNV, layering it below the CNV + # trace and labeling it with the CNV type. + fig.add_vrect( + x0 = start, + x1 = end, + fillcolor = "Black", + layer = "below", + line_width = 0, + opacity = 0.1, + annotation_text = '', + annotation_position = "top left", + annotation_font_size = 20, + annotation_font_color = "black" + ) + + # Add vertical lines at the start and end positions of the CNV. + fig.add_vline( + x = start, + line_width = 2, + line_color = "black", + layer = "below" + ) + + fig.add_vline( + x = end, + line_width = 2, + line_color = "black", + layer = "below" + ) + + # Add traces to the figure + fig.append_trace(log2_trace, row=1, col=1) + fig.append_trace(baf_trace, row=2, col=1) + + # Set the x-axis title. + fig.update_xaxes( + title_text = "Chromosome Position", + row = 2, + col = 1 + ) + + # Set the y-axis titles. + fig.update_yaxes( + title_text = r"Log2 Ratio", + row = 1, + col = 1 + ) + + fig.update_yaxes( + title_text = "B-Allele Frequency", + row = 2, + col = 1 + ) + + # Set the Y-axis range for the log2 ratio plot. + fig.update_yaxes( + range = [-2.0, 2.0], + row = 1, + col = 1 + ) + + # Set the Y-axis range for the BAF plot. + fig.update_yaxes( + range = [-0.2, 1.2], + row = 2, + col = 1 + ) + + # Set the title of the plot. 
+ fig.update_layout( + title_text = f"{sv_type_dict[sv_type]} at {chromosome}:{start}-{end} ({sv_length} bp) (LLH={likelihood})", + title_x = 0.5, + showlegend = False, + ) + # height = 800, + # width = 800 + # ) + # Save the plot to an HTML file (use a unique filename per SV) + # Use the input filepath directory as the output directory + output_dir = os.path.dirname(args.json_file) + svlen_kb = sv_length // 1000 + file_name = f"SV_{chromosome}_{start}_{end}_{sv_type}_{svlen_kb}kb.html" + file_path = os.path.join(output_dir, file_name) + fig.write_html(file_path) + print(f"Plot saved as {file_path}") diff --git a/python/mendelian_inheritance.py b/python/mendelian_inheritance.py new file mode 100644 index 00000000..128b1d1a --- /dev/null +++ b/python/mendelian_inheritance.py @@ -0,0 +1,78 @@ +import csv +import sys + + +def read_tsv(file_path): + with open(file_path, 'r') as file: + reader = csv.reader(file, delimiter='\t') + return [row for row in reader] + +def calculate_mendelian_error(father_genotype, mother_genotype, child_genotype): + # Generate all possible child genotypes by pairing one paternal allele with + # one maternal allele (e.g., father "0/1" x mother "0/0" -> {"0/0", "0/1"}, + # so a child "1/1" is counted as a Mendelian error) + child_genotypes = set() + for allele1 in father_genotype.split('/'): + for allele2 in mother_genotype.split('/'): + child_genotypes.add('/'.join(sorted([allele1, allele2]))) + + # Print the parent and child genotypes if invalid + if child_genotype not in child_genotypes: + print(f"ME: Father: {father_genotype}, Mother: {mother_genotype}, Child: {child_genotype}") + + # Check if the child genotype is valid + return 0 if child_genotype in child_genotypes else 1 + + +def main(father_file, mother_file, child_file): + father_records = read_tsv(father_file) + mother_records = read_tsv(mother_file) + child_records = read_tsv(child_file) + + if len(father_records) != len(mother_records) or len(father_records) != len(child_records): + raise ValueError("All files must have the same number of records") + + total_records = len(father_records) + error_count = 0 + + sv_type_dict = {} + sv_type_error_dict = {} + + for i in range(total_records): + father_genotype = father_records[i][5] + mother_genotype = mother_records[i][5] + child_genotype = child_records[i][5] + child_sv_type = child_records[i][2] + sv_type_dict[child_sv_type] = sv_type_dict.get(child_sv_type, 0) + 1 + + # Print SV size if error occurs + error_value = calculate_mendelian_error(father_genotype, mother_genotype, child_genotype) + if error_value == 1: + # print(f"SV size: {father_records[i][2]}") + sv_type_error_dict[child_sv_type] = sv_type_error_dict.get(child_sv_type, 0) + 1 + + error_count += error_value + # error_count += calculate_mendelian_error(father_genotype, mother_genotype, child_genotype) + + if total_records == 0: + error_rate = 0 + print("No records found") + else: + error_rate = error_count / total_records + + print(f"Mendelian Inheritance Error Rate: {error_rate:.2%} for {total_records} shared trio SVs") + + print("SV Type Distribution:") + for sv_type, count in sv_type_dict.items(): + error_count = sv_type_error_dict.get(sv_type, 0) + error_rate = error_count / count + print(f"{sv_type}: {error_rate:.2%} ({error_count}/{count})") + +if __name__ == "__main__": + if len(sys.argv) != 4: + print("Usage: python mendelian_inheritance.py <father_file> <mother_file> <child_file>") + sys.exit(1) + + father_file = sys.argv[1] + mother_file = sys.argv[2] + child_file = sys.argv[3] + + main(father_file, mother_file, child_file) diff --git a/python/plot_distributions.py b/python/plot_distributions.py index 1f684ed8..c2644a8a 100644 --- a/python/plot_distributions.py +++ 
b/python/plot_distributions.py @@ -26,11 +26,28 @@ def generate_sv_size_plot(input_vcf, output_png, plot_title="SV Caller"): # Read VCF file into a pandas DataFrame - vcf_df = pd.read_csv(input_vcf, sep='\t', comment='#', header=None, \ - names=['CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO', 'FORMAT', 'SAMPLE'], \ - dtype={'CHROM': str, 'POS': np.int64, 'ID': str, 'REF': str, 'ALT': str, 'QUAL': str, \ - 'FILTER': str, 'INFO': str, 'FORMAT': str, 'SAMPLE': str}) - + try: + vcf_df = pd.read_csv(input_vcf, sep='\t', comment='#', header=None, \ + names=['CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO', 'FORMAT', 'SAMPLE'], \ + dtype={'CHROM': str, 'POS': np.int64, 'ID': str, 'REF': str, 'ALT': str, 'QUAL': str, \ + 'FILTER': str, 'INFO': str, 'FORMAT': str, 'SAMPLE': str}) + except Exception as e: + try: + print("[DEBUG] Falling back to Truvari merged VCF format") + # Truvari merged VCF format with different columns + vcf_df = pd.read_csv(input_vcf, sep='\t', comment='#', header=None, \ + names=['CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO', 'FORMAT', 'SAMPLE', 'SAMPLE2'], \ + dtype={'CHROM': str, 'POS': np.int64, 'ID': str, 'REF': str, 'ALT': str, 'QUAL': str, \ + 'FILTER': str, 'INFO': str, 'FORMAT': str, 'SAMPLE': str, 'SAMPLE2': str}) + except Exception as e: + print("[DEBUG] Falling back to Platinum pedigree VCF format") + # Platinum pedigree VCF format with different columns + vcf_df = pd.read_csv(input_vcf, sep='\t', comment='#', header=None, \ + names=['CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO', 'FORMAT', 'SAMPLE', 'SAMPLE2', 'SAMPLE3', 'SAMPLE4', 'SAMPLE5', 'SAMPLE6', 'SAMPLE7'], \ + dtype={'CHROM': str, 'POS': np.int64, 'ID': str, 'REF': str, 'ALT': str, 'QUAL': str, \ + 'FILTER': str, 'INFO': str, 'FORMAT': str, 'SAMPLE': str, 'SAMPLE2': str, 'SAMPLE3': str, 'SAMPLE4': str, \ + 'SAMPLE5': str, 'SAMPLE6': str, 'SAMPLE7': str}) + # Initialize dictionaries to store SV sizes for each type of SV sv_sizes = {} @@ -61,6 +78,7 @@ def generate_sv_size_plot(input_vcf, output_png, plot_title="SV Caller"): # Continue if SV type is BND (no SV size) if sv_type == "BND": continue + # If the SV caller is DELLY, then we use the second SV size for non-INS # (they don't have SVLEN) and the first SV size for INS sv_size = None @@ -71,7 +89,7 @@ def generate_sv_size_plot(input_vcf, output_png, plot_title="SV Caller"): # If the plot title is GIAB, then we need to convert INS to DUP if # INFO/SVTYPE is INS and INFO/REPTYPE is DUP - if plot_title == "GIAB" and sv_type == "INS": + if "GIAB" in plot_title and sv_type == "INS": if 'REPTYPE=DUP' in record['INFO']: sv_type = "DUP" @@ -90,10 +108,11 @@ def generate_sv_size_plot(input_vcf, output_png, plot_title="SV Caller"): # Create a dictionary of SV types and their corresponding colors. 
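The nested try/except above hard-codes three possible sample-column layouts. A layout-agnostic sketch (a hypothetical read_vcf_df helper, not part of this commit) takes the column names from the VCF's own #CHROM header line instead, assuming a well-formed plain or gzipped VCF:

    import gzip
    import pandas as pd

    def read_vcf_df(path):
        # Derive column names from the '#CHROM ...' header line so any
        # number of sample columns is handled without fallbacks
        opener = gzip.open if path.endswith('.gz') else open
        with opener(path, 'rt') as fh:
            names = next(line.lstrip('#').strip().split('\t')
                         for line in fh if line.startswith('#CHROM'))
        df = pd.read_csv(path, sep='\t', comment='#', header=None,
                         names=names, dtype=str)
        df['POS'] = df['POS'].astype('int64')
        return df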
# From: https://davidmathlogic.com/colorblind/ - sv_colors = {'DEL': '#D81B60', 'DUP': '#1E88E5', 'INV': '#FFC107', 'INS': '#004D40'} + # WONG colors + sv_colors = {'DEL': '#E69F00', 'DUP': '#56B4E9', 'INV': '#009E73', 'INS': '#F0E442', 'INVDUP': '#D55E00', 'COMPLEX': '#CC79A7'} # Create a dictionary of SV types and their corresponding labels - sv_labels = {'DEL': 'Deletion', 'DUP': 'Duplication', 'INV': 'Inversion', 'INS': 'Insertion'} + sv_labels = {'DEL': 'Deletion', 'DUP': 'Duplication', 'INV': 'Inversion', 'INS': 'Insertion', 'INVDUP': 'Inverted Duplication', 'COMPLEX': 'Complex'} # Get the list of SV types and sort them in the order of the labels sv_types = sorted(sv_sizes.keys(), key=lambda x: sv_labels[x]) @@ -141,16 +160,16 @@ def generate_sv_size_plot(input_vcf, output_png, plot_title="SV Caller"): # Use a log scale for the y-axis axes[i].set_yscale('log') - # # In the same axis, plot a known duplication if within the range of the plot - if sv_type == 'DUP': - print("TEST: Found DUP") - cnv_size = 776237 / size_scale - x_min, x_max = axes[i].get_xlim() - if cnv_size > x_min and cnv_size < x_max: - axes[i].axvline(x=cnv_size, color='black', linestyle='--') - else: - # Print the values - print(f'CNV size: {cnv_size}, x_min: {x_min}, x_max: {x_max}') + # In the same axis, plot a known duplication if within the range of the plot + # if sv_type == 'DUP': + # print("TEST: Found DUP") + # cnv_size = 776237 / size_scale + # x_min, x_max = axes[i].get_xlim() + # if cnv_size > x_min and cnv_size < x_max: + # axes[i].axvline(x=cnv_size, color='black', linestyle='--') + # else: + # # Print the values + # print(f'CNV size: {cnv_size}, x_min: {x_min}, x_max: {x_max}') # Refresh the plot plt.draw() @@ -194,10 +213,18 @@ def generate_sv_size_plot(input_vcf, output_png, plot_title="SV Caller"): fig.update_layout(legend=dict( orientation='v', yanchor='top', - y=0.75, + y=0.9, xanchor='right', - x=0.75, + x=0.9, )) + # # Move the legend to the bottom right outside the plot + # fig.update_layout(legend=dict( + # orientation='v', + # yanchor='top', + # y=1.0, + # xanchor='right', + # x=1.15, + # )) # Set a larger font size for all text in the plot fig.update_layout(font=dict(size=26)) diff --git a/python/plot_venn.py b/python/plot_venn.py new file mode 100644 index 00000000..757f4408 --- /dev/null +++ b/python/plot_venn.py @@ -0,0 +1,48 @@ +# from matplotlib_venn import venn3 +from matplotlib_venn import venn2 +import argparse + +import matplotlib.pyplot as plt + +def plot_venn(AB, Ab, aB, output, plot_title, title_Ab, title_aB): + plt.figure(figsize=(8, 8)) + + print('AB:', AB) + print('Ab:', Ab) + print('aB:', aB) + + # Create scaled subsets for the venn diagram + scaling_factor = 1000 + scaled_AB = AB / scaling_factor + scaled_Ab = Ab / scaling_factor + scaled_aB = aB / scaling_factor + + # Create a venn diagram scaled to the number of elements in each set + # venn = venn2(subsets=(AB, Ab, aB), set_labels=(title_Ab, title_aB)) + venn = venn2(subsets=(scaled_Ab, scaled_aB, scaled_AB), set_labels=(title_Ab, title_aB)) + + # Update the labels to reflect the actual counts + venn.get_label_by_id('10').set_text(str(Ab)) + venn.get_label_by_id('01').set_text(str(aB)) + venn.get_label_by_id('11').set_text(str(AB)) + + # Update the title + # plt.title("contextsv and " + title_aB + " venn diagram (all SV types)") + plt.title(plot_title) + plt.savefig(output) + plt.close() + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Generate a Venn diagram.') + parser.add_argument('-a', 
type=int, required=True, help='Shared count') + parser.add_argument('-b', type=int, required=True, help='False positive count') + parser.add_argument('-c', type=int, required=True, help='False negative count') + parser.add_argument('-o', '--output', type=str, required=True, help='Output file path') + parser.add_argument('-a_title', type=str, required=True, help='Title for set A') + parser.add_argument('-b_title', type=str, required=True, help='Title for set B') + parser.add_argument('-c_title', type=str, required=True, help='Title for set C') + + args = parser.parse_args() + + plot_venn(args.a, args.b, args.c, args.output, args.a_title, args.b_title, args.c_title) + print(f'Venn diagram saved to {args.output}') diff --git a/python/sv_merger.py b/python/sv_merger.py index 2d0027b1..2f5cb94f 100644 --- a/python/sv_merger.py +++ b/python/sv_merger.py @@ -89,6 +89,12 @@ def update_support(record, cluster_size): return record +def weighted_score(sv_len, hmm_score, weight_hmm): + """ + Calculate a weighted score from SV length and HMM likelihood score. + """ + return (1 - weight_hmm) * sv_len + weight_hmm * hmm_score + def cluster_breakpoints(vcf_df, sv_type, cluster_size_min): """ Cluster SV breakpoints using HDBSCAN. @@ -124,88 +130,103 @@ def cluster_breakpoints(vcf_df, sv_type, cluster_size_min): # Get the HMM likelihood scores hmm_scores = vcf_df['INFO'].str.extract(r'HMM=(-?\d+\.?\d*)', expand=False).astype(float) - # Set all 0 values to NaN - hmm_scores[hmm_scores == 0] = np.nan + # Set all 0 values to a low negative value + hmm_scores[hmm_scores == 0] = -1e-100 + # hmm_scores[hmm_scores == 0] = np.nan # Cluster SV breakpoints using HDBSCAN cluster_labels = [] # dbscan = DBSCAN(eps=30000, min_samples=3) - dbscan = HDBSCAN(min_cluster_size=cluster_size_min, min_samples=3) + + if len(breakpoints) == 1: + return merged_records + + logging.info("Clustering %d SV breakpoints with parameters: min_cluster_size=%d", len(breakpoints), cluster_size_min) + dbscan = HDBSCAN(min_cluster_size=cluster_size_min, min_samples=2) if len(breakpoints) > 0: logging.info("Clustering %d SV breakpoints...", len(breakpoints)) cluster_labels = dbscan.fit_predict(breakpoints) logging.info("Label counts: %d", len(np.unique(cluster_labels))) - - # Set all 0 values to NaN - hmm_scores[hmm_scores == 0] = np.nan + # Merge SVs with the same label unique_labels = np.unique(cluster_labels) - for label in unique_labels: - - # Skip label -1 (outliers) - if label == -1: - # # Print the positions if any are within a certain range - # pos_min = 180915940 - # pos_max = 180950356 - - # Debug if position is found - target_pos = 180949217 + #logging.info("Unique labels: %s", unique_labels) - idx = cluster_labels == label - pos_values = breakpoints[idx][:, 0] - if target_pos in pos_values: - logging.info(f"Outlier deletion positions: {pos_values}") - - # if (np.any(pos_values >= pos_min) and np.any(pos_values <= pos_max)): - # Print all within range - # pos_within_range = pos_values[(pos_values >= pos_min) & (pos_values <= pos_max)] - # logging.info(f"Outlier deletion positions: {pos_within_range}") - # logging.info(f"Outlier deletion positions: {pos_values}") + for label in unique_labels: + # Skip label -1 (outliers) unless it is the only label + if label == -1 and len(unique_labels) > 1: continue # Get the indices of SVs with the same label idx = cluster_labels == label # Get HMM and read support values for the cluster - max_score_idx = 0 # Default to the first SV 
in the cluster cluster_hmm_scores = np.array(hmm_scores[idx]) - cluster_depth_scores = np.array(sv_support[idx]) - max_hmm = None - max_support = None - max_hmm_idx = None - max_support_idx = None + # cluster_depth_scores = np.array(sv_support[idx]) + cluster_sv_lengths = np.array(breakpoints[idx][:, 1] - breakpoints[idx][:, 0] + 1) + # max_hmm = None + # max_support = None + # max_hmm_idx = None + # max_support_idx = None # Find the maximum HMM score - if len(np.unique(cluster_hmm_scores)) > 1: - max_hmm_idx = np.nanargmax(cluster_hmm_scores) - max_hmm = cluster_hmm_scores[max_hmm_idx] + # if len(np.unique(cluster_hmm_scores)) > 1: + # max_hmm_idx = np.nanargmax(cluster_hmm_scores) + # max_hmm = cluster_hmm_scores[max_hmm_idx] # Find the maximum read alignment and clipped base support - if len(np.unique(cluster_depth_scores)) > 1: - max_support_idx = np.argmax(cluster_depth_scores) - max_support = cluster_depth_scores[max_support_idx] - - # For deletions, choose the SV with the highest HMM score if available - if sv_type == 'DEL': - if max_hmm is not None: - max_score_idx = max_hmm_idx - elif max_support is not None: - max_score_idx = max_support_idx - - # For insertions and duplications, choose the SV with the highest read - # support if available - elif sv_type == 'INS/DUP': - if max_support is not None: - max_score_idx = max_support_idx - elif max_hmm is not None: - max_score_idx = max_hmm_idx + # if len(np.unique(cluster_depth_scores)) > 1: + # max_support_idx = np.argmax(cluster_depth_scores) + # max_support = cluster_depth_scores[max_support_idx] + + # Normalize the HMM scores. Since the HMM scores are negative (log lh), we + # normalize them to the range [0, 1] by subtracting the minimum value + cluster_hmm_norm = (cluster_hmm_scores - np.min(cluster_hmm_scores)) / (np.max(cluster_hmm_scores) - np.min(cluster_hmm_scores)) + + # Normalize the SV lengths to the range [0, 1] + cluster_sv_lengths_norm = (cluster_sv_lengths - np.min(cluster_sv_lengths)) / (np.max(cluster_sv_lengths) - np.min(cluster_sv_lengths)) + + # Use a weighted approach to choose the best SV based on HMM and + # support. Deletions have higher priority for HMM scores, while + # insertions and duplications have higher priority for read alignment + # support. 
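A worked example of the normalize-and-weight selection implemented just below (hypothetical numbers; weighted_score as defined earlier in this file). One caveat the sketch shares with the code: min-max normalization divides by max - min, so a cluster whose lengths (or HMM scores) are all identical produces NaN scores:

    import numpy as np

    def weighted_score(sv_len, hmm_score, weight_hmm):
        return (1 - weight_hmm) * sv_len + weight_hmm * hmm_score

    # Hypothetical cluster of three candidate deletions
    sv_lengths = np.array([5000.0, 5200.0, 4800.0])
    hmm_scores = np.array([-120.0, -80.0, -200.0])  # log-likelihoods; higher is better

    # Min-max normalize both features to [0, 1]
    len_norm = (sv_lengths - sv_lengths.min()) / (sv_lengths.max() - sv_lengths.min())
    hmm_norm = (hmm_scores - hmm_scores.min()) / (hmm_scores.max() - hmm_scores.min())

    scores = weighted_score(len_norm, hmm_norm, 0.5)  # hmm_weight = 0.5
    best = int(np.argmax(scores))                     # index 1: best likelihood, near-longest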
+ # hmm_weight = 0.7 if sv_type == 'DEL' else 0.3 + hmm_weight = 0.5 + max_score_idx = 0 # Default to the first SV in the cluster + max_score = weighted_score(cluster_sv_lengths_norm[max_score_idx], cluster_hmm_norm[max_score_idx], hmm_weight) + # max_score = weighted_score(cluster_sv_lengths[max_score_idx], cluster_hmm_scores[max_score_idx], hmm_weight) + for k, hmm_norm in enumerate(cluster_hmm_norm): + svlen_norm = cluster_sv_lengths_norm[k] + score = weighted_score(svlen_norm, hmm_norm, hmm_weight) + if score > max_score: + max_score = score + max_score_idx = k + + # Get the VCF record with the highest score + max_record = vcf_df.iloc[idx, :].iloc[max_score_idx, :] + + # # For deletions, choose the SV with the highest HMM score if available + # if sv_type == 'DEL': + # if max_hmm is not None: + # max_score_idx = max_hmm_idx + # elif max_support is not None: + # max_score_idx = max_support_idx + + # # For insertions and duplications, choose the SV with the highest read + # # support if available + # elif sv_type == 'INS/DUP': + # if max_support is not None: + # max_score_idx = max_support_idx + # elif max_hmm is not None: + # max_score_idx = max_hmm_idx # Get the VCF record with the highest depth score - max_record = vcf_df.iloc[idx, :].iloc[max_score_idx, :] # Get the number of SVs in this cluster cluster_size = np.sum(idx) @@ -213,30 +234,7 @@ def cluster_breakpoints(vcf_df, sv_type, cluster_size_min): # Update the SUPPORT field in the INFO column max_record = update_support(max_record, cluster_size) - - # Get all position values in the cluster - pos_values = breakpoints[idx][:, 0] - - # Debug if position is found - target_pos = 180949217 - if target_pos in pos_values: - logging.info(f"Cluster size: {cluster_size}") - logging.info(f"Pos values:") - for k, pos in enumerate(pos_values): - logging.info(f"Row {k+1} - Pos: {pos}, HMM: {cluster_hmm_scores[k]}, support: {cluster_depth_scores[k]}") - - logging.info(f"Chosen position: {max_record['POS']} - HMM: {max_hmm}, support: {max_support}") - - # # If the POS value is a certain value, plot the support - # pos_min = 180915940 - # pos_max = 180950356 - # # if (np.any(pos_values >= pos_min) and np.any(pos_values <= pos_max)) or cluster_size > 1000: - # if (np.any(pos_values >= pos_min) and np.any(pos_values <= pos_max)): - # logging.info(f"Cluster size: {cluster_size}") - # logging.info(f"Pos values:") - # for k, pos in enumerate(pos_values): - # logging.info(f"Row {k+1} - Pos: {pos}, HMM: {cluster_hmm_scores[k]}, support: {cluster_depth_scores[k]}") - + # pos_values = breakpoints[idx][:, 0] # Append the chosen record to the dataframe of records that will # form the merged VCF file @@ -289,16 +287,19 @@ def sv_merger(vcf_file_path, cluster_size_min=3, suffix='.merged'): del chr_del_df # Cluster insertions and duplications - logging.info("Clustering insertions and duplications on chromosome %s...", chromosome) - chr_ins_dup_df = vcf_df[(vcf_df['CHROM'] == chromosome) & ((vcf_df['INFO'].str.contains('SVTYPE=INS')) | (vcf_df['INFO'].str.contains('SVTYPE=DUP')))] - ins_dup_records = cluster_breakpoints(chr_ins_dup_df, 'INS/DUP', cluster_size_min) - del chr_ins_dup_df + logging.info("Clustering all other SVs on chromosome %s...", chromosome) + # chr_ins_dup_df = vcf_df[(vcf_df['CHROM'] == chromosome) & + # ((vcf_df['INFO'].str.contains('SVTYPE=INS')) | + # (vcf_df['INFO'].str.contains('SVTYPE=DUP')))] + chr_non_del_df = vcf_df[(vcf_df['CHROM'] == chromosome) & 
(~vcf_df['INFO'].str.contains('SVTYPE=DEL'))] + ins_dup_records = cluster_breakpoints(chr_non_del_df, 'INS/DUP', cluster_size_min) + del chr_non_del_df # Summarize the number of deletions and insertions/duplications del_count = del_records.shape[0] ins_dup_count = ins_dup_records.shape[0] records_processed += del_count + ins_dup_count - logging.info("Chromosome %s - %d deletions, %d insertions, and duplications merged.", chromosome, del_count, ins_dup_count) + logging.info("Chromosome %s - %d deletions, %d other types merged.", chromosome, del_count, ins_dup_count) # Append the deletion and insertion/duplication records to the merged # records DataFrame @@ -392,4 +393,4 @@ def sv_merger(vcf_file_path, cluster_size_min=3, suffix='.merged'): # DBSCAN sv_merger(vcf_file_path, cluster_size_min=cluster_size_min, suffix=suffix) - \ No newline at end of file + diff --git a/setup.py b/setup.py index c57fb30f..c8591523 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,10 @@ # Set the project dependencies SRC_DIR = "src" +# SRC_FILES = glob.glob(os.path.join(SRC_DIR, "*.cpp")) SRC_FILES = glob.glob(os.path.join(SRC_DIR, "*.cpp")) +SRC_FILES = [f for f in SRC_FILES if "main.cpp" not in f] # Ignore the main.cpp file + INCLUDE_DIR = "include" INCLUDE_FILES = glob.glob(os.path.join(INCLUDE_DIR, "*.h")) @@ -37,7 +40,7 @@ name="_" + NAME, sources=SRC_FILES, include_dirs=[INCLUDE_DIR, conda_include_dir], - extra_compile_args=["-std=c++11"], + extra_compile_args=["-std=c++17"], language="c++", libraries=["hts"], library_dirs=[conda_lib_dir] diff --git a/src/cnv_caller.cpp b/src/cnv_caller.cpp index ca0588a2..66f1f146 100644 --- a/src/cnv_caller.cpp +++ b/src/cnv_caller.cpp @@ -2,6 +2,9 @@ #include "cnv_caller.h" #include +#include +#include +#include /// @cond #include @@ -20,9 +23,11 @@ #include #include #include <algorithm> // std::max +#include <utility> // std::pair +#include +#include <execution> // std::execution::par #include "utils.h" -#include "sv_data.h" #include "sv_types.h" #define MIN_PFB 0.01 @@ -31,325 +36,268 @@ using namespace sv_types; + // Function to call the Viterbi algorithm for the CHMM -std::pair<std::vector<int>, double> CNVCaller::runViterbi(CHMM hmm, SNPData& snp_data) +void CNVCaller::runViterbi(const CHMM& hmm, SNPData& snp_data, std::pair<std::vector<int>, double>& prediction) const { int data_count = (int) snp_data.pos.size(); - std::lock_guard<std::mutex> lock(this->hmm_mtx); // Lock the mutex for the HMM - std::pair<std::vector<int>, double> state_sequence = testVit_CHMM(hmm, data_count, snp_data.log2_cov, snp_data.baf, snp_data.pfb); - return state_sequence; + if (data_count == 0) + { + printError("ERROR: No SNP data found for Viterbi algorithm."); + prediction = std::make_pair(std::vector<int>(), 0.0); + return; + } + prediction = testVit_CHMM(hmm, data_count, snp_data.log2_cov, snp_data.baf, snp_data.pfb); } // Function to obtain SNP information for a region -std::pair<SNPData, bool> CNVCaller::querySNPRegion(std::string chr, int64_t start_pos, int64_t end_pos, SNPInfo& snp_info, std::unordered_map<uint32_t, int>& pos_depth_map, double mean_chr_cov) +void CNVCaller::querySNPRegion(std::string chr, uint32_t start_pos, uint32_t end_pos, const std::vector<uint32_t>& pos_depth_map, double mean_chr_cov, SNPData& snp_data, const InputData& input_data) const { - SNPData snp_data; - bool snps_found = false; - int window_size = this->input_data->getWindowSize(); - - // std::cout << "Querying SNPs for region " << chr << ":" << start_pos << - // "-" << end_pos << "..." 
<< std::endl; - // TEST - if (start_pos == 43593639 && end_pos == 43608172) { - printMessage("Querying SNPs for region " + chr + ":" + std::to_string(start_pos) + "-" + std::to_string(end_pos) + "..."); + // Initialize the SNP data with default values and sample size length + int sample_size = input_data.getSampleSize(); + std::vector snp_pos; + std::unordered_map snp_baf_map; + std::unordered_map snp_pfb_map; + // printMessage("Reading SNP data for copy number prediction: " + chr + ":" + std::to_string((int)start_pos) + "-" + std::to_string((int)end_pos)); + this->readSNPAlleleFrequencies(chr, start_pos, end_pos, snp_pos, snp_baf_map, snp_pfb_map, input_data); + + // Get the log2 ratio for evenly spaced positions in the + // region + sample_size = std::max((int) snp_pos.size(), sample_size); + + // Print an error if the end position is less than or equal to the start + // position + if (start_pos > end_pos) + { + printError("ERROR: Invalid SNP region for copy number prediction: " + chr + ":" + std::to_string((int)start_pos) + "-" + std::to_string((int)end_pos)); + return; } - // printMessage("Querying SNPs for region " + chr + ":" + std::to_string(start_pos) + "-" + std::to_string(end_pos) + "..."); - for (int64_t i = start_pos; i <= end_pos; i += window_size) + + // Loop through evenly spaced positions in the region and get the log2 ratio + double pos_step = static_cast(end_pos - start_pos + 1) / static_cast(sample_size); + std::unordered_map window_log2_map; + for (int i = 0; i < sample_size; i++) { - // Run a sliding non-overlapping window of size window_size across - // the SV region and calculate the log2 ratio for each window - int64_t window_start = i; - int64_t window_end = std::min(i + window_size - 1, end_pos); - - // Get the SNP info for the window - // std::cout << "Querying SNPs for window " << chr << ":" << window_start << "-" << window_end << "..." 
<< std::endl; - this->snp_data_mtx.lock(); - std::tuple, std::vector, std::vector> window_snps = snp_info.querySNPs(chr, window_start, window_end); - this->snp_data_mtx.unlock(); - std::vector& snp_window_pos = std::get<0>(window_snps); // SNP positions - std::vector& snp_window_bafs = std::get<1>(window_snps); // B-allele frequencies - std::vector& snp_window_pfbs = std::get<2>(window_snps); // Population frequencies of the B allele - - // Loop though the SNP positions and calculate the log2 ratio for - // the window up to the SNP, then calculate the log2 ratio centered - // at the SNP, and finally calculate the log2 ratio for the window - // after the SNP, and continue until the end of the window - std::vector window_log2_ratios; - int snp_count = (int) snp_window_pos.size(); - - // If there are no SNPs in the window, then use the default BAF and - // PFB values, and the coverage log2 ratio - if (snp_count == 0) - { - double window_log2_ratio = calculateLog2Ratio(window_start, window_end, pos_depth_map, mean_chr_cov); - double pfb_default = 0.5; - double baf_default = -1.0; // Use -1.0 to indicate no BAF data - this->updateSNPData(snp_data, (window_start + window_end) / 2, pfb_default, baf_default, window_log2_ratio, false); + uint32_t window_start = (uint32_t) (start_pos + i * pos_step); + uint32_t window_end = (uint32_t) (start_pos + (i + 1) * pos_step); - } else { - snps_found = true; + // Calculate the mean depth for the window + double cov_sum = 0.0; + int pos_count = 0; + for (int j = 0; j < pos_step; j++) + { + uint32_t pos = (uint32_t) (start_pos + i * pos_step + j); + if (pos > end_pos) + { + break; + } + if (pos < pos_depth_map.size()) { + cov_sum += pos_depth_map[pos]; + pos_count++; + } - // Loop through the SNPs and calculate the log2 ratios - int64_t bin_start = window_start; - int64_t bin_end = 0; - for (int j = 0; j < snp_count; j++) + } + double log2_cov = 0.0; + if (pos_count > 0) + { + if (cov_sum == 0) { - // SNP bin starts at 1/2 the distance between the previous SNP - // and the current SNP, and ends at 1/2 the distance between - // the current SNP and the next SNP. For the first SNP, the - // bin starts at the window start and ends at 1/2 the distance - // between the first SNP and the next SNP, and for the last - // SNP, the bin starts at 1/2 the distance between the previous - // SNP and the last SNP and ends at the window end. - int64_t snp_pos = snp_window_pos[j]; - bin_end = snp_pos + (j == snp_count-1 ? 
(window_end - snp_pos) / 2 : (snp_window_pos[j+1] - snp_pos) / 2); - - // Calculate the log2 ratio for the SNP bin - double bin_cov = calculateLog2Ratio(bin_start, bin_end, pos_depth_map, mean_chr_cov); - this->updateSNPData(snp_data, snp_pos, snp_window_pfbs[j], snp_window_bafs[j], bin_cov, true); - - // Update the previous bin start - bin_start = bin_end + 1; + // Use a small value to avoid division by zero + cov_sum = 1e-9; } + log2_cov = log2((cov_sum / (double) pos_count) / mean_chr_cov); } - } - - return std::make_pair(snp_data, snps_found); -} -void CNVCaller::updateSVsFromCopyNumberPrediction(SVData &sv_calls, std::vector> &sv_list, std::string chr) -{ - // Throw an error if there are more than two SV candidates - if (sv_list.size() > 2) { - throw std::runtime_error("Error: More than two SV candidates found for copy number prediction comparisons."); + // Store the log2 ratio for the window + std::string window_key = std::to_string(window_start) + "-" + std::to_string(window_end); + window_log2_map[window_key] = log2_cov; } - // Add a dummy call to the SV list if there is only one SV candidate - if (sv_list.size() == 1) { - SVCandidate dummy(0, 0, "."); - sv_list.push_back(std::make_pair(dummy, ".")); - } - - // Run copy number prediction for the SV pair and add only the SV - // candidate with the highest likelihood - SVCandidate& sv_one = sv_list[0].first; - SVCandidate& sv_two = sv_list[1].first; - std::tuple cnv_prediction = this->runCopyNumberPredictionPair(chr, sv_one, sv_two); - - // Get the SV info - int best_index = std::get<0>(cnv_prediction); - SVCandidate& best_sv_candidate = sv_list[best_index].first; - int64_t start_pos = std::get<0>(best_sv_candidate); - int64_t end_pos = std::get<1>(best_sv_candidate); - std::string aln_type = sv_list[best_index].second; - - // Get the prediction data - double best_likelihood = std::get<1>(cnv_prediction); - int best_cnv_type = std::get<2>(cnv_prediction); - std::string best_genotype = std::get<3>(cnv_prediction); - bool snps_found = std::get<4>(cnv_prediction); - if (snps_found) - { - aln_type += "_SNPS"; - } else { - aln_type += "_NOSNPS"; - } + // Create new vectors for the SNP data + std::vector snp_pos_hmm; + std::vector snp_baf_hmm; + std::vector snp_pfb_hmm; + std::vector snp_log2_hmm; + std::vector is_snp_hmm; - // Add the SV call to the main SV data - sv_calls.add(chr, start_pos, end_pos, best_cnv_type, ".", aln_type, best_genotype, best_likelihood); -} - -std::tuple CNVCaller::runCopyNumberPredictionPair(std::string chr, SVCandidate sv_one, SVCandidate sv_two) -{ - // std::cout << "Running copy number prediction for SV pair " << chr << ":" << std::get<0>(sv_one) << "-" << std::get<1>(sv_one) << " and " << std::get<0>(sv_two) << "-" << std::get<1>(sv_two) << "..." 
<< std::endl; - double best_likelihood = 0.0; - bool best_likelihood_set = false; - bool snps_found = false; - int best_index = 0; - std::pair best_pos; - SNPData best_snp_data; - - // Get read depths for the SV candidate region - // int64_t region_start_pos = std::min(std::get<0>(sv_one), std::get<0>(sv_two)); - // int64_t region_end_pos = std::max(std::get<1>(sv_one), std::get<1>(sv_two)); - // std::unordered_map pos_depth_map; - // calculateDepthsForSNPRegion(chr, region_start_pos, region_end_pos, pos_depth_map); - - int current_index = 0; - int predicted_cnv_type = sv_types::UNKNOWN; - std::string genotype = "./."; - for (const auto& sv_call : {sv_one, sv_two}) + // Loop through the window ranges and append all SNPs in the range, using + // the log2 ratio for the window + for (const auto& window : window_log2_map) { - // Get the SV candidate - const SVCandidate& candidate = sv_call; + uint32_t window_start = std::stoi(window.first.substr(0, window.first.find('-'))); + uint32_t window_end = std::stoi(window.first.substr(window.first.find('-') + 1)); + double log2_cov = window.second; - // Get the start and end positions of the SV call - int64_t start_pos = std::get<0>(candidate); - int64_t end_pos = std::get<1>(candidate); - - // Skip if the start position equals zero (dummy call) - if (start_pos == 0) { - continue; - } - - // Get the depth at the start position, which is used as the FORMAT/DP - // value - // int dp_value = pos_depth_map[start_pos]; - - // Run the Viterbi algorithm on SNPs in the SV region +/- 1/2 - // the SV length - int64_t sv_length = (end_pos - start_pos) / 2.0; - int64_t snp_start_pos = std::max((int64_t) 1, start_pos - sv_length); - int64_t snp_end_pos = end_pos + sv_length; - - // Query the SNP region for the SV candidate - std::pair snp_call = querySNPRegion(chr, snp_start_pos, snp_end_pos, this->snp_info, this->pos_depth_map, this->mean_chr_cov); - SNPData sv_snps = snp_call.first; - bool sv_snps_found = snp_call.second; - - // Run the Viterbi algorithm - std::pair, double> prediction = runViterbi(this->hmm, sv_snps); - std::vector& state_sequence = prediction.first; - double likelihood = prediction.second; - - // Get all the states in the SV region - std::vector sv_states; - for (size_t i = 0; i < state_sequence.size(); i++) + // Loop through the SNP positions and add them to the SNP data + bool snp_found = false; + for (uint32_t pos : snp_pos) { - if (sv_snps.pos[i] >= start_pos && sv_snps.pos[i] <= end_pos) + if (pos >= window_start && pos <= window_end) { - sv_states.push_back(state_sequence[i]); + snp_pos_hmm.push_back(pos); + snp_baf_hmm.push_back(snp_baf_map[pos]); + snp_pfb_hmm.push_back(snp_pfb_map[pos]); + snp_log2_hmm.push_back(log2_cov); + is_snp_hmm.push_back(true); + snp_found = true; } } + if (!snp_found) + { + // If no SNPs were found in the window, add a dummy SNP with the + // log2 ratio for the window, using the window center as the SNP + // position + uint32_t window_center = (window_start + window_end) / 2; + snp_pos_hmm.push_back(window_center); + snp_baf_hmm.push_back(-1.0); + snp_pfb_hmm.push_back(0.5); + snp_log2_hmm.push_back(log2_cov); + is_snp_hmm.push_back(false); + } + } - // Determine if there is a majority state within the SV region and if it - // is greater than 75% - double pct_threshold = 0.75; - int max_state = 0; - int max_count = 0; - for (int i = 0; i < 6; i++) + // Update the SNP data with all information + snp_data.pos = std::move(snp_pos_hmm); + snp_data.baf = std::move(snp_baf_hmm); + snp_data.pfb = 
std::move(snp_pfb_hmm); + snp_data.log2_cov = std::move(snp_log2_hmm); + snp_data.is_snp = std::move(is_snp_hmm); +} + +std::tuple CNVCaller::runCopyNumberPrediction(std::string chr, const CHMM& hmm, uint32_t start_pos, uint32_t end_pos, double mean_chr_cov, const std::vector& pos_depth_map, const InputData& input_data) const +{ + // Check that the start position is less than the end position + if (start_pos > end_pos) + { + printError("ERROR: Invalid SV region for copy number prediction: " + chr + ":" + std::to_string((int)start_pos) + "-" + std::to_string((int)end_pos)); + return std::make_tuple(0.0, SVType::UNKNOWN, Genotype::UNKNOWN, 0); + } + /* + // Check that there is no large number of zero-depth positions in the region + int zero_depth_count = 0; + for (uint32_t pos = start_pos; pos <= end_pos; pos++) + { + if (pos < pos_depth_map.size() && pos_depth_map[pos] == 0) { - int state_count = std::count(sv_states.begin(), sv_states.end(), i+1); - if (state_count > max_count) - { - max_state = i+1; - max_count = state_count; - } + zero_depth_count++; } - - // Update SV type and genotype based on the majority state - int state_count = (int) sv_states.size(); - if ((double) max_count / (double) state_count > pct_threshold) + } + if (zero_depth_count > 0.1 * (end_pos - start_pos + 1)) + { + printError("WARNING: Too many zero-depth positions in the SV region for copy number prediction, skipping: " + chr + ":" + std::to_string((int)start_pos) + "-" + std::to_string((int)end_pos)); + return std::make_tuple(0.0, SVType::UNKNOWN, Genotype::UNKNOWN, 0); + } + */ + + // Run the Viterbi algorithm on SNPs in the SV region + // Only extend the region if "save CNV data" is enabled + SNPData before_sv; + SNPData after_sv; + if (input_data.getSaveCNVData()) + { + int sv_half_length = (static_cast(end_pos) - static_cast(start_pos)) / 2; + int before_sv_start = std::max(1, static_cast(start_pos) - sv_half_length); + int before_sv_end = std::max(1, static_cast(start_pos) - 1); + if (before_sv_start < before_sv_end) { - predicted_cnv_type = cnv_type_map[max_state]; - genotype = cnv_genotype_map[max_state]; + querySNPRegion(chr, before_sv_start, before_sv_end, pos_depth_map, mean_chr_cov, before_sv, input_data); } - // Update the best SV call based on the likelihood - if (!best_likelihood_set || (likelihood > best_likelihood)) + int chr_last_index = static_cast(pos_depth_map.size()) - 1; + int after_sv_start = std::min(chr_last_index, static_cast(end_pos) + 1); + int after_sv_end = std::min(chr_last_index, static_cast(end_pos) + sv_half_length); + if (after_sv_start < after_sv_end) { - best_likelihood = likelihood; - best_likelihood_set = true; - snps_found = sv_snps_found; - best_index = current_index; - - // Add the state sequence to the SNP data (avoid copying the data) - sv_snps.state_sequence = std::move(state_sequence); - best_snp_data = std::move(sv_snps); - best_pos = std::make_pair(start_pos, end_pos); + querySNPRegion(chr, after_sv_start, after_sv_end, pos_depth_map, mean_chr_cov, after_sv, input_data); } - current_index++; } - // Save the SV calls as a TSV file if enabled - int64_t sv_start_pos = std::get<0>(best_pos); - int64_t sv_end_pos = std::get<1>(best_pos); - if (this->input_data->getSaveCNVData() && predicted_cnv_type != sv_types::UNKNOWN && (sv_end_pos - sv_start_pos) > 10000) + // Query the SNP region for the SV candidate + SNPData snp_data; + querySNPRegion(chr, start_pos, end_pos, pos_depth_map, mean_chr_cov, snp_data, input_data); + + // Run the Viterbi algorithm + std::pair, double> 
prediction; + runViterbi(hmm, snp_data, prediction); + if (prediction.first.size() == 0) { - std::string cnv_type_str = SVTypeString[predicted_cnv_type]; - std::string sv_filename = this->input_data->getOutputDir() + "/" + cnv_type_str + "_" + chr + "_" + std::to_string((int) sv_start_pos) + "-" + std::to_string((int) sv_end_pos) + "_SPLITALN.tsv"; - std::cout << "Saving SV split-alignment copy number predictions to " << sv_filename << std::endl; - this->saveSVCopyNumberToTSV(best_snp_data, sv_filename, chr, best_pos.first, best_pos.second, cnv_type_str, best_likelihood); + return std::make_tuple(0.0, SVType::UNKNOWN, Genotype::UNKNOWN, 0); } - return std::make_tuple(best_index, best_likelihood, predicted_cnv_type, genotype, snps_found); -} + std::vector& state_sequence = prediction.first; + double likelihood = prediction.second; -SNPData CNVCaller::runCIGARCopyNumberPrediction(std::string chr, std::map &sv_candidates, int min_length) -{ - SNPInfo& snp_info = this->snp_info; - CHMM& hmm = this->hmm; - int window_size = this->input_data->getWindowSize(); - double mean_chr_cov = this->mean_chr_cov; - SNPData snp_data; - - // Filter the SV candidates by length - std::map filtered_sv_candidates; - for (const auto& sv_call : sv_candidates) + // Determine if there is a majority state within the SV region + int max_state = 0; + int max_count = 0; + for (int i = 0; i < 6; i++) { - int64_t start_pos = std::get<0>(sv_call.first); - int64_t end_pos = std::get<1>(sv_call.first); - if ((end_pos - start_pos) >= min_length) + int state_count = std::count(state_sequence.begin(), state_sequence.end(), i+1); + if (state_count > max_count) { - filtered_sv_candidates[sv_call.first] = sv_call.second; + max_state = i+1; + max_count = state_count; } } - sv_candidates = std::move(filtered_sv_candidates); - int sv_count = (int) sv_candidates.size(); - if (sv_count == 0) - { - return snp_data; - } - // Get read depths for the SV candidate region - // int64_t first_pos = std::get<0>(sv_candidates.begin()->first); - // int64_t last_pos = std::get<1>(sv_candidates.rbegin()->first); - // std::unordered_map pos_depth_map; - // calculateDepthsForSNPRegion(chr, first_pos, last_pos, pos_depth_map); - - // Run copy number prediction for the SV candidates - // Loop through each SV candidate and predict the copy number state - printMessage("Predicting CIGAR string copy number states for chromosome " + chr + "..."); - - // Create a map with counts for each CNV type - std::map cnv_type_counts; - for (int i = 0; i < 6; i++) + // If there is no majority state, then set the state to unknown + double pct_threshold = 0.50; + int state_count = (int) state_sequence.size(); + if ((double) max_count / (double) state_count < pct_threshold) { - cnv_type_counts[i] = 0; + max_state = 0; } - // Split the SV candidates into chunks for each thread - int chunk_count = this->input_data->getThreadCount(); - std::vector> sv_chunks = splitSVCandidatesIntoChunks(sv_candidates, chunk_count); + Genotype genotype = getGenotypeFromCNState(max_state); + SVType predicted_cnv_type = getSVTypeFromCNState(max_state); - // Loop through each SV chunk and run the copy number prediction in parallel - std::vector> futures; - for (const auto& sv_chunk : sv_chunks) + // Save the SV calls if enabled + uint32_t min_length = 30000; + bool copy_number_change = (predicted_cnv_type != SVType::UNKNOWN && predicted_cnv_type != SVType::NEUTRAL); + if (input_data.getSaveCNVData() && copy_number_change && (end_pos - start_pos) >= min_length) { - // Run the copy number 
prediction for the SV chunk - std::async(std::launch::async, &CNVCaller::runCIGARCopyNumberPredictionChunk, this, chr, std::ref(sv_candidates), sv_chunk, std::ref(snp_info), hmm, window_size, mean_chr_cov, std::ref(this->pos_depth_map)); - } + // Move the state sequence to the SNP data + snp_data.state_sequence = std::move(state_sequence); - // Get the SNP data for each SV chunk - int current_chunk = 0; - for (auto& future : futures) - { - current_chunk++; - SNPData chunk_snp_data = std::move(future.get()); - if (this->input_data->getVerbose()) + // Set B-allele and population frequency values to 0 for non-SNPs + for (size_t i = 0; i < snp_data.pos.size(); i++) { - printMessage("Finished processing SV chunk " + std::to_string(current_chunk) + " of " + std::to_string(chunk_count) + "..."); + if (!snp_data.is_snp[i]) + { + snp_data.baf[i] = 0.0; + snp_data.pfb[i] = 0.0; + } + } + for (size_t i = 0; i < before_sv.pos.size(); i++) + { + if (!before_sv.is_snp[i]) + { + before_sv.baf[i] = 0.0; + before_sv.pfb[i] = 0.0; + } + } + for (size_t i = 0; i < after_sv.pos.size(); i++) + { + if (!after_sv.is_snp[i]) + { + after_sv.baf[i] = 0.0; + after_sv.pfb[i] = 0.0; + } } - } - printMessage("Finished predicting copy number states for chromosome " + chr + "..."); + // Save the SNP data to JSON + std::string cnv_type_str = getSVTypeString(predicted_cnv_type); + std::string json_filepath = input_data.getCNVOutputFile(); + printMessage("Saving SV copy number predictions to " + json_filepath + "..."); - return snp_data; + this->saveSVCopyNumberToJSON(before_sv, after_sv, snp_data, chr, start_pos, end_pos, cnv_type_str, likelihood, json_filepath); + } + + return std::make_tuple(likelihood, predicted_cnv_type, genotype, max_state); } -void CNVCaller::runCIGARCopyNumberPredictionChunk(std::string chr, std::map& sv_candidates, std::vector sv_chunk, SNPInfo& snp_info, CHMM hmm, int window_size, double mean_chr_cov, std::unordered_map& pos_depth_map) + +void CNVCaller::runCIGARCopyNumberPrediction(std::string chr, std::vector& sv_candidates, const CHMM& hmm, double mean_chr_cov, const std::vector& pos_depth_map, const InputData& input_data) const { - // printMessage("Running copy number prediction for " + std::to_string(sv_chunk.size()) + " SV candidates on chromosome " + chr + "..."); // Map with counts for each CNV type std::map cnv_type_counts; for (int i = 0; i < 6; i++) @@ -358,47 +306,39 @@ void CNVCaller::runCIGARCopyNumberPredictionChunk(std::string chr, std::map(candidate); - int64_t end_pos = std::get<1>(candidate); - - // // [TEST] Skip if not in the following list of SVs - // std::vector sv_list = {"chr19:53013528-53051102", "chr1:43593639-43617165", "chr6:35786784-35799012", "chr1:152787870-152798352", "chr17:41265461-41275765", "chr5:180950357-181003515"}; - // std::string sv_key = chr + ":" + std::to_string((int)start_pos) + "-" + std::to_string((int)end_pos); - // if (std::find(sv_list.begin(), sv_list.end(), sv_key) == sv_list.end()) - // { - // continue; - // } - - // Get the depth at the start position. 
This is used as the FORMAT/DP - // value in the VCF file - int dp_value = pos_depth_map[start_pos]; - this->updateDPValue(sv_candidates, sv_call, dp_value); - - // Loop through the SV region, calculate the log2 ratios, and run the - // Viterbi algorithm to predict the copy number states - - // We will run the Viterbi algorithm on SNPs in the SV region +/- 1/2 - // the SV length - int64_t sv_half_length = (end_pos - start_pos) / 2.0; - // std::cout << "SV half length: " << sv_half_length << std::endl; - int64_t query_start = std::max((int64_t) 1, start_pos - sv_half_length); - int64_t query_end = end_pos + sv_half_length; - - // printMessage("Querying SNPs for SV " + chr + ":" + - // std::to_string((int)start_pos) + "-" + std::to_string((int)end_pos) + - // "..."); - std::pair snp_call = this->querySNPRegion(chr, query_start, query_end, snp_info, pos_depth_map, mean_chr_cov); - SNPData& sv_snps = snp_call.first; - bool snps_found = snp_call.second; + uint32_t start_pos = sv_call.start; + uint32_t end_pos = sv_call.end; + + // Error if start > end + if (start_pos > end_pos) + { + printError("ERROR: Invalid SV region for copy number prediction: " + chr + ":" + std::to_string((int)start_pos) + "-" + std::to_string((int)end_pos)); + continue; + } + + // Skip if not the minimum length for CNV predictions + if ((end_pos - start_pos) < input_data.getMinCNVLength()) + { + continue; + } + + // Only extend the region if "save CNV data" is enabled + SNPData snp_data; + this->querySNPRegion(chr, start_pos, end_pos, pos_depth_map, mean_chr_cov, snp_data, input_data); // Run the Viterbi algorithm - std::pair, double> prediction = runViterbi(hmm, sv_snps); + if (snp_data.pos.size() == 0) { + printError("ERROR: No SNP data found for Viterbi algorithm for CIGAR SV at " + chr + ":" + std::to_string((int)start_pos) + "-" + std::to_string((int)end_pos)); + continue; + } + + // printMessage("Running Viterbi algorithm for copy number prediction: " + chr + ":" + std::to_string((int)start_pos) + "-" + std::to_string((int)end_pos)); + std::pair, double> prediction; + runViterbi(hmm, snp_data, prediction); std::vector& state_sequence = prediction.first; double likelihood = prediction.second; @@ -406,14 +346,14 @@ void CNVCaller::runCIGARCopyNumberPredictionChunk(std::string chr, std::map sv_states; for (size_t i = 0; i < state_sequence.size(); i++) { - if (sv_snps.pos[i] >= start_pos && sv_snps.pos[i] <= end_pos) + if (snp_data.pos[i] >= start_pos && snp_data.pos[i] <= end_pos) { sv_states.push_back(state_sequence[i]); } } // Determine if there is a majority state within the SV region and if it - // is greater than 75% + // is greater than 50% int max_state = 0; int max_count = 0; for (int i = 0; i < 6; i++) @@ -427,93 +367,42 @@ void CNVCaller::runCIGARCopyNumberPredictionChunk(std::string chr, std::mapupdateSVCopyNumber(sv_candidates, sv_call, cnv_type, data_type, genotype, likelihood); - - // Save the SV calls as a TSV file if enabled, if the SV type is - // known, and the length is greater than 10 kb - int updated_sv_type = sv_candidates[sv_call].sv_type; - if (this->input_data->getSaveCNVData() && updated_sv_type != sv_types::UNKNOWN && (end_pos - start_pos) > 10000) - { - // Add the state sequence to the SNP data (avoid copying the data) - sv_snps.state_sequence = std::move(state_sequence); - - // Save the SV calls as a TSV file - std::string cnv_type_str = SVTypeString[updated_sv_type]; - std::string sv_filename = this->input_data->getOutputDir() + "/" + cnv_type_str + "_" + chr + "_" + std::to_string((int) 
start_pos) + "-" + std::to_string((int) end_pos) + "_CIGAR.tsv"; - // std::cout << "Saving SV CIGAR copy number predictions to " << - // sv_filename << std::endl; - printMessage("Saving SV CIGAR copy number predictions to " + sv_filename); - this->saveSVCopyNumberToTSV(sv_snps, sv_filename, chr, start_pos, end_pos, cnv_type_str, likelihood); - } - } -} - -void CNVCaller::updateSVCopyNumber(std::map &sv_candidates, SVCandidate key, int sv_type_update, std::string data_type, std::string genotype, double hmm_likelihood) -{ - // Update SV data from the HMM copy number prediction - // Lock the SV candidate map - std::lock_guard lock(this->sv_candidates_mtx); - - // Update the SV type if the update is not unknown, and if the types don't - // conflict (To avoid overwriting CIGAR-based SV calls with SNP-based calls) - int current_sv_type = sv_candidates[key].sv_type; - if ((sv_type_update != sv_types::UNKNOWN) && ((current_sv_type == sv_type_update) || (current_sv_type == sv_types::UNKNOWN))) - { - sv_candidates[key].sv_type = sv_type_update; // Update the SV type - sv_candidates[key].data_type.insert(data_type); // Update the data type + // Update the SV information if it does not conflict with the current SV type + SVType updated_sv_type = getSVTypeFromCNState(max_state); - // Update the likelihood if it is greater than the existing likelihood, - // or if it is currently unknown (0.0) - double previous_likelihood = sv_candidates[key].hmm_likelihood; - if (previous_likelihood == 0.0 || hmm_likelihood > previous_likelihood) + // For LOH predictions, or predictions with the same type, + // update predicted information without changing the SV type + // printMessage("Updating SV call for " + chr + ":" + std::to_string((int)start_pos) + "-" + std::to_string((int)end_pos) + " with predicted CNV type: " + getSVTypeString(updated_sv_type)); + updated_sv_type = (updated_sv_type == SVType::LOH) ? 
sv_call.sv_type : updated_sv_type; + bool is_valid_update = isValidCopyNumberUpdate(sv_call.sv_type, updated_sv_type); + if (is_valid_update) { - sv_candidates[key].hmm_likelihood = hmm_likelihood; + sv_call.sv_type = updated_sv_type; + // sv_call.data_type = SVDataType::HMM; + sv_call.aln_type.set(static_cast(SVDataType::HMM)); + sv_call.hmm_likelihood = likelihood; + sv_call.genotype = genotype; + sv_call.cn_state = max_state; } - - // Update the genotype - sv_candidates[key].genotype = genotype; } } -void CNVCaller::updateDPValue(std::map& sv_candidates, SVCandidate key, int dp_value) -{ - // Lock the SV candidate map - std::lock_guard lock(this->sv_candidates_mtx); - - // Update the DP value - sv_candidates[key].read_depth = dp_value; -} - -std::vector CNVCaller::splitRegionIntoChunks(std::string chr, int64_t start_pos, int64_t end_pos, int chunk_count) +std::vector CNVCaller::splitRegionIntoChunks(std::string chr, uint32_t start_pos, uint32_t end_pos, int chunk_count) const { // Split the region into chunks std::vector region_chunks; - int64_t region_length = end_pos - start_pos + 1; - int64_t chunk_size = std::ceil((double) region_length / (double) chunk_count); - int64_t chunk_start = start_pos; - int64_t chunk_end = 0; + uint32_t region_length = end_pos - start_pos + 1; + uint32_t chunk_size = std::ceil((double) region_length / (double) chunk_count); + uint32_t chunk_start = start_pos; + uint32_t chunk_end = 0; for (int i = 0; i < chunk_count; i++) { chunk_end = chunk_start + chunk_size - 1; @@ -525,580 +414,410 @@ std::vector CNVCaller::splitRegionIntoChunks(std::string chr, int64 // Add the region chunk to the vector region_chunks.push_back(chr + ":" + std::to_string(chunk_start) + "-" + std::to_string(chunk_end)); - - // Update the chunk start chunk_start = chunk_end + 1; } return region_chunks; } -std::vector> CNVCaller::splitSVCandidatesIntoChunks(std::map& sv_candidates, int chunk_count) +// Calculate the mean chromosome coverage +void CNVCaller::calculateMeanChromosomeCoverage(const std::vector& chromosomes, std::unordered_map>& chr_pos_depth_map, std::unordered_map& chr_mean_cov_map, const std::string& bam_filepath, int thread_count) const { - // Split the SV candidates into chunks - std::vector> sv_chunks; - int sv_count = (int) sv_candidates.size(); - int chunk_size = std::ceil((double) sv_count / (double) chunk_count); - int current_chunk = 0; - std::vector current_sv_chunk; - for (auto const& sv_call : sv_candidates) + // Open the BAM file + printMessage("Opening BAM file: " + bam_filepath); + samFile *bam_file = sam_open(bam_filepath.c_str(), "r"); + if (!bam_file) { - // Add the SV candidate to the current chunk - current_sv_chunk.push_back(sv_call.first); - - // If the current chunk size is reached, then add the chunk to the - // vector and reset the current chunk - if ((int) current_sv_chunk.size() == chunk_size) - { - sv_chunks.push_back(current_sv_chunk); - current_sv_chunk.clear(); - current_chunk++; - } + printError("ERROR: Could not open BAM file: " + bam_filepath); + return; } - // Add the remaining SV candidates to the last chunk - if (current_sv_chunk.size() > 0) + // Enable multi-threading while opening the BAM file + hts_set_threads(bam_file, thread_count); + + // Read the header + bam_hdr_t *bam_header = sam_hdr_read(bam_file); + if (!bam_header) { - sv_chunks.push_back(current_sv_chunk); + sam_close(bam_file); + printError("ERROR: Could not read header from BAM file: " + bam_filepath); + return; } - return sv_chunks; -} - -CNVCaller::CNVCaller(InputData 
&input_data) -{ - this->input_data = &input_data; -} + // Load the index + hts_idx_t *bam_index = sam_index_load(bam_file, bam_filepath.c_str()); + if (!bam_index) + { + bam_hdr_destroy(bam_header); + sam_close(bam_file); + printError("ERROR: Could not load index for BAM file: " + bam_filepath); + return; + } + // BamFileGuard bam_guard(bam_file, bam_index, bam_header); // Guard to close the BAM file -void CNVCaller::loadChromosomeData(std::string chr) -{ - // Read the HMM from file - std::string hmm_filepath = this->input_data->getHMMFilepath(); - std::cout << "Reading HMM from file: " << hmm_filepath << std::endl; - this->hmm = ReadCHMM(hmm_filepath.c_str()); - - // Calculate the mean chromosome coverage and generate the position-depth map - printMessage("Calculating mean chromosome coverage for " + chr + "..."); - mean_chr_cov = calculateMeanChromosomeCoverage(chr); - printMessage("Mean chromosome coverage for " + chr + ": " + std::to_string(mean_chr_cov)); - this->mean_chr_cov = mean_chr_cov; - - // Read the SNP positions and B-allele frequency values from the VCF file - std::cout << "Reading SNP allele frequencies for chromosome " << chr << " from VCF file..." << std::endl; - std::string snp_filepath = this->input_data->getSNPFilepath(); - readSNPAlleleFrequencies(chr, snp_filepath, this->snp_info); - - // Get the population frequencies for each SNP - std::cout << "Obtaining SNP population frequencies for chromosome " << chr << "..." << std::endl; - getSNPPopulationFrequencies(chr, this->snp_info); - std::cout << "Finished loading chromosome data for " << chr << std::endl; -} + // Initialize the record + bam1_t *bam_record = bam_init1(); + if (!bam_record) + { + // Clean up the BAM file and index + bam_hdr_destroy(bam_header); + sam_close(bam_file); + printError("ERROR: Could not initialize BAM record."); + return; + } -// Calculate the mean chromosome coverage -double CNVCaller::calculateMeanChromosomeCoverage(std::string chr) -{ - // Split the chromosome into equal parts for each thread - int num_threads = this->input_data->getThreadCount(); - uint32_t chr_len = this->input_data->getRefGenomeChromosomeLength(chr); - std::vector region_chunks = splitRegionIntoChunks(chr, 1, chr_len, num_threads); - - // Calculate the mean chromosome coverage in parallel - uint32_t pos_count = 0; - uint64_t cum_depth = 0; - std::vector>>> futures; - std::string input_filepath = this->input_data->getShortReadBam(); - for (const auto& region_chunk : region_chunks) + // Iterate through each chromosome and update the depth map + int current_chr = 0; + int total_chr_count = chromosomes.size(); + for (const std::string& chr : chromosomes) { - // Create a lambda function to get the mean chromosome coverage for the - // region chunk - auto get_mean_chr_cov = [region_chunk, input_filepath]() -> std::tuple> + // Create an iterator for the chromosome + hts_itr_t *bam_iter = sam_itr_querys(bam_index, bam_header, chr.c_str()); + if (!bam_iter) { - // Run samtools depth on the entire region, and print positions and - // depths (not chromosome) - const int cmd_size = 256; - char cmd[cmd_size]; - snprintf(cmd, cmd_size,\ - "samtools depth -r %s %s | awk '{print $2, $3}'",\ - region_chunk.c_str(), input_filepath.c_str()); - - // Open a pipe to read the output of the command - FILE *fp = popen(cmd, "r"); - if (fp == NULL) + printError("ERROR: Could not create iterator for chromosome: " + chr + ", check if the chromosome exists in the BAM file."); + continue; + } + + printMessage("(" + std::to_string(++current_chr) + 
"/" + std::to_string(total_chr_count) + ") Reading BAM file for chromosome: " + chr); + std::vector& pos_depth_map = chr_pos_depth_map[chr]; + int tid = bam_name2id(bam_header, chr.c_str()); + if (tid < 0) + { + printError("ERROR: Could not find chromosome " + chr + " in BAM file."); + continue; + } + // Resize the depth map to the length of the chromosome + uint32_t chr_length = bam_header->target_len[tid] + 1; + if (pos_depth_map.size() != static_cast(chr_length)) + { + printError("ERROR: Chromosome length mismatch for " + chr + ": expected " + std::to_string(chr_length) + ", found " + std::to_string(pos_depth_map.size()) + ", resizing to " + std::to_string(chr_length)); + pos_depth_map.resize(chr_length, 0); + } + while (sam_itr_next(bam_file, bam_iter, bam_record) >= 0) + { + // Ignore UNMAP, SECONDARY, QCFAIL, and DUP reads + uint16_t flag = bam_record->core.flag; + if (flag & (BAM_FUNMAP | BAM_FSECONDARY | BAM_FQCFAIL | BAM_FDUP)) { - printError("ERROR: Could not open pipe for command: " + std::string(cmd)); - exit(EXIT_FAILURE); + continue; } - // Parse the outputs (position and depth) - std::unordered_map pos_depth_map; - const int line_size = 256; - char line[line_size]; - uint32_t pos; - int depth; - uint32_t pos_count = 0; - uint64_t cum_depth = 0; - while (fgets(line, line_size, fp) != NULL) + // Parse the CIGAR string to get the depth (match, sequence match, and + // mismatch) + uint32_t pos = (uint32_t)bam_record->core.pos + 1; // 0-based to 1-based + uint32_t ref_pos = pos; + uint32_t cigar_len = bam_record->core.n_cigar; + uint32_t *cigar = bam_get_cigar(bam_record); + for (uint32_t i = 0; i < cigar_len; i++) { - if (sscanf(line, "%u%d", &pos, &depth) == 2) + uint32_t op = bam_cigar_op(cigar[i]); + uint32_t op_len = bam_cigar_oplen(cigar[i]); + if (op == BAM_CMATCH || op == BAM_CEQUAL || op == BAM_CDIFF) { - pos_depth_map[pos] = depth; - pos_count++; - cum_depth += depth; + // Update the depth for each position in the alignment + for (uint32_t j = 0; j < op_len; j++) + { + if (ref_pos + j >= pos_depth_map.size()) + { + printError("ERROR: Reference position out of range for " + chr + ":" + std::to_string(ref_pos+j)); + continue; + } + pos_depth_map[ref_pos + j]++; + } + } + + // Update the reference coordinate based on the CIGAR operation + // https://samtools.github.io/hts-specs/SAMv1.pdf + if (op == BAM_CMATCH || op == BAM_CDEL || op == BAM_CREF_SKIP || op == BAM_CEQUAL || op == BAM_CDIFF) { + ref_pos += op_len; + } else if (op == BAM_CINS || op == BAM_CSOFT_CLIP || op == BAM_CHARD_CLIP || op == BAM_CPAD) { + // Do nothing + } else { + printError("ERROR: Unknown CIGAR operation: " + std::to_string(op)); } } - pclose(fp); // Close the process - - return std::make_tuple(pos_count, cum_depth, pos_depth_map); - }; - std::future>> future = std::async(std::launch::async, get_mean_chr_cov); - futures.push_back(std::move(future)); - } + } + hts_itr_destroy(bam_iter); - // Loop through the futures and get the results - for (auto& future : futures) - { - future.wait(); - std::tuple> result = std::move(future.get()); + uint64_t cum_depth = std::accumulate(pos_depth_map.begin(), pos_depth_map.end(), 0ULL); + uint32_t pos_count = std::count_if(pos_depth_map.begin(), pos_depth_map.end(), [](uint32_t depth) { return depth > 0; }); - // Update the position count, cumulative depth, and merge the position-depth maps - pos_count += std::get<0>(result); - cum_depth += std::get<1>(result); - this->mergePosDepthMaps(this->pos_depth_map, std::get<2>(result)); + // Calculate the mean coverage for 
the chromosome + double mean_chr_cov = (pos_count > 0) ? static_cast(cum_depth) / static_cast(pos_count) : 0.0; + printMessage("Mean coverage for chromosome " + chr + ": " + std::to_string(mean_chr_cov)); + if (mean_chr_cov != 0.0) { + chr_mean_cov_map[chr] = mean_chr_cov; + } } - double mean_chr_cov = (double) cum_depth / (double) pos_count; - return mean_chr_cov; + // Clean up the BAM file and index + printMessage("Closing BAM file " + bam_filepath); + bam_destroy1(bam_record); + hts_idx_destroy(bam_index); + bam_hdr_destroy(bam_header); + sam_close(bam_file); + bam_record = nullptr; + bam_index = nullptr; + bam_header = nullptr; + bam_file = nullptr; + printMessage("BAM file closed."); } -void CNVCaller::calculateDepthsForSNPRegion(std::string chr, int64_t start_pos, int64_t end_pos, std::unordered_map& pos_depth_map) +void CNVCaller::readSNPAlleleFrequencies(std::string chr, uint32_t start_pos, uint32_t end_pos, std::vector& snp_pos, std::unordered_map& snp_baf, std::unordered_map& snp_pfb, const InputData& input_data) const { - std::cout << "Calculating read depths for SV region " << chr << ":" << start_pos << "-" << end_pos << "..." << std::endl; - - // // If extending the CNV regions, then extend the SV region by window size * - // // N. Otherwise, log2 ratios will be zero due to missing read depth data - // // before/after the first/last SV positions - // if (this->input_data->getSaveCNVData()) - // { - // int extend_factor = 100; - // int window_size = this->input_data->getWindowSize(); - // start_pos = std::max((int64_t) 1, start_pos - (window_size * extend_factor)); - // end_pos = end_pos + (window_size * extend_factor); - // } - - // // Split the region into equal parts for each thread if the region is larger - // // than 100 kb - // int num_threads = this->input_data->getThreadCount(); - // std::vector region_chunks; - // int64_t region_size = end_pos - start_pos; - // if (region_size < 100000) - // { - // region_chunks.push_back(chr + ":" + std::to_string(start_pos) + "-" + std::to_string(end_pos)); - // } else { - // region_chunks = splitRegionIntoChunks(chr, start_pos, end_pos, num_threads); - // } - - // // Loop through each region chunk and get the mean chromosome coverage in - // // parallel - // std::string input_filepath = this->input_data->getShortReadBam(); - // std::vector>> futures; - // for (const auto& region_chunk : region_chunks) - // { - // // Create a lambda function to get the mean chromosome coverage for the - // // region chunk - // auto get_pos_depth_map = [region_chunk, input_filepath]() -> std::unordered_map - // { - // // Run samtools depth on the entire region, and print positions and - // // depths (not chromosome) - // const int cmd_size = 256; - // char cmd[cmd_size]; - // snprintf(cmd, cmd_size, - // "samtools depth -r %s %s | awk '{print $2, $3}'", - // region_chunk.c_str(), input_filepath.c_str()); - - // // Open a pipe to read the output of the command - // FILE *fp = popen(cmd, "r"); - // if (fp == NULL) - // { - // std::cerr << "ERROR: Could not open pipe for command: " << cmd << std::endl; - // exit(EXIT_FAILURE); - // } - - // // Create a map of positions and depths - // std::unordered_map pos_depth_map; - // const int line_size = 1024; - // char line[line_size]; - // while (fgets(line, line_size, fp) != NULL) - // { - // // Parse the line - // uint64_t pos; - // int depth; - // if (sscanf(line, "%ld%d", &pos, &depth) == 2) - // { - // // Add the position and depth to the map - // pos_depth_map[pos] = depth; - // } else { - // // No reads - 
// } - // } - - // // Close the pipe - // pclose(fp); - - // return pos_depth_map; - // }; - - // // Create a future for the thread - // std::future> future = std::async(std::launch::async, get_pos_depth_map); - - // // Add the future to the vector - // futures.push_back(std::move(future)); - // } - - // // Loop through the futures and get the results - // int current_chunk = 0; - // for (auto& future : futures) - // { - // current_chunk++; - // future.wait(); - // std::unordered_map result = std::move(future.get()); - - // // Merge the position depth maps - // this->mergePosDepthMaps(pos_depth_map, result); - // if (this->input_data->getVerbose()) - // { - // printMessage("Completed region chunk " + std::to_string(current_chunk) + " of " + std::to_string(region_chunks.size()) + "..."); - // } - // } -} + // Lock during reading + std::shared_lock lock(this->shared_mutex); -void CNVCaller::mergePosDepthMaps(std::unordered_map& main_map, std::unordered_map& map_update) -{ - // Merge the second depth map into the first - main_map.reserve(main_map.size() + map_update.size()); - for (auto& pos_depth : map_update) + // --------- SNP file --------- + const std::string snp_filepath = input_data.getSNPFilepath(); + if (snp_filepath.empty()) { - main_map[pos_depth.first] = std::move(pos_depth.second); + printError("ERROR: SNP file path is empty."); + return; } -} -double CNVCaller::calculateLog2Ratio(uint32_t start_pos, uint32_t end_pos, std::unordered_map &pos_depth_map, double mean_chr_cov) -{ - // Use the position and depth map to calculate the log2 ratio - double cum_depth = 0; - int pos_count = 0; - for (uint32_t i = start_pos; i <= end_pos; i++) + // Initialize the SNP file reader + bcf_srs_t *snp_reader = bcf_sr_init(); + if (!snp_reader) { - // Check if the position is in the map - auto it = pos_depth_map.find(i); - if (it == pos_depth_map.end()) - { - continue; - } - int depth = pos_depth_map[i]; - pos_count++; - cum_depth += depth; + printError("ERROR: Could not initialize SNP reader."); + return; } + snp_reader->require_index = 1; - // Calculate the window coverage log2 ratio (0 if no positions) - double window_mean_cov = 0; - if (pos_count > 0) - { - window_mean_cov = (double) cum_depth / (double) pos_count; - } + // Use multi-threading if not threading by chromosome + int thread_count = input_data.getThreadCount(); + bcf_sr_set_threads(snp_reader, thread_count); - // Calculate the log2 ratio for the window - // Avoid log2(0) by using a small value - if (window_mean_cov == 0) + // Add the SNP file to the reader + if (bcf_sr_add_reader(snp_reader, snp_filepath.c_str()) < 0) { - window_mean_cov = 0.0001; + bcf_sr_destroy(snp_reader); + printError("ERROR: Could not add SNP file to reader: " + snp_filepath); + return; } - double window_log2_ratio = log2(window_mean_cov / mean_chr_cov); - return window_log2_ratio; -} + // --------- Population allele frequency file --------- -void CNVCaller::readSNPAlleleFrequencies(std::string chr, std::string filepath, SNPInfo& snp_info) -{ - // Check that the SNP file is sorted by running bcftools index and reading - // the error output - std::string index_cmd = "bcftools index " + filepath + " 2>&1 | grep -i error"; - if (this->input_data->getVerbose()) { - std::cout << "Command: " << index_cmd << std::endl; + // Get the population allele frequency file path + bool use_pfb = true; + const std::string pfb_filepath = input_data.getAlleleFreqFilepath(chr); + if (pfb_filepath.empty()) + { + use_pfb = false; } - // Open a pipe to read the output of the command 
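    // (popen(3) runs the command through /bin/sh and returns a read-only
    // stream of its stdout; with stderr redirected via 2>&1 and filtered
    // through `grep -i error`, any line readable from this pipe signals an
    // indexing failure.)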
- FILE *index_fp = popen(index_cmd.c_str(), "r"); - if (index_fp == NULL) + // Ensure the file exists (ifsstream will throw an exception if the file + // does not exist) + std::ifstream pfb_file(pfb_filepath); + if (!pfb_file) { - std::cerr << "ERROR: Could not open pipe for command: " << index_cmd << std::endl; - exit(1); + use_pfb = false; } + pfb_file.close(); - // Read the output of the command - const int error_size = 256; - char index_error[error_size]; - while (fgets(index_error, error_size, index_fp) != NULL) + bcf_srs_t *pfb_reader = bcf_sr_init(); + std::string chr_gnomad = chr; + std::string AF_key; + if (use_pfb) { - std::cerr << "ERROR: " << index_error << std::endl; - exit(1); - } + // Determine the ethnicity-specific allele frequency key + AF_key = "AF"; + const std::string eth = input_data.getEthnicity(); + if (eth != "") + { + AF_key += "_" + eth; + } - // Close the pipe - pclose(index_fp); + // Check if the filepath uses the 'chr' prefix notations based on the + // chromosome name (*.chr1.vcf.gz vs *.1.vcf.gz) + std::string chr_prefix = "chr"; + if (pfb_filepath.find(chr_prefix) == std::string::npos) + { + // Remove the 'chr' prefix from the chromosome name + if (chr_gnomad.find(chr_prefix) != std::string::npos) + { + chr_gnomad = chr_gnomad.substr(chr_prefix.length()); + } + } else { + // Add the 'chr' prefix to the chromosome name + if (chr_gnomad.find(chr_prefix) == std::string::npos) + { + chr_gnomad = chr_prefix + chr; + } + } - // Filter variants by depth, quality, and region - if (this->input_data->getVerbose()) { - std::cout << "Filtering SNPs by depth, quality, and region..." << std::endl; - } + // Initialize the population allele frequency reader + if (!pfb_reader) + { + printError("ERROR: Could not initialize population allele frequency reader."); - // // Check if a region was specified by the user - std::string region_str = chr; - if (this->input_data->isRegionSet()) - { - std::pair region = this->input_data->getRegion(); - region_str = chr + ":" + std::to_string(region.first) + "-" + std::to_string(region.second); - } + // Clean up + bcf_sr_destroy(snp_reader); + return; + } + pfb_reader->require_index = 1; - std::string filtered_snp_vcf_filepath = this->input_data->getOutputDir() + "/filtered_snps.vcf"; - std::string cmd = "bcftools view -r " + region_str + " -v snps -i 'QUAL > 30 && DP > 10 && FILTER = \"PASS\"' " + filepath + " > " + filtered_snp_vcf_filepath; - if (this->input_data->getVerbose()) { - std::cout << "Filtering SNPs by depth and quality..." << std::endl; - std::cout << "Command: " << cmd << std::endl; - } - system(cmd.c_str()); - - if (this->input_data->getVerbose()) { - std::cout << "Filtered SNPs written to " << filtered_snp_vcf_filepath << std::endl; - } + // Add the population allele frequency file to the reader + if (bcf_sr_add_reader(pfb_reader, pfb_filepath.c_str()) < 0) + { + printError("ERROR: Could not add population allele frequency file to reader: " + pfb_filepath); - // Extract B-allele frequency data from the VCF file and sort by chromosome - // and position - if (this->input_data->getVerbose()) { - std::cout << "Extracting B-allele frequency data from filtered SNPs..." 
<< std::endl; + // Clean up + bcf_sr_destroy(pfb_reader); + bcf_sr_destroy(snp_reader); + return; + } + + // Use multi-threading if not threading by chromosome + int thread_count = input_data.getThreadCount(); + bcf_sr_set_threads(pfb_reader, thread_count); } - cmd = "bcftools query -f '%POS,[%AD]\n' " + filtered_snp_vcf_filepath + " 2>/dev/null"; - FILE *fp = popen(cmd.c_str(), "r"); - if (fp == NULL) + + // Read the SNP data + + // Set the region + std::string region_str = chr + ":" + std::to_string(start_pos) + "-" + std::to_string(end_pos); + if (bcf_sr_set_regions(snp_reader, region_str.c_str(), 0) < 0) //chr.c_str(), 0) < 0) { - std::cerr << "ERROR: Could not open pipe for command: " << cmd << std::endl; - exit(1); + printError("ERROR: Could not set region for SNP reader: " + chr); + bcf_sr_destroy(snp_reader); + bcf_sr_destroy(pfb_reader); + return; } - // Read the reference and alternate allele depths from the VCF file - std::string alt_allele = ""; // Alternate allele - uint64_t pos = 0; - int ref_ad = 0; - int alt_ad = 0; - const int line_size = 256; - char line[line_size]; // Line buffer - std::vector locations; - std::vector bafs; - while (fgets(line, line_size, fp) != NULL) + bool snp_found = false; + while (bcf_sr_next_line(snp_reader) > 0) { - // Parse the line - char *tok = strtok(line, ","); // Tokenize the line - int col = 0; // Column index - while (tok != NULL) + if (!bcf_sr_has_line(snp_reader, 0)) + { + continue; + } + bcf1_t *snp_record = bcf_sr_get_line(snp_reader, 0); + if (snp_record) { - // Get the position from column 2 - if (col == 0) + uint32_t pos = (uint32_t)snp_record->pos + 1; + + // Skip if not a SNP + if (!bcf_is_snp(snp_record)) { - pos = atoi(tok); + continue; } - // Get the AD for the reference allele from column 3 - else if (col == 1) + // Get the QUAL, DP, and AD values + if (bcf_float_is_missing(snp_record->qual) || snp_record->qual <= 30) { - ref_ad = atoi(tok); + continue; } - // Get the AD for the non-reference allele from column 4 - else if (col == 2) + // Extract DP from FORMAT field + int32_t *dp = 0; + int dp_count = 0; + int dp_ret = bcf_get_format_int32(snp_reader->readers[0].header, snp_record, "DP", &dp, &dp_count); + if (dp_ret < 0 || dp[0] <= 10) { - alt_ad = atoi(tok); + continue; } + free(dp); - // Move to the next token - tok = strtok(NULL, ","); - col++; - } - - // Calculate the B-allele frequency (BAF) as the ratio of the alternate - // allele depth to the total depth (reference + alternate) - double baf = (double) alt_ad / (double) (ref_ad + alt_ad); + // Skip if the SNP does not pass the filter + if (bcf_has_filter(snp_reader->readers[0].header, snp_record, const_cast("PASS")) != 1) + { + continue; + } - // Add a new location and BAF value to the chromosome's SNP data - // (population frequency and log2 ratio will be added later) - snp_info.insertSNPAlleleFrequency(chr, pos, baf); - } + // Extract AD from FORMAT field + int32_t *ad = 0; + int ad_count = 0; + int ad_ret = bcf_get_format_int32(snp_reader->readers[0].header, snp_record, "AD", &ad, &ad_count); + if (ad_ret < 0 || ad_count < 2) + { + continue; + } - // Close the pipe - pclose(fp); + // Calculate the B-allele frequency (BAF) + double baf = (double) ad[1] / (double) (ad[0] + ad[1]); + free(ad); - if (this->input_data->getVerbose()) { - std::cout << "Finished extracting B-allele frequency data from filtered SNPs" << std::endl; + // Add the SNP position and BAF information + snp_pos.push_back(pos); + snp_baf[pos] = baf; + snp_found = true; + } } -} -void 
CNVCaller::getSNPPopulationFrequencies(std::string chr, SNPInfo& snp_info) -{ - // Get the population frequency file for the chromosome - std::string pfb_filepath = this->input_data->getAlleleFreqFilepath(chr); - if (pfb_filepath == "") + if (snp_reader->errnum) { - std::cout << "No population frequency file provided for chromosome " << chr << std::endl; - return; + printError("ERROR: " + std::string(bcf_sr_strerror(snp_reader->errnum))); } - // Determine the ethnicity-specific allele frequency key - std::string AF_key = "AF"; - if (this->input_data->getEthnicity() != "") + // Continue if no SNP was found in the region + if (!snp_found) { - AF_key += "_" + this->input_data->getEthnicity(); + bcf_sr_destroy(snp_reader); + bcf_sr_destroy(pfb_reader); + return; } - // Check if the filepath uses the 'chr' prefix notations based on the - // chromosome name (e.g., *.chr1.vcf.gz vs *.1.vcf.gz) - std::string chr_gnomad = chr; // gnomAD data may or may not have the 'chr' prefix - std::string chr_prefix = "chr"; - if (pfb_filepath.find(chr_prefix) == std::string::npos) + // Read the population allele frequency data ---------------------- + // Get the minimum and maximum SNP positions + uint32_t min_snp_pos = *std::min_element(snp_pos.begin(), snp_pos.end()); + uint32_t max_snp_pos = *std::max_element(snp_pos.begin(), snp_pos.end()); + std::unordered_set snp_pos_set(snp_pos.begin(), snp_pos.end()); + if (use_pfb) { - // gnomaAD does not use the 'chr' prefix - // Remove the 'chr' prefix from the chromosome name - if (chr_gnomad.find(chr_prefix) != std::string::npos) + // Set the region for the population allele frequency reader + std::string pfb_region_str = chr_gnomad + ":" + std::to_string(min_snp_pos) + "-" + std::to_string(max_snp_pos); + if (bcf_sr_set_regions(pfb_reader, pfb_region_str.c_str(), 0) < 0) { - chr_gnomad = chr_gnomad.substr(chr_prefix.length()); - } - } else { - // Add the 'chr' prefix to the chromosome name - if (chr_gnomad.find(chr_prefix) == std::string::npos) - { - chr_gnomad = chr_prefix + chr; - } - } - - // Remove the 'chr' prefix from the chromosome name for SNP data. 
All - // SNP data in this program does not use the 'chr' prefix - std::string chr_snp = chr; - if (chr_snp.find(chr_prefix) != std::string::npos) - { - chr_snp = chr_snp.substr(chr_prefix.length()); - } - std::cout << "Reading population frequencies for chromosome " << chr << " from " << pfb_filepath << std::endl; - - // Get the start and end SNP positions for the chromosome (1-based - // index) - std::pair snp_range = snp_info.getSNPRange(chr); - int64_t snp_start = snp_range.first; - int64_t snp_end = snp_range.second; - if (this->input_data->isRegionSet()) - { - // Get the user-defined region - std::pair region = this->input_data->getRegion(); - if (snp_start < region.first) { - snp_start = region.first; - } else if (snp_end > region.second) { - snp_end = region.second; + printError("ERROR: Could not set region for population allele frequency reader: " + pfb_region_str); } - } - // Split region into chunks and get the population frequencies in parallel - std::cout << "SNP range for chromosome " << chr << ": " << snp_start << "-" << snp_end << std::endl; - int num_threads = this->input_data->getThreadCount(); - std::vector region_chunks = splitRegionIntoChunks(chr_gnomad, snp_start, snp_end, num_threads); - std::unordered_map pos_pfb_map; - std::vector threads; - std::vector>> futures; - for (const auto& region_chunk : region_chunks) - { - // Create a lambda function to get the population frequencies for the - // region chunk - auto get_pfb = [region_chunk, pfb_filepath, AF_key]() -> std::unordered_map + // Find the SNP position in the population allele frequency file + float *pfb_f = NULL; + int count = 0; + while (bcf_sr_next_line(pfb_reader) > 0) { - // Run bcftools query to get the population frequencies for the - // chromosome within the SNP region, filtering for SNPS only, - // and within the MIN-MAX range of frequencies. 
- // TODO: Update to use ethnicity-specific population frequencies - // Example from gnomAD: - // ##INFO= - // std::string ethnicity_suffix = "_asj"; // Ashkenazi Jewish - // (leave empty for all populations) - std::string filter_criteria = "INFO/variant_type=\"snv\" && " + AF_key + " >= " + std::to_string(MIN_PFB) + " && " + AF_key + " <= " + std::to_string(MAX_PFB); - std::string cmd = \ - "bcftools query -r " + region_chunk + " -f '%POS\t%" + AF_key + "\n' -i '" + filter_criteria + "' " + pfb_filepath + " 2>/dev/null"; - - std::cout << "Command: " << cmd << std::endl; - - // Open a pipe to read the output of the command - FILE *fp = popen(cmd.c_str(), "r"); - if (fp == NULL) + // Get the SNP record and validate + bcf1_t *pfb_record = bcf_sr_get_line(pfb_reader, 0); + if (!pfb_record || !bcf_is_snp(pfb_record)) { - std::cerr << "ERROR: Could not open pipe for command: " << cmd << std::endl; - exit(1); + continue; // Skip if not a SNP } - // Loop through the BCFTOOLS output and populate the map of population - // frequencies - std::unordered_map pos_pfb_map; - const int line_size = 256; - char line[line_size]; - while (fgets(line, line_size, fp) != NULL) + // Get the SNP position + uint32_t pfb_pos = (uint32_t) pfb_record->pos + 1; + if (snp_pos_set.find(pfb_pos) == snp_pos_set.end()) { - // Parse the line - int pos; - double pfb; - if (sscanf(line, "%d%lf", &pos, &pfb) == 2) - { - // Add the position and population frequency to the map - pos_pfb_map[pos] = pfb; - } + continue; // Skip if the SNP position is not in the set } - pclose(fp); - - return pos_pfb_map; - }; - // Create a future for the thread - std::future> future = std::async(std::launch::async, get_pfb); - futures.push_back(std::move(future)); - } - - // Loop through the futures and get the results - int pfb_count = 0; - for (auto& future : futures) - { - // Wait for the future to finish - future.wait(); - - // Get the result from the future - std::unordered_map result = std::move(future.get()); - - // Loop through the result and add to SNPInfo - // printMessage("Adding population frequencies to SNPInfo..."); - for (auto& pair : result) - { - int pos = pair.first; - double pfb = pair.second; - - // Lock the SNPInfo mutex - this->snp_data_mtx.lock(); - snp_info.insertSNPPopulationFrequency(chr_snp, pos, pfb); - this->snp_data_mtx.unlock(); - - // Increment the population frequency count - pfb_count++; + // Get the population frequency for the SNP + int pfb_status = bcf_get_info_float(pfb_reader->readers[0].header, pfb_record, AF_key.c_str(), &pfb_f, &count); + if (pfb_status < 0 || count == 0) + { + continue; + } + double pfb = static_cast(pfb_f[0]); - // [TEST] Print 15 values - if (pfb_count < 15) + // Skip if outside the acceptable range + if (pfb <= MIN_PFB || pfb >= MAX_PFB) { - printMessage("Population frequency for " + chr + ":" + std::to_string(pos) + " = " + std::to_string(pfb)); + continue; } + snp_pfb[pfb_pos] = pfb; + break; } + free(pfb_f); } + + // Clean up + bcf_sr_destroy(snp_reader); + bcf_sr_destroy(pfb_reader); } -void CNVCaller::saveSVCopyNumberToTSV(SNPData& snp_data, std::string filepath, std::string chr, int64_t start, int64_t end, std::string sv_type, double likelihood) +void CNVCaller::saveSVCopyNumberToTSV(SNPData& snp_data, std::string filepath, std::string chr, uint32_t start, uint32_t end, std::string sv_type, double likelihood) const { // Open the TSV file for writing std::ofstream tsv_file(filepath); @@ -1151,7 +870,7 @@ void CNVCaller::saveSVCopyNumberToTSV(SNPData& snp_data, std::string 
filepath, s for (int i = 0; i < snp_count; i++) { // Get the SNP data - int64_t pos = snp_data.pos[i]; + uint32_t pos = snp_data.pos[i]; bool is_snp = snp_data.is_snp[i]; double pfb = snp_data.pfb[i]; double baf = snp_data.baf[i]; @@ -1180,7 +899,172 @@ void CNVCaller::saveSVCopyNumberToTSV(SNPData& snp_data, std::string filepath, s tsv_file.close(); } -void CNVCaller::updateSNPData(SNPData& snp_data, int64_t pos, double pfb, double baf, double log2_cov, bool is_snp) +void CNVCaller::saveSVCopyNumberToJSON(SNPData &before_sv, SNPData &after_sv, SNPData &snp_data, std::string chr, uint32_t start, uint32_t end, std::string sv_type, double likelihood, const std::string& filepath) const +{ + // Append the SV information to the JSON file + std::ofstream json_file(filepath, std::ios::app); + if (!json_file.is_open()) + { + std::cerr << "ERROR: Could not open JSON file for writing: " << filepath << std::endl; + exit(1); + } + + // If not the first record, write the closing bracket + // Check if file is empty + if (isFileEmpty(filepath)) + { + json_file << "[\n"; + } else { + // Close the previous JSON object + json_file << "},\n"; + } + + json_file << "{\n"; + json_file << " \"chromosome\": \"" << chr << "\",\n"; + json_file << " \"start\": " << start << ",\n"; + json_file << " \"end\": " << end << ",\n"; + json_file << " \"sv_type\": \"" << sv_type << "\",\n"; + json_file << " \"likelihood\": " << likelihood << ",\n"; + json_file << " \"size\": " << (end - start + 1) << ",\n"; + json_file << " \"before_sv\": {\n"; + json_file << " \"positions\": ["; + for (size_t i = 0; i < before_sv.pos.size(); ++i) + { + json_file << before_sv.pos[i]; + if (i < before_sv.pos.size() - 1) + json_file << ", "; + } + json_file << "],\n"; + json_file << " \"b_allele_freq\": ["; + for (size_t i = 0; i < before_sv.baf.size(); ++i) + { + json_file << before_sv.baf[i]; + if (i < before_sv.baf.size() - 1) + json_file << ", "; + } + json_file << "],\n"; + json_file << " \"population_freq\": ["; + for (size_t i = 0; i < before_sv.pfb.size(); ++i) + { + json_file << before_sv.pfb[i]; + if (i < before_sv.pfb.size() - 1) + json_file << ", "; + } + json_file << "],\n"; + json_file << " \"log2_ratio\": ["; + for (size_t i = 0; i < before_sv.log2_cov.size(); ++i) + { + json_file << before_sv.log2_cov[i]; + if (i < before_sv.log2_cov.size() - 1) + json_file << ", "; + } + json_file << "],\n"; + json_file << " \"is_snp\": ["; + for (size_t i = 0; i < before_sv.is_snp.size(); ++i) + { + json_file << before_sv.is_snp[i]; + if (i < before_sv.is_snp.size() - 1) + json_file << ", "; + } + json_file << "]\n"; + json_file << " },\n"; + json_file << " \"after_sv\": {\n"; + json_file << " \"positions\": ["; + for (size_t i = 0; i < after_sv.pos.size(); ++i) + { + json_file << after_sv.pos[i]; + if (i < after_sv.pos.size() - 1) + json_file << ", "; + } + json_file << "],\n"; + json_file << " \"b_allele_freq\": ["; + for (size_t i = 0; i < after_sv.baf.size(); ++i) + { + json_file << after_sv.baf[i]; + if (i < after_sv.baf.size() - 1) + json_file << ", "; + } + json_file << "],\n"; + json_file << " \"population_freq\": ["; + for (size_t i = 0; i < after_sv.pfb.size(); ++i) + { + json_file << after_sv.pfb[i]; + if (i < after_sv.pfb.size() - 1) + json_file << ", "; + } + json_file << "],\n"; + json_file << " \"log2_ratio\": ["; + for (size_t i = 0; i < after_sv.log2_cov.size(); ++i) + { + json_file << after_sv.log2_cov[i]; + if (i < after_sv.log2_cov.size() - 1) + json_file << ", "; + } + json_file << "],\n"; + json_file << " \"is_snp\": ["; 
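+    // Each JSON array in this function is written with the same index-based
+    // loop so that commas separate elements without one trailing after the
+    // last. A small helper along these lines (hypothetical; not part of this
+    // patch) would remove the repetition:
+    //
+    //   template <typename T>
+    //   void writeJSONArray(std::ofstream& os, const std::vector<T>& values)
+    //   {
+    //       for (size_t i = 0; i < values.size(); ++i)
+    //       {
+    //           os << values[i];
+    //           if (i + 1 < values.size()) os << ", ";
+    //       }
+    //   }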
+ for (size_t i = 0; i < after_sv.is_snp.size(); ++i) + { + json_file << after_sv.is_snp[i]; + if (i < after_sv.is_snp.size() - 1) + json_file << ", "; + } + json_file << "]\n"; + json_file << " },\n"; + json_file << " \"sv\": {\n"; + json_file << " \"positions\": ["; + for (size_t i = 0; i < snp_data.pos.size(); ++i) + { + json_file << snp_data.pos[i]; + if (i < snp_data.pos.size() - 1) + json_file << ", "; + } + json_file << "],\n"; + json_file << " \"b_allele_freq\": ["; + for (size_t i = 0; i < snp_data.baf.size(); ++i) + { + json_file << snp_data.baf[i]; + if (i < snp_data.baf.size() - 1) + json_file << ", "; + } + json_file << "],\n"; + json_file << " \"population_freq\": ["; + for (size_t i = 0; i < snp_data.pfb.size(); ++i) + { + json_file << snp_data.pfb[i]; + if (i < snp_data.pfb.size() - 1) + json_file << ", "; + } + json_file << "],\n"; + json_file << " \"log2_ratio\": ["; + for (size_t i = 0; i < snp_data.log2_cov.size(); ++i) + { + json_file << snp_data.log2_cov[i]; + if (i < snp_data.log2_cov.size() - 1) + json_file << ", "; + } + json_file << "],\n"; + json_file << " \"states\": ["; + for (size_t i = 0; i < snp_data.state_sequence.size(); ++i) + { + json_file << snp_data.state_sequence[i]; + if (i < snp_data.state_sequence.size() - 1) + json_file << ", "; + } + json_file << "],\n"; + json_file << " \"is_snp\": ["; + for (size_t i = 0; i < snp_data.is_snp.size(); ++i) + { + json_file << snp_data.is_snp[i]; + if (i < snp_data.is_snp.size() - 1) + json_file << ", "; + } + json_file << "]\n"; + json_file << " }\n"; + json_file.close(); + printMessage("Saved copy number predictions for " + chr + ":" + std::to_string(start) + "-" + std::to_string(end) + " to " + filepath); +} + +void CNVCaller::updateSNPData(SNPData& snp_data, uint32_t pos, double pfb, double baf, double log2_cov, bool is_snp) { // Update the SNP data snp_data.pos.emplace_back(pos); diff --git a/src/cnv_data.cpp b/src/cnv_data.cpp deleted file mode 100644 index 0c4593c0..00000000 --- a/src/cnv_data.cpp +++ /dev/null @@ -1,73 +0,0 @@ -#include "cnv_data.h" - -/// @cond -#include -#include -#include -#include -#include -#include - -#include "sv_types.h" -/// @endcond - -// Include the SV types namespace -using namespace sv_types; - -void CNVData::addCNVCall(std::string chr, int snp_pos, int cnv_type) -{ - // Add the CNV call to the map - SNPLocation key(chr, snp_pos); - this->cnv_calls[key] = cnv_type; -} - -void CNVData::loadFromFile(std::string filepath) -{ - // Load CNV calls from file - std::ifstream cnv_file(filepath); - std::string line; - std::string chr; - int snp_pos; - int cnv_type; - - // Check if the file was opened successfully - if (!cnv_file.is_open()) { - std::cerr << "Error: Could not open CNV file " << filepath << std::endl; - exit(1); - } - - // Skip the first line (header) - std::getline(cnv_file, line); - - // Read the file line by line - int line_num = 1; - while (std::getline(cnv_file, line)) { - - // Parse the line - std::istringstream iss(line); - - // Get columns 1, 2, and 5 (chr, pos, cnv_type) - std::string chr; - std::getline(iss, chr, '\t'); - - std::string pos_str; - std::getline(iss, pos_str, '\t'); - snp_pos = std::stoi(pos_str); - - std::string skip_str; - std::getline(iss, skip_str, '\t'); - std::getline(iss, skip_str, '\t'); - - std::string cnv_type_str; - std::getline(iss, cnv_type_str, '\t'); - cnv_type = std::stoi(cnv_type_str); - - // Add the CNV call to the map - this->addCNVCall(chr, snp_pos, cnv_type); - - line_num++; - } - cnv_file.close(); - - std::cout << "Loaded " << 
line_num << " CNV calls" << std::endl;
-}
diff --git a/src/contextsv.cpp b/src/contextsv.cpp
index 47d68054..3957fe95 100644
--- a/src/contextsv.cpp
+++ b/src/contextsv.cpp
@@ -12,37 +12,11 @@
 #include "utils.h"
 /// @endcond
 
-ContextSV::ContextSV(InputData& input_data)
-{
-    this->input_data = &input_data;
-}
 
-// Entry point
-int ContextSV::run()
+int ContextSV::run(const InputData& input_data) const
 {
-    // Start the program's timer
-    auto start_sv = std::chrono::high_resolution_clock::now();
-
-    // Get the reference genome
-    FASTAQuery ref_genome = this->input_data->getRefGenome();
-
-    // Call SVs from long read alignments:
-    std::cout << "Running alignment-based SV calling..." << std::endl;
-    SVCaller sv_caller(*this->input_data);
-    SVData sv_calls = sv_caller.run();
-
-    // Print the total number of SVs called
-    std::cout << "Total SVs called: " << sv_calls.totalCalls() << std::endl;
-
-    // Write SV calls to file
-    std::string output_dir = this->input_data->getOutputDir();
-    std::cout << "Writing SV calls to file " << output_dir << "/output.vcf..." << std::endl;
-    sv_calls.saveToVCF(ref_genome, output_dir);
-
-    // Format and print the time taken to call SVs
-    auto end_sv = std::chrono::high_resolution_clock::now();
-    std::string elapsed_time = getElapsedTime(start_sv, end_sv);
-    std::cout << "SV calling complete. Found " << sv_calls.totalCalls() << " total SVs. Time taken (h:m:s) = " << elapsed_time << std::endl;
+    SVCaller sv_caller;
+    sv_caller.run(input_data);
 
     return 0;
 }
diff --git a/src/dbscan.cpp b/src/dbscan.cpp
new file mode 100644
index 00000000..c1f3f314
--- /dev/null
+++ b/src/dbscan.cpp
@@ -0,0 +1,81 @@
+#include "dbscan.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstddef>
+#include <vector>
+
+
+void DBSCAN::fit(const std::vector<SVCall>& sv_calls) {
+    int clusterId = 0;
+    clusters.assign(sv_calls.size(), -1); // -1 means unclassified
+
+    for (size_t i = 0; i < sv_calls.size(); ++i) {
+        if (clusters[i] == -1) { // if point is not yet classified
+            if (expandCluster(sv_calls, i, clusterId)) {
+                ++clusterId;
+            }
+        }
+    }
+}
+
+const std::vector<int>& DBSCAN::getClusters() const {
+    return clusters;
+}
+
+bool DBSCAN::expandCluster(const std::vector<SVCall>& sv_calls, size_t pointIdx, int clusterId) {
+    std::vector<size_t> seeds = regionQuery(sv_calls, pointIdx);
+    if (static_cast<int>(seeds.size()) < minPts) {
+        clusters[pointIdx] = -2; // mark as noise
+        return false;
+    }
+
+    for (size_t seedIdx : seeds) {
+        clusters[seedIdx] = clusterId;
+    }
+
+    seeds.erase(std::remove(seeds.begin(), seeds.end(), pointIdx), seeds.end());
+
+    while (!seeds.empty()) {
+        size_t currentPoint = seeds.back();
+        seeds.pop_back();
+
+        std::vector<size_t> result = regionQuery(sv_calls, currentPoint);
+        if (static_cast<int>(result.size()) >= minPts) {
+            for (size_t resultPoint : result) {
+                if (clusters[resultPoint] == -1 || clusters[resultPoint] == -2) {
+                    if (clusters[resultPoint] == -1) {
+                        seeds.push_back(resultPoint);
+                    }
+                    clusters[resultPoint] = clusterId;
+                }
+            }
+        }
+    }
+
+    return true;
+}
+
+std::vector<size_t> DBSCAN::regionQuery(const std::vector<SVCall>& sv_calls, size_t pointIdx) const {
+    std::vector<size_t> neighbors;
+    for (size_t i = 0; i < sv_calls.size(); ++i) {
+        if (distance(sv_calls[pointIdx], sv_calls[i]) <= epsilon) {
+            neighbors.push_back(i);
+        }
+    }
+    return neighbors;
+}
+
+double DBSCAN::distance(const SVCall& point1, const SVCall& point2) const {
+
+    // Calculate reciprocal overlap-based distance
+    // https://genomebiology.biomedcentral.com/articles/10.1186/s13059-022-02840-6
+    // https://link.springer.com/article/10.1186/gb-2009-10-10-r119
+    int overlap = std::max(0, std::min(static_cast<int>(point1.end), static_cast<int>(point2.end)) - std::max(static_cast<int>(point1.start), static_cast<int>(point2.start)));
+    int length1 = static_cast<int>(point1.end - point1.start);
+    int length2 = static_cast<int>(point2.end - point2.start);
+
+    // Minimum reciprocal overlap
+    double distance = 1.0 - std::min(static_cast<double>(overlap) / static_cast<double>(length1), static_cast<double>(overlap) / static_cast<double>(length2));
+    return distance; // 0.0 means identical, 1.0 means no overlap
+}
diff --git a/src/dbscan1d.cpp b/src/dbscan1d.cpp
new file mode 100644
index 00000000..90fc9458
--- /dev/null
+++ b/src/dbscan1d.cpp
@@ -0,0 +1,90 @@
+#include "dbscan1d.h"
+
+#include <algorithm>
+#include <cstdlib>
+#include <map>
+#include <vector>
+
+void DBSCAN1D::fit(const std::vector<int>& points) {
+    int clusterId = 0;
+    clusters.assign(points.size(), -1); // -1 means unclassified
+
+    for (size_t i = 0; i < points.size(); ++i) {
+        if (clusters[i] == -1) { // if point is not yet classified
+            if (expandCluster(points, i, clusterId)) {
+                ++clusterId;
+            }
+        }
+    }
+}
+
+const std::vector<int>& DBSCAN1D::getClusters() const {
+    return clusters;
+}
+
+bool DBSCAN1D::expandCluster(const std::vector<int>& points, size_t pointIdx, int clusterId) {
+    std::vector<size_t> seeds = regionQuery(points, pointIdx);
+    if (static_cast<int>(seeds.size()) < minPts) {
+        clusters[pointIdx] = -2; // mark as noise
+        return false;
+    }
+
+    for (size_t seedIdx : seeds) {
+        clusters[seedIdx] = clusterId;
+    }
+
+    seeds.erase(std::remove(seeds.begin(), seeds.end(), pointIdx), seeds.end());
+
+    while (!seeds.empty()) {
+        size_t currentPoint = seeds.back();
+        seeds.pop_back();
+
+        std::vector<size_t> result = regionQuery(points, currentPoint);
+        if (static_cast<int>(result.size()) >= minPts) {
+            for (size_t resultPoint : result) {
+                if (clusters[resultPoint] == -1 || clusters[resultPoint] == -2) {
+                    if (clusters[resultPoint] == -1) {
+                        seeds.push_back(resultPoint);
+                    }
+                    clusters[resultPoint] = clusterId;
+                }
+            }
+        }
+    }
+
+    return true;
+}
+
+std::vector<size_t> DBSCAN1D::regionQuery(const std::vector<int>& points, size_t pointIdx) const {
+    std::vector<size_t> neighbors;
+    for (size_t i = 0; i < points.size(); ++i) {
+        if (distance(points[pointIdx], points[i]) <= epsilon) {
+            neighbors.push_back(i);
+        }
+    }
+    return neighbors;
+}
+
+double DBSCAN1D::distance(int point1, int point2) const {
+    return std::abs(point1 - point2);
+}
+
+std::vector<int> DBSCAN1D::getLargestCluster(const std::vector<int> &points)
+{
+    std::vector<int> clusters = getClusters();
+    std::map<int, std::vector<int>> cluster_map;
+    for (size_t i = 0; i < clusters.size(); ++i) {
+        cluster_map[clusters[i]].push_back(points[i]);
+    }
+
+    int largest_cluster_id = -1;
+    size_t largest_size = 0;
+    for (const auto &entry : cluster_map) {
+        if (entry.first >= 0 && entry.second.size() > largest_size) {
+            largest_size = entry.second.size();
+            largest_cluster_id = entry.first;
+        }
+    }
+
+    return cluster_map[largest_cluster_id];
+}
diff --git a/src/debug.cpp b/src/debug.cpp
new file mode 100644
index 00000000..2028e5f6
--- /dev/null
+++ b/src/debug.cpp
@@ -0,0 +1,4 @@
+// debug.cpp
+#include "debug.h"
+
+std::mutex debug_mutex;
diff --git a/src/fasta_query.cpp b/src/fasta_query.cpp
index e1bc9bea..e4f0e1dc 100644
--- a/src/fasta_query.cpp
+++ b/src/fasta_query.cpp
@@ -12,7 +12,10 @@
 /// @endcond
 
-int FASTAQuery::setFilepath(std::string fasta_filepath)
+#include "utils.h"
+
+
+int ReferenceGenome::setFilepath(std::string fasta_filepath)
 {
     if (fasta_filepath == "")
     {
@@ -31,8 +34,6 @@ int FASTAQuery::setFilepath(std::string fasta_filepath)
     }
 
     // Get the chromosomes and sequences
-    std::vector<std::string> chromosomes;
-    std::unordered_map<std::string, std::string> chr_to_seq;
std::string current_chr = ""; std::string sequence = ""; std::string line_str = ""; @@ -45,13 +46,13 @@ int FASTAQuery::setFilepath(std::string fasta_filepath) // Store the previous chromosome and sequence if (current_chr != "") { - chromosomes.push_back(current_chr); // Add the chromosome to the list - chr_to_seq[current_chr] = sequence; // Add the sequence to the map + this->chromosomes.push_back(current_chr); // Add the chromosome to the list + this->chr_to_seq[current_chr] = sequence; // Add the sequence to the map + this->chr_to_length[current_chr] = sequence.length(); // Add the sequence length to the map sequence = ""; // Reset the sequence } - // Get the new chromosome - current_chr = line_str.substr(1); + current_chr = line_str.substr(1); // Remove the '>' character // Remove the description size_t space_pos = current_chr.find(" "); @@ -59,15 +60,7 @@ int FASTAQuery::setFilepath(std::string fasta_filepath) { current_chr.erase(space_pos); } - - // Check if the chromosome is already in the map - if (chr_to_seq.find(current_chr) != chr_to_seq.end()) - { - std::cerr << "Duplicate chromosome " << current_chr << std::endl; - exit(1); - } } else { - // Sequence line sequence += line_str; } } @@ -75,66 +68,78 @@ int FASTAQuery::setFilepath(std::string fasta_filepath) // Add the last chromosome at the end of the file if (current_chr != "") { - chromosomes.push_back(current_chr); // Add the chromosome to the list - chr_to_seq[current_chr] = sequence; // Add the sequence to the map + this->chromosomes.push_back(current_chr); // Add the chromosome to the list + this->chr_to_seq[current_chr] = sequence; // Add the sequence to the map + this->chr_to_length[current_chr] = sequence.length(); // Add the sequence length to the map } - // Close the file fasta_file.close(); - - // Sort the chromosomes - std::sort(chromosomes.begin(), chromosomes.end()); - - // Set the chromosomes and sequences - this->chromosomes = chromosomes; - this->chr_to_seq = chr_to_seq; + std::sort(this->chromosomes.begin(), this->chromosomes.end()); return 0; } -std::string FASTAQuery::getFilepath() +std::string ReferenceGenome::getFilepath() const { return this->fasta_filepath; } // Function to get the reference sequence at a given position range -std::string FASTAQuery::query(std::string chr, int64_t pos_start, int64_t pos_end) -{ +std::string_view ReferenceGenome::query(const std::string& chr, uint32_t pos_start, uint32_t pos_end) const +{ // Convert positions from 1-indexed (reference) to 0-indexed (string indexing) pos_start--; pos_end--; - // Ensure that the start position is not negative, and the end position is - // not larger than the chromosome length - if (pos_start < 0) + // Ensure that the end position is not larger than the chromosome length + const std::string& sequence = this->chr_to_seq.at(chr); + if (pos_end >= sequence.length() || pos_start > pos_end) { - return ""; + return {}; } - if (pos_end >= (int64_t)this->chr_to_seq[chr].length()) + + return std::string_view(sequence).substr(pos_start, (pos_end - pos_start) + 1); +} + +// Function to compare the reference sequence at a given position range +bool ReferenceGenome::compare(const std::string& chr, uint32_t pos_start, uint32_t pos_end, const std::string& compare_seq, float match_threshold) const +{ + // Convert positions from 1-indexed (reference) to 0-indexed (string indexing) + pos_start--; + pos_end--; + + // Ensure that the end position is not larger than the chromosome length + const std::string& sequence = this->chr_to_seq.at(chr); + if (pos_end >= 
sequence.length() || pos_start >= pos_end)
     {
-        return "";
+        return false;
     }
-    int64_t length = pos_end - pos_start + 1;
-
-    // Get the sequence
-    const std::string& sequence = this->chr_to_seq[chr];
-
-    // Get the substring
-    // std::string subsequence = sequence.substr(pos_start, length);
+    std::string_view subseq = std::string_view(sequence).substr(pos_start, pos_end - pos_start + 1);
+    if (subseq.length() != compare_seq.length())
+    {
+        printError("ERROR: Sequence lengths do not match for comparison");
+        return false;
+    }
 
-    // If the subsequence is empty, return empty string
-    if (sequence.substr(pos_start, length).empty())
+    // Calculate the match rate
+    size_t num_matches = 0;
+    for (size_t i = 0; i < subseq.length(); i++)
     {
-        return "";
+        if (subseq[i] == compare_seq[i])
+        {
+            num_matches++;
+        }
     }
+    float match_rate = (float)num_matches / (float)subseq.length();
 
-    return sequence.substr(pos_start, length);
+    return match_rate >= match_threshold;
 }
 
 // Function to get the chromosome contig lengths in VCF header format
-std::string FASTAQuery::getContigHeader()
+std::string ReferenceGenome::getContigHeader() const
 {
+    std::shared_lock<std::shared_mutex> lock(this->shared_mutex);
     std::string contig_header = "";
 
     // Sort the chromosomes
@@ -144,12 +149,10 @@ std::string FASTAQuery::getContigHeader()
         chromosomes.push_back(chr_seq.first);
     }
     std::sort(chromosomes.begin(), chromosomes.end());
-
-    // Iterate over the chromosomes and add them to the contig header
     for (auto const& chr : chromosomes)
     {
         // Add the contig header line
-        contig_header += "##contig=<ID=" + chr + ",length=" + std::to_string(this->chr_to_seq[chr].length()) + ">\n";
+        contig_header += "##contig=<ID=" + chr + ",length=" + std::to_string(this->chr_to_length.at(chr)) + ">\n";
     }
 
     // Remove the last newline character
@@ -158,12 +161,20 @@ std::string FASTAQuery::getContigHeader()
     return contig_header;
 }
 
-std::vector<std::string> FASTAQuery::getChromosomes()
+std::vector<std::string> ReferenceGenome::getChromosomes() const
 {
     return this->chromosomes;
 }
 
-int64_t FASTAQuery::getChromosomeLength(std::string chr)
+uint32_t ReferenceGenome::getChromosomeLength(std::string chr) const
 {
-    return this->chr_to_seq[chr].length();
+    try
+    {
+        return this->chr_to_length.at(chr);
+    }
+    catch (const std::out_of_range& e)
+    {
+        printError("Chromosome " + chr + " not found in reference genome");
+        return 0;
+    }
 }
diff --git a/src/input_data.cpp b/src/input_data.cpp
index 85e4f8d1..3b53a7d7 100644
--- a/src/input_data.cpp
+++ b/src/input_data.cpp
@@ -8,6 +8,7 @@
 #include
 
 #include "utils.h"
+#include "debug.h" // For DEBUG_PRINT
 /// @endcond
 
 #define MIN_PFB 0.01 // Minimum SNP population allele frequency
@@ -16,7 +17,6 @@
 // Constructor
 InputData::InputData()
 {
-    this->short_read_bam = "";
     this->long_read_bam = "";
     this->ref_filepath = "";
     this->snp_vcf_filepath = "";
@@ -24,40 +24,42 @@ InputData::InputData()
     this->start_end = std::make_pair(0, 0);
     this->region_set = false;
     this->output_dir = "";
-    this->window_size = 2500;
-    this->min_cnv_length = 1000;
+    this->sample_size = 20;
+    this->min_cnv_length = 2000; // Default minimum CNV length
+    this->min_reads = 5;
+    this->dbscan_epsilon = 0.1;
+    this->dbscan_min_pts_pct = 0.1;
     this->thread_count = 1;
     this->hmm_filepath = "data/wgs.hmm";
     this->verbose = false;
     this->save_cnv_data = false;
-}
-
-std::string InputData::getShortReadBam()
-{
-    return this->short_read_bam;
-}
-
-void InputData::setShortReadBam(std::string filepath)
-{
-    this->short_read_bam = filepath;
-
-    // Check if empty string
-    if (filepath == "")
+    this->single_chr = false;
+    this->cnv_output_file = "";
+    this->assembly_gaps = "";
+}
+
+void InputData::printParameters() const
+{
+    DEBUG_PRINT("Input parameters:");
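+    // DEBUG_PRINT is declared in debug.h (not shown in this diff) and takes a
+    // stream-style argument, e.g. DEBUG_PRINT("x = " << x). A minimal sketch
+    // of such a macro, assuming a DEBUG compile-time flag and the debug_mutex
+    // defined in src/debug.cpp, could look like:
+    //
+    //   #ifdef DEBUG
+    //   #define DEBUG_PRINT(msg) do { \
+    //       std::lock_guard<std::mutex> lock(debug_mutex); \
+    //       std::cerr << msg << std::endl; \
+    //   } while (0)
+    //   #else
+    //   #define DEBUG_PRINT(msg) do {} while (0)
+    //   #endif
+    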
DEBUG_PRINT("Long read BAM: " << this->long_read_bam); + DEBUG_PRINT("Reference genome: " << this->ref_filepath); + DEBUG_PRINT("SNP VCF: " << this->snp_vcf_filepath); + DEBUG_PRINT("Output directory: " << this->output_dir); + DEBUG_PRINT("Sample size: " << this->sample_size); + DEBUG_PRINT("Minimum CNV length: " << this->min_cnv_length); + DEBUG_PRINT("DBSCAN epsilon: " << this->dbscan_epsilon); + DEBUG_PRINT("DBSCAN minimum points percentage: " << this->dbscan_min_pts_pct * 100.0f << "%"); + if (this->region_set) { - return; - - } else { - // Check if the file exists - FILE *fp = fopen(filepath.c_str(), "r"); - if (fp == NULL) - { - std::cerr << "Short read BAM file does not exist: " << filepath << std::endl; - exit(1); - } + DEBUG_PRINT("Region set to: chr" + this->chr + ":" + std::to_string(this->start_end.first) + "-" + std::to_string(this->start_end.second)); + } + else + { + DEBUG_PRINT("Running on whole genome"); } } -std::string InputData::getLongReadBam() +std::string InputData::getLongReadBam() const { return this->long_read_bam; } @@ -67,7 +69,7 @@ void InputData::setLongReadBam(std::string filepath) this->long_read_bam = filepath; // Check if empty string - if (filepath == "") + if (filepath.empty()) { return; @@ -76,39 +78,24 @@ void InputData::setLongReadBam(std::string filepath) FILE *fp = fopen(filepath.c_str(), "r"); if (fp == NULL) { - std::cerr << "Long read BAM file does not exist: " << filepath << std::endl; - exit(1); + throw std::runtime_error("Long read BAM file does not exist: " + filepath); + } else { + fclose(fp); } } } -void InputData::setRefGenome(std::string fasta_filepath) -{ - // Set the reference genome - this->fasta_query.setFilepath(fasta_filepath); -} - -const FASTAQuery &InputData::getRefGenome() const -{ - return this->fasta_query; -} - -std::string InputData::queryRefGenome(std::string chr, int64_t pos_start, int64_t pos_end) +void InputData::setRefGenome(std::string filepath) { - return this->fasta_query.query(chr, pos_start, pos_end); + this->ref_filepath = filepath; } -std::vector InputData::getRefGenomeChromosomes() +std::string InputData::getRefGenome() const { - return this->fasta_query.getChromosomes(); + return this->ref_filepath; } -int64_t InputData::getRefGenomeChromosomeLength(std::string chr) -{ - return this->fasta_query.getChromosomeLength(chr); -} - -std::string InputData::getOutputDir() +std::string InputData::getOutputDir() const { return this->output_dir; } @@ -116,23 +103,28 @@ std::string InputData::getOutputDir() void InputData::setOutputDir(std::string dirpath) { this->output_dir = dirpath; - - // Create the output directory std::string cmd = "mkdir -p " + output_dir; - system(cmd.c_str()); + try + { + std::system(cmd.c_str()); + } catch (const std::exception& e) + { + std::cerr << "Error creating output directory: " << e.what() << std::endl; + exit(1); + } } -int InputData::getWindowSize() +int InputData::getSampleSize() const { - return this->window_size; + return this->sample_size; } -void InputData::setWindowSize(int window_size) +void InputData::setSampleSize(int sample_size) { - this->window_size = window_size; + this->sample_size = sample_size; } -std::string InputData::getSNPFilepath() +std::string InputData::getSNPFilepath() const { return this->snp_vcf_filepath; } @@ -142,7 +134,7 @@ void InputData::setSNPFilepath(std::string filepath) this->snp_vcf_filepath = filepath; } -std::string InputData::getEthnicity() +std::string InputData::getEthnicity() const { return this->ethnicity; } @@ -152,26 +144,80 @@ void 
InputData::setEthnicity(std::string ethnicity) this->ethnicity = ethnicity; } -int InputData::getMinCNVLength() +void InputData::setAssemblyGaps(std::string filepath) +{ + // Check if the file exists + FILE *fp = fopen(filepath.c_str(), "r"); + if (fp == NULL) + { + std::cerr << "Assembly gaps file does not exist: " << filepath << std::endl; + exit(1); + } + + // Check if the file is a BED file + std::string ext = filepath.substr(filepath.find_last_of(".") + 1); + if (ext != "bed") + { + std::cerr << "Assembly gaps file is not a BED file: " << filepath << std::endl; + exit(1); + } + fclose(fp); + + // Set the assembly gaps file + this->assembly_gaps = filepath; +} + +std::string InputData::getAssemblyGaps() const +{ + return this->assembly_gaps; +} + +uint32_t InputData::getMinCNVLength() const { return this->min_cnv_length; } void InputData::setMinCNVLength(int min_cnv_length) { - this->min_cnv_length = min_cnv_length; + this->min_cnv_length = (uint32_t) min_cnv_length; +} + +void InputData::setDBSCAN_Epsilon(double epsilon) +{ + this->dbscan_epsilon = epsilon; +} + +double InputData::getDBSCAN_Epsilon() const +{ + return this->dbscan_epsilon; +} + +void InputData::setDBSCAN_MinPtsPct(double min_pts_pct) +{ + this->dbscan_min_pts_pct = min_pts_pct; +} + +double InputData::getDBSCAN_MinPtsPct() const +{ + return this->dbscan_min_pts_pct; } void InputData::setChromosome(std::string chr) { this->chr = chr; + this->single_chr = true; } -std::string InputData::getChromosome() +std::string InputData::getChromosome() const { return this->chr; } +bool InputData::isSingleChr() const +{ + return this->single_chr; +} + void InputData::setRegion(std::string region) { // Check if the region is valid @@ -197,25 +243,24 @@ void InputData::setRegion(std::string region) // Set the region this->start_end = std::make_pair(start, end); this->region_set = true; + + std::cout << "Region set to " << this->chr << ":" << start << "-" << end << std::endl; } } - std::cout << "Region set to " << this->start_end.first << "-" << this->start_end.second << std::endl; } -std::pair InputData::getRegion() +std::pair InputData::getRegion() const { return this->start_end; } -bool InputData::isRegionSet() +bool InputData::isRegionSet() const { return this->region_set; } void InputData::setAlleleFreqFilepaths(std::string filepath) { - // this->pfb_filepath = filepath; - // Check if empty string if (filepath == "") { @@ -298,14 +343,22 @@ void InputData::setAlleleFreqFilepaths(std::string filepath) } } -std::string InputData::getAlleleFreqFilepath(std::string chr) +std::string InputData::getAlleleFreqFilepath(std::string chr) const { // Remove the chr notation if (chr.find("chr") != std::string::npos) { chr = chr.substr(3, chr.size() - 3); } - return this->pfb_filepaths[chr]; + + try + { + return this->pfb_filepaths.at(chr); + } + catch (const std::out_of_range& e) + { + return ""; + } } void InputData::setThreadCount(int thread_count) @@ -313,12 +366,12 @@ void InputData::setThreadCount(int thread_count) this->thread_count = thread_count; } -int InputData::getThreadCount() +int InputData::getThreadCount() const { return this->thread_count; } -std::string InputData::getHMMFilepath() +std::string InputData::getHMMFilepath() const { return this->hmm_filepath; } @@ -360,7 +413,17 @@ void InputData::saveCNVData(bool save_cnv_data) this->save_cnv_data = save_cnv_data; } -bool InputData::getSaveCNVData() +bool InputData::getSaveCNVData() const { return this->save_cnv_data; } + +void InputData::setCNVOutputFile(std::string filepath) 
+{ + this->cnv_output_file = filepath; +} + +std::string InputData::getCNVOutputFile() const +{ + return this->cnv_output_file; +} diff --git a/src/khmm.cpp b/src/khmm.cpp index a5d553be..7dfce96e 100644 --- a/src/khmm.cpp +++ b/src/khmm.cpp @@ -3,12 +3,17 @@ /// @cond #include +#include #include #include #include #include +#include +#include /// @endcond +#include "utils.h" + #define STATE_CHANGE 100000.0 /*this is the expected changes (D value) in the transition matrix*/ #define VITHUGE 100000000000.0 #define FLOAT_MINIMUM 1.175494351e-38 /*this is indeed machine dependent*/ @@ -50,30 +55,42 @@ std::pair, double> testVit_CHMM(CHMM hmm, int T, std::vector mean, std::vector sd, double uf, double o) { - if (o < mean[1]) + // Get the values (0-based indexing) + + // Fix within the expected normalized coverage range + if (o < mean[0]) + { + o = mean[0]; + } else if (o > mean[5]) { - o = mean[1]; + o = mean[5]; } - double p = uf + ((1 - uf) * pdf_normal(o, mean[state], sd[state])); + + double p = uf + ((1 - uf) * pdf_normal(o, mean[state-1], sd[state-1])); + + // Print the equation and the result + // printMessage("b1iot: state = " + std::to_string(state) + ", mean = " + std::to_string(mean[state-1]) + ", sd = " + std::to_string(sd[state-1]) + ", uf = " + std::to_string(uf) + ", o = " + std::to_string(o) + ", p = " + std::to_string(p)); + // printMessage("Equation: b1iot = uf + ((1 - uf) * pdf_normal(o, mean[state-1], sd[state-1]))"); return log(p); } -double b2iot(int state, double *mean, double *sd, double uf, double pfb, double b) +double b2iot(int state, const std::vector mean, const std::vector sd, double uf, double pfb, double b) { + // Get the values (0-based indexing) double p = 0; - double mean0 = mean[1]; // mean[1] = 0 - double mean25 = mean[2]; // mean[2] = 0.25 - double mean33 = mean[3]; // mean[3] = 0.33 - double mean50 = mean[4]; // mean[4] = 0.5 - double mean50_state1 = mean[5]; // mean[5] = 0.5 - double sd0 = sd[1]; // sd[1] = 0 - double sd25 = sd[2]; // sd[2] = 0.25 - double sd33 = sd[3]; // sd[3] = 0.33 - double sd50 = sd[4]; // sd[4] = 0.5 - double sd50_state1 = sd[5]; // sd[5] = 0.5 + double mean0 = mean[0]; // mean[0] = 0 + double mean25 = mean[1]; // mean[1] = 0.25 + double mean33 = mean[2]; // mean[2] = 0.33 + double mean50 = mean[3]; // mean[3] = 0.5 + double mean50_state1 = mean[4]; // mean[4] = 0.5 + double sd0 = sd[0]; // sd[0] = 0 + double sd25 = sd[1]; // sd[1] = 0.25 + double sd33 = sd[2]; // sd[2] = 0.33 + double sd50 = sd[3]; // sd[3] = 0.5 + double sd50_state1 = sd[4]; // sd[4] = 0.5 p = uf; // UF = previous alpha (transition probability) // PDF normal is the transition probability distrubution a_ij (initialized @@ -247,7 +264,8 @@ std::pair, double> ViterbiLogNP_CHMM(CHMM hmm, int T, std::vect { for (j = 1; j <= hmm.N; j++) { - A1[i][j] = hmm.A[i][j]; + // Update for 0-based indexing + A1[i][j] = hmm.A[i-1][j-1]; } } @@ -257,9 +275,11 @@ std::pair, double> ViterbiLogNP_CHMM(CHMM hmm, int T, std::vect // Threshold any zero values to avoid calculation issues. for (i = 1; i <= hmm.N; i++) { - if (hmm.pi[i] == 0) - hmm.pi[i] = 1e-9; /*eliminate problems with zero probability*/ - hmm.pi[i] = log(hmm.pi[i]); // Convert to log probability due to underflow + // Update to 0-based indexing + if (hmm.pi[i-1] == 0) { + hmm.pi[i-1] = 1e-9; /*eliminate problems with zero probability*/ + } + hmm.pi[i-1] = log(hmm.pi[i-1]); // Convert to log probability due to underflow } // Biot is the NxT matrix of state observation likelihoods. 
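The log conversions above are what keep the Viterbi recursion numerically stable: a path probability is a product of thousands of per-site emission and transition terms, which underflows double precision long before a whole chromosome is processed, whereas the equivalent sum of logs stays finite. A minimal standalone illustration of the underflow argument (not project code):

    #include <cmath>
    #include <cstdio>

    int main() {
        double prod = 1.0;     // direct product of small probabilities
        double log_sum = 0.0;  // the same quantity accumulated in log space
        for (int t = 0; t < 100000; ++t) {
            prod *= 1e-4;               // underflows to 0.0 within ~80 steps
            log_sum += std::log(1e-4);  // stays finite: about -921034
        }
        std::printf("product = %g, log-space sum = %g\n", prod, log_sum);
        return 0;
    }

Once the product underflows to zero, comparing candidate paths becomes meaningless, which is why delta stores log probabilities and the recursion adds the biot terms instead of multiplying emission probabilities.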
@@ -301,8 +321,9 @@ std::pair, double> ViterbiLogNP_CHMM(CHMM hmm, int T, std::vect /* 1. Initialization */ for (i = 1; i <= hmm.N; i++) - { - delta[1][i] = hmm.pi[i] + biot[i][1]; // Initialize the delta matrix (log probability) to the initial state distribution + the emission probability + { + // Update to 0-based indexing + delta[1][i] = hmm.pi[i-1] + biot[i][1]; // Initialize the delta matrix psi[1][i] = 0; // Initialize the psi matrix (state sequence) to 0 (no state) } @@ -339,12 +360,12 @@ std::pair, double> ViterbiLogNP_CHMM(CHMM hmm, int T, std::vect // of the state sequence ending in state i at time T, along with observing // the sequence O1, O2. q[T] = 1; - double min_prob = -VITHUGE; + double final_lh = -VITHUGE; for (i = 1; i <= hmm.N; i++) { - if (delta[T][i] > min_prob) + if (delta[T][i] > final_lh) { - min_prob = delta[T][i]; + final_lh = delta[T][i]; q[T] = i; } } @@ -359,178 +380,207 @@ std::pair, double> ViterbiLogNP_CHMM(CHMM hmm, int T, std::vect q[t] = psi[t + 1][q[t + 1]]; } - // // Print t, the state, delta, biot, and psi - // for (t = 1; t <= T; t++) - // { - // std::cout << "Time " << t << " with state " << q[t] << ":" << std::endl; - // for (i = 1; i <= hmm.N; i++) - // { - // std::cout << "State " << i << ": delta = " << delta[t][i] << ", biot = " << biot[i][t] << ", psi = " << psi[t][i] << ", LRR = " << O1[t-1] << ", BAF = " << O2[t-1] << std::endl; - // } - // std::cout << std::endl; - // } - for (i = 1; i <= hmm.N; i++) { /*recover the HMM model as original*/ - hmm.pi[i] = exp(hmm.pi[i]); + // Update to 0-based indexing + hmm.pi[i-1] = exp(hmm.pi[i-1]); } free_dmatrix(biot, 1, hmm.N, 1, T); free_dmatrix(A1, 1, hmm.N, 1, hmm.N); - // Return the state sequence and its likelihood - return std::make_pair(q, min_prob); + return std::make_pair(q, final_lh); } -CHMM ReadCHMM(const char *filename) +CHMM ReadCHMM(const std::string filename) { - FILE *fp; + std::ifstream file(filename); + if (!file.is_open()) + { + printError("Error opening file"); + return CHMM(); + } CHMM hmm; - int i, j, k; - fp = fopen(filename, "r"); - if (!fp) - fprintf(stderr, "Error: cannot read from HMM file %s\n", filename); - if (fscanf(fp, "M=%d\n", &(hmm.M)) == EOF) - fprintf(stderr, "khmm::ReadCHMM: cannot read M annotation from HMM file"); - if (fscanf(fp, "N=%d\n", &(hmm.N)) == EOF) - fprintf(stderr, "khmm::ReadCHMM: cannot read N annotation from HMM file"); + // Read M + std::string line; + std::getline(file, line); + if (sscanf(line.c_str(), "M=%d", &hmm.M) != 1) + { + printError("Error reading M"); + return CHMM(); + } - if (fscanf(fp, "A:\n") == EOF) - fprintf(stderr, "khmm::ReadCHMM: cannot read A annotation from HMM file"); - hmm.A = (double **)dmatrix(1, hmm.N, 1, hmm.N); - for (i = 1; i <= hmm.N; i++) + // Read N + std::getline(file, line); + if (sscanf(line.c_str(), "N=%d", &hmm.N) != 1) { - for (j = 1; j <= hmm.N; j++) - { - if (fscanf(fp, "%lf", &(hmm.A[i][j])) == EOF) - fprintf(stderr, "khmm::ReadCHMM: cannot read A matrix from HMM file"); - } - if (fscanf(fp, "\n") == EOF) - fprintf(stderr, "khmm::ReadCHMM: cannot read return character from HMM file"); + printError("Error reading N"); + return CHMM(); } - if (fscanf(fp, "B:\n") == EOF) - fprintf(stderr, "khmm::ReadCHMM: cannot read B annotation from HMM file"); - hmm.B = (double **)dmatrix(1, hmm.N, 1, hmm.M); - for (j = 1; j <= hmm.N; j++) + // Read A + std::getline(file, line); + if (line != "A:") { - for (k = 1; k <= hmm.M; k++) - { - if (fscanf(fp, "%lf", &(hmm.B[j][k])) == EOF) - fprintf(stderr, "khmm::ReadCHMM: cannot 
read B matrix from HMM file"); - } - if (fscanf(fp, "\n") == EOF) - fprintf(stderr, "khmm::ReadCHMM: cannot read return character from HMM file"); + printError("Error reading A"); + return CHMM(); + } + hmm.A = readMatrix(file, hmm.N, hmm.N); + if (hmm.A.size() != (size_t)hmm.N || hmm.A[0].size() != (size_t)hmm.N) + { + printError("Error reading A"); + return CHMM(); } - if (fscanf(fp, "pi:\n") == EOF) - fprintf(stderr, "khmm::ReadCHMM: cannot read PI annotation from HMM file"); - hmm.pi = (double *)dvector(1, hmm.N); - for (i = 1; i <= hmm.N; i++) + // Read B + std::getline(file, line); + if (line != "B:") { - if (fscanf(fp, "%lf", &(hmm.pi[i])) == EOF) - fprintf(stderr, "khmm::ReadCHMM: cannot read PI vector from HMM file"); - if (hmm.pi[i] < 1e-6) - hmm.pi[i] = 1e-6; + printError("Error reading B"); + return CHMM(); + } + hmm.B = readMatrix(file, hmm.N, hmm.M); + if (hmm.B.size() != (size_t)hmm.N || hmm.B[0].size() != (size_t)hmm.M) + { + printError("Error reading B"); + return CHMM(); + } + + // Read pi + std::getline(file, line); + if (line != "pi:") + { + printError("Error reading pi"); + return CHMM(); + } + hmm.pi = readVector(file, hmm.N); + if (hmm.pi.size() != (size_t)hmm.N) + { + printError("Error reading pi"); + return CHMM(); + } + + // Read B1_mean + std::getline(file, line); + if (line != "B1_mean:") + { + printError("Error reading B1_mean"); + return CHMM(); + } + hmm.B1_mean = readVector(file, hmm.N); + if (hmm.B1_mean.size() != (size_t)hmm.N) + { + printError("Error reading B1_mean"); + return CHMM(); + } + + // Read B1_sd + std::getline(file, line); + if (line != "B1_sd:") + { + printError("Error reading B1_sd"); + return CHMM(); + } + hmm.B1_sd = readVector(file, hmm.N); + if (hmm.B1_sd.size() != (size_t)hmm.N) + { + printError("Error reading B1_sd"); + return CHMM(); + } + + // Read B1_uf + std::getline(file, line); + if (line != "B1_uf:") + { + printError("Error reading B1_uf"); + return CHMM(); + } + std::getline(file, line); + try { + hmm.B1_uf = std::stod(line); + } catch (const std::invalid_argument& e) { + printError("Error reading B1_uf"); + return CHMM(); + } + + // Read B2_mean + std::getline(file, line); + if (line != "B2_mean:") + { + printError("Error reading B2_mean"); + return CHMM(); + } + hmm.B2_mean = readVector(file, 5); + if (hmm.B2_mean.size() != (size_t)5) + { + printError("Error reading B2_mean"); + return CHMM(); + } + + // Read B2_sd + std::getline(file, line); + if (line != "B2_sd:") + { + printError("Error reading B2_sd"); + return CHMM(); + } + hmm.B2_sd = readVector(file, 5); + if (hmm.B2_sd.size() != (size_t)5) + { + printError("Error reading B2_sd"); + return CHMM(); + + } + + // Read B2_uf + std::getline(file, line); + if (line != "B2_uf:") + { + printError("Error reading B2_uf"); + return CHMM(); + } + std::getline(file, line); + try { + hmm.B2_uf = std::stod(line); + } catch (const std::invalid_argument& e) { + printError("Error reading B2_uf"); + return CHMM(); } - if (fscanf(fp, "\n") == EOF) - fprintf(stderr, "khmm::ReadCHMM: cannot read return character from HMM file"); - if (fscanf(fp, "B1_mean:\n") == EOF) - fprintf(stderr, "khmm::ReadCHMM: cannot read B1_mean annotation from HMM file"); - hmm.B1_mean = (double *)dvector(1, hmm.N); - for (i = 1; i <= hmm.N; i++) - if (fscanf(fp, "%lf", &(hmm.B1_mean[i])) == EOF) - fprintf(stderr, "khmm::ReadCHMM: cannot read B1_mean vector from HMM file"); - if (fscanf(fp, "\n") == EOF) - fprintf(stderr, "khmm::ReadCHMM: cannot read return character from HMM file"); - - if (fscanf(fp, "B1_sd:\n") 
-        fprintf(stderr, "khmm::ReadCHMM: cannot read B1_sd annotation from HMM file");
-    hmm.B1_sd = (double *)dvector(1, hmm.N);
-    for (i = 1; i <= hmm.N; i++)
-        if (fscanf(fp, "%lf", &(hmm.B1_sd[i])) == EOF)
-            fprintf(stderr, "khmm::ReadCHMM: cannot read B1_sd from HMM file");
-    if (fscanf(fp, "\n") == EOF)
-        fprintf(stderr, "khmm::ReadCHMM: cannot read return character from HMM file");
-
-    if (fscanf(fp, "B1_uf:\n") == EOF)
-        fprintf(stderr, "khmm::ReadCHMM: cannot read B1_uf annotation from HMM file");
-    if (fscanf(fp, "%lf", &(hmm.B1_uf)) == EOF)
-        fprintf(stderr, "khmm::ReadCHMM: cannot read B1_uf from HMM file");
-    if (fscanf(fp, "\n") == EOF)
-        fprintf(stderr, "khmm::ReadCHMM: cannot read return character from HMM file");
-
-    if (fscanf(fp, "B2_mean:\n") == EOF)
-        fprintf(stderr, "khmm::ReadCHMM: cannot read B2_mean annotation from HMM file");
-    hmm.B2_mean = (double *)dvector(1, 5);
-    for (i = 1; i <= 5; i++)
-        if (fscanf(fp, "%lf", &(hmm.B2_mean[i])) == EOF)
-            fprintf(stderr, "khmm::ReadCHMM: cannot read B2_mean from HMM file");
-    if (fscanf(fp, "\n") == EOF)
-        fprintf(stderr, "khmm::ReadCHMM: cannot read return character from HMM file");
-
-    if (fscanf(fp, "B2_sd:\n") == EOF)
-        fprintf(stderr, "khmm::ReadCHMM: cannot read B2_sd annotation from HMM file");
-    hmm.B2_sd = (double *)dvector(1, 5);
-    for (i = 1; i <= 5; i++)
-        if (fscanf(fp, "%lf", &(hmm.B2_sd[i])) == EOF)
-            fprintf(stderr, "khmm::ReadCHMM: cannot read B2_sd from HMM file");
-    if (fscanf(fp, "\n") == EOF)
-        fprintf(stderr, "khmm::ReadCHMM: cannot read return character from HMM file");
-
-    if (fscanf(fp, "B2_uf:\n") == EOF)
-        fprintf(stderr, "khmm::ReadCHMM: cannot read B2_uf annotation from HMM file");
-    if (fscanf(fp, "%lf", &(hmm.B2_uf)) == EOF)
-        fprintf(stderr, "khmm::ReadCHMM: cannot read B2_uf from HMM file");
-    if (fscanf(fp, "\n") == EOF)
-        fprintf(stderr, "khmm::ReadCHMM: cannot read return character from HMM file");
-
-    if (fscanf(fp, "B3_mean:\n") != EOF)
-    {
-        hmm.NP_flag = 1;
-        hmm.B3_mean = (double *)dvector(1, hmm.N);
-        for (i = 1; i <= hmm.N; i++)
-            if (fscanf(fp, "%lf", &(hmm.B3_mean[i])) == EOF)
-                fprintf(stderr, "khmm::ReadCHMM: cannot read B3_mean from HMM file");
-        if (fscanf(fp, "\n") == EOF)
-            fprintf(stderr, "khmm::ReadCHMM: cannot read return character from HMM file");
-        if (fscanf(fp, "B3_sd:\n") == EOF)
-            fprintf(stderr, "khmm::ReadCHMM: cannot read B3_sd annotation from HMM file");
-        hmm.B3_sd = (double *)dvector(1, hmm.N);
-        for (i = 1; i <= hmm.N; i++)
-            if (fscanf(fp, "%lf", &(hmm.B3_sd[i])) == EOF)
-                fprintf(stderr, "khmm::ReadCHMM: cannot read B3_sd from HMM file");
-        if (fscanf(fp, "\n") == EOF)
-            fprintf(stderr, "khmm::ReadCHMM: cannot read return character from HMM file");
-        if (fscanf(fp, "B3_uf:\n") == EOF)
-            fprintf(stderr, "khmm::ReadCHMM: cannot read B3_uf annotation from HMM file");
-        if (fscanf(fp, "%lf", &(hmm.B3_uf)) == EOF)
-            fprintf(stderr, "khmm::ReadCHMM: cannot read B3_uf from HMM file");
-        if (fscanf(fp, "\n") == EOF)
-            fprintf(stderr, "khmm::ReadCHMM: cannot read return character from HMM file");
-    }
-    else
-    {
-        hmm.NP_flag = 0;
-    }
-
-    if (fscanf(fp, "DIST:\n") != EOF)
-    {
-        if (fscanf(fp, "%d", &(hmm.dist)) == EOF)
-            fprintf(stderr, "khmm:ReadCHMM: cannot read DIST from HMM file");
-    }
-    else
-    {
-        // hmm.dist = STATE_CHANGE;
-        // snp_dist is the default distance between two SNPs in the same state
-        // (not used in this implementation)
-        // Set it to 1 to disable the distance model
-        hmm.dist = 1;
-    }
-
-    fclose(fp);
     return hmm;
 }
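// Editor's sketch (not part of the patch): ReadCHMM() above expects a plain-text
// model file whose layout is inferred directly from the labels it parses; the
// values shown are illustrative only. N and M set the matrix dimensions, while
// B2_mean/B2_sd are fixed at 5 entries:
//
//   M=6
//   N=6
//   A:
//   <N x N transition probabilities, one row per line>
//   B:
//   <N x M emission parameters, one row per line>
//   pi:
//   <N initial state probabilities on one line>
//   B1_mean:
//   <N values>
//   B1_sd:
//   <N values>
//   B1_uf:
//   <one value>
//   B2_mean:
//   <5 values>
//   B2_sd:
//   <5 values>
//   B2_uf:
//   <one value>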
+
+std::vector<std::vector<double>> readMatrix(std::ifstream &file, int rows, int cols)
+{
+    std::vector<std::vector<double>> matrix(rows, std::vector<double>(cols));
+    for (int i = 0; i < rows; i++)
+    {
+        for (int j = 0; j < cols; j++)
+        {
+            if (!(file >> matrix[i][j]))
+            {
+                printError("Error reading matrix");
+                return std::vector<std::vector<double>>();
+            }
+        }
+    }
+    file.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
+    return matrix;
+}
+
+std::vector<double> readVector(std::ifstream &file, int size)
+{
+    std::vector<double> vector(size);
+    for (int i = 0; i < size; i++)
+    {
+        if (!(file >> vector[i]))
+        {
+            printError("Error reading vector");
+            return std::vector<double>();
+        }
+    }
+    file.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
+    return vector;
+}
diff --git a/src/main.cpp b/src/main.cpp
new file mode 100644
index 00000000..874f444f
--- /dev/null
+++ b/src/main.cpp
@@ -0,0 +1,223 @@
+
+#include "swig_interface.h"
+
+/// @cond DOXYGEN_IGNORE
+#include <iostream>
+#include <ctime>
+#include <unordered_map>
+
+// For signal handling
+#include <csignal>
+#include <execinfo.h>
+
+// #include
+/// @endcond
+
+#include "input_data.h"
+#include "version.h"
+#include "utils.h"
+
+
+void printStackTrace(int sig)
+{
+    void *array[10];
+    size_t size;
+
+    // get void*'s for all entries on the stack
+    size = backtrace(array, 10);
+
+    // print out all the frames to stderr
+    fprintf(stderr, "Error: signal %d:\n", sig);
+    backtrace_symbols_fd(array, size, STDERR_FILENO);
+    exit(1);
+}
+
+
+void printBanner()
+{
+    std::time_t now = std::time(nullptr);
+    char date_str[100];
+    std::strftime(date_str, sizeof(date_str), "%Y-%m-%d", std::localtime(&now));
+    std::cout << "═══════════════════════════════════════════════════════════════" << std::endl;
+    std::cout << " ContextSV - Long-read Structural Variant Caller" << std::endl;
+    std::cout << " Version: " << VERSION << std::endl;
+    std::cout << " Date: " << date_str << std::endl;
+    std::cout << "═══════════════════════════════════════════════════════════════" << std::endl;
+}
+
+void runContextSV(const std::unordered_map<std::string, std::string>& args)
+{
+    // Set up signal handling
+    signal(SIGSEGV, printStackTrace);
+    signal(SIGABRT, printStackTrace);
+    signal(SIGINT, printStackTrace);
+    signal(SIGTERM, printStackTrace);
+    signal(SIGILL, printStackTrace);
+    signal(SIGFPE, printStackTrace);
+    signal(SIGBUS, printStackTrace);
+
+    printBanner();
+
+    // Set up input data
+    InputData input_data;
+    input_data.setLongReadBam(args.at("bam-file"));
+    input_data.setRefGenome(args.at("ref-file"));
+    input_data.setSNPFilepath(args.at("snps-file"));
+    input_data.setOutputDir(args.at("output-dir"));
+    if (args.find("chr") != args.end()) {
+        input_data.setChromosome(args.at("chr"));
+    }
+    if (args.find("region") != args.end()) {
+        input_data.setRegion(args.at("region"));
+    }
+    if (args.find("thread-count") != args.end()) {
+        input_data.setThreadCount(std::stoi(args.at("thread-count")));
+    }
+    if (args.find("hmm-file") != args.end()) {
+        input_data.setHMMFilepath(args.at("hmm-file"));
+    }
+    if (args.find("sample-size") != args.end()) {
+        input_data.setSampleSize(std::stoi(args.at("sample-size")));
+    }
+    if (args.find("min-cnv") != args.end()) {
+        input_data.setMinCNVLength(std::stoi(args.at("min-cnv")));
+    }
+    if (args.find("eth") != args.end()) {
+        input_data.setEthnicity(args.at("eth"));
+    }
+    if (args.find("pfb-file") != args.end()) {
+        input_data.setAlleleFreqFilepaths(args.at("pfb-file"));
+    }
+    if (args.find("assembly-gaps") != args.end()) {
+        input_data.setAssemblyGaps(args.at("assembly-gaps"));
+    }
+    if (args.find("save-cnv") != args.end()) {
+        input_data.saveCNVData(true);
+    }
+    if (args.find("debug") != args.end()) {
+        input_data.setVerbose(true);
+    }
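// Editor's note (illustrative example, not part of the original patch): with an
// invocation such as
//     contextsv -b sample.bam -r hg38.fa -s snps.vcf -o out_dir -t 4 --debug
// parseArguments() (defined below) would populate the map roughly as
//     args["bam-file"]     = "sample.bam"
//     args["ref-file"]     = "hg38.fa"
//     args["snps-file"]    = "snps.vcf"
//     args["output-dir"]   = "out_dir"
//     args["thread-count"] = "4"
//     args["debug"]        = "true"
// and the setters above copy each entry onto the InputData object. The file
// names are hypothetical.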
+
+    // DBSCAN parameters
+    if (args.find("epsilon") != args.end()) {
+        input_data.setDBSCAN_Epsilon(std::stod(args.at("epsilon")));
+    }
+
+    if (args.find("min-pts-pct") != args.end()) {
+        input_data.setDBSCAN_MinPtsPct(std::stod(args.at("min-pts-pct")));
+    }
+
+    // Set up the CNV JSON file if enabled
+    if (input_data.getSaveCNVData()) {
+        const std::string output_dir = input_data.getOutputDir();
+        std::string json_filepath = output_dir + "/CNVCalls.json";
+
+        // Remove the old JSON file if it exists
+        if (fileExists(json_filepath)) {
+            remove(json_filepath.c_str());
+        }
+        input_data.setCNVOutputFile(json_filepath);
+        std::cout << "Saving CNV data to: " << json_filepath << std::endl;
+    }
+
+    // Run ContextSV
+    run(input_data);
+}
+
+void printUsage(const std::string& programName) {
+    std::cerr << "Usage: " << programName << " [options]\n"
+              << "Options:\n"
+              << "  -b, --bam          Long-read BAM file (required)\n"
+              << "  -r, --ref          Reference genome FASTA file (required)\n"
+              << "  -s, --snp          SNPs VCF file (required)\n"
+              << "  -o, --outdir       Output directory (required)\n"
+              << "  -c, --chr          Chromosome\n"
+              << "      --region       Region (start-end)\n"
+              << "  -t, --threads      Number of threads\n"
+              << "  -h, --hmm          HMM file\n"
+              << "  -n, --sample-size  Sample size for HMM predictions\n"
+              << "      --min-cnv      Minimum CNV length\n"
+              << "      --eps          DBSCAN epsilon\n"
+              << "      --min-pts-pct  Percentage of mean chr. coverage to use for DBSCAN minimum points\n"
+              << "  -e, --eth          Sample ethnicity\n"
+              << "  -p, --pfb          Population allele frequency (PFB) file\n"
+              << "      --assembly-gaps  Assembly gaps file\n"
+              << "      --save-cnv     Save CNV data\n"
+              << "      --debug        Debug mode with verbose logging\n"
+              << "  -v, --version      Print version and exit\n"
+              << "      --help         Print usage and exit\n";
+}
+
+std::unordered_map<std::string, std::string> parseArguments(int argc, char* argv[]) {
+    std::unordered_map<std::string, std::string> args;
+    for (int i = 1; i < argc; i++) {
+        std::string arg = argv[i];
+
+        // Handle short and long options
+        if ((arg == "-b" || arg == "--bam") && i + 1 < argc) {
+            args["bam-file"] = argv[++i];
+        } else if ((arg == "-r" || arg == "--ref") && i + 1 < argc) {
+            args["ref-file"] = argv[++i];
+        } else if ((arg == "-s" || arg == "--snp") && i + 1 < argc) {
+            args["snps-file"] = argv[++i];
+        } else if ((arg == "-o" || arg == "--outdir") && i + 1 < argc) {
+            args["output-dir"] = argv[++i];
+        } else if ((arg == "-c" || arg == "--chr") && i + 1 < argc) {
+            args["chr"] = argv[++i];
+        } else if (arg == "--region" && i + 1 < argc) {
+            // No short flag: -r is taken by --ref above
+            args["region"] = argv[++i];
+        } else if ((arg == "-t" || arg == "--threads") && i + 1 < argc) {
+            args["thread-count"] = argv[++i];
+        } else if ((arg == "-h" || arg == "--hmm") && i + 1 < argc) {
+            args["hmm-file"] = argv[++i];
+        } else if ((arg == "-n" || arg == "--sample-size") && i + 1 < argc) {
+            args["sample-size"] = argv[++i];
+        } else if (arg == "--min-cnv" && i + 1 < argc) {
+            args["min-cnv"] = argv[++i];
+        } else if (arg == "--min-reads" && i + 1 < argc) {
+            args["min-reads"] = argv[++i];
+        } else if (arg == "--eps" && i + 1 < argc) {
+            args["epsilon"] = argv[++i];
+        } else if (arg == "--min-pts-pct" && i + 1 < argc) {
+            args["min-pts-pct"] = argv[++i];
+        } else if ((arg == "-e" || arg == "--eth") && i + 1 < argc) {
+            args["eth"] = argv[++i];
+        } else if ((arg == "-p" || arg == "--pfb") && i + 1 < argc) {
+            args["pfb-file"] = argv[++i];
+        } else if (arg == "--assembly-gaps" && i + 1 < argc) {
+            args["assembly-gaps"] = argv[++i];
+        } else if (arg == "--save-cnv") {
+            args["save-cnv"] = "true";
+        } else if (arg == "--debug") {
+            args["debug"] = "true";
+        } else if ((arg == "-v" || 
arg == "--version")) { + std::cout << "ContextSV version " << VERSION << std::endl; + exit(0); + } else if (arg == "-h" || arg == "--help") { + printUsage(argv[0]); + exit(0); + } else { + std::cerr << "Unknown option: " << arg << std::endl; + } + } + + // Check for required arguments + bool hasLR = args.find("bam-file") != args.end(); + bool hasOutput = args.find("output-dir") != args.end(); + bool hasRef = args.find("ref-file") != args.end(); + bool hasSNPs = args.find("snps-file") != args.end(); + bool requiredArgs = hasLR && hasOutput && hasRef && hasSNPs; + if (!requiredArgs) { + std::cerr << "Missing required argument(s): -b/--bam, -r/--ref, -s/--snp, -o/--outdir" << std::endl; + exit(1); + } + + return args; +} + +int main(int argc, char* argv[]) { + auto args = parseArguments(argc, argv); + runContextSV(args); + + return 0; +} diff --git a/src/snp_info.cpp b/src/snp_info.cpp deleted file mode 100644 index 90045402..00000000 --- a/src/snp_info.cpp +++ /dev/null @@ -1,116 +0,0 @@ -#include "snp_info.h" - -/// @cond -#include -#include -#include -#include -#include -/// @endcond - -#define MIN_PFB 0.01 - - -// Function to remove the 'chr' prefix from chromosome names -std::string removeChrPrefix(std::string chr) -{ - if (chr.find("chr") != std::string::npos) { - return chr.substr(3); - } - return chr; -} - -void SNPInfo::insertSNPAlleleFrequency(std::string chr, int64_t pos, double baf) -{ - chr = removeChrPrefix(chr); - - // Add the chromosome to the SNP B-allele frequency map if it does not exist - if (this->snp_baf_map.find(chr) == this->snp_baf_map.end()) { - this->snp_baf_map[chr] = BST(); - } - - // Insert the SNP into the map with its position and B-allele frequency - // using a binary search tree to keep the SNP positions sorted - this->snp_baf_map[chr].insert({pos, baf}); -} - -void SNPInfo::insertSNPPopulationFrequency(std::string chr, int64_t pos, double pfb) -{ - chr = removeChrPrefix(chr); - - // Add the chromosome to the SNP population frequency map if it does not - // exist - if (this->snp_pfb_map.find(chr) == this->snp_pfb_map.end()) { - this->snp_pfb_map[chr] = std::unordered_map(); - } - - // Insert the SNP into the map with its position and population frequency of - // the B allele - this->snp_pfb_map[chr][pos] = pfb; -} - -std::tuple, std::vector, std::vector> SNPInfo::querySNPs(std::string chr, int64_t start, int64_t end) -{ - // Lock the mutex for reading SNP information - std::lock_guard lock(this->snp_info_mtx); - - chr = removeChrPrefix(chr); - - // Create an ordered map of SNP positions to BAF and PFB values - std::map> snp_map; - - // Query SNPs within a range (start, end) and return their BAF and PFB - // values as separate vectors - std::vector bafs; - std::vector pfbs; - std::vector pos; - - // Check if the chromosome exists in the B-allele frequency map - if (this->snp_baf_map.find(chr) == this->snp_baf_map.end()) { - return std::make_tuple(pos, bafs, pfbs); - } - - // Query the SNPs within the range and return their BAFs and corresponding - // positions - auto& baf_bst = this->snp_baf_map[chr]; - auto baf_start = baf_bst.lower_bound({start, 0.0}); - auto baf_end = baf_bst.upper_bound({end, 0.0}); - for (auto it = baf_start; it != baf_end; it++) { - bafs.push_back(std::get<1>(*it)); - pos.push_back(std::get<0>(*it)); - } - - // Define a default PFB value (0.5) for SNPs with no population frequency data - pfbs = std::vector(bafs.size(), 0.5); - - // Check if the chromosome exists in the population frequency map - if (this->snp_pfb_map.find(chr) == 
this->snp_pfb_map.end()) { - return std::make_tuple(pos, bafs, pfbs); - } - - // Query the PFBs for all SNP positions with PFB data - auto& pfb_map = this->snp_pfb_map[chr]; - for (size_t i = 0; i < pos.size(); i++) { - int64_t snp_pos = pos[i]; - if (pfb_map.find(snp_pos) != pfb_map.end()) { - pfbs[i] = pfb_map[snp_pos]; - } - } - - return std::make_tuple(pos, bafs, pfbs); -} - -std::pair SNPInfo::getSNPRange(std::string chr) -{ - chr = removeChrPrefix(chr); - - // Get the range of SNP positions for a given chromosome - int64_t start = 0; - int64_t end = 0; - if (this->snp_baf_map.find(chr) != this->snp_baf_map.end()) { - auto& baf_bst = this->snp_baf_map[chr]; - start = std::get<0>(*baf_bst.begin()); - end = std::get<0>(*baf_bst.rbegin()); - } - return std::make_pair(start, end); -} diff --git a/src/sv_caller.cpp b/src/sv_caller.cpp index 179a5360..29dab604 100644 --- a/src/sv_caller.cpp +++ b/src/sv_caller.cpp @@ -16,623 +16,1353 @@ #include #include #include - +#include +#include +#include +#include +#include +#include + +#include "ThreadPool.h" #include "utils.h" #include "sv_types.h" +#include "version.h" +#include "fasta_query.h" +#include "dbscan.h" +#include "dbscan1d.h" +#include "debug.h" /// @endcond # define DUP_SEQSIM_THRESHOLD 0.9 // Sequence similarity threshold for duplication detection int SVCaller::readNextAlignment(samFile *fp_in, hts_itr_t *itr, bam1_t *bam1) { - // Read the next alignment + std::shared_lock lock(this->shared_mutex); int ret = sam_itr_next(fp_in, itr, bam1); - - // Return the result of reading the next alignment return ret; } -RegionData SVCaller::detectSVsFromRegion(std::string region) +std::vector SVCaller::getChromosomes(const std::string &bam_filepath) { - SVData sv_calls; - std::string bam_filepath = this->input_data->getLongReadBam(); + // Open the BAM file + samFile *fp_in = sam_open(bam_filepath.c_str(), "r"); + if (!fp_in) { + printError("ERROR: failed to open BAM file " + bam_filepath); + return {}; + } + bam_hdr_t *bamHdr = sam_hdr_read(fp_in); + if (!bamHdr) { + sam_close(fp_in); + printError("ERROR: failed to read header from " + bam_filepath); + return {}; + } + std::vector chromosomes; + for (int i = 0; i < bamHdr->n_targets; i++) { + chromosomes.push_back(bamHdr->target_name[i]); + } + bam_hdr_destroy(bamHdr); + sam_close(fp_in); + return chromosomes; +} - // Open the BAM file in a thread-safe manner +void SVCaller::findSplitSVSignatures(std::unordered_map> &sv_calls, const InputData &input_data) +{ + // Open the BAM file + std::string bam_filepath = input_data.getLongReadBam(); samFile *fp_in = sam_open(bam_filepath.c_str(), "r"); - if (fp_in == NULL) { - std::cerr << "ERROR: failed to open " << bam_filepath << std::endl; - exit(1); + if (!fp_in) { + printError("ERROR: failed to open " + bam_filepath); + return; } - // Get the header in a thread-safe manner + // Set maximum thread count + int thread_count = input_data.getThreadCount(); + hts_set_threads(fp_in, thread_count); + printMessage("Using " + std::to_string(thread_count) + " threads for split read analysis"); + + // Load the header bam_hdr_t *bamHdr = sam_hdr_read(fp_in); - if (bamHdr == NULL) { - std::cerr << "ERROR: failed to read header for " << bam_filepath << std::endl; - exit(1); + if (!bamHdr) { + sam_close(fp_in); + printError("ERROR: failed to read header from " + bam_filepath); + return; } - // Get the index in a thread-safe manner + // Load the index hts_idx_t *idx = sam_index_load(fp_in, bam_filepath.c_str()); - if (idx == NULL) { - std::cerr << "ERROR: failed to 
load index for " << bam_filepath << std::endl; - exit(1); + if (!idx) { + bam_hdr_destroy(bamHdr); + sam_close(fp_in); + printError("ERROR: failed to load index for " + bam_filepath); + return; } - // Create a read and iterator for the region in a thread-safe manner + // Alignment data structures + std::unordered_map> primary_map; // TID-> qname -> primary alignment + std::unordered_map> supp_map; // qname -> supplementary alignment + bam1_t *bam1 = bam_init1(); - hts_itr_t *itr = sam_itr_querys(idx, bamHdr, region.c_str()); + if (!bam1) { + printError("ERROR: failed to initialize BAM record"); + return; + } + + // Set the region to the whole genome, or a user-specified chromosome + hts_itr_t *itr = nullptr; + if (input_data.isSingleChr()) { + std::string chr = input_data.getChromosome(); + itr = sam_itr_querys(idx, bamHdr, chr.c_str()); + if (!itr) { + bam_destroy1(bam1); + printError("ERROR: failed to create iterator for " + chr); + return; + } + } else { + itr = sam_itr_queryi(idx, HTS_IDX_START, 0, 0); + if (!itr) { + bam_destroy1(bam1); + printError("ERROR: failed to create iterator for the whole genome"); + return; + } + } - // Loop through the alignments - // Create a map of primary and supplementary alignments by QNAME (query template name) - int num_alignments = 0; - PrimaryMap primary_alignments; - SuppMap supplementary_alignments; + uint32_t primary_count = 0; + uint32_t supplementary_count = 0; + + // Main loop to process the alignments + printMessage("Processing alignments from " + bam_filepath); + uint32_t num_alignments = 0; + std::unordered_set alignment_tids; // All unique chromosome IDs + std::unordered_set supp_qnames; // All unique query names while (readNextAlignment(fp_in, itr, bam1) >= 0) { - // Skip secondary and unmapped alignments, duplicates, and QC failures - if (bam1->core.flag & BAM_FSECONDARY || bam1->core.flag & BAM_FUNMAP || bam1->core.flag & BAM_FDUP || bam1->core.flag & BAM_FQCFAIL) { - // Do nothing - - // Skip alignments with low mapping quality - } else if (bam1->core.qual < this->min_mapq) { - // Do nothing - - } else { - // Get the QNAME (query template name) for associating split reads - std::string qname = bam_get_qname(bam1); - - // Process primary alignments - if (!(bam1->core.flag & BAM_FSUPPLEMENTARY)) { - - // Get the primary alignment chromosome, start, end, and depth - std::string chr = bamHdr->target_name[bam1->core.tid]; - int64_t start = bam1->core.pos; - int64_t end = bam_endpos(bam1); // This is the first position after the alignment - - // Call SVs directly from the CIGAR string - std::tuple, int32_t, int32_t> query_info = this->detectSVsFromCIGAR(bamHdr, bam1, sv_calls, true); - std::unordered_map match_map = std::get<0>(query_info); - int32_t query_start = std::get<1>(query_info); - int32_t query_end = std::get<2>(query_info); - - // Add the primary alignment to the map - AlignmentData alignment(chr, start, end, ".", query_start, query_end, match_map); - primary_alignments[qname] = std::move(alignment); - - // Process supplementary alignments - } else if (bam1->core.flag & BAM_FSUPPLEMENTARY) { - - // Add the supplementary alignment to the map - std::string chr = bamHdr->target_name[bam1->core.tid]; - int32_t start = bam1->core.pos; - int32_t end = bam_endpos(bam1); - - // Get CIGAR string information, but don't call SVs - std::tuple, int32_t, int32_t> query_info = this->detectSVsFromCIGAR(bamHdr, bam1, sv_calls, false); - const std::unordered_map& match_map = std::get<0>(query_info); - int32_t query_start = std::get<1>(query_info); 
- int32_t query_end = std::get<2>(query_info); - - // Add the supplementary alignment to the map - AlignmentData alignment(chr, start, end, ".", query_start, query_end, std::move(match_map)); - supplementary_alignments[qname].emplace_back(alignment); - - // If Read ID == 8873acc1-eb84-415d-8557-a32a8f52ccee, print the - // alignment - // if (qname == "8873acc1-eb84-415d-8557-a32a8f52ccee") { - // std::cout << "Supplementary alignment: " << chr << ":" << start << "-" << end << std::endl; - // std::cout << "Query start: " << query_start << ", Query end: " << query_end << std::endl; - // std::cout << "Match map: "; - // for (const auto& entry : match_map) { - // std::cout << entry.first << ":" << entry.second << " "; - // } - // std::cout << std::endl; - // } - } + // Skip secondary and unmapped alignments, duplicates, QC failures, and low mapping quality + if (bam1->core.flag & BAM_FSECONDARY || bam1->core.flag & BAM_FUNMAP || bam1->core.flag & BAM_FDUP || bam1->core.flag & BAM_FQCFAIL || bam1->core.qual < this->min_mapq) { + continue; + } + const std::string qname = bam_get_qname(bam1); // Query template name + + // Process primary alignments + if (!(bam1->core.flag & BAM_FSUPPLEMENTARY)) { + // Store chromosome (TID), start, and end positions (1-based) of the + // primary alignment, and the strand (true for forward, false for + // reverse) + std::pair qpos = getAlignmentReadPositions(bam1); + + primary_map[bam1->core.tid][qname] = PrimaryAlignment{bam1->core.pos + 1, bam_endpos(bam1), qpos.first, qpos.second, !(bam1->core.flag & BAM_FREVERSE), 0}; + alignment_tids.insert(bam1->core.tid); + primary_count++; + + // Process supplementary alignments + } else if (bam1->core.flag & BAM_FSUPPLEMENTARY) { + // Store chromosome (TID), start, and end positions (1-based) of the + // supplementary alignment, and the strand (true for forward, false + // for reverse) + std::pair qpos = getAlignmentReadPositions(bam1); + supp_map[qname].push_back(SuppAlignment{bam1->core.tid, bam1->core.pos + 1, bam_endpos(bam1), qpos.first, qpos.second, !(bam1->core.flag & BAM_FREVERSE)}); + alignment_tids.insert(bam1->core.tid); + supp_qnames.insert(qname); + supplementary_count++; } - - // Increment the number of alignment records processed num_alignments++; + + if (num_alignments % 1000000 == 0) { + printMessage("Processed " + std::to_string(num_alignments) + " alignments"); + } } - // Destroy the iterator + // Clean up the iterator and alignment hts_itr_destroy(itr); - - // Destroy the read bam_destroy1(bam1); - // Close the BAM file + // Clean up the BAM file and index sam_close(fp_in); - - // Destroy the header - bam_hdr_destroy(bamHdr); - - // Destroy the index hts_idx_destroy(idx); - - // Return the SV calls and the primary and supplementary alignments - // return std::make_tuple(sv_calls, primary_alignments, - // supplementary_alignments); - return std::make_tuple(std::move(sv_calls), std::move(primary_alignments), std::move(supplementary_alignments)); -} - -double SVCaller::calculateMismatchRate(std::unordered_map &match_map, int32_t start, int32_t end) -{ - // Calculate the mismatch rate - int match_count = 0; - int mismatch_count = 0; - for (int i = start; i <= end; i++) { - if (match_map.find(i) != match_map.end()) { - if (match_map[i] == 1) { - match_count++; - } else { - mismatch_count++; + // bam_hdr_destroy(bamHdr); + + // Remove primary alignments without supplementary alignments + std::unordered_map> to_remove; + for (auto& chr_primary : primary_map) { + std::unordered_set qnames; + for (const auto& 
entry : chr_primary.second) { + if (supp_qnames.find(entry.first) == supp_qnames.end()) { + to_remove[chr_primary.first].insert(entry.first); } } } - double mismatch_rate = (double)mismatch_count / (double)(match_count + mismatch_count); - // Return the mismatch rate - return mismatch_rate; -} + int total_removed = 0; + for (auto& chr_primary : primary_map) { + // Remove the qnames from the primary map + total_removed += to_remove[chr_primary.first].size(); + for (const auto& qname : to_remove[chr_primary.first]) { + chr_primary.second.erase(qname); + } + } + printMessage("Removed " + std::to_string(total_removed) + " primary alignments without supplementary alignments"); + + // Process the primary alignments and find SVs + for (const auto& chr_primary : primary_map) { + int primary_tid = chr_primary.first; + std::string chr_name = bamHdr->target_name[primary_tid]; + printMessage("Processing chromosome " + chr_name + " with " + std::to_string(chr_primary.second.size()) + " primary alignments"); + + std::vector chr_sv_calls; + chr_sv_calls.reserve(1000); + const std::unordered_map& chr_primary_map = chr_primary.second; + + // Identify overlapping primary alignments and cluster endpoints + std::unique_ptr root = nullptr; + for (const auto& entry : chr_primary_map) { + const std::string& qname = entry.first; + const PrimaryAlignment& region = entry.second; + insert(root, region, qname); + } -SVCaller::SVCaller(InputData &input_data) -{ - this->input_data = &input_data; -} + std::vector> primary_clusters; + std::set processed; + for (const auto& entry : chr_primary_map) { + const std::string& qname = entry.first; + if (processed.find(qname) != processed.end()) { + continue; // Skip already processed primary alignments + } + const PrimaryAlignment& primary_aln = entry.second; + std::vector overlap_group; + findOverlaps(root, primary_aln, overlap_group); + for (const std::string& qname : overlap_group) { + processed.insert(qname); + } + if (overlap_group.size() > 1) { + primary_clusters.push_back(overlap_group); + } + } -std::tuple, int32_t, int32_t> SVCaller::detectSVsFromCIGAR(bam_hdr_t* header, bam1_t* alignment, SVData& sv_calls, bool is_primary) -{ - // Get the chromosome - std::string chr = header->target_name[alignment->core.tid]; + // For each primary alignment cluster the supplementary alignment start and + // end positions, keeping the median of the largest cluster + int current_group = 0; + int min_length = 2000; + int max_length = 1000000; + for (const auto& primary_cluster : primary_clusters) { + // Determine if the primary alignments are mostly on opposite strands to + // the corresponding supplementary alignments (potential inversions) + bool inversion = false; + int num_primary = (int) primary_cluster.size(); + int num_supp_opposite_strand = 0; + for (const std::string& qname : primary_cluster) { + const std::vector& supp_alns = supp_map[qname]; + bool primary_strand = chr_primary_map.at(qname).strand; + bool has_opposite_strand = false; + for (const SuppAlignment& supp_aln : supp_alns) { + // Analyze if on the same chromosome + if (supp_aln.tid == primary_tid && supp_aln.strand != primary_strand) { + has_opposite_strand = true; + } + } + if (has_opposite_strand) { + num_supp_opposite_strand++; + } + } + if (static_cast(num_supp_opposite_strand) / static_cast(num_primary) > 0.5) { + inversion = true; + } - // Get the position of the alignment in the reference genome - int32_t pos = alignment->core.pos; + // Use DBSCAN to cluster primary alignment start, end positions + DBSCAN1D 
dbscan(100, 5); + current_group++; + std::vector starts; + std::vector ends; + std::vector primary_strands; + for (const std::string& qname : primary_cluster) { + const PrimaryAlignment& primary_aln = chr_primary_map.at(qname); + starts.push_back(primary_aln.start); + ends.push_back(primary_aln.end); + primary_strands.push_back(primary_aln.strand); + } - // Get the CIGAR string - uint32_t* cigar = bam_get_cigar(alignment); + // Get the largest cluster of primary alignment start positions + dbscan.fit(starts); + std::vector primary_start_cluster = dbscan.getLargestCluster(starts); - // Get the CIGAR length - int cigar_len = alignment->core.n_cigar; + // Get the largest cluster of primary alignment end positions + dbscan.fit(ends); + std::vector primary_end_cluster = dbscan.getLargestCluster(ends); - // Track the query position - int query_pos = 0; + // Continue if no clusters were found + if (primary_start_cluster.empty() && primary_end_cluster.empty()) { + continue; + } - // Loop through the CIGAR string (0-based) and detect insertions and deletions in - // reference coordinates (1-based) - // POS is the leftmost position of where the alignment maps to the reference: - // https://genome.sph.umich.edu/wiki/SAM - // std::vector threads; - // std::vector sv_calls_vec; + // Get the supplementary alignment positions, and also the distances + // between the primary and supplementary alignments on the read + std::vector supp_starts; + std::vector supp_ends; + std::vector supp_strands; + std::vector read_distances; + std::vector ref_distances; + for (const std::string& qname : primary_cluster) { + const PrimaryAlignment& primary_aln = chr_primary_map.at(qname); + const std::vector& supp_alns = supp_map.at(qname); + for (const SuppAlignment& supp_aln : supp_alns) { + if (supp_aln.tid == primary_tid) { + // Same chromosome + int read_distance = 0; + int ref_distance = 0; + supp_starts.push_back(supp_aln.start); + supp_ends.push_back(supp_aln.end); + supp_strands.push_back(supp_aln.strand); + + // Calculate the distance between the primary and supplementary + // alignments on the read if on the same chromosome and same + // strand + if (supp_aln.strand == primary_aln.strand) { + // Same strand + + // Check if the primary alignment is 5'-most + bool primary_5p = false; + if (primary_aln.start < supp_aln.start) { + primary_5p = true; + } + + // Calculate distance between alignments on the read + read_distance = std::max(0, std::max(static_cast(supp_aln.query_start), static_cast(primary_aln.query_start)) - std::min(static_cast(supp_aln.query_end), static_cast(primary_aln.query_end))); + + // Calculate distance between alignments on the + // reference + ref_distance = std::max(0, std::max(static_cast(supp_aln.start), static_cast(primary_aln.start)) - std::min(static_cast(supp_aln.end), static_cast(primary_aln.end))); + + // Throw an error if the read distance is negative + if (read_distance < 0) { + printError("ERROR: negative read distance between primary and supplementary alignments for " + qname); + } + // Throw an error if the reference distance is + // negative + if (ref_distance < 0) { + printError("ERROR: negative reference distance between primary and supplementary alignments for " + qname); + } + + // Use a negative read distance to indicate that the + // primary alignment is not 5'-most + if (!primary_5p) { + read_distance = -read_distance; + } + read_distances.push_back(read_distance); + ref_distances.push_back(ref_distance); + } - // Create a map of query position to match/mismatch (1/0) for 
calculating - // the mismatch rate at alignment overlaps - std::unordered_map query_match_map; + } else { + // TODO: TRANSLOCATIONS + } + } + } - // Loop through the CIGAR string, process operations, detect SVs (primary - // only), update clipped base support, calculate sequence identity for - // potential duplications (primary only), and calculate - // the clipped base support and mismatch rate - int32_t ref_pos; - int32_t ref_end; - int32_t query_start = 0; // First alignment position in the query - int32_t query_end = 0; // Last alignment position in the query - bool first_op = false; // First alignment operation for the query - for (int i = 0; i < cigar_len; i++) { + // Get the largest cluster of supplementary alignment start positions + dbscan.fit(supp_starts); + std::vector supp_start_cluster = dbscan.getLargestCluster(supp_starts); - // Get the CIGAR operation - int op = bam_cigar_op(cigar[i]); + // Get the largest cluster of supplementary alignment end positions + dbscan.fit(supp_ends); + std::vector supp_end_cluster = dbscan.getLargestCluster(supp_ends); - // Get the CIGAR operation length - int op_len = bam_cigar_oplen(cigar[i]); - - // Check if the CIGAR operation is an insertion - if (op == BAM_CINS && is_primary) { + // Get the largest cluster of read distances + dbscan.fit(read_distances); + std::vector read_distance_cluster = dbscan.getLargestCluster(read_distances); - // Add the SV if greater than the minimum SV size - if (op_len >= this->min_sv_size) { + // Get the largest cluster of reference distances + dbscan.fit(ref_distances); + std::vector ref_distance_cluster = dbscan.getLargestCluster(ref_distances); - // Get the sequence of the insertion from the query - // std::string ins_seq_str = ""; - // uint8_t* seq_ptr = bam_get_seq(alignment); - // for (int j = 0; j < op_len; j++) { - // ins_seq_str += seq_nt16_str[bam_seqi(seq_ptr, query_pos + j)]; - // } - std::string ins_seq_str(op_len, ' '); - for (int j = 0; j < op_len; j++) { - ins_seq_str[j] = seq_nt16_str[bam_seqi(bam_get_seq(alignment), query_pos + j)]; - } - - // To determine whether the insertion is a duplication, check - // for sequence identity between the insertion and the - // reference genome (duplications are typically >= 90%) - - // Loop from the leftmost position of the insertion (pos-op_len) - // to the rightmost position of the insertion (pos+op_len-1) and - // calculate the sequence identity at each window of the - // insertion length to identify potential duplications. + // Continue if no clusters were found + if (supp_start_cluster.empty() && supp_end_cluster.empty() && read_distance_cluster.empty() && ref_distance_cluster.empty()) { + continue; + } - // Loop through the reference sequence and calculate the - // sequence identity +/- insertion length from the insertion - // position. 
- bool is_duplication = false; - int ins_ref_pos; - for (int j = pos - op_len; j <= pos; j++) { + // Use the median of the largest cluster of primary and supplementary + // alignment start, end positions as the final genome coordinates of the + // SV + std::vector primary_positions; + int primary_cluster_size = 0; + // bool primary_start = false; + bool primary_end = false; + if (!primary_start_cluster.empty()) { + std::sort(primary_start_cluster.begin(), primary_start_cluster.end()); + primary_positions.push_back(primary_start_cluster[primary_start_cluster.size() / 2]); + primary_cluster_size = primary_start_cluster.size(); + // primary_start = true; + } - // Get the string for the window (1-based coordinates) - ins_ref_pos = j + 1; - std::string window_str = this->input_data->queryRefGenome(chr, ins_ref_pos, ins_ref_pos + op_len - 1); + if (!primary_end_cluster.empty()) { + std::sort(primary_end_cluster.begin(), primary_end_cluster.end()); + primary_positions.push_back(primary_end_cluster[primary_end_cluster.size() / 2]); + primary_cluster_size = std::max(primary_cluster_size, (int) primary_end_cluster.size()); + primary_end = true; + } - // Continue if the window string is empty (out-of-range) - if (window_str == "") { - continue; - } + // Get the supplementary alignment positions + std::vector supp_positions; + // bool supp_start = false; + bool supp_end = false; + int supp_cluster_size = 0; + if (!supp_start_cluster.empty()) { + std::sort(supp_start_cluster.begin(), supp_start_cluster.end()); + supp_positions.push_back(supp_start_cluster[supp_start_cluster.size() / 2]); + supp_cluster_size = supp_start_cluster.size(); + // supp_start = true; + } + if (!supp_end_cluster.empty()) { + std::sort(supp_end_cluster.begin(), supp_end_cluster.end()); + supp_positions.push_back(supp_end_cluster[supp_end_cluster.size() / 2]); + supp_cluster_size = std::max(supp_cluster_size, (int) supp_end_cluster.size()); + supp_end = true; + } - // Calculate the sequence identity - int num_matches = 0; - for (int k = 0; k < op_len; k++) { - if (ins_seq_str[k] == window_str[k]) { - num_matches++; - } + // Store the inversion as the supplementary start and end positions + // if (inversion && supp_positions.size() > 1) { + // std::sort(supp_positions.begin(), supp_positions.end()); + // int supp_start = supp_positions.front(); + // int supp_end = supp_positions.back(); + // int sv_length = std::abs(supp_start - supp_end); + + // // Use 50bp as the minimum length for an inversion + // if (sv_length >= 50 && sv_length <= max_length) { + // SVEvidenceFlags aln_type; + // aln_type.set(static_cast(SVDataType::SUPPINV)); + // SVCall sv_candidate(supp_start, supp_end, SVType::INV, getSVTypeSymbol(SVType::INV), aln_type, Genotype::UNKNOWN, 0.0, 0, 0, supp_cluster_size); + // // SVCall sv_candidate(supp_start, supp_end, SVType::INV, getSVTypeSymbol(SVType::INV), SVDataType::SUPPINV, Genotype::UNKNOWN, 0.0, 0, 0, supp_cluster_size); + // addSVCall(chr_sv_calls, sv_candidate); + // } + // } + + // ------------------------------- + // SPLIT INSERTION CALLS + int read_distance = 0; + int ref_distance = 0; + if (!read_distance_cluster.empty() && !ref_distance_cluster.empty()) { + // Use the median of the largest cluster of split distances as the + // insertion size + std::sort(read_distance_cluster.begin(), read_distance_cluster.end()); + read_distance = read_distance_cluster[read_distance_cluster.size() / 2]; + bool primary_5p_most = read_distance > 0; + read_distance = std::abs(read_distance); + + 
std::sort(ref_distance_cluster.begin(), ref_distance_cluster.end());
+                ref_distance = ref_distance_cluster[ref_distance_cluster.size() / 2];
+
+                // Use the 3'-most primary position as the start position
+                int sv_start;
+                bool split_candidate_sv = false;
+                if (primary_5p_most && primary_end) {
+                    std::sort(primary_positions.begin(), primary_positions.end());
+                    sv_start = primary_positions.back();
+                    split_candidate_sv = true;
+                } else if (!primary_5p_most && supp_end) {
+
+                    // Supplementary alignment is upstream with the
+                    // insertion sequence, starting at the 5'-most
+                    // primary position
+                    std::sort(supp_positions.begin(), supp_positions.end());
+                    sv_start = supp_positions.back();
+                    split_candidate_sv = true;
+                }
+                SVEvidenceFlags aln_type;
+                aln_type.set(static_cast<size_t>(SVDataType::SPLITDIST1));
+                if (split_candidate_sv) {
+                    int aln_offset = static_cast<int>(ref_distance - read_distance);
+                    if (read_distance > ref_distance && read_distance >= min_length && read_distance <= max_length) {
+                        // Add an insertion SV call at the 5'-most primary position
+                        SVType sv_type = SVType::INS;
+                        SVCall sv_candidate(sv_start, sv_start + (read_distance-1), sv_type, getSVTypeSymbol(sv_type), aln_type, Genotype::UNKNOWN, 0.0, 0, aln_offset, primary_cluster_size);
+                        addSVCall(chr_sv_calls, sv_candidate);
+                        // }
+                    } else if (ref_distance > read_distance && ref_distance >= min_length && ref_distance <= max_length) {
+
+                        // Set it to unknown, SV type will be determined by the
+                        // HMM prediction
+                        SVType sv_type = SVType::UNKNOWN;
+                        SVCall sv_candidate(sv_start, sv_start + (ref_distance-1), sv_type, getSVTypeSymbol(sv_type), aln_type, Genotype::UNKNOWN, 0.0, 0, aln_offset, primary_cluster_size);
+                        addSVCall(chr_sv_calls, sv_candidate);
                     }
-                    float seq_identity = (float)num_matches / (float)op_len;
+                }
+            }

-                    // Check if the target sequence identity is reached
-                    if (seq_identity >= DUP_SEQSIM_THRESHOLD) {
-                        is_duplication = true;
-                        break;
+            // Add a dummy SV call for CNV detection
+            int cluster_size = std::max(primary_cluster_size, supp_cluster_size);
+            SVType sv_type = inversion ? SVType::INV : SVType::UNKNOWN;
+            std::string alt = (sv_type == SVType::INV) ? "<INV>" : ".";
"" : "."; + for (int primary_pos : primary_positions) { + for (int supp_pos : supp_positions) { + int sv_start = std::min(primary_pos, supp_pos); + int sv_end = std::max(primary_pos, supp_pos) - 1; + int sv_length = sv_end - sv_start + 1; + if (sv_length >= min_length && sv_length <= max_length) { + SVEvidenceFlags aln_type; + aln_type.set(static_cast(SVDataType::SPLIT)); + SVCall sv_candidate(sv_start, sv_end, sv_type, alt, aln_type, Genotype::UNKNOWN, 0.0, 0, 0, cluster_size); + addSVCall(chr_sv_calls, sv_candidate); } } + } + } - // Add to SV calls (1-based) with the appropriate SV type - ref_pos = pos+1; - ref_end = ref_pos + op_len -1; + // Combine SVs with identical start and end positions, and sum the cluster + // sizes + std::sort(chr_sv_calls.begin(), chr_sv_calls.end(), [](const SVCall& a, const SVCall& b) { + return a.start < b.start || (a.start == b.start && a.end < b.end); + }); + + // Merge duplicate SV calls with identical start and end positions, and sum the + // cluster sizes + mergeDuplicateSVs(chr_sv_calls); + sv_calls[chr_name] = std::move(chr_sv_calls); + printMessage(chr_name + ": Found " + std::to_string(sv_calls[chr_name].size()) + " SV candidates"); + } - // Lock the SV calls object and add the insertion - std::lock_guard lock(this->sv_mtx); - if (is_duplication) { - sv_calls.add(chr, ref_pos, ref_end, DUP, ins_seq_str, "CIGARDUP", "./.", 0.0); - } else { - sv_calls.add(chr, ref_pos, ref_end, INS, ins_seq_str, "CIGARINS", "./.", 0.0); - } - } + // Clean up the BAM header + bam_hdr_destroy(bamHdr); +} + +void SVCaller::findCIGARSVs(samFile* fp_in, hts_idx_t* idx, bam_hdr_t* bamHdr, const std::string& region, std::vector& sv_calls, const std::vector& pos_depth_map) +{ + // Create a read and iterator for the region + bam1_t *bam1 = bam_init1(); + if (!bam1) { + printError("ERROR: failed to initialize BAM record"); + return; + } + hts_itr_t *itr = sam_itr_querys(idx, bamHdr, region.c_str()); + if (!itr) { + bam_destroy1(bam1); + printError("ERROR: failed to query region " + region); + return; + } - // Check if the CIGAR operation is a deletion - } else if (op == BAM_CDEL && is_primary) { + // Main loop to process the alignments + while (readNextAlignment(fp_in, itr, bam1) >= 0) { - // Add the SV if greater than the minimum SV size - if (op_len >= this->min_sv_size) { - - // Add the deletion to the SV calls (1-based) - ref_pos = pos+1; - ref_end = ref_pos + op_len -1; + // Skip secondary and unmapped alignments, duplicates, QC failures, and + // low mapping quality, and supplementary alignments + if (bam1->core.flag & BAM_FSECONDARY || bam1->core.flag & BAM_FUNMAP || bam1->core.flag & BAM_FDUP || bam1->core.flag & BAM_FQCFAIL || bam1->core.qual < this->min_mapq || bam1->core.flag & BAM_FSUPPLEMENTARY) { + continue; + } - // Lock the SV calls object and add the deletion - // std::lock_guard lock(this->sv_mtx); - sv_calls.add(chr, ref_pos, ref_end, DEL, ".", "CIGARDEL", "./.", 0.0); - } + // Process the alignment + this->processCIGARRecord(bamHdr, bam1, sv_calls, pos_depth_map); + } - // Check if the CIGAR operation is a clipped base - } else if (op == BAM_CSOFT_CLIP || op == BAM_CHARD_CLIP) { + // Clean up the iterator and alignment + hts_itr_destroy(itr); + bam_destroy1(bam1); +} - // Update the clipped base support - // std::lock_guard lock(this->sv_mtx); - sv_calls.updateClippedBaseSupport(chr, pos); +void SVCaller::processCIGARRecord(bam_hdr_t *header, bam1_t *alignment, std::vector &sv_calls, const std::vector &pos_depth_map) +{ + std::string chr = 
+    std::string chr = header->target_name[alignment->core.tid]; // Chromosome name
+    uint32_t aln_start = (uint32_t)alignment->core.pos; // Leftmost position of the alignment in the reference genome (0-based)
+    uint32_t pos = aln_start;

+    uint32_t* cigar = bam_get_cigar(alignment); // CIGAR array
+    int cigar_len = alignment->core.n_cigar;
+    uint32_t query_pos = 0;

+    // Loop through the CIGAR string, process operations, detect SVs (primary
+    // only), and calculate sequence identity for potential duplications (primary only)
+    uint32_t ref_pos;
+    uint32_t ref_end;
+    double default_lh = 0.0;
+    const std::string amb_bases = "RYKMSWBDHV"; // Ambiguous bases
+    std::bitset<256> amb_bases_bitset;
+    for (char base : amb_bases) {
+        amb_bases_bitset.set(base);
+        amb_bases_bitset.set(std::tolower(base));
+    }

+    std::vector<SVCall> cigar_sv_calls;
+    cigar_sv_calls.reserve(1000);
+    for (int i = 0; i < cigar_len; i++) {
+        int op_len = bam_cigar_oplen(cigar[i]); // CIGAR operation length
+        int op = bam_cigar_op(cigar[i]); // CIGAR operation
+        if (op_len >= 50) {

+            // Process the CIGAR operation
+            if (op == BAM_CINS) {

+                // Get the sequence of the insertion from the query
+                std::string ins_seq_str(op_len, ' ');
+                for (int j = 0; j < op_len; j++) {
+                    // Replace ambiguous bases with N
+                    char base = seq_nt16_str[bam_seqi(bam_get_seq(alignment), query_pos + j)];
+                    if (amb_bases_bitset.test(base)) {
+                        ins_seq_str[j] = 'N';
+                    } else {
+                        ins_seq_str[j] = base;
+                    }
+                }

+                // Add as an insertion
+                uint32_t ins_pos = pos + 1;
+                uint32_t ins_end = ins_pos + op_len - 1;

+                // Determine the ALT allele format based on small vs. large insertion
+                std::string alt_allele = "<INS>";
+                if (op_len <= 50) {
+                    alt_allele = ins_seq_str;
+                }
+                SVEvidenceFlags aln_type;
+                aln_type.set(static_cast<size_t>(SVDataType::CIGARINS));
+                SVCall sv_call(ins_pos, ins_end, SVType::INS, alt_allele, aln_type, Genotype::UNKNOWN, default_lh, 0, 0, 0);
+                // SVCall sv_call(ins_pos, ins_end, SVType::INS, alt_allele, SVDataType::CIGARINS, Genotype::UNKNOWN, default_lh, 0, 0, 0);
+                cigar_sv_calls.emplace_back(sv_call);

+            // Process clipped bases as potential insertions
+            } else if (op == BAM_CSOFT_CLIP) {
+                // Soft-clipped bases are considered as potential insertions
+                // Skip if the position exceeds the reference genome length
+                if (pos + 1 >= pos_depth_map.size()) {
+                    continue;
+                }

+                // Get the sequence of the insertion from the query
+                std::string ins_seq_str(op_len, ' ');
+                for (int j = 0; j < op_len; j++) {
+                    // Replace ambiguous bases with N
+                    char base = seq_nt16_str[bam_seqi(bam_get_seq(alignment), query_pos + j)];
+                    if (amb_bases_bitset.test(base)) {
+                        ins_seq_str[j] = 'N';
+                    } else {
+                        ins_seq_str[j] = base;
+                    }
+                }

+                // Add as an insertion
+                uint32_t ins_pos = pos + 1;
+                uint32_t ins_end = ins_pos + op_len - 1;

+                // Determine the ALT allele format based on small vs. large insertion
+                std::string alt_allele = "<INS>";
+                if (op_len <= 50) {
+                    alt_allele = ins_seq_str;
+                }
+                SVEvidenceFlags aln_type;
+                aln_type.set(static_cast<size_t>(SVDataType::CIGARCLIP));
+                SVCall sv_call(ins_pos, ins_end, SVType::INS, alt_allele, aln_type, Genotype::UNKNOWN, default_lh, 0, 0, 0);
+                // SVCall sv_call(ins_pos, ins_end, SVType::INS, alt_allele, SVDataType::CIGARCLIP, Genotype::UNKNOWN, default_lh, 0, 0, 0);
+                cigar_sv_calls.emplace_back(sv_call);

+            // Check if the CIGAR operation is a deletion
+            } else if (op == BAM_CDEL) {

+                ref_pos = pos+1;
+                ref_end = ref_pos + op_len -1;
+                SVEvidenceFlags aln_type;
+                aln_type.set(static_cast<size_t>(SVDataType::CIGARDEL));
+                SVCall sv_call(ref_pos, ref_end, SVType::DEL, getSVTypeSymbol(SVType::DEL), aln_type, Genotype::UNKNOWN, default_lh, 0, 0, 0);
+                // SVCall sv_call(ref_pos, ref_end, SVType::DEL, getSVTypeSymbol(SVType::DEL), SVDataType::CIGARDEL, Genotype::UNKNOWN, default_lh, 0, 0, 0);
+                cigar_sv_calls.emplace_back(sv_call);
+            }
+        }

-        // Update the reference coordinate based on the CIGAR operation
+        // Update the reference position
         // https://samtools.github.io/hts-specs/SAMv1.pdf
         if (op == BAM_CMATCH || op == BAM_CDEL || op == BAM_CREF_SKIP || op == BAM_CEQUAL || op == BAM_CDIFF) {
             pos += op_len;
-        } else if (op == BAM_CINS || op == BAM_CSOFT_CLIP || op == BAM_CHARD_CLIP || op == BAM_CPAD) {
-            // Do nothing
-        } else {
-            std::cerr << "ERROR: Unknown CIGAR operation " << op << std::endl;
-            exit(1);
         }
-
-        // Update the query position based on the CIGAR operation (M, I, S, H)
+
+        // Update the query position
         if (op == BAM_CMATCH || op == BAM_CINS || op == BAM_CSOFT_CLIP || op == BAM_CEQUAL || op == BAM_CDIFF) {
             query_pos += op_len;
-        } else if (op == BAM_CDEL || op == BAM_CREF_SKIP || op == BAM_CHARD_CLIP || op == BAM_CPAD) {
-            // Do nothing
-        } else {
-            std::cerr << "ERROR: Unknown CIGAR operation " << op << std::endl;
-            exit(1);
         }
     }

-    // Update the query end position
-    query_end = query_pos;
+    for (SVCall& sv_call : cigar_sv_calls) {
+        addSVCall(sv_calls, sv_call);
+    }
+}

-    // Return the mismatch map and the query start and end positions
-    return std::tuple<std::unordered_map<int, int>, int32_t, int32_t>(query_match_map, query_start, query_end);
+std::pair<int, int> SVCaller::getAlignmentReadPositions(bam1_t *alignment)
+{
+    int query_start = -1;
+    int query_end = 0;
+    uint32_t* cigar = bam_get_cigar(alignment);
+    int cigar_len = alignment->core.n_cigar;
+    for (int i = 0; i < cigar_len; i++) {
+        int op_len = bam_cigar_oplen(cigar[i]);
+        int op = bam_cigar_op(cigar[i]);

+        // Set the query start position to the first non-soft clip operation
+        if (query_start == -1 && (op == BAM_CMATCH || op == BAM_CINS || op == BAM_CEQUAL || op == BAM_CDIFF)) {
+            query_start = query_end; // First valid query position
+        }

+        // https://github.com/samtools/htslib/blob/develop/htslib/sam.h:
+        // bam_cigar_type(o) (BAM_CIGAR_TYPE>>((o)<<1)&3) // bit 1: consume query; bit 2: consume reference
+        if (op == BAM_CMATCH || op == BAM_CINS || op == BAM_CSOFT_CLIP || op == BAM_CEQUAL || op == BAM_CDIFF) {
+            query_end += op_len;
+        }
+    }

+    if (query_start == -1) {
+        query_start = 0;
+    }

+    return std::make_pair(query_start, query_end);
+}

-// Detect SVs from split read alignments (primary and supplementary) and
-// directly from the CIGAR string
-SVData SVCaller::run()
+void SVCaller::processChromosome(const std::string& chr, std::vector<SVCall>& chr_sv_calls, const InputData& input_data, const std::vector<uint32_t>& chr_pos_depth_map, double mean_chr_cov)
 {
     // Open the BAM file
-    std::string bam_filepath = this->input_data->getLongReadBam();
+    std::string bam_filepath = input_data.getLongReadBam();
+    samFile *fp_in = sam_open(bam_filepath.c_str(), "r");
+    if (!fp_in) {
+        printError("ERROR: failed to open " + bam_filepath);
+        return;
+    }
+    hts_set_threads(fp_in, 1);

-    // Get the region data
-    std::vector<std::string> chromosomes;
-    if (this->input_data->getChromosome() != "") {
-        chromosomes.push_back(this->input_data->getChromosome());
-    } else {
-        chromosomes = this->input_data->getRefGenomeChromosomes();
+    // Load the header
+    bam_hdr_t *bamHdr = sam_hdr_read(fp_in);
+    if (!bamHdr) {
+        sam_close(fp_in);
+        printError("ERROR: failed to read header from " + bam_filepath);
+        return;
     }
-    int chr_count = chromosomes.size();

-    // Loop through each region and detect SVs
-    std::cout << "Detecting SVs from " << chr_count << " chromosome(s)..." << std::endl;
-    int region_count = 0;
-    auto start1 = std::chrono::high_resolution_clock::now();
-    SVData sv_calls;
-    int chunk_count = 10000; // Number of chunks to split the chromosome into
-    int min_cnv_length = this->input_data->getMinCNVLength();
-    for (const auto& chr : chromosomes) {
-        std::cout << "Running SV detection for chromosome " << chr << "..." << std::endl;
+    // Load the index
+    hts_idx_t *idx = sam_index_load(fp_in, bam_filepath.c_str());
+    if (!idx) {
+        bam_hdr_destroy(bamHdr);
+        sam_close(fp_in);
+        printError("ERROR: failed to load index for " + bam_filepath);
+        return;
+    }
+    // BamFileGuard bam_guard(fp_in, idx, bamHdr); // Guard to close the BAM file
+
+    // Get DBSCAN parameters
+    double dbscan_epsilon = input_data.getDBSCAN_Epsilon();
+    int dbscan_min_pts = 5;
+    double dbscan_min_pts_pct = input_data.getDBSCAN_MinPtsPct();
+    if (dbscan_min_pts_pct > 0.0) {
+        dbscan_min_pts = (int)std::ceil(mean_chr_cov * dbscan_min_pts_pct);
+        printMessage(chr + ": Mean chr. cov.: " + std::to_string(mean_chr_cov) + " (DBSCAN min. pts.= " + std::to_string(dbscan_min_pts) + ", min. pts. pct.= " + std::to_string(dbscan_min_pts_pct) + ")");
+    // Get DBSCAN parameters
+    double dbscan_epsilon = input_data.getDBSCAN_Epsilon();
+    int dbscan_min_pts = 5;
+    double dbscan_min_pts_pct = input_data.getDBSCAN_MinPtsPct();
+    if (dbscan_min_pts_pct > 0.0) {
+        dbscan_min_pts = (int)std::ceil(mean_chr_cov * dbscan_min_pts_pct);
+        printMessage(chr + ": Mean chr. cov.: " + std::to_string(mean_chr_cov) + " (DBSCAN min. pts.= " + std::to_string(dbscan_min_pts) + ", min. pts. pct.= " + std::to_string(dbscan_min_pts_pct) + ")");
+    }
+
+    // -----------------------------------------------------------------------
+    // Detect SVs from the CIGAR strings
+    printMessage(chr + ": CIGAR SVs...");
+    this->findCIGARSVs(fp_in, idx, bamHdr, chr, chr_sv_calls, chr_pos_depth_map);
+
+    // Clean up the BAM file and index
+    sam_close(fp_in);
+    hts_idx_destroy(idx);
+    bam_hdr_destroy(bamHdr);

-        // Split the chromosome into chunks
-        std::vector<std::string> region_chunks;

+    printMessage(chr + ": Merging CIGAR...");
+    mergeSVs(chr_sv_calls, dbscan_epsilon, dbscan_min_pts, false);
+
+    int region_sv_count = getSVCount(chr_sv_calls);
+    printMessage(chr + ": Found " + std::to_string(region_sv_count) + " SV candidates in the CIGAR string");
+}

-        // Get the region start and end positions
-        if (this->input_data->isRegionSet()) {
-            std::pair<int32_t, int32_t> region = this->input_data->getRegion();
-            int region_start = region.first;
-            int region_end = region.second;
-
-            // Use one chunk for the region
-            std::string chunk = chr + ":" + std::to_string(region_start) + "-" + std::to_string(region_end);
-            region_chunks.push_back(chunk);
-
-        } else {
-            int chr_len = this->input_data->getRefGenomeChromosomeLength(chr);
-            int chunk_size = chr_len / chunk_count;
-            for (int i = 0; i < chunk_count; i++) {
-                int start = i * chunk_size + 1;  // 1-based
-                int end = start + chunk_size;
-                if (i == chunk_count - 1) {
-                    end = chr_len;
-                }
-                std::string chunk = chr + ":" + std::to_string(start) + "-" + std::to_string(end);
-                region_chunks.push_back(chunk);
-            }
-        }

+void SVCaller::run(const InputData& input_data)
+{
+    bool cigar_svs = true;
+    bool cigar_cn = true;
+    bool split_svs = true;
+    bool merge_split_svs = true;
+    bool merge_final_svs = true;
+
+    // Print the input data
+    input_data.printParameters();
+
+    // Set up the reference genome
+    printMessage("Loading the reference genome...");
+    const std::string ref_filepath = input_data.getRefGenome();
+    std::shared_mutex ref_mutex;  // Dummy mutex (remove later)
+    ReferenceGenome ref_genome(ref_mutex);
+    ref_genome.setFilepath(ref_filepath);
+
+    // Get the chromosomes
+    std::vector<std::string> chromosomes;
+    if (input_data.isSingleChr()) {
+        // Get the chromosome from the user input argument
+        chromosomes.push_back(input_data.getChromosome());
+    } else {
+        // Get the chromosomes from the input BAM file
+        chromosomes = this->getChromosomes(input_data.getLongReadBam());
+    }
+
+    // Read the HMM from the file
+    std::string hmm_filepath = input_data.getHMMFilepath();
+    std::cout << "Reading HMM from file: " << hmm_filepath << std::endl;
+    const CHMM& hmm = ReadCHMM(hmm_filepath.c_str());
+
+    // Set up the JSON output file for CNV data
+    const std::string& json_fp = input_data.getCNVOutputFile();
+
+    // Calculate the mean chromosome coverage and generate the position depth
+    // maps for each chromosome (I/O is multi-threaded, which is more efficient
+    // than per-chromosome multi-threading in this case)
+    std::shared_mutex shared_mutex;
+    CNVCaller cnv_caller(shared_mutex);
+    std::unordered_map<std::string, std::vector<uint32_t>> chr_pos_depth_map;
+    std::unordered_map<std::string, double> chr_mean_cov_map;
+    const std::string bam_filepath = input_data.getLongReadBam();
+    int chr_thread_count = input_data.getThreadCount();
+
+    // Initialize the chromosome position depth map and mean coverage map
+    for (const auto& chr : chromosomes) {
+        uint32_t chr_len = ref_genome.getChromosomeLength(chr);
+        if (chr_len == 0) {
+            printError("Chromosome " + chr + " not found in reference genome");
+            return;
+        }
+        chr_pos_depth_map[chr] = std::vector<uint32_t>(chr_len + 1, 0);  // 1-based index
+        chr_mean_cov_map[chr] = 0.0;
+    }
+    cnv_caller.calculateMeanChromosomeCoverage(chromosomes, chr_pos_depth_map, chr_mean_cov_map, bam_filepath, chr_thread_count);
+
+    // Remove chromosomes with no reads (mean coverage is zero)
+    printMessage("Removing chromosomes with no reads...");
+    std::vector<std::string> valid_chr;
+    for (const auto& chr : chromosomes) {
+        if (chr_mean_cov_map[chr] > 0.0) {
+            valid_chr.push_back(chr);
+        }
+    }
+    chromosomes = valid_chr;
+
+    std::unordered_map<std::string, std::vector<SVCall>> whole_genome_sv_calls;
+    int current_chr = 0;
+    int total_chr_count = chromosomes.size();
+
+    if (cigar_svs) {
+        // Use multi-threading across chromosomes. If a single chromosome is
+        // specified, use a single main thread (multi-threading is used for file I/O)
+        int thread_count = 1;
+        if (!input_data.isSingleChr()) {
+            thread_count = input_data.getThreadCount();
+            std::cout << "Using " << thread_count << " threads for chr processing..." << std::endl;
+        }
+        ThreadPool pool(thread_count);
+        auto process_chr = [&](const std::string& chr) {
+            try {
+                std::vector<SVCall> sv_calls;
+                sv_calls.reserve(1000);
+                InputData chr_input_data = input_data;  // Use a thread-local copy
+                this->processChromosome(chr, sv_calls, chr_input_data, chr_pos_depth_map[chr], chr_mean_cov_map[chr]);
+                {
+                    // Exclusive lock: this writes into the shared map
+                    std::unique_lock<std::shared_mutex> lock(this->shared_mutex);
+                    whole_genome_sv_calls[chr] = std::move(sv_calls);
+                }
+            } catch (const std::exception& e) {
+                printError("Error processing chromosome " + chr + ": " + e.what());
+            } catch (...) {
+                printError("Unknown error processing chromosome " + chr);
+            }
+        };
+
+        // Submit tasks to the thread pool and track futures
+        std::vector<std::future<void>> futures;
+        for (const auto& chr : chromosomes) {
+            futures.emplace_back(pool.enqueue([&, chr] {
+                process_chr(chr);
+            }));
+        }

-        // Load chromosome data for copy number predictions
-        std::cout << "Loading chromosome data for copy number predictions..." << std::endl;
-        CNVCaller cnv_caller(*this->input_data);
-        cnv_caller.loadChromosomeData(chr);
-
-        // Process each chunk one at a time
-        std::cout << "Processing " << region_chunks.size() << " region(s) for chromosome " << chr << "..." << std::endl;
-        for (const auto& sub_region : region_chunks) {
-            // Detect SVs from the sub-region
-            RegionData region_data = this->detectSVsFromRegion(sub_region);
-            SVData& sv_calls_region = std::get<0>(region_data);
-            PrimaryMap& primary_map = std::get<1>(region_data);
-            SuppMap& supp_map = std::get<2>(region_data);
-            int region_sv_count = sv_calls_region.totalCalls();
-            if (region_sv_count > 0) {
-                std::cout << "Detected " << region_sv_count << " SVs from " << sub_region << "..." << std::endl;
-            }

+        // Wait for all tasks to complete
+        for (auto& future : futures) {
+            try {
+                current_chr++;
+                future.get();
+            } catch (const std::exception& e) {
+                printError("Error processing chromosome task: " + std::string(e.what()));
+            } catch (...) {
+                printError("Unknown error processing chromosome task.");
+            }
+        }
+        printMessage("All tasks have finished.");
+
+        if (cigar_cn) {
+            // -------------------------------------------------------
+            // Run copy number variant predictions on the SVs detected from the
+            // CIGAR string, using a minimum CNV length threshold

-            std::map<SVCandidate, SVInfo>& cigar_svs = sv_calls_region.getChromosomeSVs(chr);
-            if (cigar_svs.size() > 0) {
-                std::cout << "Running copy number variant detection from CIGAR string SVs..." << std::endl;
-                cnv_caller.runCIGARCopyNumberPrediction(chr, cigar_svs, min_cnv_length);
-            }

+            current_chr = 0;
+            printMessage("Running copy number predictions on CIGAR SVs...");
+            for (auto& entry : whole_genome_sv_calls) {
+                current_chr++;
+                const std::string& chr = entry.first;
+                std::vector<SVCall>& sv_calls = entry.second;
+                if (sv_calls.size() > 0) {
+                    printMessage("(" + std::to_string(current_chr) + "/" + std::to_string(total_chr_count) + ") Running copy number predictions on " + chr + "...");
+                    cnv_caller.runCIGARCopyNumberPrediction(chr, sv_calls, hmm, chr_mean_cov_map[chr], chr_pos_depth_map[chr], input_data);
+                }
+            }
+            // -------------------------------------------------------
+        }
+    }
+
+    if (split_svs) {
+        DEBUG_PRINT("Identifying split-SV signatures...");
+        std::unordered_map<std::string, std::vector<SVCall>> whole_genome_split_sv_calls;
+        this->findSplitSVSignatures(whole_genome_split_sv_calls, input_data);
+
+        DEBUG_PRINT("Running copy number predictions on split-read SVs...");
+        current_chr = 0;
+        for (auto& entry : whole_genome_split_sv_calls) {
+            const std::string& chr = entry.first;
+            std::vector<SVCall>& sv_calls = entry.second;
+
+            if (sv_calls.size() > 0) {
+                current_chr++;
+                DEBUG_PRINT("(" + std::to_string(current_chr) + "/" + std::to_string(total_chr_count) + ") Running copy number predictions on " + chr + " with " + std::to_string(sv_calls.size()) + " SV candidates...");
+                this->runSplitReadCopyNumberPredictions(chr, sv_calls, cnv_caller, hmm, chr_mean_cov_map[chr], chr_pos_depth_map[chr], input_data);
+            }
+        }
+
+        if (merge_split_svs) {
+            DEBUG_PRINT("Merging split-read SVs...");
+            for (auto& entry : whole_genome_split_sv_calls) {
+                std::vector<SVCall>& sv_calls = entry.second;
+                mergeSVs(sv_calls, 0.1, 2, true);
+            }
+        }

-            // Run split-read SV detection in a single thread, combined with
-            // copy number variant predictions
-            std::cout << "Detecting copy number variants from split reads..." << std::endl;
-            this->detectSVsFromSplitReads(sv_calls_region, primary_map, supp_map, cnv_caller);

+        DEBUG_PRINT("Unifying SVs...");
+        for (auto& entry : whole_genome_split_sv_calls) {
+            const std::string& chr = entry.first;
+            std::vector<SVCall>& sv_calls = entry.second;
+            whole_genome_sv_calls[chr].insert(whole_genome_sv_calls[chr].end(), sv_calls.begin(), sv_calls.end());
+        }
+    }

-            // Add the SV calls to the main SV calls object
-            sv_calls.concatenate(sv_calls_region);
-        }

+    if (merge_final_svs) {
+        // Merge any duplicate SV calls from the CIGAR and split-read
+        // detections (same start positions)
+        DEBUG_PRINT("Merging CIGAR and split read SV calls...");
+        for (auto& entry : whole_genome_sv_calls) {
+            std::vector<SVCall>& sv_calls = entry.second;
+            mergeSVs(sv_calls, 0.1, 2, true);
+        }
+    }

-        // Increment the region count
-        region_count++;
-        std::cout << "Completed " << region_count << " of " << chr_count << " chromosome(s)..." << std::endl;
-    }

+    if (input_data.getSaveCNVData()) {
+        closeJSON(json_fp);
+    }

-    auto end1 = std::chrono::high_resolution_clock::now();
-    std::cout << "Finished detecting " << sv_calls.totalCalls() << " SVs from " << chr_count << " chromosome(s). Elapsed time: " << getElapsedTime(start1, end1) << std::endl;

+    // Print the total number of SVs detected for each chromosome
+    uint32_t total_sv_count = 0;
+    for (const auto& entry : whole_genome_sv_calls) {
+        std::string chr = entry.first;
+        int sv_count = getSVCount(entry.second);
+        total_sv_count += sv_count;
+        printMessage("Total SVs detected for " + chr + ": " + std::to_string(sv_count));
+    }
+    printMessage("Total SVs detected: " + std::to_string(total_sv_count));

-    return sv_calls;

+    // Save to VCF
+    std::cout << "Saving SVs to VCF..." << std::endl;
+    this->saveToVCF(whole_genome_sv_calls, input_data, ref_genome, chr_pos_depth_map);
+}
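One concurrency detail in the threaded section of `run()` above: storing results into `whole_genome_sv_calls` mutates the map (possibly rehashing it), so the critical section needs the exclusive side of the `std::shared_mutex` via `std::unique_lock`; `std::shared_lock` only grants shared ownership and is meant for concurrent readers. The split, shown in isolation:

```cpp
#include <shared_mutex>
#include <string>
#include <unordered_map>
#include <vector>

std::shared_mutex map_mutex;
std::unordered_map<std::string, std::vector<int>> results;

// Writer: exclusive lock, since operator[] may insert and rehash
void storeResult(const std::string& chr, std::vector<int> calls) {
    std::unique_lock<std::shared_mutex> lock(map_mutex);
    results[chr] = std::move(calls);
}

// Reader: shared lock, so many threads can query concurrently
size_t resultCount(const std::string& chr) {
    std::shared_lock<std::shared_mutex> lock(map_mutex);
    auto it = results.find(chr);
    return it == results.end() ? 0 : it->second.size();
}
```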
+void SVCaller::findOverlaps(const std::unique_ptr<IntervalNode> &root, const PrimaryAlignment &query, std::vector<std::string> &result)
+{
+    if (!root) return;
+
+    // If overlapping, add to result
+    if (query.start <= root->region.end && query.end >= root->region.start)
+        result.push_back(root->qname);
+
+    // If left subtree may have overlaps, search left
+    if (root->left && root->left->max_end >= query.start)
+        findOverlaps(root->left, query, result);
+
+    // Always check the right subtree
+    findOverlaps(root->right, query, result);
+}

-// Detect SVs from split read alignments
-void SVCaller::detectSVsFromSplitReads(SVData& sv_calls, PrimaryMap& primary_map, SuppMap& supp_map, CNVCaller& cnv_caller)
-{
-    // Find split-read SV evidence
-    int sv_count = 0;
-    int min_cnv_length = this->input_data->getMinCNVLength();
-    for (const auto& entry : primary_map) {
-        std::string qname = entry.first;
-        AlignmentData primary_alignment = entry.second;
-        std::string primary_chr = std::get<0>(primary_alignment);
-        int32_t primary_start = std::get<1>(primary_alignment);
-        int32_t primary_end = std::get<2>(primary_alignment);
-        int32_t primary_query_start = std::get<4>(primary_alignment);
-        int32_t primary_query_end = std::get<5>(primary_alignment);
-        std::unordered_map<int, int> primary_match_map = std::get<6>(primary_alignment);
-
-        // Loop through the supplementary alignments and find gaps and overlaps
-        AlignmentVector supp_alignments = supp_map[qname];
-        for (const auto& supp_alignment : supp_alignments) {
-
-            // Get the supplementary alignment chromosome
-            std::string supp_chr = std::get<0>(supp_alignment);
-
-            // Skip supplementary alignments that are on a different chromosome
-            // for now (TODO: Use for identifying trans-chromosomal SVs such as
-            // translocations)
-            if (primary_chr != supp_chr) {
-                continue;
-            }
-            int32_t supp_start = std::get<1>(supp_alignment);
-            int32_t supp_end = std::get<2>(supp_alignment);
-            int32_t supp_query_start = std::get<4>(supp_alignment);
-            int32_t supp_query_end = std::get<5>(supp_alignment);
-            std::unordered_map<int, int> supp_match_map = std::get<6>(supp_alignment);
-
-            // Determine if there is overlap between the primary and
-            // supplementary query sequences
-            int32_t overlap_start = std::max(primary_query_start, supp_query_start);
-            int32_t overlap_end = std::min(primary_query_end, supp_query_end);
-            int32_t overlap_length = overlap_end - overlap_start;
-            if (overlap_length > 0) {
-                // Calculate the mismatch rate for each alignment at the overlap
-                double primary_mismatch_rate = this->calculateMismatchRate(primary_match_map, overlap_start, overlap_end-1);
-                double supp_mismatch_rate = this->calculateMismatchRate(supp_match_map, overlap_start, overlap_end-1);
-
-                // Trim the overlap from the alignment with the higher mismatch
-                // rate
-                if (primary_mismatch_rate > supp_mismatch_rate) {
-                    if (overlap_start == primary_query_start) {
-                        primary_start += overlap_length;
-                    } else if (overlap_end == primary_query_end) {
-                        primary_end -= overlap_length;
-                    }
-                } else {
-                    if (overlap_start == supp_query_start) {
-                        supp_start += overlap_length;
-                    } else if (overlap_end == supp_query_end) {
-                        supp_end -= overlap_length;
-                    }
-                }
-            }
-
-            // Gap analysis (deletion or duplication)
-            if (supp_start < primary_start && supp_end < primary_start) {
-
-                // Gap with supplementary before primary:
-                // [supp_start] [supp_end] -- [primary_start] [primary_end]
-                std::vector<std::pair<SVCandidate, std::string>> sv_list;  // SV candidate and alignment type
-
-                // Use the gap ends as the SV endpoints
-                if (primary_start - supp_end >= min_cnv_length) {
-                    SVCandidate sv_candidate(supp_end+1, primary_start+1, ".");
-                    std::pair<SVCandidate, std::string> sv_pair(sv_candidate, "GAPINNER_A");
-                    sv_list.push_back(sv_pair);
-                    sv_count++;
-                }
-
-                // Also use the alignment ends as the SV endpoints
-                if (primary_end - supp_start >= min_cnv_length) {
-                    SVCandidate sv_candidate(supp_start+1, primary_end+1, ".");
-                    std::pair<SVCandidate, std::string> sv_pair(sv_candidate, "GAPOUTER_A");
-                    sv_list.push_back(sv_pair);
-                    sv_count++;
-                }
-
-                // Determine which SV to keep based on HMM prediction likelihood
-                if (sv_list.size() > 0) {
-                    cnv_caller.updateSVsFromCopyNumberPrediction(sv_calls, sv_list, supp_chr);
-                }
-
-            } else if (supp_start > primary_end && supp_end > primary_end) {
-                // Gap with supplementary after primary:
-                // [primary_start] [primary_end] -- [supp_start] [supp_end]
-                std::vector<std::pair<SVCandidate, std::string>> sv_list;  // SV candidate and alignment type
-
-                // Use the gap ends as the SV endpoints
-                if (supp_start - primary_end >= min_cnv_length) {
-                    SVCandidate sv_candidate(primary_end+1, supp_start+1, ".");
-                    std::pair<SVCandidate, std::string> sv_pair(sv_candidate, "GAPINNER_B");
-                    sv_list.push_back(sv_pair);
-                    sv_count++;
-                }
-
-                // Also use the alignment ends as the SV endpoints
-                if (supp_end - primary_start >= min_cnv_length) {
-                    SVCandidate sv_candidate(primary_start+1, supp_end+1, ".");
-                    std::pair<SVCandidate, std::string> sv_pair(sv_candidate, "GAPOUTER_B");
-                    sv_list.push_back(sv_pair);
-                    sv_count++;
-                }
-
-                // Determine which SV to keep based on HMM prediction likelihood
-                if (sv_list.size() > 0) {
-                    cnv_caller.updateSVsFromCopyNumberPrediction(sv_calls, sv_list, supp_chr);
-                }
-            }
-        }
-    }
-
-    // Print the number of SVs detected from split-read alignments
-    if (sv_count > 0) {
-        std::cout << "Found " << sv_count << " SVs from split-read alignments" << std::endl;
-    }
-}

+void SVCaller::insert(std::unique_ptr<IntervalNode> &root, const PrimaryAlignment &region, std::string qname)
+{
+    if (!root) {
+        root = std::make_unique<IntervalNode>(region, qname);
+        return;
+    }
+
+    if (region.start < root->region.start)
+    {
+        insert(root->left, region, qname);
+    } else {
+        insert(root->right, region, qname);
+    }
+
+    // Update max_end
+    root->max_end = std::max(root->max_end, region.end);
+}
+// Run copy number predictions on the SVs detected from the split reads
+void SVCaller::runSplitReadCopyNumberPredictions(const std::string& chr, std::vector<SVCall>& split_sv_calls, const CNVCaller& cnv_caller, const CHMM& hmm, double mean_chr_cov, const std::vector<uint32_t>& pos_depth_map, const InputData& input_data)
+{
+    std::vector<SVCall> additional_calls;
+    for (auto& sv_candidate : split_sv_calls) {
+
+        std::tuple<double, SVType, Genotype, int> result = cnv_caller.runCopyNumberPrediction(chr, hmm, sv_candidate.start, sv_candidate.end, mean_chr_cov, pos_depth_map, input_data);
+        double supp_lh = std::get<0>(result);
+        SVType supp_type = std::get<1>(result);
+        Genotype genotype = std::get<2>(result);
+        int cn_state = std::get<3>(result);
+
+        // Update the SV type if the predicted type is not unknown
+        if (supp_type != SVType::UNKNOWN) {
+            // Update all information if the current SV call is not known and
+            // there is a predicted CNV type
+            if (sv_candidate.sv_type == SVType::UNKNOWN && (supp_type == SVType::DEL || supp_type == SVType::DUP)) {
+                sv_candidate.sv_type = supp_type;
+                sv_candidate.alt_allele = getSVTypeSymbol(supp_type);  // Update the ALT allele format
+                sv_candidate.aln_type.set(static_cast<size_t>(SVDataType::HMM));
+                sv_candidate.hmm_likelihood = supp_lh;
+                sv_candidate.genotype = genotype;
+                sv_candidate.cn_state = cn_state;
+
+            // For predictions with the same type, or LOH/neutral predictions,
+            // update the prediction information
+            } else if (sv_candidate.sv_type != SVType::UNKNOWN && (supp_type == sv_candidate.sv_type || supp_type == SVType::LOH || supp_type == SVType::NEUTRAL)) {
+                sv_candidate.aln_type.set(static_cast<size_t>(SVDataType::HMM));
+                sv_candidate.hmm_likelihood = supp_lh;
+                sv_candidate.genotype = genotype;
+                sv_candidate.cn_state = cn_state;
+
+            // Add an additional SV call if the type is different
+            } else if (sv_candidate.sv_type != SVType::UNKNOWN && (supp_type != sv_candidate.sv_type && (supp_type == SVType::DEL || supp_type == SVType::DUP))) {
+                // For inversions, just update the alignment type, copy number
+                // state, and HMM likelihood. Coverage changes for these may be
+                // predicted as CNVs
+                if (sv_candidate.sv_type == SVType::INV) {
+                    sv_candidate.aln_type.set(static_cast<size_t>(SVDataType::HMM));
+                    sv_candidate.hmm_likelihood = supp_lh;
+                    sv_candidate.genotype = genotype;
+                    sv_candidate.cn_state = cn_state;
+
+                // For insertions predicted as duplications, update all information
+                } else if (sv_candidate.sv_type == SVType::INS && supp_type == SVType::DUP) {
+                    sv_candidate.sv_type = supp_type;
+                    sv_candidate.alt_allele = getSVTypeSymbol(supp_type);  // Update the ALT allele format
+                    sv_candidate.aln_type.set(static_cast<size_t>(SVDataType::HMM));
+                    sv_candidate.hmm_likelihood = supp_lh;
+                    sv_candidate.genotype = genotype;
+                    sv_candidate.cn_state = cn_state;
+                } else {
+                    // Add a new SV call with the conflicting type
+                    SVCall new_sv_call = sv_candidate;  // Copy the original SV call
+                    new_sv_call.sv_type = supp_type;
+                    new_sv_call.alt_allele = getSVTypeSymbol(supp_type);  // Update the ALT allele format
+                    new_sv_call.aln_type.set(static_cast<size_t>(SVDataType::HMM));
+                    new_sv_call.hmm_likelihood = supp_lh;
+                    new_sv_call.genotype = genotype;
+                    new_sv_call.cn_state = cn_state;
+                    additional_calls.push_back(new_sv_call);
+                }
+            }
+        }
+    }
+
+    // Add the additional SV calls to the original list, replacing any existing
+    // ones
+    for (auto& new_sv_call : additional_calls) {
+        bool found = false;
+        for (auto& existing_sv_call : split_sv_calls) {
+            if (existing_sv_call.start == new_sv_call.start && existing_sv_call.end == new_sv_call.end &&
+                existing_sv_call.sv_type == new_sv_call.sv_type) {
+                // Update the existing SV call with the new one
+                existing_sv_call = new_sv_call;
+                found = true;
+                break;
+            }
+        }
+        if (!found) {
+            addSVCall(split_sv_calls, new_sv_call);  // Add as a new SV call
+        }
+    }
+}
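A small readability note on the prediction unpacking in `runSplitReadCopyNumberPredictions` above: the build is already C++17 (the Makefile passes `-std=c++17`), so the `std::get<0>`..`std::get<3>` sequence could collapse into a structured binding. A sketch, assuming `runCopyNumberPrediction` returns the (likelihood, type, genotype, copy-number state) tuple used above:

```cpp
// Equivalent to the std::get<0>..std::get<3> chain above (C++17).
auto [supp_lh, supp_type, genotype, cn_state] = cnv_caller.runCopyNumberPrediction(
    chr, hmm, sv_candidate.start, sv_candidate.end, mean_chr_cov, pos_depth_map, input_data);
```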
+void SVCaller::saveToVCF(const std::unordered_map<std::string, std::vector<SVCall>>& sv_calls, const InputData &input_data, const ReferenceGenome& ref_genome, const std::unordered_map<std::string, std::vector<uint32_t>>& chr_pos_depth_map) const
+{
+    // Check if an assembly gap file was provided
+    std::string assembly_gap_file = input_data.getAssemblyGaps();
+    std::unordered_map<std::string, std::vector<std::pair<uint32_t, uint32_t>>> assembly_gaps;
+    if (!assembly_gap_file.empty()) {
+        std::cout << "Loading assembly gap file: " << assembly_gap_file << std::endl;
+        // Load the assembly gap file and process it
+        std::ifstream gap_stream(assembly_gap_file);
+        if (!gap_stream.is_open()) {
+            printError("Failed to open assembly gap file: " + assembly_gap_file);
+            return;
+        }
+        std::string line;
+        while (std::getline(gap_stream, line)) {
+            // Skip empty lines and comments
+            if (line.empty() || line[0] == '#') {
+                continue;
+            }
+
+            // Parse the line (assuming tab-separated values)
+            std::istringstream iss(line);
+            std::string chr;
+            uint32_t start, end;
+            if (!(iss >> chr >> start >> end)) {
+                printError("Failed to parse assembly gap file line: " + line);
+                continue;
+            }
+            // Add the assembly gap to the map
+            assembly_gaps[chr].emplace_back(start, end);
+        }
+        gap_stream.close();
+        std::cout << "Loaded assembly gaps for " << assembly_gaps.size() << " chromosome(s)." << std::endl;
+    }
+
+    std::cout << "Creating VCF writer..." << std::endl;
+    std::string output_dir = input_data.getOutputDir();
+    std::string output_vcf = output_dir + "/output.vcf";
+    std::cout << "Writing VCF file to " << output_vcf << std::endl;
+    std::ofstream vcf_stream(output_vcf);
+    if (!vcf_stream.is_open()) {
+        printError("Failed to open VCF file for writing.");
+        return;
+    }
+
+    std::string sample_name = "SAMPLE";
+
+    std::cout << "Getting reference genome filepath..." << std::endl;
+    try {
+        std::string ref_fp = ref_genome.getFilepath();
+        std::cout << "Reference genome filepath: " << ref_fp << std::endl;
+    } catch (const std::exception& e) {
+        std::cerr << "Error: " << e.what() << std::endl;
+        return;
+    }
+
+    // Set the header lines
+    std::cout << "Getting reference genome header..." << std::endl;
+    const std::string contig_header = ref_genome.getContigHeader();
+    std::cout << "Formatting VCF header..." << std::endl;
+    std::vector<std::string> header_lines = {
+        std::string("##reference=") + ref_genome.getFilepath(),
+        contig_header,
+        "##INFO=<ID=END,Number=1,Type=Integer,Description=\"End position of the structural variant\">",
+        "##INFO=<ID=SVTYPE,Number=1,Type=String,Description=\"Type of structural variant\">",
+        "##INFO=<ID=SVLEN,Number=1,Type=Integer,Description=\"Length of the structural variant\">",
+        "##INFO=<ID=SVMETHOD,Number=1,Type=String,Description=\"SV detection method\">",
+        "##INFO=<ID=ALN,Number=1,Type=String,Description=\"Alignment evidence used to call the SV\">",
+        "##INFO=<ID=HMM,Number=1,Type=Float,Description=\"HMM likelihood of the copy number prediction\">",
+        "##INFO=<ID=SUPPORT,Number=1,Type=Integer,Description=\"Read depth at the SV position\">",
+        "##INFO=<ID=CLUSTER,Number=1,Type=Integer,Description=\"Size of the merged SV cluster\">",
+        "##INFO=<ID=ALNOFFSET,Number=1,Type=Integer,Description=\"Alignment offset of the SV\">",
+        "##INFO=<ID=CN,Number=1,Type=Integer,Description=\"Predicted copy number state\">",
+        "##INFO=<ID=LOH,Number=0,Type=Flag,Description=\"Loss of heterozygosity\">",
+        "##FILTER=<ID=PASS,Description=\"All filters passed\">",
+        "##FILTER=<ID=AssemblyGap,Description=\"SV overlaps an assembly gap\">",
+        "##FORMAT=<ID=GT,Number=1,Type=String,Description=\"Genotype\">",
+        "##FORMAT=<ID=DP,Number=1,Type=Integer,Description=\"Read depth\">",
+    };
+
+    std::cout << "Writing VCF header..." << std::endl;
+
+    // Add the file format
+    std::string file_format = "##fileformat=VCFv4.2";
+    vcf_stream << file_format << std::endl;
+
+    // Add date and time
+    time_t rawtime;
+    struct tm * timeinfo;
+    char buffer[80];
+    time(&rawtime);
+    timeinfo = localtime(&rawtime);
+    strftime(buffer, sizeof(buffer), "%Y%m%d", timeinfo);
+    vcf_stream << "##fileDate=" << buffer << std::endl;
+
+    // Add source
+    std::string sv_method = "ContextSV" + std::string(VERSION);
+    std::string source = "##source=" + sv_method;
+    vcf_stream << source << std::endl;
+
+    // Loop over the header metadata lines
+    for (const auto &line : header_lines) {
+        vcf_stream << line << std::endl;
+    }
+
+    // Add the header line
+    std::string header_line = "#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tSAMPLE";
+    vcf_stream << header_line << std::endl;
+
+    std::cout << "Saving SV calls to " << output_vcf << std::endl;
+    int total_count = 0;
+    int unclassified_svs = 0;
+    int filtered_svs = 0;
+    int assembly_gap_filtered_svs = 0;
+    for (const auto& pair : sv_calls) {
+        std::string chr = pair.first;
+        const std::vector<SVCall>& sv_calls = pair.second;
+        std::cout << "Saving SV calls for " << chr << "..." << std::endl;
+        for (const auto& sv_call : sv_calls) {
+            uint32_t start = sv_call.start;
+            uint32_t end = sv_call.end;
+            int sv_length = end - start + 1;
+            std::string ref_allele = ".";
+            std::string alt_allele = sv_call.alt_allele;
+            SVType sv_type = sv_call.sv_type;
+            std::string genotype = getGenotypeString(sv_call.genotype);
+            std::string data_type_str = getSVAlignmentTypeString(sv_call.aln_type);
+            double hmm_likelihood = sv_call.hmm_likelihood;
+            int cluster_size = sv_call.cluster_size;
+            std::string filter = "PASS";
+            int aln_offset = sv_call.aln_offset;
+            int cn_state = sv_call.cn_state;
+
+            SVType cn_type = getSVTypeFromCNState(cn_state);
+            std::string loh = (cn_type == SVType::LOH) ? ";LOH" : "";
+
+            // If the SV type is unknown or neutral, skip it
+            if (sv_type == SVType::UNKNOWN || sv_type == SVType::NEUTRAL) {
+                unclassified_svs += 1;
+                continue;
+            } else {
+                total_count += 1;
+            }
+
+            // Check if the SV is in an assembly gap (0-based)
+            if (assembly_gap_file != "") {
+                bool in_assembly_gap = false;
+                auto it = assembly_gaps.find(chr);
+                if (it != assembly_gaps.end()) {
+                    // Check whether the SV overlaps any assembly gap by more
+                    // than 20%
+                    for (const auto& gap : it->second) {
+                        uint32_t overlap_start = std::max(start, gap.first + 1);  // Convert to 1-based
+                        uint32_t overlap_end = std::min(end, gap.second + 1);     // Convert to 1-based
+                        if (overlap_start <= overlap_end) {
+                            // Calculate the overlap length and percentage
+                            uint32_t overlap_length = overlap_end - overlap_start + 1;
+                            double overlap_pct = static_cast<double>(overlap_length) / static_cast<double>(sv_length);
+                            if (overlap_pct > 0.2) {
+                                in_assembly_gap = true;
+                                break;
+                            }
+                        }
+                    }
+                    if (in_assembly_gap) {
+                        filter = "AssemblyGap";
+                        assembly_gap_filtered_svs += 1;
+                    }
+                }
+            }
+
+            // Deletion
+            if (sv_type == SVType::DEL) {
+                // Get the deleted sequence from the reference genome, also including the preceding base
+                uint32_t preceding_pos = (uint32_t) std::max(1, static_cast<int>(start) - 1);  // Make sure the position is not negative
+                ref_allele = ref_genome.query(chr, preceding_pos, end);
+
+                // Use the preceding base as the alternate allele
+                if (ref_allele != "") {
+                    // The ALT allele is the preceding base, and the reference
+                    // allele is the deleted sequence including the preceding base
+                    alt_allele = ref_allele.at(0);
+                } else {
+                    // If the reference allele is empty, use a symbolic allele
+                    ref_allele = "N";
+                    alt_allele = "<DEL>";  // Symbolic allele, by convention for imprecise deletions
+                    std::cerr << "Warning: Reference allele is empty for deletion at " << chr << ":" << start << "-" << end << std::endl;
+                }
+
+                sv_length = -1 * sv_length;  // Negative length for deletions
+                start = preceding_pos;  // Update the position to the preceding base
+
+            // Other types (duplications, insertions, inversions)
+            } else {
+
+                if (sv_type == SVType::INS) {
+                    // Update the position to the preceding base
+                    if (static_cast<int>(start) > 1) {
+                        uint32_t preceding_pos = start - 1;
+                        ref_allele = ref_genome.query(chr, preceding_pos, preceding_pos);
+                        start = preceding_pos;
+                        if (ref_allele != "") {
+                            if (alt_allele != "") {
+                                // Insert the reference allele before the insertion sequence
+                                alt_allele.insert(0, ref_allele);
+                            } else {
+                                // Large insertion without a stored sequence: use a symbolic allele
+                                alt_allele = "<INS>";
+                            }
+                        } else {
+                            // If the reference allele is empty, use a symbolic allele
+                            ref_allele = "N";
+                            alt_allele = "<INS>";  // Symbolic allele, by convention for insertions
+                            std::cerr << "Warning: Reference allele is empty for insertion at " << chr << ":" << start << "-" << end << std::endl;
+                        }
+                    } else {
+                        // Skip insertions at the first position
+                        std::cerr << "Error: Insertion at the first position " << chr << ":" << start << "-" << end << std::endl;
+                        continue;
+                    }
+                    end = start;  // Update the end position to the same base
+
+                } else {
+                    ref_allele = "N";  // Convention for INV and DUP
+                }
+            }
+
+            // Fix ambiguous bases in the reference allele
+            const std::string amb_bases = "RYKMSWBDHV";  // Ambiguous bases
+            std::bitset<256> amb_bases_bitset;
+            for (char base : amb_bases) {
+                amb_bases_bitset.set(base);
+                amb_bases_bitset.set(std::tolower(base));
+            }
+            for (char& base : ref_allele) {
+                if (amb_bases_bitset.test(static_cast<unsigned char>(base))) {
+                    base = 'N';
+                }
+            }
+
+            int read_depth = this->getReadDepth(chr_pos_depth_map.at(chr), start);
+
+            // Create the VCF parameter strings
+            std::string sv_type_str = getSVTypeString(sv_type);
+            std::string info_str = "END=" + std::to_string(end) + ";SVTYPE=" + sv_type_str + ";SVLEN=" + std::to_string(sv_length) + ";SVMETHOD=" + sv_method + ";ALN=" + data_type_str + ";HMM=" + std::to_string(hmm_likelihood) + ";SUPPORT=" + std::to_string(read_depth) + ";CLUSTER=" + std::to_string(cluster_size) + ";ALNOFFSET=" + std::to_string(aln_offset) + ";CN=" + std::to_string(cn_state) + loh;
+            std::string format_str = "GT:DP";
+            std::string sample_str = genotype + ":" + std::to_string(read_depth);
+            std::vector<std::string> samples = {sample_str};
+
+            // Count the record as PASS if no filter was applied
+            if (filter == "PASS") {
+                filtered_svs += 1;
+            }
+
+            // Write the SV call to the file (CHROM, POS, ID, REF, ALT, QUAL,
+            // FILTER, INFO, FORMAT, SAMPLES)
+            vcf_stream << chr << "\t" << start << "\t" << "." << "\t" << ref_allele << "\t" << alt_allele << "\t" << "." << "\t" << filter << "\t" << info_str << "\t" << format_str << "\t" << samples[0] << std::endl;
+        }
+    }
+    vcf_stream.close();
+    std::cout << "Saved SV calls to " << output_vcf << std::endl;
+
+    // Print the number of SV calls written and skipped
+    std::cout << "Finished writing VCF file. Total records: " << total_count << std::endl;
+    if (unclassified_svs > 0) {
+        std::cout << "Total unclassified SVs: " << unclassified_svs << std::endl;
+    }
+    printMessage("Total PASS SVs: " + std::to_string(filtered_svs));
+    printMessage("Total assembly-gap filtered SVs: " + std::to_string(assembly_gap_filtered_svs));
+}

+int SVCaller::getReadDepth(const std::vector<uint32_t>& pos_depth_map, uint32_t start) const
+{
+    int read_depth = 0;
+    try {
+        read_depth += pos_depth_map.at(start);
+    } catch (const std::out_of_range& e) {
+        // Occurs with clipped reads (insertion evidence) that are outside the
+        // range of the depth map
+        printError("Warning: Read depth for position " + std::to_string(start) + " is out of range of size " + std::to_string(pos_depth_map.size()));
+    }
+
+    return read_depth;
+}
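For reference, the deletion branch of `saveToVCF` above follows standard VCF normalization: POS moves to the base before the event, REF carries that anchor base plus the deleted sequence, ALT is the anchor base alone, and SVLEN is negative. An illustrative record (coordinates and bases made up for the example):

```
# Deletion of GTC at 1-based positions 101-103, with anchor base A at position 100:
chr21  100  .  AGTC  A  .  PASS  END=103;SVTYPE=DEL;SVLEN=-3  GT:DP  0/1:30
```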
diff --git a/src/sv_data.cpp b/src/sv_data.cpp
deleted file mode 100644
index 96e5a2fd..00000000
--- a/src/sv_data.cpp
+++ /dev/null
@@ -1,310 +0,0 @@
-#include "sv_data.h"
-#include "vcf_writer.h"
-
-/// @cond
-#include <iostream>
-#include <fstream>
-#include <algorithm>
-/// @endcond
-
-int SVData::add(std::string chr, int64_t start, int64_t end, int sv_type, std::string alt_allele, std::string data_type, std::string genotype, double hmm_likelihood)
-{
-    // Check if the alternate allele contains ambiguous bases
-    const std::unordered_set<char> ambiguous_bases = {'R', 'Y', 'W', 'S', 'K', 'M', 'B', 'D', 'H', 'V'};
-    for (char c : alt_allele) {
-        if (ambiguous_bases.count(c) > 0) {
-            c = 'N';
-        }
-    }
-
-    // Check if the SV candidate already exists in the map
-    SVCandidate candidate(start, end, alt_allele);
-    if (this->sv_calls[chr].find(candidate) != this->sv_calls[chr].end()) {
-        // Update the alignment-based support count (+1)
-        SVInfo& sv_info = this->sv_calls[chr][candidate];
-        sv_info.read_support += 1;
-
-        // Update the SV type if it is unknown
-        if (sv_info.sv_type == UNKNOWN) {
-            sv_info.sv_type = sv_type;
-        }
-
-        // Update the genotype if it is unknown
-        if (sv_info.genotype == "./.") {
-            sv_info.genotype = genotype;
-        }
-
-        // Update the HMM likelihood
-        if ((sv_info.hmm_likelihood == 0.0) || (hmm_likelihood > sv_info.hmm_likelihood)) {
-            sv_info.hmm_likelihood = hmm_likelihood;
-        }
-
-        // Add the alignment type used to call the SV
-        sv_info.data_type.insert(data_type);
-
-        return 0;  // SV call already exists
-
-    // Otherwise, add the SV candidate to the map
-    } else {
-        // For insertions and duplications, the SV length is the length of the
-        // inserted sequence, not including the insertion position
-        int sv_length = 0;
-        if (sv_type == INS || sv_type == DUP || sv_type == TANDUP) {
-            sv_length = end - start;
-        } else {
-            // For deletions, the SV length is the length of the deletion
-            sv_length = end - start + 1;
-        }
-
-        // Create a new SVInfo object (SV type, alignment support, read depth, data type, SV length, genotype)
-        SVInfo sv_info(sv_type, 1, 0, data_type, sv_length, genotype, hmm_likelihood);
-
-        // Add the SV candidate to the map
-        this->sv_calls[chr][candidate] = sv_info;
-
-        return 1;  // SV call added
-    }
-}
-
-void SVData::concatenate(const SVData &sv_data)
-{
-    // Iterate over the chromosomes in the other SVData object
-    for (auto const& chr_sv_calls : sv_data.sv_calls) {
-        std::string chr = chr_sv_calls.first;
-
-        // Iterate over the SV calls in the other SVData object
-        for (auto const& sv_call : chr_sv_calls.second) {
-
-            // Add the SV call to the map of candidate locations. Since the region
-            // is unique (per chromosome), there is no need to check if the SV
-            // candidate already exists in the map.
-            SVCandidate candidate = sv_call.first;  // (start, end, alt_allele)
-            SVInfo info = sv_call.second;  // (sv_type, read_support, data_type, sv_length)
-            this->sv_calls[chr][candidate] = info;
-        }
-    }
-}
-
-void SVData::updateClippedBaseSupport(std::string chr, int64_t pos)
-{
-    // Update clipped base support
-    std::pair<std::string, int64_t> key(chr, pos);
-    if (this->clipped_base_support.find(key) != this->clipped_base_support.end()) {
-        // Update the depth
-        this->clipped_base_support[key] += 1;
-    } else {
-        // Add the depth
-        this->clipped_base_support[key] = 1;
-    }
-}
-
-int SVData::getClippedBaseSupport(std::string chr, int64_t pos, int64_t end)
-{
-    // Clipped base support is the maximum clipped base support at the start
-    // and end positions
-    int clipped_base_support = 0;
-    std::pair<std::string, int64_t> pos_key(chr, pos);
-
-    if (pos == end) {
-        // If the start and end positions are the same, then the clipped base
-        // support is the same at both positions
-        clipped_base_support = this->clipped_base_support[pos_key];
-
-    } else {
-        // Otherwise, get the clipped base support at the start and end
-        // positions
-        int pos_support = 0;
-        int end_support = 0;
-        std::pair<std::string, int64_t> end_key(chr, end);
-        if (this->clipped_base_support.find(pos_key) != this->clipped_base_support.end()) {
-            pos_support = this->clipped_base_support[pos_key];
-        }
-        if (this->clipped_base_support.find(end_key) != this->clipped_base_support.end()) {
-            end_support = this->clipped_base_support[end_key];
-        }
-        clipped_base_support = std::max(pos_support, end_support);
-    }
-
-    return clipped_base_support;
-}
-
-void SVData::saveToVCF(FASTAQuery& ref_genome, std::string output_dir)
-{
-    // Create a VCF writer
-    std::string output_vcf = output_dir + "/output.vcf";
-    VcfWriter vcf_writer(output_vcf);
-
-    // Set the sample name
-    std::string sample_name = "SAMPLE";
-
-    // Set the header lines
-    std::vector<std::string> header_lines = {
-        std::string("##reference=") + ref_genome.getFilepath(),
-        ref_genome.getContigHeader(),
-        "##INFO=<ID=END,Number=1,Type=Integer,Description=\"End position of the structural variant\">",
-        "##INFO=<ID=SVTYPE,Number=1,Type=String,Description=\"Type of structural variant\">",
-        "##INFO=<ID=SVLEN,Number=1,Type=Integer,Description=\"Length of the structural variant\">",
-        "##INFO=<ID=SUPPORT,Number=1,Type=Integer,Description=\"Number of reads supporting the SV\">",
-        "##INFO=<ID=SVMETHOD,Number=1,Type=String,Description=\"SV detection method\">",
-        "##INFO=<ID=ALN,Number=1,Type=String,Description=\"Alignment evidence used to call the SV\">",
-        "##INFO=<ID=CLIPSUP,Number=1,Type=Integer,Description=\"Clipped base support at the SV breakpoints\">",
-        "##INFO=<ID=REPTYPE,Number=1,Type=String,Description=\"Repeat type of the SV\">",
-        "##INFO=<ID=HMM,Number=1,Type=Float,Description=\"HMM likelihood of the copy number prediction\">",
-        "##FILTER=<ID=PASS,Description=\"All filters passed\">",
-        "##FILTER=<ID=LowQual,Description=\"Low quality\">",
-        "##FORMAT=<ID=GT,Number=1,Type=String,Description=\"Genotype\">",
-        "##FORMAT=<ID=DP,Number=1,Type=Integer,Description=\"Read depth\">"
-    };
-
-    // Write the header lines
-    vcf_writer.writeHeader(header_lines);
-
-    // Save the SV calls
-    std::cout << "Saving SV calls to " << output_vcf << std::endl;
-    std::string sv_method = "CONTEXTSVv0.1";
-    int num_sv_calls = this->totalCalls();
-    int skip_count = 0;
-    std::set<std::string> chrs = this->getChromosomes();
-    for (auto const& chr : chrs) {
-        if (this->sv_calls.find(chr) == this->sv_calls.end()) {
-            continue;
-        }
-        std::cout << "Saving SV calls for " << chr << " (" << this->sv_calls[chr].size() << " SV calls)..." << std::endl;
-        for (auto const& sv_call : this->sv_calls[chr]) {
-
-            // Get the SV candidate and SV info
-            SVCandidate candidate = sv_call.first;
-            SVInfo info = sv_call.second;
-            int sv_type = info.sv_type;
-            int read_support = info.read_support;
-            int read_depth = info.read_depth;
-            int sv_length = info.sv_length;
-            std::set<std::string> data_type = info.data_type;
-            std::string genotype = info.genotype;
-            double hmm_likelihood = info.hmm_likelihood;
-
-            // Convert the data type set to a string
-            std::string data_type_str = "";
-            for (auto const& type : data_type) {
-                data_type_str += type + ",";
-            }
-
-            // Get the CHROM, POS, END, and ALT
-            int64_t pos = std::get<0>(candidate);
-            int64_t end = std::get<1>(candidate);
-
-            // If the SV type is unknown, skip it
-            if (sv_type == UNKNOWN) {
-                skip_count += 1;
-                continue;
-            }
-
-            // Process by SV type
-            std::string ref_allele = ".";
-            std::string alt_allele = ".";
-            std::string repeat_type = "NA";
-
-            // Deletion
-            if (sv_type == DEL) {
-                // Get the deleted sequence from the reference genome, also including the preceding base
-                int64_t preceding_pos = (int64_t) std::max(1, (int) pos - 1);  // Make sure the position is not negative
-                ref_allele = ref_genome.query(chr, preceding_pos, end);
-
-                // Use the preceding base as the alternate allele
-                if (ref_allele != "") {
-                    alt_allele = ref_allele.at(0);
-                } else {
-                    alt_allele = "<DEL>";  // Use symbolic allele for imprecise deletions
-                    std::cerr << "Warning: Reference allele is empty for deletion at " << chr << ":" << pos << "-" << end << std::endl;
-                }
-
-                // Make the SV length negative
-                sv_length = -1 * sv_length;
-
-                // Update the position
-                pos = preceding_pos;
-
-            // Duplications and insertions
-            } else if (sv_type == INS || sv_type == DUP || sv_type == TANDUP) {
-                // Use the preceding base as the reference allele
-                int64_t preceding_pos = (int64_t) std::max(1, (int) pos - 1);  // Make sure the position is not negative
-                ref_allele = ref_genome.query(chr, preceding_pos, preceding_pos);
-
-                // Format novel insertions
-                if (sv_type == INS) {
-                    // Use the insertion sequence as the alternate allele
-                    alt_allele = std::get<2>(candidate);
-
-                    // Insert the reference base into the alternate allele
-                    alt_allele.insert(0, ref_allele);
-
-                    // Update the position
-                    pos = preceding_pos;
-
-                    // Update the end position to the start position to change from
-                    // query to reference coordinates for insertions
-                    end = pos;
-                } else if (sv_type == DUP) {
-                    // Use a symbolic allele for duplications
-                    alt_allele = "<DUP>";
-
-                    // Set the repeat type as an interspersed duplication
-                    repeat_type = "INTERSPERSED";
-                } else if (sv_type == TANDUP) {
-                    // Use a symbolic allele for tandem duplications
-                    alt_allele = "<DUP>";
-
-                    // Set the repeat type
-                    repeat_type = "TANDEM";
-                }
-            }
-
-            // Create the VCF parameter strings
-            int clipped_base_support = this->getClippedBaseSupport(chr, pos, end);
-            std::string sv_type_str = this->sv_type_map[sv_type];
-            std::string info_str = "END=" + std::to_string(end) + ";SVTYPE=" + sv_type_str + \
-                ";SVLEN=" + std::to_string(sv_length) + ";SUPPORT=" + std::to_string(read_support) + \
-                ";SVMETHOD=" + sv_method + ";ALN=" + data_type_str + ";CLIPSUP=" + std::to_string(clipped_base_support) + \
-                ";REPTYPE=" + repeat_type + ";HMM=" + std::to_string(hmm_likelihood);
-
-            std::string format_str = "GT:DP";
-            std::string sample_str = genotype + ":" + std::to_string(read_depth);
-            std::vector<std::string> samples = {sample_str};
-
-            // Write the SV call to the file (CHROM, POS, ID, REF, ALT, QUAL, FILTER, INFO, FORMAT, SAMPLES)
-            vcf_writer.writeRecord(chr, pos, ".", ref_allele, alt_allele, ".", "PASS", info_str, format_str, samples);
-        }
-    }
-
-    // Print the number of SV calls skipped
-    std::cout << "Skipped " << skip_count << " of " << num_sv_calls << " SV calls because the SV type is unknown" << std::endl;
-
-    // Close the output stream
-    vcf_writer.close();
-}
-
-std::map<SVCandidate, SVInfo>& SVData::getChromosomeSVs(std::string chr)
-{
-    return this->sv_calls[chr];
-}
-
-std::set<std::string> SVData::getChromosomes()
-{
-    std::set<std::string> chromosomes;
-    for (auto const& sv_call : this->sv_calls) {
-        chromosomes.insert(sv_call.first);
-    }
-    return chromosomes;
-}
-
-int SVData::totalCalls()
-{
-    int sv_calls = 0;
-    for (auto const& sv_call : this->sv_calls) {
-        sv_calls += sv_call.second.size();
-    }
-
-    return sv_calls;
-}
diff --git a/src/sv_object.cpp b/src/sv_object.cpp
new file mode 100644
index 00000000..d6f46b82
--- /dev/null
+++ b/src/sv_object.cpp
@@ -0,0 +1,350 @@
+#include "sv_object.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <fstream>
+#include <iterator>
+#include <map>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "dbscan.h"
+#include "utils.h"
+#include "debug.h"
+
+bool SVCall::operator<(const SVCall & other) const
+{
+    return start < other.start || (start == other.start && end < other.end);
+}
+
+void addSVCall(std::vector<SVCall>& sv_calls, SVCall& sv_call)
+{
+    // Check if the SV call is valid
+    if (sv_call.start > sv_call.end) {
+        printError("ERROR: Invalid SV call at position " + std::to_string(sv_call.start) + "-" + std::to_string(sv_call.end) + " from data type " + getSVAlignmentTypeString(sv_call.aln_type));
+        return;
+    }
+
+    // Insert the SV call in sorted order
+    auto it = std::lower_bound(sv_calls.begin(), sv_calls.end(), sv_call);
+    sv_calls.insert(it, sv_call);
+}
+
+uint32_t getSVCount(const std::vector<SVCall>& sv_calls)
+{
+    return (uint32_t) sv_calls.size();
+}
+
+void concatenateSVCalls(std::vector<SVCall> &target, const std::vector<SVCall>& source)
+{
+    target.insert(target.end(), source.begin(), source.end());
+}
+
+void mergeSVs(std::vector<SVCall>& sv_calls, double epsilon, int min_pts, bool keep_noise, const std::string& json_filepath)
+{
+    printMessage("Merging SVs with DBSCAN, eps=" + std::to_string(epsilon) + ", min_pts=" + std::to_string(min_pts));
+
+    if (sv_calls.size() < 2) {
+        return;
+    }
+
+    // Set this to print cluster information for a specific SV type when
+    // debugging how SVs are merged
+    bool debug_mode = false;
+    SVType debug_sv_type = SVType::INV;
+
+    // Cluster SVs using DBSCAN for each SV type
+    int initial_size = sv_calls.size();
+    std::vector<SVCall> merged_sv_calls;
+    DBSCAN dbscan(epsilon, min_pts);
+    for (const auto& sv_type : {
+        SVType::DEL,
+        SVType::DUP,
+        SVType::INV,
+        SVType::INS,
+        SVType::BND,
+    })
+    {
+        // Skip if not the debug SV type
+        if (debug_mode && (sv_type != debug_sv_type)) {
+            DEBUG_PRINT("DEBUG: Skipping SV type " + getSVTypeString(sv_type) + " for debug mode");
+            continue;
+        }
+
+        DEBUG_PRINT("Merging SV type: " + getSVTypeString(sv_type) + " (epsilon=" + std::to_string(epsilon) + ", min_pts=" + std::to_string(min_pts) + ", num SVs=" + std::to_string(sv_calls.size()) + ")");
+        std::vector<SVCall> merged_sv_type_calls;
+
+        // Create a vector of SV calls for the current SV type
+        std::vector<SVCall> sv_type_calls;
+        std::copy_if(sv_calls.begin(), sv_calls.end(), std::back_inserter(sv_type_calls), [sv_type](const SVCall& sv_call) {
+            return sv_call.sv_type == sv_type;
+        });
+
+        if (sv_type_calls.size() < 2) {
+            // Too few calls to cluster: keep them as-is in the merged list
+            merged_sv_calls.insert(merged_sv_calls.end(), sv_type_calls.begin(), sv_type_calls.end());
+            continue;
+        }
+
+        dbscan.fit(sv_type_calls);
+
+        // Create a map of cluster IDs to SV calls
+        const std::vector<int>& clusters = dbscan.getClusters();
+        std::map<int, std::vector<SVCall>> cluster_map;  // Cluster ID to SV calls
+        for (size_t i = 0; i < clusters.size(); ++i) {
+            cluster_map[clusters[i]].push_back(sv_type_calls[i]);
+        }
+
+        // Save clusters to JSON if requested
+        if (!json_filepath.empty()) {
+            // Create the directory if it doesn't exist
+            std::string dir = json_filepath.substr(0, json_filepath.find_last_of('/'));
+            if (!fileExists(dir)) {
+                std::string command = "mkdir -p " + dir;
+                system(command.c_str());
+            }
+            // Save the clusters to a JSON file, prepending the SV type before
+            // the file extension
+            std::string json_filename_no_ext = json_filepath.substr(0, json_filepath.find_last_of('.'));
+            std::string json_filename = json_filename_no_ext + "_" + getSVTypeString(sv_type) + ".json";
+            saveClustersToJSON(json_filename, cluster_map);
+        }
+
+        // Merge SVs in each cluster
+        int cluster_count = 0;
+        for (auto& cluster : cluster_map) {
+            int cluster_id = cluster.first;
+            std::vector<SVCall>& cluster_sv_calls = cluster.second;
+
+            // Skip clusters with fewer than 2 SV calls (due to the CIGARCLIP filter)
+            if (cluster_sv_calls.size() < 2) {
+                continue;
+            }
+
+            // Add unmerged SV calls
+            if (cluster_id < 0 && keep_noise) {
+
+                // Add all unclustered points to the merged list
+                for (const auto& sv_call : cluster_sv_calls) {
+                    SVCall noise_sv_call = sv_call;
+                    merged_sv_type_calls.push_back(noise_sv_call);
+
+                    // Print the added SV calls if >10 kb and the debug SV type
+                    if (debug_mode && noise_sv_call.sv_type == debug_sv_type && (noise_sv_call.end - noise_sv_call.start) > 10000) {
+                        DEBUG_PRINT("DEBUG: Adding noise SV call at " + std::to_string(noise_sv_call.start) + "-" + std::to_string(noise_sv_call.end) +
+                            ", type: " + getSVTypeString(noise_sv_call.sv_type) +
+                            ", length: " + std::to_string(noise_sv_call.end - noise_sv_call.start) +
+                            ", cluster size: " + std::to_string(noise_sv_call.cluster_size) +
+                            ", likelihood: " + std::to_string(noise_sv_call.hmm_likelihood));
+                    }
+                }
+
+            // Merge clustered SV calls
+            } else {
+
+                // ----------------------------
+                // HMM-BASED MERGING
+                // ----------------------------
+
+                // Check if any SV has a non-zero likelihood
+                bool has_nonzero_likelihood = false;
+                for (const auto& sv_call : cluster_sv_calls) {
+                    if (sv_call.hmm_likelihood != 0.0) {
+                        has_nonzero_likelihood = true;
+                        break;
+                    }
+                }
+
+                SVCall merged_sv_call = cluster_sv_calls[0];
+                if (has_nonzero_likelihood) {
+                    // Choose the SV with the highest cluster size of all SVs
+                    // with non-zero likelihood (if equal, choose the larger SV)
+                    std::sort(cluster_sv_calls.begin(), cluster_sv_calls.end(), [](const SVCall& a, const SVCall& b) {
+                        return a.cluster_size > b.cluster_size || (a.cluster_size == b.cluster_size && a.end - a.start > b.end - b.start);
+                    });
+                    auto it = std::find_if(cluster_sv_calls.begin(), cluster_sv_calls.end(), [](const SVCall& sv_call) {
+                        return sv_call.hmm_likelihood != 0.0;
+                    });
+
+                    // Add SV call
+                    merged_sv_call = *it;
+                    merged_sv_type_calls.push_back(merged_sv_call);
+
+                // ----------------------------
+                // CIGAR-BASED MERGING
+                // ----------------------------
+
+                } else {
+                    // Use the median-length SV of the top 20% of the cluster
+                    // by length (shorter reads are often noise)
+                    std::sort(cluster_sv_calls.begin(), cluster_sv_calls.end(), [](const SVCall& a, const SVCall& b) {
+                        return (a.end - a.start) > (b.end - b.start);
+                    });
+
+                    // Print the cluster's SV calls if >10 kb and the debug SV type
+                    if (debug_mode && sv_type == debug_sv_type) {
+                        DEBUG_PRINT("DEBUG: Cluster " + std::to_string(cluster_id) + " with " + std::to_string(cluster_sv_calls.size()) + " SV calls (length sorted):");
+                        for (const auto& sv_call : cluster_sv_calls) {
+                            if ((sv_call.end - sv_call.start) > 10000) {
+                                DEBUG_PRINT("DEBUG: SV call at " + std::to_string(sv_call.start) + "-" + std::to_string(sv_call.end) +
+                                    ", type: " + getSVTypeString(sv_call.sv_type) +
+                                    ", length: " + std::to_string(sv_call.end - sv_call.start) +
+                                    ", cluster size: " + std::to_string(sv_call.cluster_size) +
+                                    ", likelihood: " + std::to_string(sv_call.hmm_likelihood));
+                            }
+                        }
+                    }
+
+                    // Get the top % of the cluster
+                    double top_pct = 0.2;
+                    size_t top_pct_size = std::max(1, (int) (cluster_sv_calls.size() * top_pct));
+                    std::vector<SVCall> top_pct_calls(cluster_sv_calls.begin(), cluster_sv_calls.begin() + top_pct_size);
+
+                    // Print the retained SV calls if >10 kb and the debug SV type
+                    if (debug_mode && sv_type == debug_sv_type) {
+                        DEBUG_PRINT("DEBUG: Top " + std::to_string((int)(top_pct * 100)) + "% of cluster " + std::to_string(cluster_id) + " with " +
+                            std::to_string(top_pct_calls.size()) + " SV calls (length sorted):");
+                        for (const auto& sv_call : top_pct_calls) {
+                            if ((sv_call.end - sv_call.start) > 10000) {
+                                DEBUG_PRINT("DEBUG: SV call at " + std::to_string(sv_call.start) + "-" + std::to_string(sv_call.end) +
+                                    ", type: " + getSVTypeString(sv_call.sv_type) +
+                                    ", length: " + std::to_string(sv_call.end - sv_call.start) +
+                                    ", cluster size: " + std::to_string(sv_call.cluster_size) +
+                                    ", likelihood: " + std::to_string(sv_call.hmm_likelihood));
+                            }
+                        }
+                    }
+
+                    // Get the median SV for the top % of the cluster
+                    size_t median_index = top_pct_calls.size() / 2;
+                    merged_sv_call = top_pct_calls[median_index];
+
+                    // Print the merged SV call
+                    if (debug_mode && sv_type == debug_sv_type) {
+                        DEBUG_PRINT("DEBUG: Merged SV call at " + std::to_string(merged_sv_call.start) + "-" + std::to_string(merged_sv_call.end) +
+                            ", type: " + getSVTypeString(merged_sv_call.sv_type) +
+                            ", length: " + std::to_string(merged_sv_call.end - merged_sv_call.start) +
+                            ", cluster size: " + std::to_string(merged_sv_call.cluster_size) +
+                            ", likelihood: " + std::to_string(merged_sv_call.hmm_likelihood));
+                    }
+
+                    // Add SV call
+                    merged_sv_call.cluster_size = (int) cluster_sv_calls.size();
+                    merged_sv_type_calls.push_back(merged_sv_call);
+                }
+                cluster_count++;
+            }
+        }
+        DEBUG_PRINT("Merged " + std::to_string(cluster_count) + " clusters of " + getSVTypeString(sv_type) + ", found " + std::to_string(merged_sv_type_calls.size()) + " merged SV calls");
+
+        // Print SV call start, end, type, and length for debugging if >10 kb
+        if (debug_mode && sv_type == debug_sv_type) {
+            DEBUG_PRINT("DEBUG: Merged SV calls for " + getSVTypeString(sv_type) + ":");
+            for (const auto& sv_call : merged_sv_type_calls) {
+                if ((sv_call.end - sv_call.start) > 10000) {
+                    DEBUG_PRINT("DEBUG: SV call at " + std::to_string(sv_call.start) + "-" + std::to_string(sv_call.end) +
+                        ", type: " + getSVTypeString(sv_call.sv_type) +
+                        ", length: " + std::to_string(sv_call.end - sv_call.start) +
+                        ", cluster size: " + std::to_string(sv_call.cluster_size) +
+                        ", likelihood: " + std::to_string(sv_call.hmm_likelihood));
+                }
+            }
+        }
+
+        merged_sv_calls.insert(merged_sv_calls.end(),
+            merged_sv_type_calls.begin(), merged_sv_type_calls.end());
+    }
+
+    sv_calls = std::move(merged_sv_calls);  // Replace with the merged list
+    int updated_size = sv_calls.size();
+    printMessage("Merged " + std::to_string(initial_size) + " SV calls into " + std::to_string(updated_size) + " SV calls");
+}
+
+void saveClustersToJSON(const std::string &filename, const std::map<int, std::vector<SVCall>> &clusters)
+{
+    // Check if the filename is empty
+    if (filename.empty()) {
+        printError("ERROR: Filename is empty");
+        return;
+    }
+
+    // Remove the file if it already exists
+    if (fileExists(filename)) {
+        std::remove(filename.c_str());
+    }
+
+    // Open the JSON file for writing
+    std::ofstream json_file(filename);
+    if (!json_file.is_open()) {
+        printError("ERROR: Unable to open JSON file for writing: " + filename);
+        return;
+    }
+
+    // Count the non-noise clusters so the trailing comma is placed correctly
+    size_t num_clusters = 0;
+    for (const auto& [cluster_id, cluster] : clusters) {
+        if (cluster_id >= 0) {
+            num_clusters++;
+        }
+    }
+
+    json_file << "{\n";
+    json_file << "  \"clusters\": [\n";
+    size_t count = 0;
+    for (const auto& [cluster_id, cluster] : clusters) {
+        if (cluster_id < 0) {
+            continue;  // Skip noise points
+        }
+
+        json_file << "    {\n";
+        json_file << "      \"cluster_id\": " << cluster_id << ",\n";
+        json_file << "      \"cluster_size\": " << cluster.size() << ",\n";
+        json_file << "      \"sv_calls\": [\n";
+        for (size_t j = 0; j < cluster.size(); ++j) {
+            const auto& sv_call = cluster[j];
+            json_file << "        {\n";
+            json_file << "          \"start\": " << sv_call.start << ",\n";
+            json_file << "          \"end\": " << sv_call.end << "\n";
+            json_file << "        }" << (j < cluster.size() - 1 ? "," : "") << "\n";
+        }
+        json_file << "      ]\n";
+        count++;
+        json_file << "    }" << (count < num_clusters ? "," : "") << "\n";
+    }
+    json_file << "  ]\n";
+    json_file << "}\n";
+    json_file.close();
+    printMessage("Saved clusters to JSON file: " + filename);
+}
+
+void mergeDuplicateSVs(std::vector<SVCall> &sv_calls)
+{
+    int initial_size = sv_calls.size();
+    std::vector<SVCall> combined_sv_calls;
+
+    // Sort first by start position, then by SV type
+    std::sort(sv_calls.begin(), sv_calls.end(), [](const SVCall& a, const SVCall& b) {
+        return std::tie(a.start, a.sv_type) < std::tie(b.start, b.sv_type);
+    });
+    for (size_t i = 0; i < sv_calls.size(); i++) {
+        SVCall& sv_call = sv_calls[i];
+
+        // Merge cluster sizes if start and end positions are the same
+        if (i > 0 && sv_call.start == sv_calls[i - 1].start && sv_call.end == sv_calls[i - 1].end) {
+            // Combine the cluster sizes
+            sv_call.cluster_size += sv_calls[i - 1].cluster_size;
+            combined_sv_calls.back() = sv_call;
+        } else {
+            combined_sv_calls.push_back(sv_call);
+        }
+    }
+
+    int merge_count = initial_size - combined_sv_calls.size();
+    sv_calls = std::move(combined_sv_calls);  // Replace with the merged list
+    if (merge_count > 0) {
+        printMessage("Merged " + std::to_string(merge_count) + " SV candidates with identical start and end positions");
+    }
+}
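The CIGAR-based branch of `mergeSVs` above amounts to a robust-consensus heuristic: sort a cluster's calls by length (descending), keep the top 20% on the assumption that shorter fragmentary calls are noise, and take the median of that subset as the representative. Isolated from the `SVCall` machinery, the selection looks like this (a sketch over plain lengths):

```cpp
#include <algorithm>
#include <functional>
#include <vector>

// Pick a representative SV length from one cluster: the median of the top
// 20% of calls by length, mirroring the CIGAR-based merge heuristic above.
int representativeLength(std::vector<int> lengths)
{
    // Longest calls first; short fragmentary calls sort to the back
    std::sort(lengths.begin(), lengths.end(), std::greater<int>());

    // Keep at least one call, at most the top 20%
    size_t top = std::max<size_t>(1, lengths.size() / 5);

    // Median of the retained subset
    return lengths[top / 2];
}
```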
diff --git a/src/swig_interface.cpp b/src/swig_interface.cpp
index 87334fec..76eb2151 100644
--- a/src/swig_interface.cpp
+++ b/src/swig_interface.cpp
@@ -1,7 +1,3 @@
-//
-// Created by jperdomo on 1/8/2023.
-//
-
 #include "swig_interface.h"
 #include "contextsv.h"
@@ -11,14 +7,13 @@
 // Run the CLI with the given parameters
-int run(InputData input_data)
+int run(const InputData& input_data)
 {
-    // Run ContextSV
-    ContextSV contextsv(input_data);
+    ContextSV contextsv;
     try
     {
-        contextsv.run();
+        contextsv.run(input_data);
     }
     catch (std::exception& e)
diff --git a/src/utils.cpp b/src/utils.cpp
index 62088fe2..a27263b7 100644
--- a/src/utils.cpp
+++ b/src/utils.cpp
@@ -1,9 +1,13 @@
 #include "utils.h"
 
 /// @cond
+#include <sys/resource.h>  // getrusage
+#include <iomanip>
 #include <iostream>
 #include <mutex>
 #include <chrono>
+#include <fstream>
+#include <filesystem>
 /// @endcond
@@ -13,10 +17,7 @@ std::mutex print_mtx;
 // Print a progress bar
 void printProgress(int progress, int total)
 {
-    // Get the percentage
     float percent = (float)progress / (float)total * 100.0;
-
-    // Get the number of hashes
     int num_hashes = (int)(percent / 2.0);
 
     // Print the progress bar
@@ -99,4 +100,44 @@ std::string getElapsedTime(std::chrono::high_resolution_clock::time_point start,
     int seconds = elapsed.count() - (hours * 3600) - (minutes * 60);
     std::string elapsed_time = std::to_string(hours) + ":" + std::to_string(minutes) + ":" + std::to_string(seconds);
     return elapsed_time;
-}
\ No newline at end of file
+}
+
+// Remove the 'chr' prefix from chromosome names
+std::string removeChrPrefix(std::string chr)
+{
+    if (chr.rfind("chr", 0) == 0) {
+        return chr.substr(3);
+    }
+    return chr;
+}
+
+void printMemoryUsage(const std::string& functionName) {
+    struct rusage usage;
+    getrusage(RUSAGE_SELF, &usage);
+
+    // Convert from KB to GB
+    double mem_usage_gb = (double)usage.ru_maxrss / 1024.0 / 1024.0;
+    std::cout << functionName << " memory usage: "
+              << std::fixed << std::setprecision(2) << mem_usage_gb << " GB" << std::endl;
+}
+
+bool fileExists(const std::string &filepath)
+{
+    std::ifstream file(filepath);
+    return file.is_open();
+}
+
+bool isFileEmpty(const std::string &filepath)
+{
+    return std::filesystem::file_size(filepath) == 0;
+}
+
+void closeJSON(const std::string &filepath)
+{
+    std::ofstream json_file(filepath, std::ios::app);
+
+    json_file << "}\n";  // Close the last JSON object
+    json_file << "]";    // Close the JSON array
+    json_file.close();
+}
diff --git a/src/vcf_writer.cpp b/src/vcf_writer.cpp
deleted file mode 100644
index eaf41ea5..00000000
--- a/src/vcf_writer.cpp
+++ /dev/null
@@ -1,63 +0,0 @@
-#include "vcf_writer.h"
-
-/// @cond
-#include <iostream>
-#include <ctime>
-/// @endcond
-
-VcfWriter::VcfWriter(const std::string &filename)
-{
-    // Remove the file if it already exists
-    std::remove(filename.c_str());
-
-    // Open the VCF file
-    this->file_stream.open(filename);
-    if (!this->file_stream.is_open()) {
-        std::cerr << "Error: Unable to open " << filename << std::endl;
-        exit(1);
-    }
-}
-
-void VcfWriter::writeHeader(const std::vector<std::string> &headerLines)
-{
-    // Add the file format
-    std::string file_format = "##fileformat=VCFv4.2";
-    this->file_stream << file_format << std::endl;
-
-    // Add date and time
-    time_t rawtime;
-    struct tm * timeinfo;
-    char buffer[80];
-    time (&rawtime);
-    timeinfo = localtime(&rawtime);
-    strftime(buffer, sizeof(buffer), "%Y%m%d", timeinfo);
-    file_stream << "##fileDate=" << buffer << std::endl;
-
-    // Add source
-    std::string source = "##source=ContexSV";
-    this->file_stream << source << std::endl;
-
-    // Loop over the header metadata lines
-    for (auto &line : headerLines) {
-        this->file_stream << line << std::endl;
-    }
-
-    // Add the header line
-    std::string header_line = "#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tSAMPLE";
-    this->file_stream << header_line << std::endl;
-
-    // Flush the stream to ensure that the header is written
-    this->file_stream.flush();
-}
-
-void VcfWriter::writeRecord(const std::string &chrom, int pos, const std::string &id, const std::string &ref, const std::string &alt, const std::string &qual, const std::string &filter, const std::string &info, const std::string &format, const std::vector<std::string> &samples)
-{
-    // Write a record to the VCF file
-    this->file_stream << chrom << "\t" << pos << "\t" << id << "\t" << ref << "\t" << alt << "\t" << qual << "\t" << filter << "\t" << info << "\t" << format << "\t" << samples[0] << std::endl;
-}
-
-void VcfWriter::close()
-{
-    // Close the VCF file
-    this->file_stream.close();
-}
diff --git a/tests/test_general.py b/tests/test_general.py
index 499805c8..ff65faba 100644
--- a/tests/test_general.py
+++ b/tests/test_general.py
@@ -40,7 +40,6 @@ def test_run():
     # Set input parameters.
     input_data = contextsv.InputData()
-    input_data.setShortReadBam(TEST_BAM_FILE)
     input_data.setLongReadBam(TEST_BAM_FILE)
     input_data.setRefGenome(TEST_REF_FILE)
     input_data.setSNPFilepath(TEST_SNPS_FILE)
@@ -64,11 +63,11 @@ def test_run():
     # Check that the VCF file has the correct number of lines.
     with open(output_file, 'r', encoding='utf-8') as f:
-        assert len(f.readlines()) == 21
+        assert len(f.readlines()) == 22
 
     # Check that the VCF file has the correct header, and the correct
     # VCF CHROM, POS, and INFO fields in the next 2 lines.
-    header_line = 18
+    header_line = 17
     with open(output_file, 'r', encoding='utf-8') as f:
         for i, line in enumerate(f):
             if i == header_line:
@@ -78,11 +77,11 @@ def test_run():
                 fields = line.strip().split('\t')
                 assert fields[0] == "21"
                 assert fields[1] == "14458394"
-                assert fields[7] == "END=14458394;SVTYPE=INS;SVLEN=1341;SUPPORT=1;SVMETHOD=CONTEXTSVv0.1;ALN=CIGARINS,;CLIPSUP=0;REPTYPE=NA;HMM=0.000000"
+                assert fields[7] == "END=14458394;SVTYPE=INS;SVLEN=1344;SUPPORT=1;SVMETHOD=CONTEXTSVv0.1;ALN=CIGARINS;CLIPSUP=0;REPTYPE=NA;HMM=0.000000"
             elif i == header_line + 2:
                 fields = line.strip().split('\t')
                 assert fields[0] == "21"
-                assert fields[1] == "14458394"
-                assert fields[7] == "END=14458394;SVTYPE=INS;SVLEN=1344;SUPPORT=1;SVMETHOD=CONTEXTSVv0.1;ALN=CIGARINS,;CLIPSUP=0;REPTYPE=NA;HMM=0.000000"
+                assert fields[1] == "14502888"
+                assert fields[7] == "END=14502953;SVTYPE=BOUNDARY;SVLEN=65;SUPPORT=1;SVMETHOD=CONTEXTSVv0.1;ALN=BOUNDARY;CLIPSUP=0;REPTYPE=NA;HMM=-4.606171"
                 break
\ No newline at end of file