From 6bc066079f42c2bd56859f99d225b4224712e721 Mon Sep 17 00:00:00 2001 From: Fabio Vera Date: Mon, 25 Sep 2023 12:00:08 -0400 Subject: [PATCH] add init args to drlearner, causalforestdml Signed-off-by: Fabio Vera --- econml/dml/causal_forest.py | 4 +++- econml/dr/_drlearner.py | 8 ++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/econml/dml/causal_forest.py b/econml/dml/causal_forest.py index 8faf157c2..bc30aa1d4 100644 --- a/econml/dml/causal_forest.py +++ b/econml/dml/causal_forest.py @@ -577,6 +577,7 @@ def __init__(self, *, model_t='auto', featurizer=None, treatment_featurizer=None, + binary_outcome=False, discrete_treatment=False, categories='auto', cv=2, @@ -630,7 +631,8 @@ def __init__(self, *, self.subforest_size = subforest_size self.n_jobs = n_jobs self.verbose = verbose - super().__init__(discrete_treatment=discrete_treatment, + super().__init__(binary_outcome=binary_outcome, + discrete_treatment=discrete_treatment, treatment_featurizer=treatment_featurizer, categories=categories, cv=cv, diff --git a/econml/dr/_drlearner.py b/econml/dr/_drlearner.py index 3ca702a0c..e4e264113 100644 --- a/econml/dr/_drlearner.py +++ b/econml/dr/_drlearner.py @@ -402,6 +402,7 @@ def __init__(self, *, model_propensity='auto', model_regression='auto', model_final=StatsModelsLinearRegression(), + binary_outcome=False, multitask_model_final=False, featurizer=None, min_propensity=1e-6, @@ -419,6 +420,7 @@ def __init__(self, *, super().__init__(cv=cv, mc_iters=mc_iters, mc_agg=mc_agg, + binary_outcome=binary_outcome, discrete_treatment=True, treatment_featurizer=None, # treatment featurization not supported with discrete treatment discrete_instrument=False, # no instrument, so doesn't matter @@ -864,6 +866,7 @@ class LinearDRLearner(StatsModelsCateEstimatorDiscreteMixin, DRLearner): def __init__(self, *, model_propensity='auto', model_regression='auto', + binary_outcome=False, featurizer=None, fit_cate_intercept=True, min_propensity=1e-6, @@ -876,6 +879,7 @@ def __init__(self, *, super().__init__(model_propensity=model_propensity, model_regression=model_regression, model_final=None, + binary_outcome=binary_outcome, featurizer=featurizer, multitask_model_final=False, min_propensity=min_propensity, @@ -1137,6 +1141,7 @@ def __init__(self, *, model_regression='auto', featurizer=None, fit_cate_intercept=True, + binary_outcome=False, alpha='auto', n_alphas=100, alpha_cov='auto', @@ -1161,6 +1166,7 @@ def __init__(self, *, super().__init__(model_propensity=model_propensity, model_regression=model_regression, model_final=None, + binary_outcome=binary_outcome, featurizer=featurizer, multitask_model_final=False, min_propensity=min_propensity, @@ -1413,6 +1419,7 @@ class ForestDRLearner(ForestModelFinalCateEstimatorDiscreteMixin, DRLearner): def __init__(self, *, model_regression="auto", model_propensity="auto", + binary_outcome=False, featurizer=None, min_propensity=1e-6, categories='auto', @@ -1449,6 +1456,7 @@ def __init__(self, *, super().__init__(model_regression=model_regression, model_propensity=model_propensity, model_final=None, + binary_outcome=binary_outcome, featurizer=featurizer, multitask_model_final=False, min_propensity=min_propensity,