Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix A* implementation #125

Merged
merged 3 commits into from
May 22, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ docs/site/
benchmark/.results/*
benchmark/.tune.jld
*.cov
Manifest.toml
gdalle marked this conversation as resolved.
Show resolved Hide resolved
/Manifest.toml
238 changes: 238 additions & 0 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.7.2"
manifest_format = "2.0"

[[deps.ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"

[[deps.ArnoldiMethod]]
deps = ["LinearAlgebra", "Random", "StaticArrays"]
git-tree-sha1 = "62e51b39331de8911e4a7ff6f5aaf38a5f4cc0ae"
uuid = "ec485272-7323-5ecc-a04f-4719b315124d"
version = "0.2.0"

[[deps.Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"

[[deps.Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[deps.Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "b153278a25dd42c65abbf4e62344f9d22e59191b"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.43.0"

[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"

[[deps.DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "cc1a8e22627f33c789ab60b36a9132ac050bbf75"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.18.12"

[[deps.Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[deps.DelimitedFiles]]
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"

[[deps.Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[deps.DocStringExtensions]]
deps = ["LibGit2"]
git-tree-sha1 = "b19534d1895d702889b219c382a6e18010797f0b"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.8.6"

[[deps.Documenter]]
deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
git-tree-sha1 = "3ebb967819b284dc1e3c0422229b58a40a255649"
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
version = "0.26.3"

[[deps.Downloads]]
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"

[[deps.Graphs]]
deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"]
path = ".."
uuid = "86223c79-3864-5bf0-83f7-82e725a168b6"
version = "1.6.0"

[[deps.IOCapture]]
deps = ["Logging"]
git-tree-sha1 = "377252859f740c217b936cebcd918a44f9b53b59"
uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89"
version = "0.1.1"

[[deps.Inflate]]
git-tree-sha1 = "f5fc07d4e706b84f72d54eedcc1c13d92fb0871c"
uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9"
version = "0.1.2"

[[deps.InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[deps.JSON]]
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
git-tree-sha1 = "3c837543ddb02250ef42f4738347454f95079d4e"
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
version = "0.21.3"

[[deps.LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"

[[deps.LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"

[[deps.LibGit2]]
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[deps.LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"

[[deps.Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

[[deps.LinearAlgebra]]
deps = ["Libdl", "libblastrampoline_jll"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[deps.MacroTools]]
deps = ["Markdown", "Random"]
git-tree-sha1 = "3d3e902b31198a27340d0bf00d6ac452866021cf"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.5.9"

[[deps.Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[deps.MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"

[[deps.Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[deps.MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"

[[deps.NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"

[[deps.OpenBLAS_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"

[[deps.OrderedCollections]]
git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.4.1"

[[deps.Parsers]]
deps = ["Dates"]
git-tree-sha1 = "1285416549ccfcdf0c50d4997a94331e88d68413"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "2.3.1"

[[deps.Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

[[deps.Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[deps.REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"

[[deps.Random]]
deps = ["SHA", "Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[deps.SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

[[deps.Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[[deps.SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"

[[deps.SimpleTraits]]
deps = ["InteractiveUtils", "MacroTools"]
git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231"
uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
version = "0.9.4"

[[deps.Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

[[deps.SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[deps.StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "cd56bf18ed715e8b09f06ef8c6b781e6cdc49911"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.4.4"

[[deps.Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[deps.TOML]]
deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"

[[deps.Tar]]
deps = ["ArgTools", "SHA"]
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"

[[deps.Test]]
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[deps.UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[[deps.Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[deps.Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"

[[deps.libblastrampoline_jll]]
deps = ["Artifacts", "Libdl", "OpenBLAS_jll"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"

[[deps.nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"

[[deps.p7zip_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"

[compat]
Documenter = "~0.26.2"
55 changes: 23 additions & 32 deletions src/shortestpaths/astar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,42 +6,34 @@
function reconstruct_path!(total_path, # a vector to be filled with the shortest path
came_from, # a vector holding the parent of each node in the A* exploration
end_idx, # the end vertex
g) # the graph

E = edgetype(g)
g, # the graph
edgetype_to_return::Type{E}=edgetype(g)) where {E<:AbstractEdge}
curr_idx = end_idx
while came_from[curr_idx] != curr_idx
pushfirst!(total_path, E(came_from[curr_idx], curr_idx))
pushfirst!(total_path, edgetype_to_return(came_from[curr_idx], curr_idx))
curr_idx = came_from[curr_idx]
end
end

function a_star_impl!(g, # the graph
goal, # the end vertex
open_set, # an initialized heap containing the active vertices
closed_set, # an (initialized) color-map to indicate status of vertices
g_score, # a vector holding g scores for each node
f_score, # a vector holding f scores for each node
came_from, # a vector holding the parent of each node in the A* exploration
distmx,
heuristic)

E = edgetype(g)
total_path = Vector{E}()
heuristic,
edgetype_to_return::Type{E}) where {E<:AbstractEdge}
total_path = Vector{edgetype_to_return}()

@inbounds while !isempty(open_set)
current = dequeue!(open_set)

if current == goal
reconstruct_path!(total_path, came_from, current, g)
reconstruct_path!(total_path, came_from, current, g, edgetype_to_return)
return total_path
end

closed_set[current] = true

for neighbor in Graphs.outneighbors(g, current)
closed_set[neighbor] && continue

tentative_g_score = g_score[current] + distmx[current, neighbor]

if tentative_g_score < g_score[neighbor]
Expand All @@ -56,38 +48,37 @@ function a_star_impl!(g, # the graph
end

"""
a_star(g, s, t[, distmx][, heuristic])
a_star(g, s, t[, distmx][, heuristic][, edgetype_to_return])

Compute a shortest path using the [A* search algorithm](http://en.wikipedia.org/wiki/A%2A_search_algorithm).

Return a vector of edges comprising the shortest path between vertices `s` and `t`
using the [A* search algorithm](http://en.wikipedia.org/wiki/A%2A_search_algorithm).
An optional heuristic function and edge distance matrix may be supplied. If missing,
the distance matrix is set to [`Graphs.DefaultDistance`](@ref) and the heuristic is set to
`n -> 0`.
# Arguments
- `g::AbstractGraph`: the graph
- `s::Integer`: the source vertex
- `t::Integer`: the target vertex
- `distmx::AbstractMatrix`: an optional (possibly sparse) `n × n` matrix of edge weights. It is set to `weights(g)` by default (which itself falls back on [`Graphs.DefaultDistance`](@ref)).
- `heuristic`: an optional callable mapping each vertex to a lower estimate of the remaining distance from `v` to `t`. It is set to `v -> 0` by default (which corresponds to Dijkstra's algorithm)
- `edgetype_to_return::Type{E}`: the eltype `E<:AbstractEdge` of the vector of edges returned. It is set to `edgetype(g)` by default. Note that the two-argument constructor `E(u, v)` must be defined, even for weighted edges: if it isn't, consider using `E = Graphs.SimpleEdge`.
"""
function a_star(g::AbstractGraph{U}, # the g
s::Integer, # the start vertex
t::Integer, # the end vertex
distmx::AbstractMatrix{T}=weights(g),
heuristic::Function=n -> zero(T)) where {T, U}

E = Edge{eltype(g)}

heuristic=n -> zero(T),
gdalle marked this conversation as resolved.
Show resolved Hide resolved
edgetype_to_return::Type{E}=edgetype(g)) where {T, U, E<:AbstractEdge}
gdalle marked this conversation as resolved.
Show resolved Hide resolved
# if we do checkbounds here, we can use @inbounds in a_star_impl!
checkbounds(distmx, Base.OneTo(nv(g)), Base.OneTo(nv(g)))

open_set = PriorityQueue{Integer, T}()
open_set = PriorityQueue{U, T}()
enqueue!(open_set, s, 0)

closed_set = zeros(Bool, nv(g))

g_score = fill(Inf, nv(g))
g_score[s] = 0

f_score = fill(Inf, nv(g))
f_score[s] = heuristic(s)

came_from = fill(-one(s), nv(g))
came_from[s] = s

a_star_impl!(g, t, open_set, closed_set, g_score, f_score, came_from, distmx, heuristic)
a_star_impl!(
g, t, open_set, g_score, came_from, distmx, heuristic, edgetype_to_return
)
end
7 changes: 7 additions & 0 deletions test/shortestpaths/astar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,11 @@
g = complete_graph(4)
w = float([1 1 1 4; 1 1 1 1; 1 1 1 1; 4 1 1 1])
@test length(a_star(g, 1, 4, w)) == 2

# test for #120
struct MyFavoriteEdgeType <: AbstractEdge{Int}
s::Int
d::Int
end
@test eltype(a_star(g, 1, 4, w, n -> 0, MyFavoriteEdgeType)) == MyFavoriteEdgeType
end