Skip to content

Commit

Permalink
switch to GroupPerm for efficient sortperm (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
piever authored Mar 21, 2019
1 parent bb17d3d commit 5962572
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 86 deletions.
4 changes: 4 additions & 0 deletions src/StructArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ function __init__()
Requires.@require Tables="bd369af6-aec1-5ad0-b16a-f7cc5008161c" include("tables.jl")
Requires.@require WeakRefStrings="ea10d353-3f73-51f8-a26c-33c1cb351aa5" begin
fastpermute!(v::WeakRefStrings.StringArray, p::AbstractVector) = permute!(v, p)
@inline function roweq(a::WeakRefStrings.StringArray{String}, i, j)
weaksa = convert(WeakRefStrings.StringArray{WeakRefStrings.WeakRefString{UInt8}}, a)
@inbounds isequal(weaksa[i], weaksa[j])
end
end
end

Expand Down
106 changes: 34 additions & 72 deletions src/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,99 +4,61 @@ fastpermute!(v::AbstractArray, p::AbstractVector) = copyto!(v, v[p])
fastpermute!(v::StructArray, p::AbstractVector) = permute!(v, p)
fastpermute!(v::PooledArray, p::AbstractVector) = permute!(v, p)

optimize_isequal(v::AbstractArray) = v
optimize_isequal(v::PooledArray) = v.refs
optimize_isequal(v::StructArray{<:Union{Tuple, NamedTuple}}) = StructArray(map(optimize_isequal, fieldarrays(v)))

recover_original(v::AbstractArray, el) = el
recover_original(v::PooledArray, el) = v.pool[el]
recover_original(v::StructArray{T}, el) where {T<:Union{Tuple, NamedTuple}} = T(map(recover_original, fieldarrays(v), el))

pool(v::AbstractArray, condition = !isbitstypeeltype) = condition(v) ? convert(PooledArray, v) : v
pool(v::StructArray, condition = !isbitstypeeltype) = replace_storage(t -> pool(t, condition), v)

function Base.permute!(c::StructArray, p::AbstractVector)
foreachfield(v -> fastpermute!(v, p), c)
return c
end

struct TiedIndices{T<:AbstractVector, V<:AbstractVector{<:Integer}, U<:AbstractUnitRange}
vec::T
perm::V
pool(v::AbstractArray, condition = !isbitstypeeltype) = condition(v) ? convert(PooledArray, v) : v
pool(v::StructArray, condition = !isbitstypeeltype) = replace_storage(t -> pool(t, condition), v)

struct GroupPerm{V<:AbstractVector, P<:AbstractVector{<:Integer}, U<:AbstractUnitRange}
vec::V
perm::P
within::U
end

TiedIndices(vec::AbstractVector, perm=sortperm(vec)) =
TiedIndices(vec, perm, axes(vec, 1))

Base.IteratorSize(::Type{<:TiedIndices}) = Base.SizeUnknown()
GroupPerm(vec, perm=sortperm(vec)) = GroupPerm(vec, perm, axes(vec, 1))

Base.eltype(::Type{<:TiedIndices{T}}) where {T} =
Pair{eltype(T), UnitRange{Int}}
Base.sortperm(g::GroupPerm) = g.perm

Base.sortperm(t::TiedIndices) = t.perm

function Base.iterate(n::TiedIndices, i = first(n.within))
vec, perm = n.vec, n.perm
l = last(n.within)
function Base.iterate(g::GroupPerm, i = first(g.within))
vec, perm = g.vec, g.perm
l = last(g.within)
i > l && return nothing
@inbounds row = vec[perm[i]]
@inbounds pi = perm[i]
i1 = i+1
@inbounds while i1 <= l && isequal(row, vec[perm[i1]])
@inbounds while i1 <= l && roweq(vec, pi, perm[i1])
i1 += 1
end
return (row => i:(i1-1), i1)
return (i:(i1-1), i1)
end

"""
`tiedindices(v, perm=sortperm(v))`
Given an abstract vector `v` and a permutation vector `perm`, return an iterator
of pairs `val => range` where `range` is a maximal interval such as `v[perm[range]]`
is constant: `val` is the unique value of `v[perm[range]]`.
"""
tiedindices(v, perm=sortperm(v)) = TiedIndices(v, perm)

"""
`maptiedindices(f, v, perm)`
Given a function `f`, compute the iterator `tiedindices(v, perm)` and return
in iterable object which yields `f(val, idxs)` where `val => idxs` are the pairs
iterated by `tiedindices(v, perm)`.
## Examples
`maptiedindices` is a low level building block that can be used to define grouping
operators. For example:
```jldoctest
julia> function mygroupby(f, keys, data)
perm = sortperm(keys)
StructArrays.maptiedindices(keys, perm) do key, idxs
key => f(data[perm[idxs]])
end
end
mygroupby (generic function with 1 method)
julia> StructArray(mygroupby(sum, [1, 2, 1, 3], [1, 4, 10, 11]))
3-element StructArray{Pair{Int64,Int64},1,NamedTuple{(:first, :second),Tuple{Array{Int64,1},Array{Int64,1}}}}:
1 => 11
2 => 4
3 => 11
```
"""
function maptiedindices(f, v, perm)
fast_v = optimize_isequal(v)
itr = TiedIndices(fast_v, perm)
(f(recover_original(v, val), idxs) for (val, idxs) in itr)
Base.IteratorSize(::Type{<:GroupPerm}) = Base.SizeUnknown()

Base.eltype(::Type{<:GroupPerm}) = UnitRange{Int}

@inline roweq(x::AbstractVector, i, j) = (@inbounds eq=isequal(x[i], x[j]); eq)
@inline roweq(a::PooledArray, i, j) = (@inbounds x=a.refs[i] == a.refs[j]; x)
@generated function roweq(c::StructVector{D,C}, i, j) where {D,C}
N = fieldcount(C)
ex = :(roweq(getfield(fieldarrays(c),1), i, j))
for n in 2:N
ex = :(($ex) && (roweq(getfield(fieldarrays(c),$n), i, j)))
end
ex
end

function uniquesorted(keys, perm=sortperm(keys))
maptiedindices((key, _) -> key, keys, perm)
(keys[perm[idxs[1]]] for idxs in GroupPerm(keys, perm))
end

function finduniquesorted(keys, perm=sortperm(keys))
maptiedindices((key, idxs) -> (key => perm[idxs]), keys, perm)
func = function (idxs)
p_idxs = perm[idxs]
return keys[p_idxs[1]] => p_idxs
end
(func(idxs) for idxs in GroupPerm(keys, perm))
end

function Base.sortperm(c::StructVector{T}) where {T<:Union{Tuple, NamedTuple}}
Expand Down Expand Up @@ -126,7 +88,7 @@ function refine_perm!(p, cols, c, x, y′, lo, hi)
order = Perm(Forward, y′)
y = something(forward_vec(order), y′)
nc = length(cols)
for (_, idxs) in TiedIndices(optimize_isequal(x), p, lo:hi)
for idxs in GroupPerm(x, p, lo:hi)
i, i1 = extrema(idxs)
if i1 > i
sort_sub_by!(p, i, i1, y, order, temp)
Expand Down
29 changes: 15 additions & 14 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,19 @@ end
@test v_pooled == StructArrays.pool(v)
end

@testset "optimize_isequal" begin
@testset "roweq" begin
a = ["a", "b", "a", "a"]
b = PooledArrays.PooledArray(["x", "y", "z", "x"])
s = StructArray((a, b))
t = StructArrays.optimize_isequal(s)
@test t[1] != t[2]
@test t[1] != t[3]
@test t[1] == t[4]
@test t[1][2] isa Integer
@test StructArrays.recover_original(s, t[1]) == s[1]
@test StructArrays.recover_original(s, t[2]) == s[2]
@test StructArrays.recover_original(s, t[3]) == s[3]
@test StructArrays.recover_original(s, t[4]) == s[4]
@test StructArrays.roweq(s, 1, 1)
@test !StructArrays.roweq(s, 1, 2)
@test !StructArrays.roweq(s, 1, 3)
@test StructArrays.roweq(s, 1, 4)
strs = WeakRefStrings.StringArray(["a", "a", "b"])
@test StructArrays.roweq(strs, 1, 1)
@test StructArrays.roweq(strs, 1, 2)
@test !StructArrays.roweq(strs, 1, 3)
@test !StructArrays.roweq(strs, 2, 3)
end

@testset "namedtuple" begin
Expand Down Expand Up @@ -95,11 +95,12 @@ end

@testset "iterators" begin
c = [1, 2, 3, 1, 1]
d = StructArrays.tiedindices(c)
@test eltype(d) == Pair{Int, UnitRange{Int}}
d = StructArrays.GroupPerm(c)
@test eltype(d) == UnitRange{Int}
@test Base.IteratorEltype(d) == Base.HasEltype()
@test sortperm(d) == sortperm(c)
s = collect(d)
@test first.(s) == [1, 2, 3]
@test last.(s) == [1:3, 4:4, 5:5]
@test s == [1:3, 4:4, 5:5]
t = collect(StructArrays.finduniquesorted(c))
@test first.(t) == [1, 2, 3]
@test last.(t) == [[1, 4, 5], [2], [3]]
Expand Down

0 comments on commit 5962572

Please sign in to comment.