From 0ecd99694a7468d137c6470cc6b5809c1e712eed Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Wed, 24 May 2023 12:15:27 +0200 Subject: [PATCH 1/7] add specialization for hcat and vcat --- src/utils.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index 0942fa5f7..db1f53d21 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -94,6 +94,8 @@ Base.setindex!(D::ColVecs, v::AbstractVector, i) = setindex!(D.X, v, :, i) Base.vcat(a::ColVecs, b::ColVecs) = ColVecs(hcat(a.X, b.X)) Base.zero(x::ColVecs) = ColVecs(zero(x.X)) +Base.reduce(::typeof(hcat), a::ColVecs) = a.X +Base.reduce(::typeof(vcat), a::ColVecs) = reshape(a.X, :) dim(x::ColVecs) = size(x.X, 1) From e8e55624e097636ac8f62d49ce2707a0278bcc04 Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Wed, 24 May 2023 12:18:27 +0200 Subject: [PATCH 2/7] test coverage --- test/utils.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/utils.jl b/test/utils.jl index 42784548a..553c1908c 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -40,6 +40,9 @@ Y = randn(rng, D, N + 1) DY = ColVecs(Y) + + @test reduce(vcat, DY) == vcat(DY...) + @test reduce(hcat, DY) == hcat(DY...) @test KernelFunctions.pairwise(SqEuclidean(), DX) ≈ pairwise(SqEuclidean(), X; dims=2) @test KernelFunctions.pairwise(SqEuclidean(), DX, DY) ≈ From 24508ef7196caaa4b77e4d85cb48f6ea1c0232ba Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Wed, 24 May 2023 12:35:37 +0200 Subject: [PATCH 3/7] bump patch --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c763a8824..1b597b7c9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "KernelFunctions" uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" -version = "0.10.55" +version = "0.10.56" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From 15f30c9338710eb1a16d944e8e01467cee04c4ac Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Wed, 24 May 2023 14:35:31 +0200 Subject: [PATCH 4/7] simplify reshape to vec Co-authored-by: David Widmann --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index db1f53d21..45e861db7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -95,7 +95,7 @@ Base.setindex!(D::ColVecs, v::AbstractVector, i) = setindex!(D.X, v, :, i) Base.vcat(a::ColVecs, b::ColVecs) = ColVecs(hcat(a.X, b.X)) Base.zero(x::ColVecs) = ColVecs(zero(x.X)) Base.reduce(::typeof(hcat), a::ColVecs) = a.X -Base.reduce(::typeof(vcat), a::ColVecs) = reshape(a.X, :) +Base.reduce(::typeof(vcat), a::ColVecs) = vec(a.X) dim(x::ColVecs) = size(x.X, 1) From 630258c14362356d50cc3c59790dc790809cbaea Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Wed, 24 May 2023 16:24:01 +0200 Subject: [PATCH 5/7] add implementation for RowVecs --- src/utils.jl | 2 ++ test/utils.jl | 3 +++ 2 files changed, 5 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index 45e861db7..4339b1bdb 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -169,6 +169,8 @@ Base.setindex!(D::RowVecs, v::AbstractVector, i) = setindex!(D.X, v, i, :) Base.vcat(a::RowVecs, b::RowVecs) = RowVecs(vcat(a.X, b.X)) Base.zero(x::RowVecs) = RowVecs(zero(x.X)) +Base.reduce(::typeof(hcat), a::RowVecs) = permutedims(a.X) +Base.reduce(::typeof(vcat), a::RowVecs) = vec(permutedims(a.X)) dim(x::RowVecs) = size(x.X, 2) diff --git a/test/utils.jl b/test/utils.jl index 553c1908c..6698c306e 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -101,6 +101,9 @@ Y = randn(rng, D + 1, N) DY = RowVecs(Y) + + @test reduce(vcat, DY) == vcat(DY...) + @test reduce(hcat, DY) == hcat(DY...) @test KernelFunctions.pairwise(SqEuclidean(), DX) ≈ pairwise(SqEuclidean(), X; dims=1) @test KernelFunctions.pairwise(SqEuclidean(), DX, DY) ≈ From 94734f0d60a2d5e4eb6dd5d353d03284e9d465da Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Wed, 24 May 2023 16:43:34 +0200 Subject: [PATCH 6/7] formatter --- test/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/utils.jl b/test/utils.jl index 6698c306e..6b7869b26 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -101,7 +101,7 @@ Y = randn(rng, D + 1, N) DY = RowVecs(Y) - + @test reduce(vcat, DY) == vcat(DY...) @test reduce(hcat, DY) == hcat(DY...) @test KernelFunctions.pairwise(SqEuclidean(), DX) ≈ From 387837f4867710eb047bd6cb5e2a77f7abb1608e Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Thu, 25 May 2023 09:43:10 +0200 Subject: [PATCH 7/7] copy to be safe --- src/utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 4339b1bdb..11372cfe6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -94,8 +94,8 @@ Base.setindex!(D::ColVecs, v::AbstractVector, i) = setindex!(D.X, v, :, i) Base.vcat(a::ColVecs, b::ColVecs) = ColVecs(hcat(a.X, b.X)) Base.zero(x::ColVecs) = ColVecs(zero(x.X)) -Base.reduce(::typeof(hcat), a::ColVecs) = a.X -Base.reduce(::typeof(vcat), a::ColVecs) = vec(a.X) +Base.reduce(::typeof(hcat), a::ColVecs) = copy(a.X) +Base.reduce(::typeof(vcat), a::ColVecs) = copy(vec(a.X)) dim(x::ColVecs) = size(x.X, 1)