Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Isolate config and argument parsing #1086

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 55 additions & 140 deletions experiments/ClimaEarth/run_amip.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,133 +86,51 @@ dictionary and the simulation-specific configuration dictionary, which allows th
We can additionally pass the configuration dictionary to the component model initializers, which will then override the default settings of the component models.
=#

## coupler simulation default configuration
include("cli_options.jl")
parsed_args = parse_commandline(argparse_settings())

## modify parsed args for fast testing from REPL #hide
if isinteractive()
parsed_args["config_file"] =
isnothing(parsed_args["config_file"]) ? joinpath(pkg_dir, "config/ci_configs/interactive_debug.yml") :
parsed_args["config_file"]
parsed_args["job_id"] = "interactive_debug"
end

## the unique job id should be passed in via the command line
job_id = parsed_args["job_id"]
@assert !isnothing(job_id) "job_id must be passed in via the command line"

## read in config dictionary from file, overriding the coupler defaults in `parsed_args`
config_dict = YAML.load_file(parsed_args["config_file"])
config_dict = merge(parsed_args, config_dict)
include("user_io/arg_parsing.jl")
config_dict, job_id = get_coupler_config()
(;
mode_name,
random_seed,
FT,
comms_ctx,
t_end,
t_start,
date0,
date,
Δt_cpl,
component_dt_dict,
saveat,
hourly_checkpoint,
hourly_checkpoint_dt,
restart_dir,
restart_t,
use_coupler_diagnostics,
use_land_diagnostics,
calendar_dt,
evolving_ocean,
mono_surface,
turb_flux_partition,
land_domain_type,
land_albedo_type,
land_temperature_anomaly,
energy_check,
conservation_softfail,
output_dir_root,
make_ci_plots,
) = get_args(config_dict)

comms_ctx = Utilities.get_comms_context(config_dict)
## get component model dictionaries (if applicable)
# TODO don't modify coupler config dict in get_atmos_config_dict - what do we have to change to make this possible?
# TODO this has to come after arg parsing to get the correct dt's, but we should read in configs, then parse args
atmos_config_dict, config_dict = get_atmos_config_dict(config_dict, job_id)
(; dt_rad, output_default_diagnostics) = get_atmos_args(atmos_config_dict)

## set unique random seed if desired, otherwise use default
random_seed = config_dict["unique_seed"] ? time_ns() : 1234
Random.seed!(random_seed)
@info "Random seed set to $(random_seed)"

## set up diagnostics before retrieving atmos config
mode_name = config_dict["mode_name"]
use_coupler_diagnostics = config_dict["use_coupler_diagnostics"]
t_end = Float64(Utilities.time_to_seconds(config_dict["t_end"]))
t_start = 0.0

"""
    get_period(t_start, t_end)

Pick the averaging window for diagnostics from the simulation length
`t_end - t_start` (in seconds).

Returns a tuple `(period, calendar_dt)`, where `period` is the window as a string
and `calendar_dt` is the matching `Dates` period:

- at least 90 days: `("1months", Dates.Month(1))`
- at least 30 days: `("10days", Dates.Day(10))`
- at least 1 day:   `("1days", Dates.Day(1))`
- otherwise:        `("1hours", Dates.Hour(1))`
"""
function get_period(t_start, t_end)
    day = 86400 # seconds in one day
    span = t_end - t_start

    # Test the longest window first; the first threshold that is met wins.
    span >= 90 * day && return ("1months", Dates.Month(1))
    span >= 30 * day && return ("10days", Dates.Day(10))
    span >= day && return ("1days", Dates.Day(1))

    # Runs shorter than one day fall back to hourly means.
    return ("1hours", Dates.Hour(1))
end

## For AMIP runs with coupler diagnostics enabled, append a time-averaged
## `toa_fluxes_net` entry to the "diagnostics" list of `config_dict`.
if mode_name == "amip" && use_coupler_diagnostics
@info "Using default AMIP diagnostics"
## averaging window (string) and matching calendar period, chosen from the run length
(period, calendar_dt) = get_period(t_start, t_end)

## create the "diagnostics" list if the config does not define one yet
!haskey(config_dict, "diagnostics") && (config_dict["diagnostics"] = Vector{Dict{Any, Any}}())
push!(
config_dict["diagnostics"],
Dict("short_name" => ["toa_fluxes_net"], "reduction_time" => "average", "period" => period),
)
end


## read in some parsed command line arguments, required by this script
energy_check = config_dict["energy_check"]
const FT = config_dict["FLOAT_TYPE"] == "Float64" ? Float64 : Float32
land_sim_name = "bucket"
t_end = Float64(Utilities.time_to_seconds(config_dict["t_end"]))
t_start = 0.0
tspan = (t_start, t_end)
Δt_cpl = Float64(config_dict["dt_cpl"])
## Resolve the timestep of every component model into `component_dt_dict`.
## If all four component dt's are set in the config they are used, and each must evenly
## divide the coupling timestep `Δt_cpl`; otherwise every component uses the shared `dt`.
## Note that `config_dict` is mutated: whichever of `dt` / the component dt's is not used
## gets removed from it.
component_dt_names = ["dt_atmos", "dt_land", "dt_ocean", "dt_seaice"]
component_dt_dict = Dict{String, Float64}()
# check if all component dt's are specified
## NOTE(review): indexing `config_dict[key]` assumes every component key is present
## (possibly with value `nothing`) — presumably guaranteed by the CLI defaults; confirm.
if all(key -> !isnothing(config_dict[key]), component_dt_names)
# when all component dt's are specified, ignore the dt field
if haskey(config_dict, "dt")
@warn "Removing dt in favor of individual component dt's"
delete!(config_dict, "dt")
end
for key in component_dt_names
## convert the configured value to seconds before checking divisibility
component_dt = Float64(Utilities.time_to_seconds(config_dict[key]))
@assert Δt_cpl % component_dt == 0.0 "Coupler dt must be divisible by all component dt's\n dt_cpl = $Δt_cpl\n $key = $component_dt"
component_dt_dict[key] = component_dt
end
else
# when not all component dt's are specified, use the dt field
@assert haskey(config_dict, "dt") "dt or (dt_atmos, dt_land, dt_ocean, and dt_seaice) must be specified"
for key in component_dt_names
if !isnothing(config_dict[key])
@warn "Removing $key from config in favor of dt because not all component dt's are specified"
end
## drop the per-component entry so that only `dt` remains in the config
delete!(config_dict, key)
component_dt_dict[key] = Float64(Utilities.time_to_seconds(config_dict["dt"]))
end
end
## get component model dictionaries (if applicable)
atmos_config_dict, config_dict = get_atmos_config_dict(config_dict, job_id)
atmos_config_object = CA.AtmosConfig(atmos_config_dict)

saveat = Float64(Utilities.time_to_seconds(config_dict["dt_save_to_sol"]))
date0 = date = Dates.DateTime(config_dict["start_date"], Dates.dateformat"yyyymmdd")
mono_surface = config_dict["mono_surface"]
hourly_checkpoint = config_dict["hourly_checkpoint"]
hourly_checkpoint_dt = config_dict["hourly_checkpoint_dt"]
restart_dir = config_dict["restart_dir"]
restart_t = Int(config_dict["restart_t"])
evolving_ocean = config_dict["evolving_ocean"]
dt_rad = config_dict["dt_rad"]
use_land_diagnostics = config_dict["use_land_diagnostics"]

#=
## Setup Communication Context
We set up communication context for CPU single thread/CPU with MPI/GPU. If no device is passed to `ClimaComms.context()`
then `ClimaComms` automatically selects the device from which this code is called.
=#


## make sure we don't use animations for GPU runs
if comms_ctx.device isa ClimaComms.CUDADevice
config_dict["anim"] = false
end

#=
### I/O Directory Setup
Expand All @@ -221,14 +139,11 @@ the plots (from postprocessing and the conservation checks) of the simulation wi
temporary files will be saved.
=#

COUPLER_OUTPUT_DIR = joinpath(config_dict["coupler_output_dir"], joinpath(mode_name, job_id))
COUPLER_OUTPUT_DIR = joinpath(output_dir_root, mode_name, job_id)
dir_paths = Utilities.setup_output_dirs(output_dir = COUPLER_OUTPUT_DIR, comms_ctx = comms_ctx)
@info "Coupler output directory $(dir_paths.output)"
@info "Coupler artifacts directory $(dir_paths.artifacts)"

@info(dir_paths.output)
config_dict["print_config_dict"] && @info(config_dict)

#=
## Data File Paths
=#
Expand Down Expand Up @@ -264,7 +179,7 @@ This uses the `ClimaAtmos.jl` model, with parameterization options specified in
Utilities.show_memory_usage()

## init atmos model component
atmos_sim = atmos_init(atmos_config_object);
atmos_sim = atmos_init(CA.AtmosConfig(atmos_config_dict));
# Get surface elevation from `atmos` coordinate field
surface_elevation = CC.Fields.level(CC.Fields.coordinate_field(atmos_sim.integrator.u.f).z, CC.Utilities.half)
Utilities.show_memory_usage()
Expand Down Expand Up @@ -329,9 +244,9 @@ if mode_name == "amip"
land_sim = bucket_init(
FT,
tspan,
config_dict["land_domain_type"],
config_dict["land_albedo_type"],
config_dict["land_temperature_anomaly"],
land_domain_type,
land_albedo_type,
land_temperature_anomaly,
dir_paths;
dt = component_dt_dict["dt_land"],
space = boundary_space,
Expand Down Expand Up @@ -427,9 +342,9 @@ elseif mode_name in ("slabplanet", "slabplanet_aqua", "slabplanet_terra")
land_sim = bucket_init(
FT,
tspan,
config_dict["land_domain_type"],
config_dict["land_albedo_type"],
config_dict["land_temperature_anomaly"],
land_domain_type,
land_albedo_type,
land_temperature_anomaly,
dir_paths;
dt = component_dt_dict["dt_land"],
space = boundary_space,
Expand Down Expand Up @@ -477,9 +392,9 @@ elseif mode_name == "slabplanet_eisenman"
land_sim = bucket_init(
FT,
tspan,
config_dict["land_domain_type"],
config_dict["land_albedo_type"],
config_dict["land_temperature_anomaly"],
land_domain_type,
land_albedo_type,
land_temperature_anomaly,
dir_paths;
dt = component_dt_dict["dt_land"],
space = boundary_space,
Expand Down Expand Up @@ -622,9 +537,9 @@ callbacks =
Decide on the type of turbulent flux partition, partitioned or combined (see `FluxCalculator` documentation for more details).
=#
turbulent_fluxes = nothing
if config_dict["turb_flux_partition"] == "PartitionedStateFluxes"
if turb_flux_partition == "PartitionedStateFluxes"
turbulent_fluxes = FluxCalculator.PartitionedStateFluxes()
elseif config_dict["turb_flux_partition"] == "CombinedStateFluxesMOST"
elseif turb_flux_partition == "CombinedStateFluxesMOST"
turbulent_fluxes = FluxCalculator.CombinedStateFluxesMOST()
else
error("turb_flux_partition must be either PartitionedStateFluxes or CombinedStateFluxesMOST")
Expand Down Expand Up @@ -891,7 +806,7 @@ end
#=
## Postprocessing
All postprocessing is performed using the root process only, if applicable.
Our postprocessing consists of outputting a number of plots and animations to visualize the model output.
Our postprocessing consists of outputting a number of plots to visualize the model output.

The postprocessing includes:
- Energy and water conservation checks (if running SlabPlanet with checks enabled)
Expand All @@ -910,14 +825,14 @@ if ClimaComms.iamroot(comms_ctx)
plot_global_conservation(
cs.conservation_checks.energy,
cs,
config_dict["conservation_softfail"],
conservation_softfail,
figname1 = joinpath(dir_paths.artifacts, "total_energy_bucket.png"),
figname2 = joinpath(dir_paths.artifacts, "total_energy_log_bucket.png"),
)
plot_global_conservation(
cs.conservation_checks.water,
cs,
config_dict["conservation_softfail"],
conservation_softfail,
figname1 = joinpath(dir_paths.artifacts, "total_water_bucket.png"),
figname2 = joinpath(dir_paths.artifacts, "total_water_log_bucket.png"),
)
Expand Down Expand Up @@ -954,15 +869,15 @@ if ClimaComms.iamroot(comms_ctx)
end

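# Only compute the leaderboard for runs longer than roughly three months, because we only want monthly data for making plots
# NOTE(review): 84600 below looks like a transposition of 86400 (seconds per day) - confirm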
if t_end > 84600 * 31 * 3 && config_dict["output_default_diagnostics"]
if t_end > 84600 * 31 * 3 && output_default_diagnostics
include("leaderboard/leaderboard.jl")
diagnostics_folder_path = atmos_sim.integrator.p.output_dir
leaderboard_base_path = dir_paths.artifacts
compute_leaderboard(leaderboard_base_path, diagnostics_folder_path)
end
end
## plot extra atmosphere diagnostics if specified
if config_dict["ci_plots"]
if make_ci_plots
@info "Generating CI plots"
include("user_io/ci_plots.jl")
make_ci_plots(atmos_sim.integrator.p.output_dir, dir_paths.artifacts)
Expand Down
Loading
Loading