Skip to content

Commit

Permalink
refactor keyword arg to be bool only, add more tests
Browse files (browse the repository at this point in the history)
Signed-off-by: Fabio Vera <fabiovera@microsoft.com>
  • Loading branch information
fverac committed Aug 14, 2023
1 parent ab283c4 commit 2f06055
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 43 deletions.
16 changes: 10 additions & 6 deletions econml/_ortho_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def _gen_ortho_learner_model_final(self):
def __init__(self, *,
discrete_treatment, treatment_featurizer,
discrete_instrument, categories, cv, random_state,
mc_iters=None, mc_agg='mean', enable_missing=None):
mc_iters=None, mc_agg='mean', enable_missing=False):
self.cv = cv
self.discrete_treatment = discrete_treatment
self.treatment_featurizer = treatment_featurizer
Expand All @@ -446,9 +446,12 @@ def __init__(self, *,
self.categories = categories
self.mc_iters = mc_iters
self.mc_agg = mc_agg
self._enable_missing = enable_missing or []
self.enable_missing = enable_missing
super().__init__()

def _gen_allowed_missing_vars(self):
return ['X', 'W'] if self.enable_missing else []

@abstractmethod
def _gen_ortho_learner_model_nuisance(self):
""" Must return a fresh instance of a nuisance model
Expand Down Expand Up @@ -611,8 +614,8 @@ def fit(self, Y, T, *, X=None, W=None, Z=None, sample_weight=None, freq_weight=N
if check_input:
Y, T, Z, sample_weight, freq_weight, sample_var, groups = check_input_arrays(
Y, T, Z, sample_weight, freq_weight, sample_var, groups)
X, = check_input_arrays(X, force_all_finite='allow-nan' if 'X' in self._enable_missing else True)
W, = check_input_arrays(W, force_all_finite='allow-nan' if 'W' in self._enable_missing else True)
X, = check_input_arrays(X, force_all_finite='allow-nan' if 'X' in self._gen_allowed_missing_vars() else True)
W, = check_input_arrays(W, force_all_finite='allow-nan' if 'W' in self._gen_allowed_missing_vars() else True)
self._check_input_dims(Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups)

if not only_final:
Expand Down Expand Up @@ -884,8 +887,9 @@ def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
"""
if not hasattr(self._ortho_learner_model_final, 'score'):
raise AttributeError("Final model does not have a score method!")
Y, T, X, Z = check_input_arrays(Y, T, X, Z)
W, = check_input_arrays(W, force_all_finite='allow-nan')
Y, T, Z = check_input_arrays(Y, T, Z)
X, = check_input_arrays(X, force_all_finite='allow-nan' if 'X' in self._gen_allowed_missing_vars() else True)
W, = check_input_arrays(W, force_all_finite='allow-nan' if 'W' in self._gen_allowed_missing_vars() else True)
self._check_fitted_dims(X)
self._check_fitted_dims_w_z(W, Z)
X, T = self._expand_treatments(X, T)
Expand Down
2 changes: 1 addition & 1 deletion econml/dml/_rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def _gen_rlearner_model_final(self):
"""

def __init__(self, *, discrete_treatment, treatment_featurizer, categories,
cv, random_state, mc_iters=None, mc_agg='mean', enable_missing=None):
cv, random_state, mc_iters=None, mc_agg='mean', enable_missing=False):
super().__init__(discrete_treatment=discrete_treatment,
treatment_featurizer=treatment_featurizer,
discrete_instrument=False, # no instrument, so doesn't matter
Expand Down
7 changes: 5 additions & 2 deletions econml/dml/causal_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,10 @@ def __init__(self, *,
mc_iters=mc_iters,
mc_agg=mc_agg,
random_state=random_state,
enable_missing=['W'] if enable_missing else None)
enable_missing=enable_missing)

def _gen_allowed_missing_vars(self):
return ['W'] if self.enable_missing else []

def _get_inference_options(self):
options = super()._get_inference_options()
Expand Down Expand Up @@ -740,7 +743,7 @@ def tune(self, Y, T, *, X=None, W=None,
"""
from ..score import RScorer # import here to avoid circular import issue
Y, T, X, sample_weight, groups = check_input_arrays(Y, T, X, sample_weight, groups)
W, = check_input_arrays(W, force_all_finite='allow-nan' if self._enable_missing else True)
W, = check_input_arrays(W, force_all_finite='allow-nan' if 'W' in self._gen_allowed_missing_vars() else True)

if params == 'auto':
params = {
Expand Down
22 changes: 17 additions & 5 deletions econml/dml/dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,10 @@ def __init__(self, *,
mc_iters=mc_iters,
mc_agg=mc_agg,
random_state=random_state,
enable_missing=['X', 'W'] if enable_missing else None)
enable_missing=enable_missing)

def _gen_allowed_missing_vars(self):
return ['X', 'W'] if self.enable_missing else []

def _gen_featurizer(self):
return clone(self.featurizer, safe=False)
Expand Down Expand Up @@ -711,7 +714,9 @@ def __init__(self, *,
mc_agg=mc_agg,
random_state=random_state,
enable_missing=enable_missing)
self._enable_missing = ['W'] if enable_missing else [] # override super's default, which is ['X', 'W']

def _gen_allowed_missing_vars(self):
return ['W'] if self.enable_missing else []

def _gen_model_final(self):
return StatsModelsLinearRegression(fit_intercept=False)
Expand Down Expand Up @@ -961,7 +966,9 @@ def __init__(self, *,
mc_agg=mc_agg,
random_state=random_state,
enable_missing=enable_missing)
self._enable_missing = ['W'] if enable_missing else [] # override super's default, which is ['X', 'W']

def _gen_allowed_missing_vars(self):
return ['W'] if self.enable_missing else []

def _gen_model_final(self):
return MultiOutputDebiasedLasso(alpha=self.alpha,
Expand Down Expand Up @@ -1165,7 +1172,9 @@ def __init__(self, model_y='auto', model_t='auto',
mc_agg=mc_agg,
random_state=random_state,
enable_missing=enable_missing)
self._enable_missing = ['W'] if enable_missing else [] # override super's default, which is ['X', 'W']

def _gen_allowed_missing_vars(self):
return ['W'] if self.enable_missing else []

def _gen_model_final(self):
return ElasticNetCV(fit_intercept=False, random_state=self.random_state)
Expand Down Expand Up @@ -1354,7 +1363,10 @@ def __init__(self, *,
mc_iters=mc_iters,
mc_agg=mc_agg,
random_state=random_state,
enable_missing=['X', 'W'] if enable_missing else None)
enable_missing=enable_missing)

def _gen_allowed_missing_vars(self):
return ['X', 'W'] if self.enable_missing else []

def _get_inference_options(self):
# add blb to parent's options
Expand Down
5 changes: 4 additions & 1 deletion econml/dowhy.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,11 @@ def fit(self, Y, T, X=None, W=None, Z=None, *, outcome_names=None, treatment_nam
column_names = outcome_names + treatment_names + feature_names + confounder_names + instrument_names

# transfer input to numpy arrays
if 'X' in self._cate_estimator._gen_allowed_missing_vars():
warnings.warn("DoWhyWrapper does not support missing values in X.")
Y, T, X, Z = check_input_arrays(Y, T, X, Z)
W, = check_input_arrays(W, force_all_finite='allow-nan')
W, = check_input_arrays(
W, force_all_finite='allow-nan' if 'W' in self._cate_estimator._gen_allowed_missing_vars() else True)
# transfer input to 2d arrays
n_obs = Y.shape[0]
Y, T, X, W, Z = reshape_arrays_2dim(n_obs, Y, T, X, W, Z)
Expand Down
17 changes: 13 additions & 4 deletions econml/dr/_drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,10 @@ def __init__(self, *,
discrete_instrument=False, # no instrument, so doesn't matter
categories=categories,
random_state=random_state,
enable_missing=['X', 'W'] if enable_missing else None)
enable_missing=enable_missing)

def _gen_allowed_missing_vars(self):
return ['X', 'W'] if self.enable_missing else []

# override only so that we can exclude treatment featurization verbiage in docstring
def const_marginal_effect(self, X=None):
Expand Down Expand Up @@ -888,7 +891,9 @@ def __init__(self, *,
mc_agg=mc_agg,
random_state=random_state,
enable_missing=enable_missing)
self._enable_missing = ['W'] if enable_missing else [] # override super's default, which is ['X', 'W']

def _gen_allowed_missing_vars(self):
return ['W'] if self.enable_missing else []

def _gen_model_final(self):
return StatsModelsLinearRegression(fit_intercept=self.fit_cate_intercept)
Expand Down Expand Up @@ -1176,7 +1181,9 @@ def __init__(self, *,
mc_agg=mc_agg,
random_state=random_state,
enable_missing=enable_missing)
self._enable_missing = ['W'] if enable_missing else [] # override super's default, which is ['X', 'W']

def _gen_allowed_missing_vars(self):
return ['W'] if self.enable_missing else []

def _gen_model_final(self):
return DebiasedLasso(alpha=self.alpha,
Expand Down Expand Up @@ -1467,7 +1474,9 @@ def __init__(self, *,
mc_agg=mc_agg,
random_state=random_state,
enable_missing=enable_missing)
self._enable_missing = ['W'] if enable_missing else [] # override super's default, which is ['X', 'W']

def _gen_allowed_missing_vars(self):
return ['W'] if self.enable_missing else []

def _gen_model_final(self):
return RegressionForest(n_estimators=self.n_estimators,
Expand Down
9 changes: 6 additions & 3 deletions econml/iv/dml/_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,10 @@ def __init__(self, *,
mc_iters=mc_iters,
mc_agg=mc_agg,
random_state=random_state,
enable_missing=['W'] if enable_missing else None)
enable_missing=enable_missing)

def _gen_allowed_missing_vars(self):
return ['W'] if self.enable_missing else []

def _gen_featurizer(self):
return clone(self.featurizer, safe=False)
Expand Down Expand Up @@ -1167,7 +1170,7 @@ def __init__(self, *,
mc_iters=mc_iters,
mc_agg=mc_agg,
random_state=random_state,
enable_missing=['X', 'W'] if enable_missing else None)
enable_missing=enable_missing)

def _gen_featurizer(self):
return clone(self.featurizer, safe=False)
Expand Down Expand Up @@ -1559,7 +1562,7 @@ def __init__(self, *,
mc_iters=mc_iters,
mc_agg=mc_agg,
random_state=random_state,
enable_missing=['X', 'W'] if enable_missing else None)
enable_missing=enable_missing)

def _gen_featurizer(self):
return clone(self.featurizer, safe=False)
Expand Down
24 changes: 12 additions & 12 deletions econml/iv/dr/_dr.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def __init__(self, *,
mc_iters=None,
mc_agg='mean',
random_state=None,
enable_missing=None):
enable_missing=False):
self.model_final = clone(model_final, safe=False)
self.featurizer = clone(featurizer, safe=False)
self.fit_cate_intercept = fit_cate_intercept
Expand All @@ -325,6 +325,9 @@ def __init__(self, *,
random_state=random_state,
enable_missing=enable_missing)

def _gen_allowed_missing_vars(self):
return ['W'] if self.enable_missing else []

# Maggie: I think that would be the case?
def _get_inference_options(self):
options = super()._get_inference_options()
Expand Down Expand Up @@ -560,7 +563,7 @@ def __init__(self, *,
mc_iters=None,
mc_agg='mean',
random_state=None,
enable_missing=None):
enable_missing=False):
self.model_y_xw = clone(model_y_xw, safe=False)
self.model_t_xw = clone(model_t_xw, safe=False)
self.model_t_xwz = clone(model_t_xwz, safe=False)
Expand Down Expand Up @@ -879,7 +882,7 @@ def __init__(self, *,
mc_iters=mc_iters,
mc_agg=mc_agg,
random_state=random_state,
enable_missing=['W'] if enable_missing else None)
enable_missing=enable_missing)

def _gen_model_final(self):
if self.model_final is None:
Expand Down Expand Up @@ -907,7 +910,7 @@ def _gen_prel_model_effect(self):
mc_iters=self.mc_iters,
mc_agg=self.mc_agg,
random_state=self.random_state,
enable_missing=self._enable_missing)
enable_missing=self.enable_missing)
elif self.prel_cate_approach == "dmliv":
return NonParamDMLIV(model_y_xw=clone(self.model_y_xw, safe=False),
model_t_xw=clone(self.model_t_xw, safe=False),
Expand All @@ -921,7 +924,7 @@ def _gen_prel_model_effect(self):
mc_iters=self.mc_iters,
mc_agg=self.mc_agg,
random_state=self.random_state,
enable_missing=True if self._enable_missing else False)
enable_missing=self.enable_missing)
else:
raise ValueError(
"We only support 'dmliv' or 'driv' preliminary model effect, "
Expand Down Expand Up @@ -975,9 +978,6 @@ def fit(self, Y, T, *, Z, X=None, W=None, sample_weight=None, freq_weight=None,
assert self.model_t_xwz == "auto", ("In the case of projection=False and prel_cate_approach='driv', "
"model_t_xwz will not be fitted, "
"please keep it as default!")
# assert not (self._enable_missing and self.prel_cate_approach == "dmliv" and not self.projection), \
# ("Cannot handle missing data when prel_cate_approach='dmliv' and projection=False!")

return super().fit(Y, T, X=X, W=W, Z=Z,
sample_weight=sample_weight, freq_weight=freq_weight, sample_var=sample_var, groups=groups,
cache_values=cache_values, inference=inference)
Expand Down Expand Up @@ -2288,7 +2288,7 @@ def __init__(self, *,
mc_iters=None,
mc_agg='mean',
random_state=None,
enable_missing=None):
enable_missing=False):
self.model_y_xw = clone(model_y_xw, safe=False)
self.model_t_xwz = clone(model_t_xwz, safe=False)
self.prel_model_effect = clone(prel_model_effect, safe=False)
Expand Down Expand Up @@ -2533,7 +2533,7 @@ def __init__(self, *,
mc_iters=mc_iters,
mc_agg=mc_agg,
random_state=random_state,
enable_missing=['W'] if enable_missing else None)
enable_missing=enable_missing)

def _gen_model_final(self):
if self.model_final is None:
Expand All @@ -2553,7 +2553,7 @@ def _gen_prel_model_effect(self):
opt_reweighted=self.prel_opt_reweighted,
cv=self.prel_cv,
random_state=self.random_state,
enable_missing=self._enable_missing)
enable_missing=self.enable_missing)
elif self.prel_cate_approach == "dmliv":
return NonParamDMLIV(model_y_xw=clone(self.model_y_xw, safe=False),
model_t_xw=clone(self.model_t_xwz, safe=False),
Expand All @@ -2567,7 +2567,7 @@ def _gen_prel_model_effect(self):
mc_iters=self.mc_iters,
mc_agg=self.mc_agg,
random_state=self.random_state,
enable_missing=True if self._enable_missing else False)
enable_missing=self.enable_missing)
else:
raise ValueError(
"We only support 'dmliv' or 'driv' preliminary model effect, "
Expand Down
11 changes: 8 additions & 3 deletions econml/panel/dml/_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,8 @@ def __init__(self, *,
cv=2,
mc_iters=None,
mc_agg='mean',
random_state=None):
random_state=None,
enable_missing=False):
self.fit_cate_intercept = fit_cate_intercept
self.linear_first_stages = linear_first_stages
self.featurizer = clone(featurizer, safe=False)
Expand All @@ -476,7 +477,11 @@ def __init__(self, *,
cv=GroupKFold(cv) if isinstance(cv, int) else cv,
mc_iters=mc_iters,
mc_agg=mc_agg,
random_state=random_state)
random_state=random_state,
enable_missing=enable_missing)

def _gen_allowed_missing_vars(self):
return ['W'] if self.enable_missing else []

# override only so that we can exclude treatment featurization verbiage in docstring
def const_marginal_effect(self, X=None):
Expand Down Expand Up @@ -672,7 +677,7 @@ def score(self, Y, T, X=None, W=None, sample_weight=None, *, groups):
if not hasattr(self._ortho_learner_model_final, 'score'):
raise AttributeError("Final model does not have a score method!")
Y, T, X, groups = check_input_arrays(Y, T, X, groups)
W, = check_input_arrays(W, force_all_finite='allow-nan')
W, = check_input_arrays(W, force_all_finite='allow-nan' if 'W' in self._gen_allowed_missing_vars() else True)
self._check_fitted_dims(X)
X, T = super()._expand_treatments(X, T)
n_iters = len(self._models_nuisance)
Expand Down
Loading

0 comments on commit 2f06055

Please sign in to comment.