Skip to content

Commit

Permalink
[WIP] pushing to share. Partial strandedness implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
a-frantz committed Feb 8, 2024
1 parent d5b5b5c commit fe6eabd
Show file tree
Hide file tree
Showing 2 changed files with 238 additions and 25 deletions.
6 changes: 6 additions & 0 deletions src/derive/command/strandedness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ pub fn derive(args: DeriveStrandednessArgs) -> anyhow::Result<()> {
bail!("No gene records matched criteria. Check your GFF file and `--gene-feature-name` and `--all-genes` options.");
}
if exon_records.is_empty() {
// TODO move this below?
bail!("No exon records matched criteria. Check your GFF file and `--exon-feature-name` option.");
}

Expand All @@ -148,6 +149,11 @@ pub fn derive(args: DeriveStrandednessArgs) -> anyhow::Result<()> {
let stop: usize = record.end().into();
let strand = record.strand();

if strand != gff::record::Strand::Forward && strand != gff::record::Strand::Reverse {
exon_metrics.bad_strand += 1;
continue;
}

exon_intervals.entry(seq_name).or_default().push(Interval {
start,
stop,
Expand Down
257 changes: 232 additions & 25 deletions src/derive/strandedness/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,22 @@ use noodles::bam;
use noodles::core::{Position, Region};
use noodles::gff;
use noodles::sam;
use noodles::sam::record::data::field::Tag;
use rand::Rng;
use rust_lapper::{Interval, Lapper};
use rust_lapper::Lapper;
use serde::Serialize;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;

use crate::utils::read_groups::{validate_read_group_info, OVERALL, UNKNOWN_READ_GROUP};

const STRANDED_THRESHOLD: f64 = 0.80;
const UNSTRANDED_THRESHOLD: f64 = 0.40;

/// General gene metrics that are tallied as a part of the
/// strandedness subcommand.
#[derive(Clone, Default, Serialize)]
#[derive(Clone, Default, Serialize, Debug)]
pub struct GeneRecordMetrics {
/// The total number of genes found in the GFF.
pub total: usize,
Expand All @@ -21,9 +29,12 @@ pub struct GeneRecordMetrics {
/// If --all-genes is set this will not be tallied.
pub protein_coding: usize,

/// The number of genes tested.
pub tested: usize,

/// The number of genes which were discarded due to having
/// exons on both strands.
pub exons_on_both_strands: usize,
/// an unknown/invalid strand OR with exons on both strands.
pub bad_strands: usize,

/// The number of genes which were discarded due to not having
/// enough reads.
Expand All @@ -32,15 +43,18 @@ pub struct GeneRecordMetrics {

/// General exon metrics that are tallied as a part of the
/// strandedness subcommand.
#[derive(Clone, Default, Serialize)]
#[derive(Clone, Default, Serialize, Debug)]
pub struct ExonRecordMetrics {
    /// The total number of exons found in the GFF.
    pub total: usize,

    /// The number of exons discarded due to having an unknown/invalid strand
    /// (i.e. a strand that is neither forward nor reverse).
    pub bad_strand: usize,
}

/// General read record metrics that are tallied as a part of the
/// strandedness subcommand.
#[derive(Clone, Default, Serialize)]
#[derive(Clone, Default, Serialize, Debug)]
pub struct ReadRecordMetrics {
/// The number of records that have been filtered because of their flags.
/// (i.e. they were qc_fail, duplicates, secondary, or supplementary)
Expand All @@ -64,12 +78,6 @@ pub struct ReadRecordMetrics {
/// Struct for tracking count results.
#[derive(Clone, Default)]
struct Counts {
/// The number of reads determined to be Paired-End.
paired_end_reads: usize,

/// The number of reads determined to be Single-End.
single_end_reads: usize,

/// The number of reads that are evidence of Forward Strandedness.
forward: usize,

Expand Down Expand Up @@ -157,6 +165,18 @@ pub struct DerivedStrandednessResult {
/// One for each read group in the BAM,
/// and potentially one for any reads with an unknown read group.
pub read_groups: Vec<ReadGroupDerivedStrandednessResult>,

/// General read record metrics that are tallied as a part of the
/// strandedness subcommand.
pub read_metrics: ReadRecordMetrics,

/// General gene metrics that are tallied as a part of the
/// strandedness subcommand.
pub gene_metrics: GeneRecordMetrics,

/// General exon metrics that are tallied as a part of the
/// strandedness subcommand.
pub exon_metrics: ExonRecordMetrics,
}

impl DerivedStrandednessResult {
Expand All @@ -167,6 +187,9 @@ impl DerivedStrandednessResult {
forward: usize,
reverse: usize,
read_groups: Vec<ReadGroupDerivedStrandednessResult>,
read_metrics: ReadRecordMetrics,
gene_metrics: GeneRecordMetrics,
exon_metrics: ExonRecordMetrics,
) -> Self {
DerivedStrandednessResult {
succeeded,
Expand All @@ -177,6 +200,59 @@ impl DerivedStrandednessResult {
forward_pct: (forward as f64 / (forward + reverse) as f64) * 100.0,
reverse_pct: (reverse as f64 / (forward + reverse) as f64) * 100.0,
read_groups,
read_metrics,
gene_metrics,
exon_metrics,
}
}
}

/// The strand of evidence a read or gene lies on.
///
/// Only the two definite orientations are representable; unknown or
/// unstranded GFF values are rejected at conversion time (see the
/// `TryFrom<gff::record::Strand>` impl).
#[derive(Clone, Copy, Debug)]
enum Strand {
    /// The forward strand.
    Forward,
    /// The reverse strand.
    Reverse,
}

impl From<sam::record::Flags> for Strand {
    /// Derives a read's strand from its SAM flags: reads flagged as
    /// reverse-complemented map to [`Strand::Reverse`]; all others map to
    /// [`Strand::Forward`].
    fn from(flags: sam::record::Flags) -> Self {
        match flags.is_reverse_complemented() {
            true => Self::Reverse,
            false => Self::Forward,
        }
    }
}

impl TryFrom<gff::record::Strand> for Strand {
    type Error = ();

    /// Converts a GFF strand into a [`Strand`], rejecting any value that is
    /// not explicitly forward or reverse.
    fn try_from(strand: gff::record::Strand) -> Result<Self, Self::Error> {
        use gff::record::Strand as GffStrand;

        match strand {
            GffStrand::Forward => Ok(Self::Forward),
            GffStrand::Reverse => Ok(Self::Reverse),
            _ => Err(()),
        }
    }
}

/// The position of a read segment within its template.
///
/// Derived from SAM flags; reads that are not segmented, or whose
/// first/last-segment flags are ambiguous, fail conversion (see the
/// `TryFrom<sam::record::Flags>` impl).
#[derive(Clone, Copy, Debug)]
enum SegmentOrder {
    /// The first segment in the template.
    First,
    /// The last segment in the template.
    Last,
}

impl TryFrom<sam::record::Flags> for SegmentOrder {
    type Error = ();

    /// Determines a read's position in its template from its SAM flags.
    ///
    /// Fails for unsegmented (single-end) reads and for segmented reads
    /// whose first/last-segment flags are both set or both unset.
    fn try_from(flags: sam::record::Flags) -> Result<Self, Self::Error> {
        if !flags.is_segmented() {
            return Err(());
        }

        match (flags.is_first_segment(), flags.is_last_segment()) {
            (true, false) => Ok(SegmentOrder::First),
            (false, true) => Ok(SegmentOrder::Last),
            _ => Err(()),
        }
    }
}
Expand Down Expand Up @@ -216,6 +292,9 @@ fn disqualify_gene(
exons: &HashMap<&str, Lapper<usize, gff::record::Strand>>,
) -> bool {
let gene_strand = gene.strand();
if gene_strand != gff::record::Strand::Forward && gene_strand != gff::record::Strand::Reverse {
return true;
}
let mut all_on_same_strand = true;
let mut at_least_one_exon = false;

Expand Down Expand Up @@ -291,12 +370,101 @@ fn query_filtered_reads(
return filtered_reads;
}

// fn classify_read(
// read: &sam::alignment::Record,
// gene_strand: &gff::record::Strand,
// ) -> {
// // TODO
// }
/// Classifies a single read as evidence of Forward or Reverse strandedness
/// relative to the strand of the gene it overlaps, and tallies the result
/// into both the read's read-group counts and the overall counts.
///
/// Also increments the paired-end/single-end counters in `read_metrics`.
///
/// # Panics
///
/// Panics if `gene_strand` is not forward/reverse (callers are expected to
/// have discarded such genes already) or if a segmented read's flags do not
/// identify it as exactly the first or last segment.
fn classify_read(
    read: &sam::alignment::Record,
    gene_strand: &gff::record::Strand,
    all_counts: &mut HashMap<&str, Counts>,
    read_metrics: &mut ReadRecordMetrics,
) {
    // `TryFrom` is implemented for `gff::record::Strand` by value, so the
    // reference must be dereferenced here (the original passed `&Strand`,
    // which does not satisfy the impl).
    let gene_strand = Strand::try_from(*gene_strand)
        .expect("gene strand must be forward or reverse by this point");

    // Resolve the read group from the RG tag, falling back to the
    // "unknown" bucket when the tag is absent or unparseable.
    let read_group = match read.data().get(Tag::ReadGroup) {
        Some(rg) => rg.as_str().unwrap_or_else(|| {
            tracing::warn!("Could not parse a RG tag from a read in the file.");
            UNKNOWN_READ_GROUP.as_str()
        }),
        None => UNKNOWN_READ_GROUP.as_str(),
    };

    let read_strand = Strand::from(read.flags());

    // Decide whether this read is evidence of Forward strandedness;
    // otherwise it is evidence of Reverse strandedness.
    let is_forward_evidence = if read.flags().is_segmented() {
        read_metrics.paired_end_reads += 1;

        let order = SegmentOrder::try_from(read.flags())
            .expect("segmented read must be exactly one of first/last segment");

        // Paired-end: a first segment matching the gene strand, or a last
        // segment opposing it, counts as forward evidence.
        matches!(
            (order, read_strand, gene_strand),
            (SegmentOrder::First, Strand::Forward, Strand::Forward)
                | (SegmentOrder::First, Strand::Reverse, Strand::Reverse)
                | (SegmentOrder::Last, Strand::Forward, Strand::Reverse)
                | (SegmentOrder::Last, Strand::Reverse, Strand::Forward)
        )
    } else {
        read_metrics.single_end_reads += 1;

        // Single-end: a read matching the gene strand counts as forward
        // evidence.
        matches!(
            (read_strand, gene_strand),
            (Strand::Forward, Strand::Forward) | (Strand::Reverse, Strand::Reverse)
        )
    };

    // Tally into the read-group entry and the overall entry one at a time.
    // The original held two `entry()` mutable borrows of `all_counts`
    // simultaneously, which the borrow checker rejects.
    for key in [read_group, OVERALL.as_str()] {
        let counts = all_counts.entry(key).or_default();
        if is_forward_evidence {
            counts.forward += 1;
        } else {
            counts.reverse += 1;
        }
    }
}

/// Method to predict the strandedness of a read group.
///
/// Returns a failed, "Inconclusive" result when no reads were counted for
/// the read group. Otherwise classifies the read group as "Forward",
/// "Reverse", or "Unstranded" based on the fraction of forward- and
/// reverse-evidence reads, leaving the result "Inconclusive" (and
/// unsuccessful) when no threshold is met.
fn predict_strandedness(rg_name: &str, counts: &Counts) -> ReadGroupDerivedStrandednessResult {
    if counts.forward == 0 && counts.reverse == 0 {
        // Constructed literally rather than via `new()` so the percentage
        // fields are a well-defined 0.0 instead of 0/0.
        return ReadGroupDerivedStrandednessResult {
            read_group: rg_name.to_string(),
            succeeded: false,
            strandedness: "Inconclusive".to_string(),
            total: 0,
            forward: 0,
            reverse: 0,
            forward_pct: 0.0,
            reverse_pct: 0.0,
        };
    }

    let mut result = ReadGroupDerivedStrandednessResult::new(
        rg_name.to_string(),
        false,
        "Inconclusive".to_string(),
        counts.forward,
        counts.reverse,
    );

    // Compare *fractions* (0.0..=1.0) against the thresholds. The overall
    // result computes its `*_pct` fields as percentages (x100.0), and the
    // original code compared such percentage fields directly against the
    // fraction thresholds (0.80 / 0.40), which would classify nearly every
    // read group as stranded. Computing fractions from the raw counts is
    // correct regardless of how the `*_pct` fields are scaled.
    let total = (counts.forward + counts.reverse) as f64;
    let forward_frac = counts.forward as f64 / total;
    let reverse_frac = counts.reverse as f64 / total;

    if forward_frac > STRANDED_THRESHOLD {
        result.succeeded = true;
        result.strandedness = "Forward".to_string();
    } else if reverse_frac > STRANDED_THRESHOLD {
        result.succeeded = true;
        result.strandedness = "Reverse".to_string();
    } else if forward_frac > UNSTRANDED_THRESHOLD && reverse_frac > UNSTRANDED_THRESHOLD {
        result.succeeded = true;
        result.strandedness = "Unstranded".to_string();
    }

    result
}

/// Main method to evaluate the observed strand state and
/// return a result for the derived strandedness. This may fail, and the
Expand All @@ -310,10 +478,14 @@ pub fn predict(
filters: &StrandednessFilters,
gene_metrics: &mut GeneRecordMetrics,
exon_metrics: &mut ExonRecordMetrics,
read_metrics: &mut ReadRecordMetrics,
) -> Result<DerivedStrandednessResult, anyhow::Error> {
let rng = rand::thread_rng();
let mut num_tested_genes: usize = 0;
let mut read_metrics = ReadRecordMetrics::default();
let mut rng = rand::thread_rng();
let mut num_tested_genes: usize = 0; // Local to this attempt
let mut all_counts: HashMap<&str, Counts> = HashMap::new();

all_counts.insert(UNKNOWN_READ_GROUP.as_str(), Counts::default());
all_counts.insert(OVERALL.as_str(), Counts::default());

for _ in 0..max_iterations_per_try {
if num_tested_genes > num_genes {
Expand All @@ -323,15 +495,15 @@ pub fn predict(
let cur_gene = gene_records.swap_remove(rng.gen_range(0..gene_records.len()));

if disqualify_gene(&cur_gene, exons) {
gene_metrics.exons_on_both_strands += 1;
gene_metrics.bad_strands += 1;
continue;
}

let mut enough_reads = false;
for read in query_filtered_reads(parsed_bam, &cur_gene, filters, &mut read_metrics) {
for read in query_filtered_reads(parsed_bam, &cur_gene, filters, read_metrics) {
enough_reads = true;

// TODO classify_read(&read, &cur_gene.strand());
classify_read(&read, &cur_gene.strand(), &mut all_counts, read_metrics);
}
if enough_reads {
num_tested_genes += 1;
Expand All @@ -340,5 +512,40 @@ pub fn predict(
}
}

anyhow::Ok(result)
gene_metrics.tested += num_tested_genes; // Add to any other attempts

// Overly complicated but IDK how to simplify this
let found_rgs = all_counts
.keys()
.cloned()
.map(|rg| rg.to_string())
.collect::<Vec<_>>();
let found_rgs_arc = found_rgs
.iter()
.map(|rg| Arc::new(rg.clone()))
.collect::<HashSet<_>>();

let rgs_in_header_not_found = validate_read_group_info(&found_rgs_arc, &parsed_bam.header);
for rg in rgs_in_header_not_found {
all_counts.insert(rg.as_str(), Counts::default());
}

let mut final_result = DerivedStrandednessResult::new(
true,
"Inconclusive".to_string(),
0,
0,
Vec::new(),
read_metrics.clone(),
gene_metrics.clone(),
exon_metrics.clone(),
);

for (rg, counts) in all_counts {
if rg == UNKNOWN_READ_GROUP.as_str() && counts.forward == 0 && counts.reverse == 0 {
continue;
}
}

anyhow::Ok(final_result)
}

0 comments on commit fe6eabd

Please sign in to comment.