From 695e661e2bf8bac5032f4593da60c75d557e406e Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Thu, 21 Dec 2023 14:46:50 +0100 Subject: [PATCH] move back to argparse but keep semantics --- Project.toml | 4 +- src/TargetedEstimation.jl | 7 +- src/cli.jl | 180 ++++++++++++++++++++++++++++++++++++++ src/runner.jl | 2 +- src/sieve_variance.jl | 2 +- src/summary.jl | 2 +- test/runner.jl | 54 ++++++------ test/sieve_variance.jl | 26 ++++-- test/summary.jl | 27 +++--- 9 files changed, 247 insertions(+), 57 deletions(-) create mode 100644 src/cli.jl diff --git a/Project.toml b/Project.toml index cff4a17..7575ae0 100644 --- a/Project.toml +++ b/Project.toml @@ -4,11 +4,11 @@ authors = ["Olivier Labayle"] version = "0.7.4" [deps] +ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63" Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" -Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542" Configurations = "5218b696-f38b-4ac9-8b61-a12ec717816d" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" @@ -33,11 +33,11 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" [compat] +ArgParse = "1.1.4" Arrow = "2.5.2" CSV = "0.10" CategoricalArrays = "0.10" Combinatorics = "1.0.2" -Comonicon = "1.0.6" Configurations = "0.17.6" DataFrames = "1.3.4" EvoTrees = "0.16.5" diff --git a/src/TargetedEstimation.jl b/src/TargetedEstimation.jl index 74856d8..a3ee31c 100644 --- a/src/TargetedEstimation.jl +++ b/src/TargetedEstimation.jl @@ -4,6 +4,7 @@ if occursin("Intel", Sys.cpu_info()[1].model) using MKL end +using ArgParse using DataFrames using MLJBase using MLJ @@ -25,7 +26,6 @@ using Tables using Random using YAML using JSON -using Comonicon using Configurations import MLJModelInterface @@ -42,10 +42,9 @@ include("resampling.jl") include(joinpath("models", "glmnet.jl")) include(joinpath("models", "adaptive_interaction_transformer.jl")) include(joinpath("models", "biallelic_snp_encoder.jl")) +include("cli.jl") -@main - -export Runner, tmle, sieve_variance_plateau, make_summary +export Runner, tmle, sieve_variance_plateau, make_summary, main export GLMNetRegressor, GLMNetClassifier export RestrictedInteractionTransformer, BiAllelicSNPEncoder export AdaptiveCV, AdaptiveStratifiedCV, JointStratifiedCV diff --git a/src/cli.jl b/src/cli.jl new file mode 100644 index 0000000..985b088 --- /dev/null +++ b/src/cli.jl @@ -0,0 +1,180 @@ +function cli_settings() + s = ArgParseSettings(description="TMLE CLI.") + + @add_arg_table s begin + "tmle" + action = :command + help = "Run TMLE." + + "svp" + action = :command + help = "Run Sieve Variance Plateau." + + "merge" + action = :command + help = "Merges TMLE outputs together." + end + + @add_arg_table s["tmle"] begin + "dataset" + arg_type = String + required = true + help = "Path to the dataset (either .csv or .arrow)" + + "--estimands" + arg_type = String + help = "A string (`generateATEs`) or a serialized TMLE.Configuration (accepted formats: .json | .yaml | .jls)" + default = "generateATEs" + + "--estimators" + arg_type = String + help = "A julia file containing the estimators to use." + default = "glmnet" + + "--verbosity" + arg_type = Int + default = 0 + help = "Verbosity level" + + "--hdf5-output" + arg_type = String + help = "HDF5 file output." + + "--json-output" + arg_type = String + help = "JSON file output." + + "--jls-output" + arg_type = String + help = "JLS file output." + + "--chunksize" + arg_type = Int + help = "Results are written in batches of size chunksize." + default = 100 + + "--rng" + arg_type = Int + help = "Random seed (Only used for estimands ordering at the moment)." + default = 123 + + "--cache-strategy" + arg_type = String + help = "Caching Strategy for the nuisance functions, any of (`release-unusable`, `no-cache`, `max-size`)." + default = "release-unusable" + + "--sort-estimands" + help = "Sort estimands to minimize cache usage (A brute force approach will be used, resulting in exponentially long sorting time)." + action = :store_true + end + + @add_arg_table s["svp"] begin + "input-prefix" + arg_type = String + help = "Input prefix to HDF5 files generated by the tmle CLI." + + "--out" + arg_type = String + help = "Output filename." + default = "svp.hdf5" + + "--grm-prefix" + arg_type = String + help = "Prefix to the aggregated GRM." + default = "GRM" + + "--verbosity" + arg_type = Int + default = 0 + help = "Verbosity level" + + "--n-estimators" + arg_type = Int + default = 10 + help = "Number of variance estimators to build for each estimate." + + "--max-tau" + arg_type = Float64 + default = 0.8 + help = "Maximum distance between any two individuals." + + "--estimator-key" + arg_type = String + help = "Estimator to use to proceed with sieve variance correction." + default = "TMLE" + end + + @add_arg_table s["merge"] begin + "prefix" + arg_type = String + help = "Prefix to .hdf5 files to be used to create the summary file." + + "--hdf5-output" + arg_type = String + help = "HDF5 file output." + + "--json-output" + arg_type = String + help = "JSON file output." + + "--jls-output" + arg_type = String + help = "JLS file output." + end + + return s +end + + +makeOutput(T::Type, ::Nothing) = T() + +function makeOutput(T::Type, str) + args = split(str, ",") + kwargs = Dict(fn => tryparse(ft, val) for (val, fn, ft) ∈ zip(args, fieldnames(T), fieldtypes(T))) + return T(;kwargs...) +end + +make_outputs(hdf5_string, json_string, jls_tring) = Outputs( + hdf5=makeOutput(HDF5Output, hdf5_string), + json=makeOutput(JSONOutput, json_string), + jls=makeOutput(JLSOutput, jls_tring) +) + +function main(args=ARGS) + settings = parse_args(args, cli_settings()) + cmd = settings["%COMMAND%"] + cmd_settings = settings[cmd] + if cmd ∈ ("tmle", "merge") + outputs = make_outputs(cmd_settings["hdf5-output"], cmd_settings["json-output"], cmd_settings["jls-output"]) + if cmd == "tmle" + tmle(cmd_settings["dataset"]; + estimands=cmd_settings["estimands"], + estimators=cmd_settings["estimators"], + verbosity=cmd_settings["verbosity"], + outputs=outputs, + chunksize=cmd_settings["chunksize"], + rng=cmd_settings["rng"], + cache_strategy=cmd_settings["cache-strategy"], + sort_estimands=cmd_settings["sort-estimands"] + ) + else + make_summary(cmd_settings["prefix"]; + outputs=outputs + ) + end + else + sieve_variance_plateau(cmd_settings["input-prefix"]; + out=cmd_settings["out"], + grm_prefix=cmd_settings["grm-prefix"], + verbosity=cmd_settings["verbosity"], + n_estimators=cmd_settings["n-estimators"], + max_tau=cmd_settings["max-tau"], + estimator_key=cmd_settings["estimator-key"] + ) + end +end + +function julia_main()::Cint + main() + return 0 +end \ No newline at end of file diff --git a/src/runner.jl b/src/runner.jl index a827bd5..1079db7 100644 --- a/src/runner.jl +++ b/src/runner.jl @@ -150,7 +150,7 @@ TMLE CLI. - `-s, --sort_estimands`: Sort estimands to minimize cache usage (A brute force approach will be used, resulting in exponentially long sorting time). """ -@cast function tmle(dataset::String; +function tmle(dataset::String; estimands::String="generateATEs", estimators::String="glmnet", verbosity::Int=0, diff --git a/src/sieve_variance.jl b/src/sieve_variance.jl index 70c446d..ea41eb9 100644 --- a/src/sieve_variance.jl +++ b/src/sieve_variance.jl @@ -226,7 +226,7 @@ Sieve Variance Plateau CLI. - `-m, --max_tau`: Maximum distance between any two individuals. - `-e, --estimator-key`: Estimator to use to proceed with sieve variance correction. """ -@cast function sieve_variance_plateau(input_prefix::String; +function sieve_variance_plateau(input_prefix::String; out::String="svp.hdf5", grm_prefix::String="GRM", verbosity::Int=0, diff --git a/src/summary.jl b/src/summary.jl index 0089df9..a83a383 100644 --- a/src/summary.jl +++ b/src/summary.jl @@ -28,7 +28,7 @@ Combines multiple TMLE .hdf5 output files in a single file. Multiple formats can - `-o, --outputs`: Ouptuts configuration. """ -@cast function make_summary( +function make_summary( prefix::String; outputs::Outputs=Outputs() ) diff --git a/test/runner.jl b/test/runner.jl index f52ead4..a8ce1f2 100644 --- a/test/runner.jl +++ b/test/runner.jl @@ -154,22 +154,24 @@ end @testset "Test tmle: lower p-value threshold only JSON output" begin build_dataset(;n=1000, format="csv") - outputs = TargetedEstimation.Outputs( - json=TargetedEstimation.JSONOutput(filename="output.json", pval_threshold=1e-15) - ) tmpdir = mktempdir(cleanup=true) estimandsfile = joinpath(tmpdir, "configuration.json") configuration = statistical_estimands_only_config() TMLE.write_json(estimandsfile, configuration) estimatorfile = joinpath(CONFIGDIR, "ose_config.jl") datafile = "data.csv" - tmle(datafile; - estimands=estimandsfile, - estimators=estimatorfile, - outputs=outputs) + + # Using the main entry point + main([ + "tmle", + datafile, + "--estimands", estimandsfile, + "--estimators", estimatorfile, + "--json-output", "output.json,1e-15"] + ) # Essential results - results_from_json = TMLE.read_json(outputs.json.filename) + results_from_json = TMLE.read_json("output.json") n_IC_empties = 0 for result in results_from_json if result[:OSE].IC != [] @@ -179,7 +181,7 @@ end @test n_IC_empties > 0 rm(datafile) - rm(outputs.json.filename) + rm("output.json") end @testset "Test tmle: Failing estimands" begin @@ -238,11 +240,6 @@ end @testset "Test tmle: Causal and Composed Estimands" begin build_dataset(;n=1000, format="csv") - outputs = TargetedEstimation.Outputs( - json = TargetedEstimation.JSONOutput(filename="output.json"), - jls = TargetedEstimation.JLSOutput(filename="output.jls"), - hdf5 = TargetedEstimation.HDF5Output(filename="output.hdf5") - ) tmpdir = mktempdir(cleanup=true) estimandsfile = joinpath(tmpdir, "configuration.jls") @@ -251,16 +248,21 @@ end estimatorfile = joinpath(CONFIGDIR, "ose_config.jl") datafile = "data.csv" - tmle(datafile; - estimands=estimandsfile, - estimators=estimatorfile, - outputs=outputs, - chunksize=2 - ) + # Using the main entry point + main([ + "tmle", + datafile, + "--estimands", estimandsfile, + "--estimators", estimatorfile, + "--chunksize", "2", + "--json-output", "output.json", + "--hdf5-output", "output.hdf5", + "--jls-output", "output.jls" + ]) # JLS Output results = [] - open(outputs.jls.filename) do io + open("output.jls") do io while !eof(io) push!(results, deserialize(io)) end @@ -279,19 +281,19 @@ end @test results[3].OSE isa TMLE.ComposedEstimate # JSON Output - results_from_json = TMLE.read_json(outputs.json.filename) + results_from_json = TMLE.read_json("output.json") @test length(results_from_json) == 3 # HDF5 - results_from_json = jldopen(outputs.hdf5.filename) + results_from_json = jldopen("output.hdf5") @test length(results_from_json["Batch_1"]) == 2 composed_result = only(results_from_json["Batch_2"]) @test composed_result.OSE.cov == results[3].OSE.cov rm(datafile) - rm(outputs.jls.filename) - rm(outputs.json.filename) - rm(outputs.hdf5.filename) + rm("output.jls") + rm("output.json") + rm("output.hdf5") end diff --git a/test/sieve_variance.jl b/test/sieve_variance.jl index 62fb7a9..de1465c 100644 --- a/test/sieve_variance.jl +++ b/test/sieve_variance.jl @@ -285,10 +285,13 @@ end TMLE.write_json(estimandsfile_2, config_2) build_tmle_output_file(grm_ids.SAMPLE_ID, estimandsfile_2, "tmle_output_2"; pval=pval) - sieve_variance_plateau("tmle_output"; - grm_prefix=joinpath(TESTDIR, "data", "grm", "test.grm"), - max_tau=0.75 - ) + # Using the main command + main([ + "svp", + "tmle_output", + "--grm-prefix", joinpath(TESTDIR, "data", "grm", "test.grm"), + "--max-tau", "0.75" + ]) io = jldopen("svp.hdf5") # Check τs @@ -332,11 +335,16 @@ end "tmle_output"; estimatorfile=joinpath(TESTDIR, "config", "ose_config.jl") ) - sieve_variance_plateau("tmle_output"; - grm_prefix=joinpath(TESTDIR, "data", "grm", "test.grm"), - max_tau=0.75, - estimator_key="OSE" - ) + + # Using the main command + main([ + "svp", + "tmle_output", + "--grm-prefix", joinpath(TESTDIR, "data", "grm", "test.grm"), + "--max-tau", "0.75", + "--estimator-key", "OSE" + ]) + # The ComposedEstimate std is not updated but each component is. src_results = jldopen("tmle_output.hdf5")["Batch_1"] io = jldopen("svp.hdf5") diff --git a/test/summary.jl b/test/summary.jl index 1b8f5d8..92c454b 100644 --- a/test/summary.jl +++ b/test/summary.jl @@ -41,27 +41,28 @@ include(joinpath(TESTDIR, "testutils.jl")) outputs=tmle_output_2 ) - # Make summary files - outputs = TargetedEstimation.Outputs( - json=TargetedEstimation.JSONOutput(filename="summary.json"), - hdf5=TargetedEstimation.HDF5Output(filename="summary.hdf5"), - jls=TargetedEstimation.JLSOutput(filename="summary.jls") - ) - make_summary("tmle_output", outputs=outputs) + # Using the main entry point + main([ + "merge", + "tmle_output", + "--json-output", "summary.json", + "--jls-output", "summary.jls", + "--hdf5-output", "summary.hdf5" + ]) # Test correctness hdf5file_1 = jldopen("tmle_output_1.hdf5") hdf5file_2 = jldopen("tmle_output_2.hdf5") inputs = vcat(hdf5file_1["Batch_1"], hdf5file_1["Batch_2"], hdf5file_2["Batch_1"]) - json_outputs = TMLE.read_json(outputs.json.filename) + json_outputs = TMLE.read_json("summary.json") jls_outputs = [] - open(outputs.jls.filename) do io + open("summary.jls") do io while !eof(io) push!(jls_outputs, deserialize(io)) end end - hdf5_output = jldopen(outputs.hdf5.filename) + hdf5_output = jldopen("summary.hdf5") hdf5_outputs = vcat((hdf5_output[key] for key in keys(hdf5_output))...) @test length(inputs) == 9 @@ -72,9 +73,9 @@ include(joinpath(TESTDIR, "testutils.jl")) # cleanup rm("tmle_output_1.hdf5") rm("tmle_output_2.hdf5") - rm(outputs.json.filename) - rm(outputs.jls.filename) - rm(outputs.hdf5.filename) + rm("summary.hdf5") + rm("summary.jls") + rm("summary.json") rm(datafile) end