Skip to content

Commit

Permalink
Merge pull request #44 from abstractqqq/restructure
Browse files Browse the repository at this point in the history
restructured rust code org
  • Loading branch information
abstractqqq authored Dec 30, 2023
2 parents b83471f + c6c8ac2 commit d13cd33
Show file tree
Hide file tree
Showing 49 changed files with 245 additions and 219 deletions.
233 changes: 127 additions & 106 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pyo3 = {version = "*", features = ["extension-module"]}
pyo3-polars = {version = "*", features = ["derive"]}
polars = {version = "0.35.4", features = ["performant", "lazy", "dtype-array", "ndarray", "log", "nightly"]}
num = "0.4.1"
faer = {version = "0.15", features = ["ndarray", "nightly"]}
faer = {version = "0.16", features = ["ndarray", "nightly"]}
serde = {version = "*", features=["derive"]}
ndarray = {version="0.15.6", features=["rayon"]} # see if we can get rid of this
hashbrown = {version = "0.14.2", features=["nightly"]}
Expand Down
17 changes: 17 additions & 0 deletions python/polars_ds/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,23 @@ class MetricExt:
def __init__(self, expr: pl.Expr):
self._expr: pl.Expr = expr

def max_error(self, pred: pl.Expr) -> pl.Expr:
"""
Computes the max absolute error between actual and pred.
"""
x = self._expr - pred
return pl.max_horizontal(x.max(), -x.min())

def mean_gamma_deviance(self, pred: pl.Expr) -> pl.Expr:
"""
Computes the mean gamma deviance between actual and pred.
Note that this will return NaNs when any value is < 0. This only makes sense when y_true
and y_pred as strictly positive.
"""
x = self._expr / pred
return 2.0 * (x.log() + x - 1).mean()

def hubor_loss(self, pred: pl.Expr, delta: float) -> pl.Expr:
"""
Computes huber loss between this and the other expression. This assumes
Expand Down
20 changes: 4 additions & 16 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
mod num_ext;
mod num;
mod stats_utils;
mod stats;
mod stats_ext;
mod str_ext;
mod str2;
mod utils;
use polars::{
error::{PolarsError, PolarsResult},
series::Series,
};
use pyo3::{pymodule, types::PyModule, PyResult, Python};

#[cfg(target_os = "linux")]
Expand All @@ -16,18 +12,10 @@ use jemallocator::Jemalloc;
#[cfg(target_os = "linux")]
static ALLOC: Jemalloc = Jemalloc;

// #[inline]
// pub fn no_null_in_inputs(inputs: &[Series], err_msg: String) -> PolarsResult<()> {
// for s in inputs {
// if s.null_count() > 0 {
// return Err(PolarsError::ComputeError(err_msg.into()));
// }
// }
// Ok(())
// }

#[pymodule]
#[pyo3(name = "_polars_ds")]
fn _polars_ds(_py: Python<'_>, _m: &PyModule) -> PyResult<()> {

Ok(())
}
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion src/num_ext/entrophies.rs → src/num/entrophies.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::num_ext::knn::{build_standard_kdtree, query_nb_cnt, KdtreeKwargs};
use crate::num::knn::{build_standard_kdtree, query_nb_cnt, KdtreeKwargs};
use ndarray::s;
use polars::prelude::*;
use pyo3_polars::derive::polars_expr;
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion src/stats_ext/chi2.rs → src/stats/chi2.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::simple_stats_output;
use crate::stats::gamma;
use crate::stats_utils::gamma;
use polars::prelude::*;
use pyo3_polars::derive::polars_expr;

Expand Down
2 changes: 1 addition & 1 deletion src/stats_ext/fstats.rs → src/stats/fstats.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/// Multiple F-statistics at once and F test
use super::{list_float_output, simple_stats_output, StatsResult};
use crate::stats::beta::fisher_snedecor_sf;
use crate::stats_utils::beta::fisher_snedecor_sf;
use itertools::Itertools;
use polars::prelude::*;
use pyo3_polars::derive::polars_expr;
Expand Down
2 changes: 1 addition & 1 deletion src/stats_ext/ks.rs → src/stats/ks.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/// KS statistics.
use crate::stats_ext::StatsResult;
use crate::stats::StatsResult;
use crate::utils::binary_search_right;
use itertools::Itertools;
use polars::prelude::*;
Expand Down
80 changes: 63 additions & 17 deletions src/stats/mod.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,64 @@
/// This submodule is mostly taken from the project statrs. See credit section in README.md
/// The reason I do not want to add it as a dependency is that it has a nalgebra dependency for
/// multi-variate distributions, which is something that I think will not be needed in this
/// package. Another reason is that if I want to do linear algebra, I would use Faer since Faer
/// performs better and nalgebra is too much of a dependency for this package right now.
pub mod beta;
pub mod gamma;
pub mod normal;

pub const PREC_ACC: f64 = 0.0000000000000011102230246251565;
pub const LN_PI: f64 = 1.1447298858494001741434273513530587116472948129153;
//pub const LN_SQRT_2PI: f64 = 0.91893853320467274178032973640561763986139747363778;
pub const LN_2_SQRT_E_OVER_PI: f64 = 0.6207822376352452223455184457816472122518527279025978;

#[inline]
pub fn is_zero(x: f64) -> bool {
x.abs() < PREC_ACC
mod chi2;
mod fstats;
mod ks;
mod normal_test;
mod sample;
mod t_test;

use polars::prelude::*;

pub fn list_float_output(_: &[Field]) -> PolarsResult<Field> {
Ok(Field::new(
"list_float",
DataType::List(Box::new(DataType::Float64)),
))
}

pub fn simple_stats_output(_: &[Field]) -> PolarsResult<Field> {
let s = Field::new("statistic", DataType::Float64);
let p = Field::new("pvalue", DataType::Float64);
let v: Vec<Field> = vec![s, p];
Ok(Field::new("", DataType::Struct(v)))
}

struct StatsResult {
pub statistic: f64,
pub p: Option<f64>,
}

impl StatsResult {
pub fn new(s: f64, p: f64) -> StatsResult {
StatsResult {
statistic: s,
p: Some(p),
}
}

pub fn from_stats(s: f64) -> StatsResult {
StatsResult {
statistic: s,
p: None,
}
}

pub fn unwrap_p_or(&self, default: f64) -> f64 {
self.p.unwrap_or(default)
}
}

pub enum Alternative {
TwoSided,
Less,
Greater,
}

impl From<&str> for Alternative {
fn from(s: &str) -> Alternative {
match s.to_lowercase().as_str() {
"two-sided" | "two" => Alternative::TwoSided,
"less" => Alternative::Less,
"greater" => Alternative::Greater,
_ => Alternative::TwoSided,
}
}
}
2 changes: 1 addition & 1 deletion src/stats_ext/normal_test.rs → src/stats/normal_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
///
/// I chose this over the Shapiro Francia test because the distribution is unknown and would require Monte Carlo
use super::{simple_stats_output, StatsResult};
use crate::stats::{gamma, is_zero};
use crate::stats_utils::{gamma, is_zero};
use polars::prelude::*;
use pyo3_polars::derive::polars_expr;

Expand Down
File renamed without changes.
8 changes: 4 additions & 4 deletions src/stats_ext/t_test.rs → src/stats/t_test.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/// Student's t test and Welch's t test.
use super::{simple_stats_output, Alternative, StatsResult};
use crate::stats::{beta, is_zero};
use crate::{stats_utils::{beta, is_zero}, stats};
use polars::prelude::*;
use pyo3_polars::derive::polars_expr;

Expand Down Expand Up @@ -110,7 +110,7 @@ fn pl_ttest_2samp(inputs: &[Series]) -> PolarsResult<Series> {

let alt = inputs[5].utf8()?;
let alt = alt.get(0).unwrap();
let alt = super::Alternative::from(alt);
let alt = stats::Alternative::from(alt);

let valid = mean1.is_finite() && mean2.is_finite() && var1.is_finite() && var2.is_finite();
if !valid {
Expand Down Expand Up @@ -147,7 +147,7 @@ fn pl_welch_t(inputs: &[Series]) -> PolarsResult<Series> {

let alt = inputs[6].utf8()?;
let alt = alt.get(0).unwrap();
let alt = super::Alternative::from(alt);
let alt = stats::Alternative::from(alt);

// No need to check for validity because input is sanitized.

Expand Down Expand Up @@ -175,7 +175,7 @@ fn pl_ttest_1samp(inputs: &[Series]) -> PolarsResult<Series> {

let alt = inputs[4].utf8()?;
let alt = alt.get(0).unwrap();
let alt = super::Alternative::from(alt);
let alt = stats::Alternative::from(alt);

// No need to check for validity because input is sanitized.

Expand Down
64 changes: 0 additions & 64 deletions src/stats_ext/mod.rs

This file was deleted.

File renamed without changes.
File renamed without changes.
18 changes: 18 additions & 0 deletions src/stats_utils/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/// This submodule is mostly taken from the project statrs. See credit section in README.md
/// The reason I do not want to add it as a dependency is that it has a nalgebra dependency for
/// multi-variate distributions, which is something that I think will not be needed in this
/// package. Another reason is that if I want to do linear algebra, I would use Faer since Faer
/// performs better and nalgebra is too much of a dependency for this package right now.
pub mod beta;
pub mod gamma;
pub mod normal;

pub const PREC_ACC: f64 = 0.0000000000000011102230246251565;
pub const LN_PI: f64 = 1.1447298858494001741434273513530587116472948129153;
//pub const LN_SQRT_2PI: f64 = 0.91893853320467274178032973640561763986139747363778;
pub const LN_2_SQRT_E_OVER_PI: f64 = 0.6207822376352452223455184457816472122518527279025978;

#[inline]
pub fn is_zero(x: f64) -> bool {
x.abs() < PREC_ACC
}
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
#![allow(unused_mut)]
#![allow(unused_parens)]
#![allow(unused_variables)]
use crate::str_ext::snowball::Among;
use crate::str_ext::snowball::SnowballEnv;
use crate::str2::snowball::Among;
use crate::str2::snowball::SnowballEnv;

static A_0: &'static [Among<Context>; 3] = &[
Among("arsen", -1, -1, None),
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::str_ext::snowball::SnowballEnv;
use crate::str2::snowball::SnowballEnv;

pub struct Among<T: 'static>(
pub &'static str,
Expand Down
4 changes: 2 additions & 2 deletions src/str_ext/snowball/mod.rs → src/str2/snowball/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@ mod among;
mod snowball_env;

// TODO: why do we need this `crate::`?
pub use crate::str_ext::snowball::among::Among;
pub use crate::str_ext::snowball::snowball_env::SnowballEnv;
pub use crate::str2::snowball::among::Among;
pub use crate::str2::snowball::snowball_env::SnowballEnv;
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::str_ext::snowball::Among;
use crate::str2::snowball::Among;
use std::borrow::Cow;

#[derive(Debug, Clone)]
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 comments on commit d13cd33

Please sign in to comment.