From f28ce2b871b29e28a228afc6b64de23405e857b7 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Sun, 25 Feb 2024 23:57:37 +0800 Subject: [PATCH] update --- src/PermMatrix.jl | 14 ++++-- src/broadcast.jl | 4 +- src/iterate.jl | 4 +- src/kronecker.jl | 6 +++ src/linalg.jl | 18 ++++--- test/PermMatrixCSC.jl | 107 ++++++++++++++++++++++++++++++++++++++++++ test/broadcast.jl | 16 ++++++- test/iterate.jl | 1 + test/kronecker.jl | 11 +++-- test/runtests.jl | 1 + 10 files changed, 163 insertions(+), 19 deletions(-) create mode 100644 test/PermMatrixCSC.jl diff --git a/src/PermMatrix.jl b/src/PermMatrix.jl index d6d0d70..ab3bc73 100644 --- a/src/PermMatrix.jl +++ b/src/PermMatrix.jl @@ -150,10 +150,16 @@ pmcscrand(n::Int) = pmcscrand(Float64, n) Base.show(io::IO, ::MIME"text/plain", M::AbstractPermMatrix) = show(io, M) function Base.show(io::IO, M::AbstractPermMatrix) - println(io, "PermMatrix") - for ((i, j), p) in IterNz(M) - print(io, "($i, $j) = $p") - i < length(M.perm) && println(io) + n = size(M, 1) + println(io, typeof(M)) + nmax = 20 + for (k, (i, j, p)) in enumerate(IterNz(M)) + if k <= nmax || k > n-nmax + print(io, "($i, $j) = $p") + k < n && println(io) + elseif k == nmax+1 + println(io, "...") + end end end diff --git a/src/broadcast.jl b/src/broadcast.jl index ba476e8..e8804a5 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -44,7 +44,7 @@ Broadcast.broadcasted( # specialize perm matrix function _broadcast_perm_prod(A::AbstractPermMatrix, B::AbstractMatrix) dest = similar(A, Base.promote_op(*, eltype(A), eltype(B))) - @inbounds for ((i, j), a) in IterNz(A) + @inbounds for (i, j, a) in IterNz(A) dest[i, j] = a * B[i, j] end return dest @@ -71,7 +71,7 @@ Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::IMatrix, B::Abstr Diagonal(B) function _broadcast_diag_perm_prod(A::Diagonal, B::AbstractPermMatrix) - Diagonal(A.diag .* getindex.(Ref(B), 1:size(A, 1))) + Diagonal(A.diag .* getindex.(Ref(B), 1:size(A, 1), 1:size(A, 2))) end Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::AbstractPermMatrix, B::Diagonal) = diff --git a/src/iterate.jl b/src/iterate.jl index bb4cc84..d2d61ca 100644 --- a/src/iterate.jl +++ b/src/iterate.jl @@ -47,12 +47,12 @@ end # PermMatrixCSC function Base.iterate(it::IterNz{<:PermMatrixCSC}) 0 == length(it) && return nothing - return ((@inbounds it.A.perm[1], 1), (@inbounds it.A.vals[1])), 1 + return ((@inbounds it.A.perm[1]), 1, (@inbounds it.A.vals[1])), 1 end function Base.iterate(it::IterNz{<:PermMatrixCSC}, state) state == length(it) && return nothing state += 1 - return ((@inbounds it.A.perm[state], state), (@inbounds it.A.vals[state])), state + return ((@inbounds it.A.perm[state]), state, (@inbounds it.A.vals[state])), state end # AbstractMatrix diff --git a/src/kronecker.jl b/src/kronecker.jl index 7d7e909..eb2b608 100644 --- a/src/kronecker.jl +++ b/src/kronecker.jl @@ -31,6 +31,12 @@ LinearAlgebra.kron(A::IMatrix{Ta}, B::IMatrix{Tb}) where {Ta<:Number,Tb<:Number} LinearAlgebra.kron(A::IMatrix{<:Number}, B::Diagonal{<:Number}) = A.n == 1 ? B : Diagonal(orepeat(B.diag, A.n)) LinearAlgebra.kron(B::Diagonal{<:Number}, A::IMatrix) = A.n == 1 ? B : Diagonal(irepeat(B.diag, A.n)) +####### diagonal kron ######## +LinearAlgebra.kron(A::StridedMatrix{<:Number}, B::Diagonal{<:Number}) = kron(A, PermMatrixCSC(B)) +LinearAlgebra.kron(A::Diagonal{<:Number}, B::StridedMatrix{<:Number}) = kron(PermMatrixCSC(A), B) +LinearAlgebra.kron(A::Diagonal{<:Number}, B::SparseMatrixCSC{<:Number}) = kron(PermMatrixCSC(A), B) +LinearAlgebra.kron(A::SparseMatrixCSC{<:Number}, B::Diagonal{<:Number}) = kron(A, PermMatrixCSC(B)) + function LinearAlgebra.kron(A::AbstractMatrix{Tv}, B::IMatrix) where {Tv<:Number} B.n == 1 && return A mA, nA = size(A) diff --git a/src/linalg.jl b/src/linalg.jl index edc3a53..cf696cb 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -64,7 +64,7 @@ function LinearAlgebra.mul!(Y::AbstractVector, A::AbstractPermMatrix, X::Abstrac length(X) == size(A, 2) || throw(DimensionMismatch("input X length does not match permutation matrix A")) length(Y) == size(A, 2) || throw(DimensionMismatch("output Y length does not match permutation matrix A")) - @inbounds for ((i, j), p) in IterNz(A) + @inbounds for (i, j, p) in IterNz(A) Y[i] = p * X[j] * alpha + beta * Y[i] end return Y @@ -85,22 +85,28 @@ function Base.:*(A::PermMatrixCSC{Ta}, D::Diagonal{Td}) where {Td,Ta} end # to self -function Base.:*(A::AbstractPermMatrix, B::AbstractPermMatrix) +function Base.:*(A::PermMatrix, B::PermMatrix) @assert basetype(A) == basetype(B) size(A, 1) == size(B, 1) || throw(DimensionMismatch()) - PermMatrix(B.perm[A.perm], A.vals .* view(B.vals, A.perm)) + basetype(A)(B.perm[A.perm], A.vals .* view(B.vals, A.perm)) +end + +function Base.:*(A::PermMatrixCSC, B::PermMatrixCSC) + @assert basetype(A) == basetype(B) + size(A, 1) == size(B, 1) || throw(DimensionMismatch()) + basetype(A)(A.perm[B.perm], B.vals .* view(A.vals, B.perm)) end # to matrix -function LinearAlgebra.:mul!(C::AbstractMatrix, A::AbstractPermMatrix, X::AbstractMatrix, alpha::Number, beta::Number) +function LinearAlgebra.mul!(C::AbstractMatrix, A::AbstractPermMatrix, X::AbstractMatrix, alpha::Number, beta::Number) size(X, 1) == size(A, 2) || throw(DimensionMismatch()) AR = PermMatrix(A) - C .= C .* beta .+ AR.vals .* view(X,AR.perm,:) .* alpha + C .= C .* beta .+ AR.vals .* view(X, AR.perm, :) .* alpha end function LinearAlgebra.mul!(C::AbstractMatrix, X::AbstractMatrix, A::AbstractPermMatrix, alpha::Number, beta::Number) size(X, 2) == size(A, 1) || throw(DimensionMismatch()) AC = PermMatrixCSC(A) - C .= C .* beta .+ reshape(AC.vals, 1, :) .* view(X, :, perm) .* alpha + C .= C .* beta .+ reshape(AC.vals, 1, :) .* view(X, :, AC.perm) .* alpha end # NOTE: this is just a temperory fix for v0.7. We should overload mul! in diff --git a/test/PermMatrixCSC.jl b/test/PermMatrixCSC.jl new file mode 100644 index 0000000..24bc2a0 --- /dev/null +++ b/test/PermMatrixCSC.jl @@ -0,0 +1,107 @@ +using Test, Random +import LuxurySparse: PermMatrixCSC, pmcscrand +import LuxurySparse +using SparseArrays: sprand, SparseMatrixCSC +using LinearAlgebra + +Random.seed!(2) +p1 = PermMatrixCSC([1, 4, 2, 3], [0.1, 0.2, 0.4im, 0.5]) +p2 = PermMatrixCSC([2, 1, 4, 3], [0.1, 0.2, 0.4, 0.5]) +#p3 = PermMatrix([4,1,2,3],[0.5, 0.4im, 0.3, 0.2]) +p3 = pmcscrand(4) +sp = sprand(4, 4, 0.3) +v = [0.5, 0.3im, 0.2, 1.0] + +@testset "basic" begin + @test p1 == copy(p1) + @test eltype(p1) == ComplexF64 + @test eltype(p2) == Float64 + @test eltype(p3) == Float64 + @test size(p1) == (4, 4) + @test size(p3) == (4, 4) + @test size(p1, 1) == size(p1, 2) == 4 + @test Matrix(p1) ≈ transpose([0.1 0 0 0; 0 0 0 0.2; 0 0.4im 0 0; 0 0 0.5 0]) + p0 = similar(p1) + @test p0.perm == p1.perm + @test p0.perm !== p1.perm + @test p0.vals !== p1.vals + @test p1[2, 2] === 0.0im + @test p1[1, 1] === 0.1 + 0.0im + copyto!(p0, p1) + @test p0 == p1 +end + +@testset "linalg" begin + @test inv(p1) ≈ inv(Matrix(p1)) + @test transpose(p1) ≈ transpose(Matrix(p1)) + @test inv(p1) * p1 ≈ Matrix(I, 4, 4) + @test p1 * transpose(p1) ≈ diagm(0 => p1.vals[invperm(p1.perm)] .^ 2) + #@test p1*adjoint(p1) == diagm(0=>abs.(p1.vals).^2) + #@test all(isapprox.(adjoint(p3), transpose(conj(Matrix(p3))))) + @test p1 * p1' == diagm(0 => abs.(p1.vals[invperm(p1.perm)]) .^ 2) + @test all(isapprox.(p3', transpose(conj(Matrix(p3))))) +end + +@testset "mul" begin + @test p3 * p2 ≈ SparseMatrixCSC(p3) * p2 ≈ Matrix(p3) * p2 + + # Multiply vector + @test p3 * v == Matrix(p3) * v + @test v' * p3 == v' * Matrix(p3) + @test vec(collect(1:4)' * p3) ≈ p3.perm .* p3.vals + + # Diagonal matrices + Dv = Diagonal(v) + @test p3 * Dv == Matrix(p3) * Dv + @test Dv * p3 == Dv * Matrix(p3) +end + +@testset "elementary" begin + @test all(isapprox.(conj(p1), conj(Matrix(p1)))) + @test all(isapprox.(real(p1), real(Matrix(p1)))) + @test all(isapprox.(imag(p1), imag(Matrix(p1)))) +end + +@testset "basicmath" begin + @test p1 * 2 == Matrix(p1) * 2 + @test p1 / 2 == Matrix(p1) / 2 +end + +@testset "memorysafe" begin + @test p1 == PermMatrixCSC([1, 4, 2, 3], [0.1, 0.2, 0.4im, 0.5]) + @test p2 == PermMatrixCSC([2, 1, 4, 3], [0.1, 0.2, 0.4, 0.5]) + @test v == [0.5, 0.3im, 0.2, 1.0] +end + +@testset "sparse" begin + Random.seed!(2) + pm = pmrand(10) + out = zeros(10, 10) + @test LuxurySparse.nnz(pm) == 10 + @test LuxurySparse.findnz(pm)[3] == pm.vals +end + +@testset "identity sparse" begin + p1 = Diagonal(randn(10)) + @test LuxurySparse.nnz(p1) == 10 + @test LuxurySparse.findnz(p1)[3] == p1.diag +end + +@testset "setindex" begin + pm = PermMatrix([3, 2, 4, 1], [0.0, 0.0, 0.0, 0.0]) + pm[3, 4] = 1.0 + @test_throws BoundsError pm[3, 1] = 1.0 + @test pm[3, 4] == 1.0 +end + +@testset "broadcast" begin + pm = PermMatrix([3, 2, 4, 1], [0.2, 0.6, 0.1, 0.3]) + res = pm .* 3im + @test res == PermMatrix([3, 2, 4, 1], [0.2, 0.6, 0.1, 0.3] .* 3im) && res isa PermMatrix +end + +@testset "fix dense-perm multiplication" begin + A = randn(ComplexF64, 4, 4) + pm = PermMatrix([3, 2, 4, 1], [0.2im, 0.6im, 0.1, 0.3]) + @test A * pm ≈ A * Matrix(pm) +end \ No newline at end of file diff --git a/test/broadcast.jl b/test/broadcast.jl index 7aa3425..4879b70 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -5,7 +5,7 @@ using SparseArrays @testset "broadcast *" begin - @testset "Diagonal .* $(nameof(typeof(M)))" for M in Any[pmrand(3)] + @testset "Diagonal .* $(nameof(typeof(M)))" for M in [[pmrand(3)]..., pmcscrand(3)] M1 = Diagonal(rand(3)) out = M1 .* M @test typeof(out) <: Diagonal @@ -29,11 +29,21 @@ using SparseArrays out = M .* M1 @test typeof(out) <: PermMatrix @test out ≈ M .* Matrix(M1) + + M1 = pmcscrand(3) + out = M1 .* M + @test typeof(out) <: PermMatrixCSC + @test out ≈ Matrix(M1) .* M + + out = M .* M1 + !(M isa PermMatrix) && @test typeof(out) <: PermMatrixCSC + @test out ≈ M .* Matrix(M1) end @testset "IMatrix .* $(nameof(typeof(M)))" for M in Any[ rand(3, 3), pmrand(3), + pmcscrand(3), sprand(3, 3, 0.5), ] eye = IMatrix(3) @@ -77,6 +87,10 @@ end M1 = pmrand(3) @test M1 .- M ≈ Matrix(M1) .- M @test M .- M1 ≈ M .- Matrix(M1) + + M1 = pmcscrand(3) + @test M1 .- M ≈ Matrix(M1) .- M + @test M .- M1 ≈ M .- Matrix(M1) end @testset "IMatrix .* $(nameof(typeof(M)))" for M in Any[ diff --git a/test/iterate.jl b/test/iterate.jl index c3c9d39..81b7ca0 100644 --- a/test/iterate.jl +++ b/test/iterate.jl @@ -3,6 +3,7 @@ using Test, LuxurySparse, SparseArrays, LinearAlgebra @testset "iterate" begin for M in Any[ pmrand(10), + pmcscrand(10), Diagonal(randn(10)), IMatrix(10), randn(10, 10), diff --git a/test/kronecker.jl b/test/kronecker.jl index a2a8792..9df3c20 100644 --- a/test/kronecker.jl +++ b/test/kronecker.jl @@ -1,5 +1,5 @@ using Test, Random, SparseArrays, LinearAlgebra -import LuxurySparse: IMatrix, PermMatrix +import LuxurySparse: IMatrix, PermMatrix, PermMatrixCSC, basetype, AbstractPermMatrix @testset "kron" begin Random.seed!(2) @@ -8,12 +8,15 @@ import LuxurySparse: IMatrix, PermMatrix sp = sprand(ComplexF64, 4, 4, 0.5) ds = rand(ComplexF64, 4, 4) pm = PermMatrix([2, 3, 4, 1], randn(4)) - pm = PermMatrix([2, 3, 4, 1], randn(4)) + pmc = PermMatrixCSC([2, 3, 4, 1], randn(4)) v = [0.5, 0.3im, 0.2, 1.0] dv = Diagonal(v) - for source in Any[p1, sp, ds, dv, pm], - target in Any[p1, sp, ds, dv, pm] + for source in Any[p1, sp, ds, dv, pm, pmc], + target in Any[p1, sp, ds, dv, pm, pmc] + if source isa AbstractPermMatrix && target isa AbstractPermMatrix && basetype(source) != basetype(target) + continue + end lres = kron(source, target) rres = kron(target, source) flres = kron(Matrix(source), Matrix(target)) diff --git a/test/runtests.jl b/test/runtests.jl index 8ee7851..17fca67 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,7 @@ end @testset "PermMatrix" begin include("PermMatrix.jl") + include("PermMatrixCSC.jl") end @testset "SparseMatrixCOO" begin