Skip to content

Commit

Permalink
added atac_layer argument to train_model tasks and made tests for it
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderAivazidis committed Jul 8, 2024
1 parent 4429882 commit f34196b
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/pyrovelocity/tasks/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def train_dataset(
"batch_size": batch_size,
}
)

adata, trained_model, posterior_samples = train_model(
adata=adata,
guide_type=guide_type,
Expand Down Expand Up @@ -279,6 +279,7 @@ def check_shared_time(posterior_samples, adata):
@beartype
def train_model(
adata: str | AnnData,
atac_layer: Optional[str] = None,
guide_type: str = "auto",
model_type: str = "auto",
batch_size: int = -1,
Expand All @@ -305,6 +306,7 @@ def train_model(
Args:
adata (str | AnnData): Path to a file that can be read to an AnnData object or an AnnData object.
atac_layer (Optional[str], optional): Name of AnnData layer that contains atac data, if present.
guide_type (str, optional): The type of guide function for the Pyro model. Default is "auto".
model_type (str, optional): The type of Pyro model. Default is "auto".
batch_size (int, optional): Batch size for training. Default is -1, which indicates using the full dataset.
Expand Down Expand Up @@ -347,6 +349,9 @@ def train_model(
>>> copy_raw_counts(adata)
>>> _, model, posterior_samples = train_model(adata, use_gpu="auto", seed=99, max_epochs=200, loss_plot_path=loss_plot_path)
"""

if atac_layer:
logger.info("Multiome model not yet implemented. Proceeding without atac data.")
if isinstance(adata, str):
adata = load_anndata_from_path(adata)

Expand Down
1 change: 1 addition & 0 deletions src/pyrovelocity/tests/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Unit test package for pyrovelocity."""
12 changes: 12 additions & 0 deletions src/pyrovelocity/tests/tasks/test_train_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Tests for `pyrovelocity._train_model` task."""

from pyrovelocity.tasks.train import train_model
from pyrovelocity.utils import generate_sample_data
from pyrovelocity.tasks.preprocess import copy_raw_counts
def test_train_model(tmp_path):
loss_plot_path = str(tmp_path) + "/loss_plot_docs.png"
print(loss_plot_path)
adata = generate_sample_data(random_seed=99)
copy_raw_counts(adata)
_, model, posterior_samples = train_model(adata, atac_layer = 'atac',
use_gpu="auto", seed=99, max_epochs=200, loss_plot_path=loss_plot_path)

0 comments on commit f34196b

Please sign in to comment.