Skip to content

Commit

Permalink
add init args to drlearner, causalforestdml
Browse files Browse the repository at this point in the history
Signed-off-by: Fabio Vera <fabiovera@microsoft.com>
  • Loading branch information
fverac committed Sep 25, 2023
1 parent 06f85fe commit 6bc0660
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
4 changes: 3 additions & 1 deletion econml/dml/causal_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,7 @@ def __init__(self, *,
model_t='auto',
featurizer=None,
treatment_featurizer=None,
binary_outcome=False,
discrete_treatment=False,
categories='auto',
cv=2,
Expand Down Expand Up @@ -630,7 +631,8 @@ def __init__(self, *,
self.subforest_size = subforest_size
self.n_jobs = n_jobs
self.verbose = verbose
super().__init__(discrete_treatment=discrete_treatment,
super().__init__(binary_outcome=binary_outcome,
discrete_treatment=discrete_treatment,
treatment_featurizer=treatment_featurizer,
categories=categories,
cv=cv,
Expand Down
8 changes: 8 additions & 0 deletions econml/dr/_drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ def __init__(self, *,
model_propensity='auto',
model_regression='auto',
model_final=StatsModelsLinearRegression(),
binary_outcome=False,
multitask_model_final=False,
featurizer=None,
min_propensity=1e-6,
Expand All @@ -419,6 +420,7 @@ def __init__(self, *,
super().__init__(cv=cv,
mc_iters=mc_iters,
mc_agg=mc_agg,
binary_outcome=binary_outcome,
discrete_treatment=True,
treatment_featurizer=None, # treatment featurization not supported with discrete treatment
discrete_instrument=False, # no instrument, so doesn't matter
Expand Down Expand Up @@ -864,6 +866,7 @@ class LinearDRLearner(StatsModelsCateEstimatorDiscreteMixin, DRLearner):
def __init__(self, *,
model_propensity='auto',
model_regression='auto',
binary_outcome=False,
featurizer=None,
fit_cate_intercept=True,
min_propensity=1e-6,
Expand All @@ -876,6 +879,7 @@ def __init__(self, *,
super().__init__(model_propensity=model_propensity,
model_regression=model_regression,
model_final=None,
binary_outcome=binary_outcome,
featurizer=featurizer,
multitask_model_final=False,
min_propensity=min_propensity,
Expand Down Expand Up @@ -1137,6 +1141,7 @@ def __init__(self, *,
model_regression='auto',
featurizer=None,
fit_cate_intercept=True,
binary_outcome=False,
alpha='auto',
n_alphas=100,
alpha_cov='auto',
Expand All @@ -1161,6 +1166,7 @@ def __init__(self, *,
super().__init__(model_propensity=model_propensity,
model_regression=model_regression,
model_final=None,
binary_outcome=binary_outcome,
featurizer=featurizer,
multitask_model_final=False,
min_propensity=min_propensity,
Expand Down Expand Up @@ -1413,6 +1419,7 @@ class ForestDRLearner(ForestModelFinalCateEstimatorDiscreteMixin, DRLearner):
def __init__(self, *,
model_regression="auto",
model_propensity="auto",
binary_outcome=False,
featurizer=None,
min_propensity=1e-6,
categories='auto',
Expand Down Expand Up @@ -1449,6 +1456,7 @@ def __init__(self, *,
super().__init__(model_regression=model_regression,
model_propensity=model_propensity,
model_final=None,
binary_outcome=binary_outcome,
featurizer=featurizer,
multitask_model_final=False,
min_propensity=min_propensity,
Expand Down

0 comments on commit 6bc0660

Please sign in to comment.