Skip to content

Commit

Permalink
Merge pull request #3 from abstractqqq/better_lstsq
Browse files Browse the repository at this point in the history
Better lstsq
  • Loading branch information
abstractqqq authored Nov 3, 2023
2 parents 063f713 + 6ac4bb0 commit 1ea4ce1
Show file tree
Hide file tree
Showing 10 changed files with 883 additions and 152 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ polars-lazy = "0.34"
num = "0.4.1"
faer = {version = "0.14.1", features = ["ndarray"]}
ndarray = "0.15.6"
hashbrown = "0.14.2"

[target.'cfg(target_os = "linux")'.dependencies]
jemallocator = { version = "0.5", features = ["disable_initial_exec_tls"] }
23 changes: 23 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Developer convenience targets for building the Rust extension and its Python package.
# The recipes below use `source`, which is a bash builtin, hence bash is forced here.
SHELL=/bin/bash

# Create the `.venv` virtual environment and install requirements.txt into it.
venv: ## Set up virtual environment
python3 -m venv .venv
.venv/bin/pip install -r requirements.txt

# Build and install the extension into `.venv` via `maturin develop` (no --release flag).
# CONDA_PREFIX is unset first — presumably so maturin picks `.venv` rather than an
# active conda environment; TODO confirm.
install: venv
unset CONDA_PREFIX && \
source .venv/bin/activate && maturin develop -m Cargo.toml

# Same as `install`, but passes --release to maturin, then runs `pip install .`.
# NOTE(review): `pip install .` is its own recipe line (not chained with `&& \`), so it
# does not share the shell that activated `.venv` — confirm which `pip` is intended.
dev-release: venv
unset CONDA_PREFIX && \
source .venv/bin/activate && maturin develop --release -m Cargo.toml
pip install .

# Format the Rust sources, then lint them with clippy (all features enabled).
pre-commit: venv
cargo fmt --all --manifest-path Cargo.toml && cargo clippy --all-features --manifest-path Cargo.toml

# Disabled targets kept for reference.
# NOTE(review): `install-release` is not defined in this Makefile, and the second recipe
# activates `venv/` rather than `.venv/` — fix both before re-enabling.
# run: install
# source .venv/bin/activate && python run.py

# run-release: install-release
# source venv/bin/activate && python run.py
37 changes: 26 additions & 11 deletions python/polars_ds/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ def hubor_loss(self, other: pl.Expr, delta: float) -> pl.Expr:

def l1_loss(self, other: pl.Expr, normalize: bool = True) -> pl.Expr:
"""
Computes L1 loss (normalized L1 distance) between this and the other expression
Computes L1 loss (normalized L1 distance) between this and the other expression. This
is the norm without 1/p power.
Parameters
----------
Expand All @@ -149,7 +150,9 @@ def l1_loss(self, other: pl.Expr, normalize: bool = True) -> pl.Expr:

def l2_loss(self, other: pl.Expr, normalize: bool = True) -> pl.Expr:
"""
Computes L2 loss (normalized L2 distance) between this and the other expression
Computes L2 loss (normalized L2 distance) between this and the other expression. This
is the norm without 1/p power.
Parameters
----------
Expand All @@ -166,7 +169,9 @@ def l2_loss(self, other: pl.Expr, normalize: bool = True) -> pl.Expr:

def lp_loss(self, other: pl.Expr, p: float, normalize: bool = True) -> pl.Expr:
"""
Computes LP loss (normalized LP distance) between this and the other expression
Computes LP loss (normalized LP distance) between this and the other expression,
for finite p. This is the norm without the 1/p power.
Parameters
Expand All @@ -179,9 +184,9 @@ def lp_loss(self, other: pl.Expr, p: float, normalize: bool = True) -> pl.Expr:
if p <= 0:
raise ValueError(f"Input `p` must be > 0, not {p}")

temp = (self._expr - other).abs().pow(p)
temp = (self._expr - other).abs().pow(p).sum()
if normalize:
return temp / self._expr.count()
return (temp / self._expr.count())
return temp

def chebyshev_loss(self, other: pl.Expr, normalize: bool = True) -> pl.Expr:
Expand Down Expand Up @@ -243,20 +248,30 @@ def smape(self, other: pl.Expr) -> pl.Expr:
denominator = 1.0 / (self._expr.abs() + other.abs())
return (1.0 / self._expr.count()) * numerator.dot(denominator)

def lstsq(self, *other: pl.Expr) -> pl.Expr:
def lstsq(self, other: list[pl.Expr], add_bias:bool=False) -> pl.Expr:
"""
Computes least squares solution to a linear matrix equation.
Computes least squares solution to a linear matrix equation. If columns are
not linearly independent, numerical issues or errors may occur. Unrealistic
coefficient values are an indication of a `silent` numerical problem during the
computation.
If add_bias is true, the bias term will be the last coefficient in the output,
and the output will have length |other| + 1.
Parameters
----------
other
Either an int or a Polars expression
List of Polars expressions. They should have the same size.
add_bias
Whether to add a bias term
"""
return self._expr._register_plugin(

return self._expr.register_plugin(
lib=lib,
symbol="lstsq",
args=list(other),
symbol="pl_lstsq",
args=[pl.lit(add_bias, dtype=pl.Boolean)] + other,
is_elementwise=False,
returns_scalar=True
)


Expand Down
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
maturin
polars
numpy
pytest
Loading

0 comments on commit 1ea4ce1

Please sign in to comment.