Skip to content

Commit

Permalink
cleanup branch: test, file structure, docstring and linting
Browse files Browse the repository at this point in the history
Signed-off-by: kgao <kevin.leo.gao@gmail.com>
  • Loading branch information
kgao committed Oct 6, 2023
1 parent e652b08 commit 6c7e485
Showing 1 changed file with 23 additions and 5 deletions.
28 changes: 23 additions & 5 deletions econml/tests/test_federated_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import unittest
from econml.dml import LinearDML
from econml.inference import StatsModelsInference
from econml.sklearn_extensions.federated_learning import FederatedLearner
from econml.federated_learning import FederatedEstimator


class FunctionRegressor:
"""A simple model that ignores the data it is fitted on, always just using the specified function to predict"""
def __init__(self, func):
self.func = func

Expand All @@ -20,7 +21,20 @@ def predict(self, X):


class TestFederatedLearning(unittest.TestCase):
"""
A set of unit tests for the FederatedLearner class.
These tests check various scenarios of splitting, aggregation, and comparison
between FederatedLearner and individual LinearDML estimators.
Parameters
----------
None
Returns
-------
None
"""
def test_splitting_works(self):

num_samples = 1000
Expand Down Expand Up @@ -69,15 +83,19 @@ def test_splitting_works(self):
sample_weight=weights, freq_weight=freq_weights, sample_var=sample_var,
inference=StatsModelsInference(cov_type=cov_type))
est_h1.fit(Y1, T1, X=X1, W=W1,
sample_weight=weights1, freq_weight=freq_weights1, sample_var=sample_var1,
sample_weight=weights1,
freq_weight=freq_weights1,
sample_var=sample_var1,
inference=StatsModelsInference(cov_type=cov_type))
est_h2.fit(Y2, T2, X=X2, W=W2,
sample_weight=weights2, freq_weight=freq_weights2, sample_var=sample_var2,
sample_weight=weights2,
freq_weight=freq_weights2,
sample_var=sample_var2,
inference=StatsModelsInference(cov_type=cov_type))

est_fed1 = FederatedLearner([est_all])
est_fed1 = FederatedEstimator([est_all])

est_fed2 = FederatedLearner([est_h1, est_h2])
est_fed2 = FederatedEstimator([est_h1, est_h2])

np.testing.assert_allclose(est_fed1.model_final_._param,
est_fed2.model_final_._param)
Expand Down

0 comments on commit 6c7e485

Please sign in to comment.