allow missing in X for some ortholearner subclasses

Signed-off-by: Fabio Vera <fabiovera@microsoft.com>
fverac committed Aug 14, 2023
1 parent 00dd506 commit ab283c4
Showing 8 changed files with 266 additions and 59 deletions.
13 changes: 9 additions & 4 deletions econml/_ortho_learner.py
@@ -298,6 +298,9 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
         How to aggregate the nuisance value for each sample across the `mc_iters` monte carlo iterations of
         cross-fitting.
+    enable_missing: list, default ['W']
+        Which data arrays to allow missing values for. Currently only supports missing values for W, X.
 
     Examples
     --------
@@ -434,7 +437,7 @@ def _gen_ortho_learner_model_final(self):
     def __init__(self, *,
                  discrete_treatment, treatment_featurizer,
                  discrete_instrument, categories, cv, random_state,
-                 mc_iters=None, mc_agg='mean'):
+                 mc_iters=None, mc_agg='mean', enable_missing=None):
         self.cv = cv
         self.discrete_treatment = discrete_treatment
         self.treatment_featurizer = treatment_featurizer
@@ -443,6 +446,7 @@ def __init__(self, *,
         self.categories = categories
         self.mc_iters = mc_iters
         self.mc_agg = mc_agg
+        self._enable_missing = enable_missing or []
         super().__init__()
 
     @abstractmethod
@@ -605,9 +609,10 @@ def fit(self, Y, T, *, X=None, W=None, Z=None, sample_weight=None, freq_weight=N
         assert not (self.discrete_treatment and self.treatment_featurizer), "Treatment featurization " \
             "is not supported when treatment is discrete"
         if check_input:
-            Y, T, X, Z, sample_weight, freq_weight, sample_var, groups = check_input_arrays(
-                Y, T, X, Z, sample_weight, freq_weight, sample_var, groups)
-            W, = check_input_arrays(W, force_all_finite='allow-nan')
+            Y, T, Z, sample_weight, freq_weight, sample_var, groups = check_input_arrays(
+                Y, T, Z, sample_weight, freq_weight, sample_var, groups)
+            X, = check_input_arrays(X, force_all_finite='allow-nan' if 'X' in self._enable_missing else True)
+            W, = check_input_arrays(W, force_all_finite='allow-nan' if 'W' in self._enable_missing else True)
         self._check_input_dims(Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups)
 
         if not only_final:
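Note on the gate above: `check_input_arrays` effectively forwards `force_all_finite` to sklearn's array validation, so `enable_missing` only decides whether NaN-bearing arrays survive input checking. A minimal sketch of that behavior using `sklearn.utils.check_array` directly (illustrative only, not the econml code path verbatim):

    import numpy as np
    from sklearn.utils import check_array

    W = np.array([[1.0, np.nan],
                  [0.5, 2.0]])

    # Default validation rejects NaNs outright.
    try:
        check_array(W)
    except ValueError as err:
        print('rejected:', err)

    # 'allow-nan' (what enable_missing switches on) lets NaNs through;
    # handling them then becomes the first-stage models' responsibility.
    print(check_array(W, force_all_finite='allow-nan'))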
5 changes: 3 additions & 2 deletions econml/dml/_rlearner.py
@@ -272,15 +272,16 @@ def _gen_rlearner_model_final(self):
         """
 
     def __init__(self, *, discrete_treatment, treatment_featurizer, categories,
-                 cv, random_state, mc_iters=None, mc_agg='mean'):
+                 cv, random_state, mc_iters=None, mc_agg='mean', enable_missing=None):
         super().__init__(discrete_treatment=discrete_treatment,
                          treatment_featurizer=treatment_featurizer,
                          discrete_instrument=False,  # no instrument, so doesn't matter
                          categories=categories,
                          cv=cv,
                          random_state=random_state,
                          mc_iters=mc_iters,
-                         mc_agg=mc_agg)
+                         mc_agg=mc_agg,
+                         enable_missing=enable_missing)
 
     @abstractmethod
     def _gen_model_y(self):
8 changes: 5 additions & 3 deletions econml/dml/causal_forest.py
@@ -601,7 +601,8 @@ def __init__(self, *,
                  subforest_size=4,
                  n_jobs=-1,
                  random_state=None,
-                 verbose=0):
+                 verbose=0,
+                 enable_missing=False):
 
         # TODO: consider whether we need more care around stateful featurizers,
         # since we clone it and fit separate copies
@@ -636,7 +637,8 @@ def __init__(self, *,
                          cv=cv,
                          mc_iters=mc_iters,
                          mc_agg=mc_agg,
-                         random_state=random_state)
+                         random_state=random_state,
+                         enable_missing=['W'] if enable_missing else None)
 
     def _get_inference_options(self):
         options = super()._get_inference_options()
@@ -738,7 +740,7 @@ def tune(self, Y, T, *, X=None, W=None,
         """
         from ..score import RScorer  # import here to avoid circular import issue
         Y, T, X, sample_weight, groups = check_input_arrays(Y, T, X, sample_weight, groups)
-        W, = check_input_arrays(W, force_all_finite='allow-nan')
+        W, = check_input_arrays(W, force_all_finite='allow-nan' if self._enable_missing else True)
 
         if params == 'auto':
             params = {
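Usage sketch for the CausalForestDML change above, which whitelists only W. This is a hedged illustration with toy data, and it assumes the nuisance models themselves tolerate NaNs (e.g. sklearn's HistGradientBoostingRegressor); validation only stops rejecting the array, it does not impute anything:

    import numpy as np
    from sklearn.ensemble import HistGradientBoostingRegressor
    from econml.dml import CausalForestDML

    rng = np.random.default_rng(0)
    n = 500
    X = rng.normal(size=(n, 3))
    W = rng.normal(size=(n, 2))
    W[rng.random(n) < 0.1, 0] = np.nan   # missing values in W only
    T = X[:, 0] + W[:, 1] + rng.normal(size=n)
    Y = 2 * T + X[:, 0] + rng.normal(size=n)

    est = CausalForestDML(model_y=HistGradientBoostingRegressor(),  # NaN-tolerant nuisances
                          model_t=HistGradientBoostingRegressor(),
                          enable_missing=True)                      # allows NaN in W, not in X
    est.fit(Y, T, X=X, W=W)
    print(est.effect(X[:5]))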
36 changes: 25 additions & 11 deletions econml/dml/dml.py
@@ -108,8 +108,9 @@ def __init__(self, model_final, fit_cate_intercept, featurizer, use_weight_trick
         else:
             self._fit_cate_intercept = fit_cate_intercept
         if self._fit_cate_intercept:
+            # data is already validated at initial fit time
             add_intercept_trans = FunctionTransformer(add_intercept,
-                                                      validate=True)
+                                                      validate=False)
         if featurizer:
             self._featurizer = Pipeline([('featurize', self._original_featurizer),
                                          ('add_intercept', add_intercept_trans)])
@@ -466,7 +467,8 @@ def __init__(self, *,
                  cv=2,
                  mc_iters=None,
                  mc_agg='mean',
-                 random_state=None):
+                 random_state=None,
+                 enable_missing=False):
         # TODO: consider whether we need more care around stateful featurizers,
         # since we clone it and fit separate copies
         self.fit_cate_intercept = fit_cate_intercept
@@ -481,7 +483,8 @@ def __init__(self, *,
                          cv=cv,
                          mc_iters=mc_iters,
                          mc_agg=mc_agg,
-                         random_state=random_state)
+                         random_state=random_state,
+                         enable_missing=['X', 'W'] if enable_missing else None)
 
     def _gen_featurizer(self):
         return clone(self.featurizer, safe=False)
@@ -692,7 +695,8 @@ def __init__(self, *,
                  cv=2,
                  mc_iters=None,
                  mc_agg='mean',
-                 random_state=None):
+                 random_state=None,
+                 enable_missing=False):
         super().__init__(model_y=model_y,
                          model_t=model_t,
                          model_final=None,
@@ -705,7 +709,9 @@ def __init__(self, *,
                          cv=cv,
                          mc_iters=mc_iters,
                          mc_agg=mc_agg,
-                         random_state=random_state,)
+                         random_state=random_state,
+                         enable_missing=enable_missing)
+        self._enable_missing = ['W'] if enable_missing else []  # override super's default, which is ['X', 'W']
 
     def _gen_model_final(self):
         return StatsModelsLinearRegression(fit_intercept=False)
@@ -932,7 +938,8 @@ def __init__(self, *,
                  cv=2,
                  mc_iters=None,
                  mc_agg='mean',
-                 random_state=None):
+                 random_state=None,
+                 enable_missing=False):
         self.alpha = alpha
         self.n_alphas = n_alphas
         self.alpha_cov = alpha_cov
@@ -952,7 +959,9 @@
                          cv=cv,
                          mc_iters=mc_iters,
                          mc_agg=mc_agg,
-                         random_state=random_state)
+                         random_state=random_state,
+                         enable_missing=enable_missing)
+        self._enable_missing = ['W'] if enable_missing else []  # override super's default, which is ['X', 'W']
 
     def _gen_model_final(self):
         return MultiOutputDebiasedLasso(alpha=self.alpha,
@@ -1139,7 +1148,8 @@ def __init__(self, model_y='auto', model_t='auto',
                  bw=1.0,
                  cv=2,
                  mc_iters=None, mc_agg='mean',
-                 random_state=None):
+                 random_state=None,
+                 enable_missing=False):
         self.dim = dim
         self.bw = bw
         super().__init__(model_y=model_y,
@@ -1153,7 +1163,9 @@ def __init__(self, model_y='auto', model_t='auto',
                          cv=cv,
                          mc_iters=mc_iters,
                          mc_agg=mc_agg,
-                         random_state=random_state)
+                         random_state=random_state,
+                         enable_missing=enable_missing)
+        self._enable_missing = ['W'] if enable_missing else []  # override super's default, which is ['X', 'W']
 
     def _gen_model_final(self):
         return ElasticNetCV(fit_intercept=False, random_state=self.random_state)
@@ -1326,7 +1338,8 @@ def __init__(self, *,
                  cv=2,
                  mc_iters=None,
                  mc_agg='mean',
-                 random_state=None):
+                 random_state=None,
+                 enable_missing=False):
 
         # TODO: consider whether we need more care around stateful featurizers,
         # since we clone it and fit separate copies
@@ -1340,7 +1353,8 @@ def __init__(self, *,
                          cv=cv,
                          mc_iters=mc_iters,
                          mc_agg=mc_agg,
-                         random_state=random_state)
+                         random_state=random_state,
+                         enable_missing=['X', 'W'] if enable_missing else None)
 
     def _get_inference_options(self):
         # add blb to parent's options
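The pattern repeated above for LinearDML, SparseLinearDML, and KernelDML (accept the flag, then narrow `_enable_missing` to ['W']) means those subclasses accept NaN in W but still reject NaN in X, since X feeds their parametric final stage. A hedged sketch of the resulting behavior (toy data; NaN-tolerant first-stage models assumed, as before):

    import numpy as np
    from sklearn.ensemble import HistGradientBoostingRegressor
    from econml.dml import LinearDML

    rng = np.random.default_rng(1)
    n = 400
    X = rng.normal(size=(n, 2))
    W = rng.normal(size=(n, 2))
    W[rng.random(n) < 0.1, 1] = np.nan   # fine: W is whitelisted
    T = W[:, 0] + rng.normal(size=n)
    Y = T * X[:, 0] + rng.normal(size=n)

    est = LinearDML(model_y=HistGradientBoostingRegressor(),
                    model_t=HistGradientBoostingRegressor(),
                    enable_missing=True)
    est.fit(Y, T, X=X, W=W)              # NaNs in W pass validation

    X_bad = X.copy()
    X_bad[0, 0] = np.nan
    try:
        est.fit(Y, T, X=X_bad, W=W)      # NaN in X is still rejected
    except ValueError as err:
        print('rejected:', err)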
27 changes: 19 additions & 8 deletions econml/dr/_drlearner.py
@@ -409,7 +409,8 @@ def __init__(self, *,
                  cv=2,
                  mc_iters=None,
                  mc_agg='mean',
-                 random_state=None):
+                 random_state=None,
+                 enable_missing=False):
         self.model_propensity = clone(model_propensity, safe=False)
         self.model_regression = clone(model_regression, safe=False)
         self.model_final = clone(model_final, safe=False)
@@ -423,7 +424,8 @@
                          treatment_featurizer=None,  # treatment featurization not supported with discrete treatment
                          discrete_instrument=False,  # no instrument, so doesn't matter
                          categories=categories,
-                         random_state=random_state)
+                         random_state=random_state,
+                         enable_missing=['X', 'W'] if enable_missing else None)
 
     # override only so that we can exclude treatment featurization verbiage in docstring
     def const_marginal_effect(self, X=None):
@@ -871,7 +873,8 @@ def __init__(self, *,
                  cv=2,
                  mc_iters=None,
                  mc_agg='mean',
-                 random_state=None):
+                 random_state=None,
+                 enable_missing=False):
         self.fit_cate_intercept = fit_cate_intercept
         super().__init__(model_propensity=model_propensity,
                          model_regression=model_regression,
@@ -883,7 +886,9 @@
                          cv=cv,
                          mc_iters=mc_iters,
                          mc_agg=mc_agg,
-                         random_state=random_state)
+                         random_state=random_state,
+                         enable_missing=enable_missing)
+        self._enable_missing = ['W'] if enable_missing else []  # override super's default, which is ['X', 'W']
 
     def _gen_model_final(self):
         return StatsModelsLinearRegression(fit_intercept=self.fit_cate_intercept)
@@ -1149,7 +1154,8 @@ def __init__(self, *,
                  cv=2,
                  mc_iters=None,
                  mc_agg='mean',
-                 random_state=None):
+                 random_state=None,
+                 enable_missing=False):
         self.fit_cate_intercept = fit_cate_intercept
         self.alpha = alpha
         self.n_alphas = n_alphas
@@ -1168,7 +1174,9 @@
                          cv=cv,
                          mc_iters=mc_iters,
                          mc_agg=mc_agg,
-                         random_state=random_state)
+                         random_state=random_state,
+                         enable_missing=enable_missing)
+        self._enable_missing = ['W'] if enable_missing else []  # override super's default, which is ['X', 'W']
 
     def _gen_model_final(self):
         return DebiasedLasso(alpha=self.alpha,
@@ -1432,7 +1440,8 @@ def __init__(self, *,
                  subforest_size=4,
                  n_jobs=-1,
                  verbose=0,
-                 random_state=None):
+                 random_state=None,
+                 enable_missing=False):
         self.n_estimators = n_estimators
         self.max_depth = max_depth
         self.min_samples_split = min_samples_split
@@ -1456,7 +1465,9 @@
                          cv=cv,
                          mc_iters=mc_iters,
                          mc_agg=mc_agg,
-                         random_state=random_state)
+                         random_state=random_state,
+                         enable_missing=enable_missing)
+        self._enable_missing = ['W'] if enable_missing else []  # override super's default, which is ['X', 'W']
 
     def _gen_model_final(self):
         return RegressionForest(n_estimators=self.n_estimators,
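DRLearner itself whitelists both X and W (its final stage is a generic ML model), while the Linear/Sparse/Forest subclasses above narrow the whitelist back to W. A sketch with NaNs in X, under the same assumptions (toy data; NaN-tolerant models):

    import numpy as np
    from sklearn.ensemble import (HistGradientBoostingClassifier,
                                  HistGradientBoostingRegressor)
    from econml.dr import DRLearner

    rng = np.random.default_rng(2)
    n = 600
    X = rng.normal(size=(n, 3))
    X[rng.random(n) < 0.05, 2] = np.nan   # allowed: DRLearner whitelists X and W
    T = rng.binomial(1, 0.5, size=n)      # DRLearner requires a discrete treatment
    Y = T * (1 + X[:, 0]) + rng.normal(size=n)

    est = DRLearner(model_propensity=HistGradientBoostingClassifier(),
                    model_regression=HistGradientBoostingRegressor(),
                    model_final=HistGradientBoostingRegressor(),
                    enable_missing=True)
    est.fit(Y, T, X=X)
    print(est.effect(X[:5]))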
18 changes: 12 additions & 6 deletions econml/iv/dml/_dml.py
@@ -350,7 +350,8 @@ def __init__(self, *,
                  cv=2,
                  mc_iters=None,
                  mc_agg='mean',
-                 random_state=None):
+                 random_state=None,
+                 enable_missing=False):
         self.model_y_xw = clone(model_y_xw, safe=False)
         self.model_t_xw = clone(model_t_xw, safe=False)
         self.model_t_xwz = clone(model_t_xwz, safe=False)
@@ -366,7 +367,8 @@
                          cv=cv,
                          mc_iters=mc_iters,
                          mc_agg=mc_agg,
-                         random_state=random_state)
+                         random_state=random_state,
+                         enable_missing=['W'] if enable_missing else None)
 
     def _gen_featurizer(self):
         return clone(self.featurizer, safe=False)
@@ -1149,7 +1151,8 @@ def __init__(self, *,
                  cv=2,
                  mc_iters=None,
                  mc_agg='mean',
-                 random_state=None):
+                 random_state=None,
+                 enable_missing=False):
         self.model_y_xw = clone(model_y_xw, safe=False)
         self.model_t_xw = clone(model_t_xw, safe=False)
         self.model_t_xwz = clone(model_t_xwz, safe=False)
@@ -1163,7 +1166,8 @@
                          cv=cv,
                          mc_iters=mc_iters,
                          mc_agg=mc_agg,
-                         random_state=random_state)
+                         random_state=random_state,
+                         enable_missing=['X', 'W'] if enable_missing else None)
 
     def _gen_featurizer(self):
         return clone(self.featurizer, safe=False)
@@ -1540,7 +1544,8 @@ def __init__(self, *,
                  cv=2,
                  mc_iters=None,
                  mc_agg='mean',
-                 random_state=None):
+                 random_state=None,
+                 enable_missing=False):
         self.model_y_xw = clone(model_y_xw, safe=False)
         self.model_t_xw = clone(model_t_xw, safe=False)
         self.model_t_xwz = clone(model_t_xwz, safe=False)
@@ -1553,7 +1558,8 @@
                          cv=cv,
                          mc_iters=mc_iters,
                          mc_agg=mc_agg,
-                         random_state=random_state)
+                         random_state=random_state,
+                         enable_missing=['X', 'W'] if enable_missing else None)
 
     def _gen_featurizer(self):
         return clone(self.featurizer, safe=False)
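On the IV side, the first hunk pair above whitelists only W while the other two whitelist both X and W; assuming these correspond to OrthoIV, DMLIV, and NonParamDMLIV respectively (inferred from the line numbers and the `model_y_xw`/`model_t_xw`/`model_t_xwz` kwargs), a minimal OrthoIV sketch under the same assumptions as the earlier examples:

    import numpy as np
    from sklearn.ensemble import HistGradientBoostingRegressor
    from econml.iv.dml import OrthoIV

    rng = np.random.default_rng(3)
    n = 500
    W = rng.normal(size=(n, 2))
    W[rng.random(n) < 0.1, 0] = np.nan    # missing confounder values
    Z = rng.normal(size=n)                # instrument
    T = Z + W[:, 1] + rng.normal(size=n)
    Y = 2 * T + W[:, 1] + rng.normal(size=n)

    est = OrthoIV(model_y_xw=HistGradientBoostingRegressor(),
                  model_t_xw=HistGradientBoostingRegressor(),
                  model_t_xwz=HistGradientBoostingRegressor(),
                  enable_missing=True)    # NaN allowed in W only
    est.fit(Y, T, Z=Z, W=W)
    print(est.effect())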
[diffs for the remaining 2 changed files not loaded]
