From 2af34885624e9a6a8c0bdbb71cca040d99c7bcda Mon Sep 17 00:00:00 2001 From: Fredrik Ekre Date: Sun, 29 Sep 2024 23:17:24 +0200 Subject: [PATCH] WIP HYPRE extension --- .github/workflows/Check.yml | 5 +- Project.toml | 8 +- docs/Manifest.toml | 82 +++++++++++++--- docs/Project.toml | 2 + ext/FerriteHYPRE.jl | 93 ++++++++++++++++++ test/hypre_mpi.jl | 187 ++++++++++++++++++++++++++++++++++++ 6 files changed, 361 insertions(+), 16 deletions(-) create mode 100644 ext/FerriteHYPRE.jl create mode 100644 test/hypre_mpi.jl diff --git a/.github/workflows/Check.yml b/.github/workflows/Check.yml index 08dc6b83e0..ac038093a7 100644 --- a/.github/workflows/Check.yml +++ b/.github/workflows/Check.yml @@ -32,11 +32,12 @@ jobs: PackageSpec(name = "ExplicitImports", version = "1.6"), PackageSpec(name = "Metis"), PackageSpec(name = "BlockArrays"), + PackageSpec(name = "HYPRE"), ]) - name: ExplicitImports.jl code checks shell: julia --project {0} run: | - using Ferrite, ExplicitImports, Metis, BlockArrays + using Ferrite, ExplicitImports, Metis, BlockArrays, HYPRE # Check Ferrite allow_unanalyzable = (ColoringAlgorithm,) # baremodules check_no_implicit_imports(Ferrite; allow_unanalyzable) @@ -44,7 +45,7 @@ jobs: check_all_qualified_accesses_via_owners(Ferrite) check_no_self_qualified_accesses(Ferrite) # Check extension modules - for ext in (:FerriteBlockArrays, :FerriteMetis) + for ext in (:FerriteBlockArrays, :FerriteHYPRE, :FerriteMetis) extmod = Base.get_extension(Ferrite, ext) if extmod !== nothing check_no_implicit_imports(extmod) diff --git a/Project.toml b/Project.toml index 52bb20f614..ff58b676ad 100644 --- a/Project.toml +++ b/Project.toml @@ -17,15 +17,20 @@ WriteVTK = "64499a7a-5c06-52f2-abe2-ccb03c286192" [weakdeps] BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" +HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771" +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" Metis = "2679e427-3c69-5b7f-982b-ece356f1e94b" [extensions] FerriteBlockArrays = "BlockArrays" +FerriteHYPRE = ["HYPRE", "MPI"] FerriteMetis = "Metis" [compat] BlockArrays = "0.16, 1" EnumX = "1" +HYPRE = "1.3" +MPI = "0.20" Metis = "1.3" NearestNeighbors = "0.4" OrderedCollections = "1" @@ -37,6 +42,7 @@ julia = "1.9" [extras] BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" +HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771" Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6" FerriteGmsh = "4f95f4f8-b27c-4ae5-9a39-ea55e634e36b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -55,4 +61,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" [targets] -test = ["BlockArrays", "Downloads", "FerriteGmsh", "ForwardDiff", "Gmsh", "IterativeSolvers", "Metis", "Pkg", "NBInclude", "OhMyThreads", "ProgressMeter", "Random", "SHA", "TaskLocalValues", "Test", "TimerOutputs", "Logging"] +test = ["BlockArrays", "Downloads", "FerriteGmsh", "ForwardDiff", "Gmsh", "HYPRE", "IterativeSolvers", "Metis", "Pkg", "NBInclude", "OhMyThreads", "ProgressMeter", "Random", "SHA", "TaskLocalValues", "Test", "TimerOutputs", "Logging"] diff --git a/docs/Manifest.toml b/docs/Manifest.toml index a4d4e460ea..2a7bc845d2 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.10.5" manifest_format = "2.0" -project_hash = "4e52f4aa4cee9f66ec4f633f0ae538fbd227ac5e" +project_hash = "de8ffdfec9d1f3a15bebd18bcb5ace52305c3d9c" [[deps.ADTypes]] git-tree-sha1 = "eea5d80188827b35333801ef97a40c2ed653b081" @@ -183,6 +183,11 @@ git-tree-sha1 = "9e2a6b69137e6969bab0152632dcb3bc108c8bdd" uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" version = "1.0.8+1" +[[deps.CEnum]] +git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" +uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" +version = "0.5.0" + [[deps.CPUSummary]] deps = ["CpuId", "IfElse", "PrecompileTools", "Static"] git-tree-sha1 = "5a97e67919535d6841172016c9530fd69494e5ec" @@ -366,9 +371,9 @@ version = "1.9.1" [[deps.DiffEqBase]] deps = ["ArrayInterface", "ConcreteStructs", "DataStructures", "DocStringExtensions", "EnumX", "EnzymeCore", "FastBroadcast", "FastClosures", "ForwardDiff", "FunctionWrappers", "FunctionWrappersWrappers", "LinearAlgebra", "Logging", "Markdown", "MuladdMacro", "Parameters", "PreallocationTools", "PrecompileTools", "Printf", "RecursiveArrayTools", "Reexport", "SciMLBase", "SciMLOperators", "SciMLStructures", "Setfield", "Static", "StaticArraysCore", "Statistics", "Tricks", "TruncatedStacktraces"] -git-tree-sha1 = "6b1af0db32958b200b7b1745796432e75821bf48" +git-tree-sha1 = "ada2a9faba0e365dca3cc456b5eca94cf3887ac3" uuid = "2b5f629d-d688-5b77-993f-72d75c75574e" -version = "6.155.2" +version = "6.156.1" [deps.DiffEqBase.extensions] DiffEqBaseCUDAExt = "CUDA" @@ -494,9 +499,9 @@ uuid = "4e289a0a-7415-4d19-859d-a7e5c4648b56" version = "1.0.4" [[deps.EnzymeCore]] -git-tree-sha1 = "ee11500b17d87b22bc638e9ed8c71a7478c53d61" +git-tree-sha1 = "9c3a42611e525352e9ad5e4134ddca5c692ff209" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" -version = "0.8.3" +version = "0.8.4" weakdeps = ["Adapt"] [deps.EnzymeCore.extensions] @@ -580,10 +585,13 @@ version = "1.0.0" [deps.Ferrite.extensions] FerriteBlockArrays = "BlockArrays" + FerriteHYPRE = ["HYPRE", "MPI"] FerriteMetis = "Metis" [deps.Ferrite.weakdeps] BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" + HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771" + MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" Metis = "2679e427-3c69-5b7f-982b-ece356f1e94b" [[deps.FerriteGmsh]] @@ -772,9 +780,9 @@ version = "1.3.14+0" [[deps.Graphs]] deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] -git-tree-sha1 = "ebd18c326fa6cee1efb7da9a3b45cf69da2ed4d9" +git-tree-sha1 = "1dc470db8b1131cfc7fb4c115de89fe391b9e780" uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" -version = "1.11.2" +version = "1.12.0" [[deps.Grisu]] git-tree-sha1 = "53bb909d1151e57e2484c3d1b53e19552b887fb2" @@ -793,6 +801,28 @@ git-tree-sha1 = "d1d712be3164d61d1fb98e7ce9bcbc6cc06b45ed" uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" version = "1.10.8" +[[deps.HYPRE]] +deps = ["CEnum", "HYPRE_jll", "Libdl", "MPI"] +git-tree-sha1 = "1594ec3b54b5736531e0f0c36b036c9140ce4141" +uuid = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771" +version = "1.6.0" + + [deps.HYPRE.extensions] + HYPREPartitionedArrays = ["PartitionedArrays", "SparseArrays", "SparseMatricesCSR"] + HYPRESparseArrays = "SparseArrays" + HYPRESparseMatricesCSR = ["SparseArrays", "SparseMatricesCSR"] + + [deps.HYPRE.weakdeps] + PartitionedArrays = "5a9dfac6-5c52-46f7-8278-5e2210713be9" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + SparseMatricesCSR = "a0a7dd2c-ebf4-11e9-1f05-cf50bc540ca1" + +[[deps.HYPRE_jll]] +deps = ["Artifacts", "JLLWrappers", "LAPACK_jll", "LazyArtifacts", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenBLAS_jll", "OpenMPI_jll", "Pkg", "TOML"] +git-tree-sha1 = "b77d3eca75f8442e034ccf415c87405a49e77985" +uuid = "0a602bbd-b08b-5d75-8d32-0de6eef44785" +version = "2.23.1+1" + [[deps.HarfBuzz_jll]] deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll"] git-tree-sha1 = "401e4f3f30f43af2c8478fc008da50096ea5240f" @@ -807,9 +837,9 @@ version = "0.1.17" [[deps.Hwloc_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "378267a829b1e17423d32ce6d905f37a12c1fd84" +git-tree-sha1 = "dd3b49277ec2bb2c6b94eb1604d4d0616016f7a6" uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8" -version = "2.11.1+1" +version = "2.11.2+0" [[deps.IOCapture]] deps = ["Logging", "Random"] @@ -928,6 +958,12 @@ git-tree-sha1 = "170b660facf5df5de098d866564877e119141cbd" uuid = "c1c5ebd0-6772-5130-a774-d5fcae4a789d" version = "3.100.2+0" +[[deps.LAPACK_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "libblastrampoline_jll"] +git-tree-sha1 = "1b25c30fa49db281be615793e0f85282a8f22822" +uuid = "51474c39-65e3-53ba-86ba-03b1b862ec14" +version = "3.12.0+2" + [[deps.LERC_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "bf36f528eec6634efc60d7ec062008f171071434" @@ -1209,6 +1245,20 @@ git-tree-sha1 = "70a59df96945782bb0d43b56d0fbfdf1ce2e4729" uuid = "86086c02-e288-5929-a127-40944b0018b7" version = "5.6.0+0" +[[deps.MPI]] +deps = ["Distributed", "DocStringExtensions", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "PkgVersion", "PrecompileTools", "Requires", "Serialization", "Sockets"] +git-tree-sha1 = "892676019c58f34e38743bc989b0eca5bce5edc5" +uuid = "da04e1cc-30fd-572f-bb4f-1f8673147195" +version = "0.20.22" + + [deps.MPI.extensions] + AMDGPUExt = "AMDGPU" + CUDAExt = "CUDA" + + [deps.MPI.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + [[deps.MPICH_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] git-tree-sha1 = "19d4bd098928a3263693991500d05d74dbdc2004" @@ -1671,6 +1721,12 @@ deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", " uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" version = "1.10.0" +[[deps.PkgVersion]] +deps = ["Pkg"] +git-tree-sha1 = "f9501cc0430a26bc3d156ae1b5b0c1b47af4d6da" +uuid = "eebad327-c553-4316-9ea0-9fa01ccd7688" +version = "0.3.3" + [[deps.PlotThemes]] deps = ["PlotUtils", "Statistics"] git-tree-sha1 = "6e55c6841ce3411ccb3457ee52fc48cb698d6fb0" @@ -1889,9 +1945,9 @@ version = "0.6.43" [[deps.SciMLBase]] deps = ["ADTypes", "Accessors", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "Expronicon", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "SciMLStructures", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"] -git-tree-sha1 = "71857d6bab17e7ac6802d86ffcc75423b8c1d812" +git-tree-sha1 = "ce6fb9b0d756446d902e4495f2447fa2ebfbb1f4" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -version = "2.54.0" +version = "2.54.2" [deps.SciMLBase.extensions] SciMLBaseChainRulesCoreExt = "ChainRulesCore" @@ -2543,9 +2599,9 @@ version = "1.18.0+0" [[deps.libpng_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Zlib_jll"] -git-tree-sha1 = "d7015d2e18a5fd9a4f47de711837e980519781a4" +git-tree-sha1 = "b70c870239dc3d7bc094eb2d6be9b73d27bef280" uuid = "b53b4c65-9356-5827-b1ea-8c7a1a84506f" -version = "1.6.43+1" +version = "1.6.44+0" [[deps.libvorbis_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Ogg_jll", "Pkg"] diff --git a/docs/Project.toml b/docs/Project.toml index ead8045637..431bd55686 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -8,11 +8,13 @@ FerriteGmsh = "4f95f4f8-b27c-4ae5-9a39-ea55e634e36b" FerriteMeshParser = "0f8c756f-80dd-4a75-85c6-b0a5ab9d4620" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Gmsh = "705231aa-382f-11e9-3f0c-b7cb4346fdeb" +HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" LiveServer = "16fef848-5104-11e9-1b77-fb7a48bbb589" +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5" Optim = "429524aa-4258-5aef-a3af-852621145aeb" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" diff --git a/ext/FerriteHYPRE.jl b/ext/FerriteHYPRE.jl new file mode 100644 index 0000000000..b33fea5913 --- /dev/null +++ b/ext/FerriteHYPRE.jl @@ -0,0 +1,93 @@ +module FerriteHYPRE + +using Ferrite: Ferrite, ConstraintHandler, SparsityPattern +using HYPRE.LibHYPRE: @check, HYPRE_BigInt, HYPRE_Complex, HYPRE_IJMatrixAddToValues, + HYPRE_IJMatrixSetValues, HYPRE_IJVectorAddToValues, HYPRE_Int +using HYPRE: HYPRE, HYPREMatrix, HYPREVector +using MPI: MPI + +################################### +## Creating the sparsity pattern ## +################################### + +function Ferrite.allocate_matrix(::Type{<:HYPREMatrix}, sp::SparsityPattern) + # Create a new matrix + ilower = HYPRE_BigInt(1) + iupper = HYPRE_BigInt(Ferrite.getnrows(sp)) + @assert Ferrite.getnrows(sp) == Ferrite.getncols(sp) + A = HYPREMatrix(MPI.COMM_SELF, ilower, iupper) + # Add the rows, one at a time + nrows = HYPRE_Int(1) + ncols = HYPRE_Int[0] + rows = HYPRE_BigInt[0] + cols = HYPRE_BigInt[] + values = HYPRE_Complex[] + for (rowidx, colidxs) in zip(1:Ferrite.getnrows(sp), Ferrite.eachrow(sp)) + rows[1] = rowidx + n = length(colidxs) + ncols[1] = n + resize!(cols, n) + copyto!(cols, colidxs) + resize!(values, n) + fill!(values, 0) + @check HYPRE_IJMatrixSetValues(A, nrows, ncols, rows, cols, values) + end + HYPRE.Internals.assemble_matrix(A) + return A +end + +############################################ +### HYPREAssembler and associated methods ## +############################################ + +struct HYPREAssembler <: Ferrite.AbstractAssembler + A::HYPRE.HYPREAssembler +end + +Ferrite.matrix_handle(a::HYPREAssembler) = a.A.A.A # :) +Ferrite.vector_handle(a::HYPREAssembler) = a.A.b.b # :) + +function Ferrite.start_assemble(K::HYPREMatrix, f::HYPREVector) + return HYPREAssembler(HYPRE.start_assemble!(K, f)) +end + +function Ferrite.assemble!(a::HYPREAssembler, dofs::AbstractVector{<:Integer}, ke::AbstractMatrix, fe::AbstractVector) + HYPRE.assemble!(a.A, dofs, ke, fe) + return +end + +function Ferrite.finish_assemble(assembler::HYPREAssembler) + HYPRE.finish_assemble!(assembler.A) + return +end + +function Ferrite.apply!( + ::HYPREMatrix, ::Union{HYPREVector, AbstractVector}, ::ConstraintHandler + ) + msg = "Condensation of constraints with `apply!` after assembling not supported " * + "for HYPREMatrix, use local condensation with `apply_assemble!` instead." + error(msg) +end + + +### Methods for arrayutils.jl ## + +function Ferrite.addindex!(A::HYPREMatrix, v, i::Int, j::Int) + nrows = HYPRE_Int(1) + ncols = Ref{HYPRE_Int}(1) + rows = Ref{HYPRE_BigInt}(i) + cols = Ref{HYPRE_BigInt}(j) + values = Ref{HYPRE_Complex}(v) + @check HYPRE_IJMatrixAddToValues(A.ijmatrix, nrows, ncols, rows, cols, values) + return A +end + +function Ferrite.addindex!(b::HYPREVector, v, i::Int) + nvalues = HYPRE_Int(1) + indices = Ref{HYPRE_BigInt}(i) + values = Ref{HYPRE_Complex}(v) + @check HYPRE_IJVectorAddToValues(b.ijvector, nvalues, indices, values) + return b +end + +end # module FerriteHYPRE diff --git a/test/hypre_mpi.jl b/test/hypre_mpi.jl new file mode 100644 index 0000000000..e0eb413cff --- /dev/null +++ b/test/hypre_mpi.jl @@ -0,0 +1,187 @@ +# Run with e.g. +# mpiexecjl --project=docs -np 4 julia test/hypre_mpi.jl 1000 +using Ferrite, MPI, HYPRE, Metis, TimerOutputs + +# Initialize MPI and HYPRE +MPI.Init() +HYPRE.Init() + +const comm = MPI.COMM_WORLD +const root = 0 + 1 +const rank = MPI.Comm_rank(comm) + 1 +const comm_size = MPI.Comm_size(comm) + +# No changes from serial solve +function assemble_element!(Ke::Matrix, fe::Vector, cellvalues::CellValues) + n_basefuncs = getnbasefunctions(cellvalues) + fill!(Ke, 0) + fill!(fe, 0) + for q_point in 1:getnquadpoints(cellvalues) + dΩ = getdetJdV(cellvalues, q_point) + for i in 1:n_basefuncs + δu = shape_value(cellvalues, q_point, i) + ∇δu = shape_gradient(cellvalues, q_point, i) + fe[i] += δu * dΩ + for j in 1:n_basefuncs + ∇u = shape_gradient(cellvalues, q_point, j) + Ke[i, j] += (∇δu ⋅ ∇u) * dΩ + end + end + end + return Ke, fe +end + +# No changes from serial solve other than looping over owned cells +function assemble_global(cellvalues::CellValues, A::HYPREMatrix, b::HYPREVector, dh::DofHandler, ch::ConstraintHandler) + n_basefuncs = getnbasefunctions(cellvalues) + Ke = zeros(n_basefuncs, n_basefuncs) + fe = zeros(n_basefuncs) + assembler = start_assemble(A, b) + for cell in CellIterator(dh, getcellset(dh.grid, "proc-$(rank)")) + reinit!(cellvalues, cell) + assemble_element!(Ke, fe, cellvalues) + apply_assemble!(assembler, ch, celldofs(cell), Ke, fe) + end + # TODO: Should maybe be finish_assemble! (with !) + finish_assemble(assembler) + return A, b +end + +# Partition the grid using Metis.jl +function partition_grid!(grid) + # TODO: Can this be done on all ranks? Not sure if Metis is deterministic. + if rank == root + cell_connectivity = Ferrite.create_incidence_matrix(grid) + parts = Metis.partition(cell_connectivity, comm_size) + else + parts = Vector{Cint}(undef, getncells(grid)) + end + MPI.Bcast!(parts, comm) + + # Create the cell sets based on the Metis partition + sets = [Set{Int}() for _ in 1:comm_size] + for (cell_id, part_id) in pairs(parts) + push!(sets[part_id], cell_id) + end + for p in 1:comm_size + addcellset!(grid, "proc-$p", sets[p]) + end + return grid +end + +function main(n) + + reset_timer!() + + # FE Values + ip = Lagrange{RefQuadrilateral, 1}() + qr = QuadratureRule{RefQuadrilateral}(2) + cellvalues = CellValues(qr, ip) + + # Create the grid + @timeit "Generate grid" grid = generate_grid(Quadrilateral, (n, n)) + + # Partition the mesh + @timeit "Partition grid" partition_grid!(grid) + + # Create the DofHandler + @timeit "Create DofHandler" begin + dh = DofHandler(grid) + add!(dh, :u, ip) + close!(dh) + end + + # Renumber dofs by part + @timeit "Renumber DoFs by processor" begin + all = Set{Int}() + sets = [Set{Int}() for _ in 1:comm_size] + cc = CellCache(dh) + for p in 1:comm_size + set = sets[p] + for cell_id in getcellset(grid, "proc-$p") + reinit!(cc, cell_id) + union!(set, cc.dofs) + end + setdiff!(set, all) + union!(all, set) + end + iperm = Int[] + rank_dof_ranges = UnitRange{Int}[] + for set in sets + push!(rank_dof_ranges, (length(iperm) + 1):(length(iperm) + length(set))) + append!(iperm, sort!(collect(set))) + end + perm = invperm(iperm) + renumber!(dh, perm) + rank_dof_range = rank_dof_ranges[rank] + end + + + # ConstraintHandler + @timeit "Create ConstraintHandler" begin + ch = ConstraintHandler(dh) + ∂Ω = union( + getfacetset(grid, "left"), + getfacetset(grid, "right"), + getfacetset(grid, "top"), + getfacetset(grid, "bottom"), + ) + dbc = Dirichlet(:u, ∂Ω, (x, t) -> 0) + add!(ch, dbc) + close!(ch) + end + + + # Set up HYPRE arrays + ilower, iupper = extrema(rank_dof_range) + A = HYPREMatrix(comm, ilower, iupper) + b = HYPREVector(comm, ilower, iupper) + + # Assemble + @timeit "Assembly ($(length(getcellset(grid, "proc-$(rank)"))) of $(getncells(grid)) elements)" begin + assemble_global(cellvalues, A, b, dh, ch) + end + + # Set up solver and solve + @timeit "HYPRE setup and solve" begin + precond = HYPRE.BoomerAMG() + solver = HYPRE.PCG(; Precond = precond) + xh = HYPRE.solve(solver, A, b) + end + + # Copy solution from HYPRE to Julia + @timeit "Collect solution to root for VTK output" begin + x = Vector{Float64}(undef, length(rank_dof_range)) + copy!(x, xh) + + # Collect to root rank + if rank == root + X = Vector{Float64}(undef, ndofs(dh)) + counts = length.(rank_dof_ranges) + MPI.Gatherv!(x, VBuffer(X, counts), comm) + else + MPI.Gatherv!(x, nothing, comm) + end + end + + # Exporting to VTK + if rank == root + @timeit "VTK export" begin + VTKGridFile("heat_equation", dh) do vtk + write_solution(vtk, dh, X) + end + end + end + + # Print the timer on root proc + rank == root && print_timer() + + return +end + +# Run it! +if abspath(PROGRAM_FILE) == @__FILE__ + n = parse(Int, get(ARGS, 1, "100")) + main(n) + main(n) +end