Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
Signed-off-by: Fabio Vera <fabiovera@microsoft.com>
  • Loading branch information
fverac committed Aug 15, 2023
1 parent 2f06055 commit 9812ef4
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 15 deletions.
6 changes: 4 additions & 2 deletions econml/_ortho_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,8 +614,10 @@ def fit(self, Y, T, *, X=None, W=None, Z=None, sample_weight=None, freq_weight=N
if check_input:
Y, T, Z, sample_weight, freq_weight, sample_var, groups = check_input_arrays(
Y, T, Z, sample_weight, freq_weight, sample_var, groups)
X, = check_input_arrays(X, force_all_finite='allow-nan' if 'X' in self._gen_allowed_missing_vars() else True)
W, = check_input_arrays(W, force_all_finite='allow-nan' if 'W' in self._gen_allowed_missing_vars() else True)
X, = check_input_arrays(
X, force_all_finite='allow-nan' if 'X' in self._gen_allowed_missing_vars() else True)
W, = check_input_arrays(
W, force_all_finite='allow-nan' if 'W' in self._gen_allowed_missing_vars() else True)
self._check_input_dims(Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups)

if not only_final:
Expand Down
2 changes: 1 addition & 1 deletion econml/dml/dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ def __init__(self, *,
mc_agg=mc_agg,
random_state=random_state,
enable_missing=enable_missing)

def _gen_allowed_missing_vars(self):
return ['W'] if self.enable_missing else []

Expand Down
2 changes: 1 addition & 1 deletion econml/iv/dml/_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def __init__(self, *,
mc_agg=mc_agg,
random_state=random_state,
enable_missing=enable_missing)

def _gen_allowed_missing_vars(self):
return ['W'] if self.enable_missing else []

Expand Down
22 changes: 11 additions & 11 deletions econml/tests/test_missing_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,28 +158,28 @@ def test_missing2(self):
NonParamDML(model_y=model_y, model_t=model_t, model_final=non_param_model_final,
discrete_treatment=discrete_treatment, enable_missing=True),
DML(model_y=model_y, model_t=model_t, model_final=param_model_final, enable_missing=True),
DMLIV(model_y_xw=model_y, model_t_xw=model_t, model_t_xwz=model_t,
model_final=param_model_final, discrete_treatment=discrete_treatment,
DMLIV(model_y_xw=model_y, model_t_xw=model_t, model_t_xwz=model_t,
model_final=param_model_final, discrete_treatment=discrete_treatment,
discrete_instrument=discrete_instrument, enable_missing=True),
NonParamDMLIV(model_y_xw=model_y, model_t_xw=model_t, model_t_xwz=model_t,
model_final=non_param_model_final, discrete_treatment=discrete_treatment,
model_final=non_param_model_final, discrete_treatment=discrete_treatment,
discrete_instrument=discrete_instrument, enable_missing=True),
DRLearner(model_propensity=model_t, model_regression=model_y, model_final=model_final, enable_missing=True)
]

# test W only
w_missing_models = [
DRIV(model_y_xw=model_y, model_t_xw=model_t, model_z_xw=model_t, model_tz_xw=model_t,
discrete_treatment=discrete_treatment, discrete_instrument=discrete_instrument,
DRIV(model_y_xw=model_y, model_t_xw=model_t, model_z_xw=model_t, model_tz_xw=model_t,
discrete_treatment=discrete_treatment, discrete_instrument=discrete_instrument,
prel_cate_approach='driv', projection=False, enable_missing=True),
DRIV(model_y_xw=model_y, model_t_xw=model_t, model_t_xwz=model_t, model_tz_xw=model_y,
discrete_treatment=discrete_treatment, discrete_instrument=discrete_instrument,
DRIV(model_y_xw=model_y, model_t_xw=model_t, model_t_xwz=model_t, model_tz_xw=model_y,
discrete_treatment=discrete_treatment, discrete_instrument=discrete_instrument,
prel_cate_approach='driv', projection=True, enable_missing=True),
DRIV(model_y_xw=model_y, model_t_xw=model_t, model_z_xw=model_t, model_t_xwz=model_t, model_tz_xw=model_t,
discrete_treatment=discrete_treatment, discrete_instrument=discrete_instrument,
discrete_treatment=discrete_treatment, discrete_instrument=discrete_instrument,
prel_cate_approach='dmliv', projection=False, enable_missing=True),
DRIV(model_y_xw=model_y, model_t_xw=model_t, model_t_xwz=model_t, model_tz_xw=model_y,
discrete_treatment=discrete_treatment, discrete_instrument=discrete_instrument,
DRIV(model_y_xw=model_y, model_t_xw=model_t, model_t_xwz=model_t, model_tz_xw=model_y,
discrete_treatment=discrete_treatment, discrete_instrument=discrete_instrument,
prel_cate_approach='dmliv', projection=True, enable_missing=True),
IntentToTreatDRIV(model_y_xw=model_y, model_t_xwz=model_t, prel_cate_approach='driv',
model_final=model_final, enable_missing=True),
Expand All @@ -195,7 +195,7 @@ def test_missing2(self):
OrthoIV(model_y_xw=model_y, model_t_xw=model_t, model_z_xw=model_t,
discrete_treatment=True, discrete_instrument=True, enable_missing=True),
LinearDRIV(model_y_xw=model_y, model_t_xw=model_t, model_z_xw=model_t, model_tz_xw=model_t,
prel_cate_approach='driv', discrete_treatment=True, discrete_instrument=True,
prel_cate_approach='driv', discrete_treatment=True, discrete_instrument=True,
enable_missing=True),
SparseLinearDRIV(model_y_xw=model_y, model_t_xw=model_t, model_z_xw=model_t, model_tz_xw=model_t,
prel_cate_approach='driv', discrete_treatment=True, discrete_instrument=True,
Expand Down

0 comments on commit 9812ef4

Please sign in to comment.