Skip to content

Commit

Permalink
Matrix: add 'solve' method (#171)
Browse files Browse the repository at this point in the history
  • Loading branch information
gyrdym authored Dec 24, 2023
1 parent becea99 commit 8b30c13
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 8 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## 13.12.0
- `Matrix`:
- add `solve` method for solving a system of linear equations

## 13.11.31
- Lib migration to Dart 3.0.0 (only non-breaking changes)

Expand Down
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
- [Matrix Cholesky inversion](#matrix-cholesky-inversion)
- [Lower triangular matrix inversion](#lower-triangular-matrix-inversion)
- [Upper triangular matrix inversion](#upper-triangular-matrix-inversion)
- [Solving a system of linear equations](#solving-a-system-of-linear-equations)
- [Obtaining Matrix eigenvectors and eigenvalues, Power Iteration method](#obtaining-matrix-eigenvectors-and-eigenvalues-power-iteration-method)
- [Matrix row-wise reduce](#matrix-row-wise-reduce)
- [Matrix column-wise reduce](#matrix-column-wise-reduce)
Expand Down Expand Up @@ -977,6 +978,37 @@ print(matrix1 - matrix2);
// [0, 0, 1],
```


#### Solving a system of linear equations

A matrix notation for [a system of linear equations](https://en.wikipedia.org/wiki/System_of_linear_equations):

```
AX=B
```

To solve the system and find `X`, one may use the `solve` method:

````Dart
import 'package:ml_linalg/linalg.dart';
void main() {
final A = Matrix.fromList([
[1, 1, 1],
[0, 2, 5],
[2, 5, -1],
], dtype: dtype);
final B = Matrix.fromList([
[6],
[-4],
[27],
], dtype: dtype);
final result = A.solve(B);
print(result); // the output is close to [[5], [3], [-2]]
}
````

#### Obtaining Matrix eigenvectors and eigenvalues, Power Iteration method

The method returns a collection of pairs of an eigenvector and its corresponding eigenvalue.
Expand Down
33 changes: 26 additions & 7 deletions lib/matrix.dart
Original file line number Diff line number Diff line change
Expand Up @@ -738,10 +738,10 @@ abstract class Matrix implements Iterable<Iterable<double>> {
/// ```
Vector toVector();

/// Returns maximal value of the matrix
/// Returns a max value of the matrix
double max();

/// Return minimal value of the matrix
/// Return a min value of the matrix
double min();

/// Returns a norm of a matrix
Expand Down Expand Up @@ -774,14 +774,16 @@ abstract class Matrix implements Iterable<Iterable<double>> {
/// Creates a new [Matrix] composed of Euler's numbers raised to powers which
/// are the elements of this [Matrix]
Matrix exp(
{@Deprecated('The flag is useless, it\'ll be removed in the next major update')
bool skipCaching = false});
{@Deprecated(
'The flag is useless, it\'ll be removed in the next major update')
bool skipCaching = false});

/// Creates a new [Matrix] composed of natural logarithms of the source
/// matrix elements
Matrix log(
{@Deprecated('The flag is useless, it\'ll be removed in the next major update')
bool skipCaching = false});
{@Deprecated(
'The flag is useless, it\'ll be removed in the next major update')
bool skipCaching = false});

/// Performs Hadamard product - element-wise matrices multiplication
Matrix multiply(Matrix other);
Expand All @@ -803,9 +805,26 @@ abstract class Matrix implements Iterable<Iterable<double>> {
{EigenMethod method, Vector? initial, int iterationCount, int? seed});

/// Finds the inverse of the original matrix. Product of the inverse and the original matrix results in singular matrix
/// Default value id [Inverse.LU]
/// Default value is [Inverse.LU]
Matrix inverse([Inverse inverseType]);

/// Returns a solution for [a system of linear equations](https://en.wikipedia.org/wiki/System_of_linear_equations):
///
/// ```
/// A*X = B
/// ```
///
/// where `A` is this Matrix, [B] is a column matrix of [this matrix row count]x1 dimension
///
/// To solve the system, one should do the following
///
/// ```
/// X = inverse(A)*B
/// ```
///
/// To find the inverse of this matrix, one should specify the [Inverse] type through passing the [inverse] argument, default value is [Inverse.LU]
Matrix solve(Matrix B, [Inverse inverse]);

/// Returns a serializable map
Map<String, dynamic> toJson();
}
5 changes: 5 additions & 0 deletions lib/src/matrix/float32_matrix.dart
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,11 @@ class Float32Matrix
}
}

@override
Matrix solve(Matrix B, [Inverse inverse = Inverse.LU]) {
return this.inverse(inverse) * B;
}

@override
Iterable<Matrix> decompose(
[Decomposition decompositionType = Decomposition.LU]) {
Expand Down
5 changes: 5 additions & 0 deletions lib/src/matrix/float64_matrix.g.dart
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,11 @@ class Float64Matrix
}
}

@override
Matrix solve(Matrix B, [Inverse inverse = Inverse.LU]) {
return this.inverse(inverse) * B;
}

@override
Iterable<Matrix> decompose(
[Decomposition decompositionType = Decomposition.LU]) {
Expand Down
2 changes: 1 addition & 1 deletion pubspec.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: ml_linalg
description: SIMD-based linear algebra and statistics, efficient manipulation with numeric data
version: 13.11.31
version: 13.12.0
homepage: https://github.com/gyrdym/ml_linalg

environment:
Expand Down
7 changes: 7 additions & 0 deletions test/matrix/methods/solve/solve_test.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import 'package:ml_linalg/dtype.dart';
import 'solve_test_group_factory.dart';

void main() {
matrixSolveTestGroupFactory(DType.float32);
matrixSolveTestGroupFactory(DType.float64);
}
50 changes: 50 additions & 0 deletions test/matrix/methods/solve/solve_test_group_factory.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import 'package:ml_linalg/dtype.dart';
import 'package:ml_linalg/matrix.dart';
import 'package:test/test.dart';

import '../../../dtype_to_title.dart';
import '../../../helpers.dart';

void matrixSolveTestGroupFactory(DType dtype) =>
group(dtypeToMatrixTestTitle[dtype], () {
group('solve', () {
test('should solve system of linear equations, 2x2 matrix', () {
final A = Matrix.fromList([
[3, 1],
[4, 3],
], dtype: dtype);
final B = Matrix.fromList([
[3],
[17],
], dtype: dtype);
final actual = A.solve(B);
final expected = [
[-1.6],
[7.8],
];

expect(actual, iterable2dAlmostEqualTo(expected, 1e-6));
});

test('should solve system of linear equations, 3x3 matrix', () {
final A = Matrix.fromList([
[1, 1, 1],
[0, 2, 5],
[2, 5, -1],
], dtype: dtype);
final B = Matrix.fromList([
[6],
[-4],
[27],
], dtype: dtype);
final actual = A.solve(B);
final expected = [
[5],
[3],
[-2]
];

expect(actual, iterable2dAlmostEqualTo(expected, 1e-6));
});
});
});

0 comments on commit 8b30c13

Please sign in to comment.