Skip to content

Commit

Permalink
Simplify pa-bin binary
Browse files Browse the repository at this point in the history
  • Loading branch information
RagnarGrootKoerkamp committed Mar 27, 2024
1 parent 03bed57 commit 7ce4fdb
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 211 deletions.
33 changes: 8 additions & 25 deletions astarpa-next/src/bin/path_pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
//~ This file is mostly identical to `pa-bin/src/main.rs`, but wraps the given
// heuristic in the `PathHeuristic`. To achieve this, some more functions are inlined here.

use astarpa::HeuristicParams;
use astarpa_next::path_pruning::PathHeuristic;
use clap::Parser;
use pa_affine_types::{AffineAligner, AffineCost};
use pa_base_algos::{
nw::{AffineFront, NW},
Domain,
};
use pa_bin::cli::Cli;
use pa_bin::Cli;
use pa_heuristic::{Heuristic, HeuristicMapper};
use pa_types::*;
use pa_vis_types::{NoVis, VisualizerT};
Expand All @@ -18,31 +19,13 @@ use std::{
ops::ControlFlow,
};

pub fn astar_aligner(args: &Cli) -> Box<dyn AffineAligner> {
#[cfg(not(feature = "vis"))]
{
make_path_heuristic_aligner(args, NoVis)
}

#[cfg(feature = "vis")]
{
use pa_vis::cli::VisualizerType;
match args.vis.make_visualizer() {
VisualizerType::NoVisualizer => make_path_heuristic_aligner(args, NoVis),
VisualizerType::Visualizer(vis) => {
eprintln!("vis!");
make_path_heuristic_aligner(args, vis)
}
}
}
pub fn astar_aligner() -> Box<dyn AffineAligner> {
make_path_heuristic_aligner(NoVis)
}

fn make_path_heuristic_aligner(
args: &Cli,
vis: impl VisualizerT + 'static,
) -> Box<dyn AffineAligner> {
let dt = args.diagonal_transition;
let h = &args.heuristic;
fn make_path_heuristic_aligner(vis: impl VisualizerT + 'static) -> Box<dyn AffineAligner> {
let dt = true;
let h = &HeuristicParams::default();
struct Mapper<V: VisualizerT> {
#[allow(unused)]
dt: bool,
Expand Down Expand Up @@ -71,7 +54,7 @@ fn make_path_heuristic_aligner(
fn main() {
let args = Cli::parse();

let mut aligner = astar_aligner(&args);
let mut aligner = astar_aligner();

let mut out_file = args
.output
Expand Down
2 changes: 1 addition & 1 deletion astarpa/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ impl<V: VisualizerT, H: Heuristic> AstarPa<V, H> {
}

/// Helper trait to erase the type of the heuristic that additionally returns alignment statistics.
pub trait AstarStatsAligner {
pub trait AstarStatsAligner: Aligner {
fn align(&self, a: Seq, b: Seq) -> ((Cost, Cigar), AstarStats);
}

Expand Down
6 changes: 2 additions & 4 deletions pa-bin/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ astarpa.workspace = true
astarpa2.workspace = true
itertools.workspace = true
clap.workspace = true
pa-vis = {workspace = true, optional=true}
serde.workspace = true
bio.workspace = true
rand_chacha = "0.3"
Expand All @@ -27,13 +26,12 @@ pa-base-algos.workspace = true
pa-affine-types.workspace = true
pa-vis-types.workspace = true
pa-bitpacking.workspace = true
pa-vis.workspace = true

[features]
# Needed to correctly show pruned matches in visualizations.
example = ["pa-heuristic/example", "astarpa2/example"]
# Visualizer features can be disabled.
vis = ["dep:pa-vis"]
default = ["vis"]
default = []

# A*PA figures
[[example]]
Expand Down
130 changes: 0 additions & 130 deletions pa-bin/src/cli.rs

This file was deleted.

135 changes: 134 additions & 1 deletion pa-bin/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1 +1,134 @@
pub mod cli;
#![feature(trait_upcasting)]

use astarpa::{make_aligner, HeuristicParams};
use astarpa2::AstarPa2Params;
use bio::io::fasta;
use clap::{value_parser, Parser};
use itertools::Itertools;
use pa_types::{Aligner, Seq};
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use serde::{Deserialize, Serialize};
use std::{
fs::File,
io::{BufRead, BufReader},
ops::ControlFlow,
path::PathBuf,
};

#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum AlignerType {
Astarpa,
Astarpa2Simple,
#[default]
Astarpa2Full,
}

impl AlignerType {
pub fn build(&self) -> Box<dyn Aligner> {
match self {
AlignerType::Astarpa => make_aligner(true, &HeuristicParams::default()),
AlignerType::Astarpa2Simple => AstarPa2Params::simple().make_aligner(true),
AlignerType::Astarpa2Full => AstarPa2Params::full().make_aligner(true),
}
}
}

/// Globally align pairs of sequences using A*PA.
#[derive(Parser, Serialize, Deserialize)]
#[clap(author, about, disable_version_flag(true))]
// Override some generator flags
#[clap(mut_arg("seed", |a| a.hide_short_help(true)))]
#[clap(mut_arg("cnt", |a| a.hide_short_help(true)))]
#[clap(mut_arg("size", |a| a.hide_short_help(true)))]
#[clap(mut_arg("error_model", |a| a.hide_short_help(true)))]
#[clap(mut_arg("error_model", |a| a.hide_short_help(true)))]
#[clap(group(
clap::ArgGroup::new("input_type")
.required(true)
.args(&["input", "length"]),
))]
pub struct Cli {
/// A .seq, .txt, or Fasta file with sequence pairs to align.
#[clap(short, long, value_parser = value_parser!(PathBuf), display_order = 1)]
pub input: Option<PathBuf>,

/// Write a .csv of `{cost},{cigar}` lines
#[clap(short, long, value_parser = value_parser!(PathBuf), display_order = 1)]
pub output: Option<PathBuf>,

/// The aligner to use.
#[clap(long, default_value = "astarpa2-full")]
pub aligner: AlignerType,

/// Options to generate an input pair.
#[clap(flatten, next_help_heading = "Generated input")]
pub generate: pa_generate::DatasetGenerator,
}

impl Cli {
/// Call the given function for each pair in the input.
pub fn process_input_pairs(&self, mut run_pair: impl FnMut(Seq, Seq) -> ControlFlow<()>) {
if let Some(input) = &self.input {
// Parse file
let files = if input.is_file() {
vec![input.clone()]
} else {
input
.read_dir()
.expect(&format!("{} is not a file or directory", input.display()))
.map(|x| x.unwrap().path())
.collect_vec()
};

'outer: for f in files {
match f.extension().expect("Unknown file extension") {
ext if ext == "seq" || ext == "txt" => {
let f = std::fs::File::open(&f).unwrap();
let f = BufReader::new(f);
for (mut a, mut b) in f.lines().map(|l| l.unwrap().into_bytes()).tuples() {
if ext == "seq" {
assert_eq!(a.remove(0), '>' as u8);
assert_eq!(b.remove(0), '<' as u8);
}
if let ControlFlow::Break(()) = run_pair(&a, &b) {
break 'outer;
}
}
}
ext if ext == "fna" || ext == "fa" || ext == "fasta" => {
for (a, b) in fasta::Reader::new(BufReader::new(File::open(&f).unwrap()))
.records()
.tuples()
{
if let ControlFlow::Break(()) =
run_pair(a.unwrap().seq(), b.unwrap().seq())
{
break 'outer;
}
}
}
ext => {
unreachable!(
"Unknown file extension {ext:?}. Must be in {{seq,txt,fna,fa,fasta}}."
)
}
};
}
} else {
// Generate random input.
let seed = self.generate.seed.unwrap_or_else(|| {
let seed = ChaCha8Rng::from_entropy().gen_range(0..1_000);
eprintln!("Seed: {seed}");
seed
});
let ref mut rng = ChaCha8Rng::seed_from_u64(seed);
for _ in 0..self.generate.cnt.unwrap() {
let (a, b) = self.generate.settings.generate(rng);
if let ControlFlow::Break(()) = run_pair(&a, &b) {
break;
}
}
}
}
}
Loading

0 comments on commit 7ce4fdb

Please sign in to comment.