Skip to content

Commit

Permalink
move back to argparse but keep semantics
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed Dec 21, 2023
1 parent 5e9662b commit 695e661
Show file tree
Hide file tree
Showing 9 changed files with 247 additions and 57 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ authors = ["Olivier Labayle"]
version = "0.7.4"

[deps]
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
Configurations = "5218b696-f38b-4ac9-8b61-a12ec717816d"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
Expand All @@ -33,11 +33,11 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"

[compat]
ArgParse = "1.1.4"
Arrow = "2.5.2"
CSV = "0.10"
CategoricalArrays = "0.10"
Combinatorics = "1.0.2"
Comonicon = "1.0.6"
Configurations = "0.17.6"
DataFrames = "1.3.4"
EvoTrees = "0.16.5"
Expand Down
7 changes: 3 additions & 4 deletions src/TargetedEstimation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ if occursin("Intel", Sys.cpu_info()[1].model)
using MKL
end

using ArgParse
using DataFrames
using MLJBase
using MLJ
Expand All @@ -25,7 +26,6 @@ using Tables
using Random
using YAML
using JSON
using Comonicon
using Configurations

import MLJModelInterface
Expand All @@ -42,10 +42,9 @@ include("resampling.jl")
include(joinpath("models", "glmnet.jl"))
include(joinpath("models", "adaptive_interaction_transformer.jl"))
include(joinpath("models", "biallelic_snp_encoder.jl"))
include("cli.jl")

@main

export Runner, tmle, sieve_variance_plateau, make_summary
export Runner, tmle, sieve_variance_plateau, make_summary, main
export GLMNetRegressor, GLMNetClassifier
export RestrictedInteractionTransformer, BiAllelicSNPEncoder
export AdaptiveCV, AdaptiveStratifiedCV, JointStratifiedCV
Expand Down
180 changes: 180 additions & 0 deletions src/cli.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
"""
    cli_settings()

Build and return the `ArgParseSettings` describing the TMLE command line
interface: three sub-commands (`tmle`, `svp`, `merge`), each with its own
argument table. Used by `main` via `parse_args`.
"""
function cli_settings()
    s = ArgParseSettings(description="TMLE CLI.")

    # Top-level sub-commands; ArgParse records the chosen one under "%COMMAND%".
    @add_arg_table s begin
        "tmle"
            action = :command
            help = "Run TMLE."

        "svp"
            action = :command
            help = "Run Sieve Variance Plateau."

        "merge"
            action = :command
            help = "Merges TMLE outputs together."
    end

    # Arguments for the `tmle` sub-command.
    @add_arg_table s["tmle"] begin
        "dataset"
            arg_type = String
            required = true
            help = "Path to the dataset (either .csv or .arrow)"

        "--estimands"
            arg_type = String
            help = "A string (`generateATEs`) or a serialized TMLE.Configuration (accepted formats: .json | .yaml | .jls)"
            default = "generateATEs"

        "--estimators"
            arg_type = String
            help = "A julia file containing the estimators to use."
            default = "glmnet"

        "--verbosity"
            arg_type = Int
            default = 0
            help = "Verbosity level"

        # The three output options are comma-separated field strings consumed
        # by `makeOutput` (see below); absent options mean default outputs.
        "--hdf5-output"
            arg_type = String
            help = "HDF5 file output."

        "--json-output"
            arg_type = String
            help = "JSON file output."

        "--jls-output"
            arg_type = String
            help = "JLS file output."

        "--chunksize"
            arg_type = Int
            help = "Results are written in batches of size chunksize."
            default = 100

        "--rng"
            arg_type = Int
            help = "Random seed (Only used for estimands ordering at the moment)."
            default = 123

        "--cache-strategy"
            arg_type = String
            help = "Caching Strategy for the nuisance functions, any of (`release-unusable`, `no-cache`, `max-size`)."
            default = "release-unusable"

        "--sort-estimands"
            help = "Sort estimands to minimize cache usage (A brute force approach will be used, resulting in exponentially long sorting time)."
            action = :store_true
    end

    # Arguments for the `svp` (Sieve Variance Plateau) sub-command.
    @add_arg_table s["svp"] begin
        "input-prefix"
            arg_type = String
            help = "Input prefix to HDF5 files generated by the tmle CLI."

        "--out"
            arg_type = String
            help = "Output filename."
            default = "svp.hdf5"

        "--grm-prefix"
            arg_type = String
            help = "Prefix to the aggregated GRM."
            default = "GRM"

        "--verbosity"
            arg_type = Int
            default = 0
            help = "Verbosity level"

        "--n-estimators"
            arg_type = Int
            default = 10
            help = "Number of variance estimators to build for each estimate."

        "--max-tau"
            arg_type = Float64
            default = 0.8
            help = "Maximum distance between any two individuals."

        "--estimator-key"
            arg_type = String
            help = "Estimator to use to proceed with sieve variance correction."
            default = "TMLE"
    end

    # Arguments for the `merge` sub-command.
    @add_arg_table s["merge"] begin
        "prefix"
            arg_type = String
            help = "Prefix to .hdf5 files to be used to create the summary file."

        "--hdf5-output"
            arg_type = String
            help = "HDF5 file output."

        "--json-output"
            arg_type = String
            help = "JSON file output."

        "--jls-output"
            arg_type = String
            help = "JLS file output."
    end

    return s
end


makeOutput(T::Type, ::Nothing) = T()

"""
    makeOutput(T::Type, str)

Construct an output configuration of type `T` from the comma-separated CLI
string `str`: values are matched positionally against the fields of `T` and
parsed to the corresponding field types. Fields beyond the provided values
keep their defaults (`zip` stops at the shortest collection).
"""
function makeOutput(T::Type, str)
    args = split(str, ",")
    # Fixed: the membership keyword `in` was missing from the generator.
    # NOTE(review): assumes `tryparse` has a method for every field type of
    # `T` (including String-like fields) — confirm against the Output types.
    kwargs = Dict(fn => tryparse(ft, val) for (val, fn, ft) in zip(args, fieldnames(T), fieldtypes(T)))
    return T(;kwargs...)
end

"""
    make_outputs(hdf5_string, json_string, jls_string)

Assemble an `Outputs` object from the three (possibly `nothing`) CLI output
strings, delegating each to `makeOutput` for parsing.
"""
# Fixed typo in the third parameter name: `jls_tring` -> `jls_string`
# (internal positional name, no caller impact).
make_outputs(hdf5_string, json_string, jls_string) = Outputs(
    hdf5=makeOutput(HDF5Output, hdf5_string),
    json=makeOutput(JSONOutput, json_string),
    jls=makeOutput(JLSOutput, jls_string)
)

"""
    main(args=ARGS)

CLI entry point: parses `args` against `cli_settings()` and dispatches to the
selected sub-command (`tmle`, `merge` or `svp`) with the parsed options as
keyword arguments.
"""
function main(args=ARGS)
    settings = parse_args(args, cli_settings())
    # ArgParse stores the selected sub-command name under "%COMMAND%".
    cmd = settings["%COMMAND%"]
    cmd_settings = settings[cmd]
    # Fixed: the membership operator `∈` was missing from this condition.
    if cmd ∈ ("tmle", "merge")
        # These two sub-commands share the same triplet of output options.
        outputs = make_outputs(cmd_settings["hdf5-output"], cmd_settings["json-output"], cmd_settings["jls-output"])
        if cmd == "tmle"
            tmle(cmd_settings["dataset"];
                estimands=cmd_settings["estimands"],
                estimators=cmd_settings["estimators"],
                verbosity=cmd_settings["verbosity"],
                outputs=outputs,
                chunksize=cmd_settings["chunksize"],
                rng=cmd_settings["rng"],
                cache_strategy=cmd_settings["cache-strategy"],
                sort_estimands=cmd_settings["sort-estimands"]
            )
        else
            make_summary(cmd_settings["prefix"];
                outputs=outputs
            )
        end
    else
        # Only remaining sub-command: "svp".
        sieve_variance_plateau(cmd_settings["input-prefix"];
            out=cmd_settings["out"],
            grm_prefix=cmd_settings["grm-prefix"],
            verbosity=cmd_settings["verbosity"],
            n_estimators=cmd_settings["n-estimators"],
            max_tau=cmd_settings["max-tau"],
            estimator_key=cmd_settings["estimator-key"]
        )
    end
end

# C-compatible entry point returning an exit status; presumably the
# PackageCompiler app convention (`julia_main()::Cint`) — confirm.
julia_main()::Cint = (main(); Cint(0))
2 changes: 1 addition & 1 deletion src/runner.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ TMLE CLI.
- `-s, --sort_estimands`: Sort estimands to minimize cache usage (A brute force approach will be used, resulting in exponentially long sorting time).
"""
@cast function tmle(dataset::String;
function tmle(dataset::String;
estimands::String="generateATEs",
estimators::String="glmnet",
verbosity::Int=0,
Expand Down
2 changes: 1 addition & 1 deletion src/sieve_variance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ Sieve Variance Plateau CLI.
- `-m, --max_tau`: Maximum distance between any two individuals.
- `-e, --estimator-key`: Estimator to use to proceed with sieve variance correction.
"""
@cast function sieve_variance_plateau(input_prefix::String;
function sieve_variance_plateau(input_prefix::String;
out::String="svp.hdf5",
grm_prefix::String="GRM",
verbosity::Int=0,
Expand Down
2 changes: 1 addition & 1 deletion src/summary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Combines multiple TMLE .hdf5 output files in a single file. Multiple formats can
- `-o, --outputs`: Outputs configuration.
"""
@cast function make_summary(
function make_summary(
prefix::String;
outputs::Outputs=Outputs()
)
Expand Down
54 changes: 28 additions & 26 deletions test/runner.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,22 +154,24 @@ end

@testset "Test tmle: lower p-value threshold only JSON output" begin
build_dataset(;n=1000, format="csv")
outputs = TargetedEstimation.Outputs(
json=TargetedEstimation.JSONOutput(filename="output.json", pval_threshold=1e-15)
)
tmpdir = mktempdir(cleanup=true)
estimandsfile = joinpath(tmpdir, "configuration.json")
configuration = statistical_estimands_only_config()
TMLE.write_json(estimandsfile, configuration)
estimatorfile = joinpath(CONFIGDIR, "ose_config.jl")
datafile = "data.csv"
tmle(datafile;
estimands=estimandsfile,
estimators=estimatorfile,
outputs=outputs)

# Using the main entry point
main([
"tmle",
datafile,
"--estimands", estimandsfile,
"--estimators", estimatorfile,
"--json-output", "output.json,1e-15"]
)

# Essential results
results_from_json = TMLE.read_json(outputs.json.filename)
results_from_json = TMLE.read_json("output.json")
n_IC_empties = 0
for result in results_from_json
if result[:OSE].IC != []
Expand All @@ -179,7 +181,7 @@ end
@test n_IC_empties > 0

rm(datafile)
rm(outputs.json.filename)
rm("output.json")
end

@testset "Test tmle: Failing estimands" begin
Expand Down Expand Up @@ -238,11 +240,6 @@ end

@testset "Test tmle: Causal and Composed Estimands" begin
build_dataset(;n=1000, format="csv")
outputs = TargetedEstimation.Outputs(
json = TargetedEstimation.JSONOutput(filename="output.json"),
jls = TargetedEstimation.JLSOutput(filename="output.jls"),
hdf5 = TargetedEstimation.HDF5Output(filename="output.hdf5")
)
tmpdir = mktempdir(cleanup=true)
estimandsfile = joinpath(tmpdir, "configuration.jls")

Expand All @@ -251,16 +248,21 @@ end
estimatorfile = joinpath(CONFIGDIR, "ose_config.jl")
datafile = "data.csv"

tmle(datafile;
estimands=estimandsfile,
estimators=estimatorfile,
outputs=outputs,
chunksize=2
)
# Using the main entry point
main([
"tmle",
datafile,
"--estimands", estimandsfile,
"--estimators", estimatorfile,
"--chunksize", "2",
"--json-output", "output.json",
"--hdf5-output", "output.hdf5",
"--jls-output", "output.jls"
])

# JLS Output
results = []
open(outputs.jls.filename) do io
open("output.jls") do io
while !eof(io)
push!(results, deserialize(io))
end
Expand All @@ -279,19 +281,19 @@ end
@test results[3].OSE isa TMLE.ComposedEstimate

# JSON Output
results_from_json = TMLE.read_json(outputs.json.filename)
results_from_json = TMLE.read_json("output.json")
@test length(results_from_json) == 3

# HDF5
results_from_json = jldopen(outputs.hdf5.filename)
results_from_json = jldopen("output.hdf5")
@test length(results_from_json["Batch_1"]) == 2
composed_result = only(results_from_json["Batch_2"])
@test composed_result.OSE.cov == results[3].OSE.cov

rm(datafile)
rm(outputs.jls.filename)
rm(outputs.json.filename)
rm(outputs.hdf5.filename)
rm("output.jls")
rm("output.json")
rm("output.hdf5")
end


Expand Down
Loading

0 comments on commit 695e661

Please sign in to comment.