diff --git a/econml/tests/test_federated_learning.py b/econml/tests/test_federated_learning.py index 8db9b6043..cfb1a1322 100644 --- a/econml/tests/test_federated_learning.py +++ b/econml/tests/test_federated_learning.py @@ -5,10 +5,11 @@ import unittest from econml.dml import LinearDML from econml.inference import StatsModelsInference -from econml.sklearn_extensions.federated_learning import FederatedLearner +from econml.federated_learning import FederatedEstimator class FunctionRegressor: + """A simple model that ignores the data it is fitted on, always just using the specified function to predict""" def __init__(self, func): self.func = func @@ -20,7 +21,20 @@ def predict(self, X): class TestFederatedLearning(unittest.TestCase): + """ + A set of unit tests for the FederatedLearner class. + These tests check various scenarios of splitting, aggregation, and comparison + between FederatedLearner and individual LinearDML estimators. + + Parameters + ---------- + None + + Returns + ------- + None + """ def test_splitting_works(self): num_samples = 1000 @@ -69,15 +83,19 @@ def test_splitting_works(self): sample_weight=weights, freq_weight=freq_weights, sample_var=sample_var, inference=StatsModelsInference(cov_type=cov_type)) est_h1.fit(Y1, T1, X=X1, W=W1, - sample_weight=weights1, freq_weight=freq_weights1, sample_var=sample_var1, + sample_weight=weights1, + freq_weight=freq_weights1, + sample_var=sample_var1, inference=StatsModelsInference(cov_type=cov_type)) est_h2.fit(Y2, T2, X=X2, W=W2, - sample_weight=weights2, freq_weight=freq_weights2, sample_var=sample_var2, + sample_weight=weights2, + freq_weight=freq_weights2, + sample_var=sample_var2, inference=StatsModelsInference(cov_type=cov_type)) - est_fed1 = FederatedLearner([est_all]) + est_fed1 = FederatedEstimator([est_all]) - est_fed2 = FederatedLearner([est_h1, est_h2]) + est_fed2 = FederatedEstimator([est_h1, est_h2]) np.testing.assert_allclose(est_fed1.model_final_._param, est_fed2.model_final_._param)