Skip to content

Commit

Permalink
Merge pull request #53 from altescy/rename
Browse files: browse the repository at this point in the history
Rename MLFlow -> Mlflow
  • Loading branch information
altescy authored Jun 20, 2024
2 parents 3d03c74 + 24b6486 commit e57b609
Show file tree
Hide file tree
Showing 11 changed files with 98 additions and 98 deletions.
6 changes: 3 additions & 3 deletions examples/breast_cancer/breast_cancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import tango
from tango.common import Registrable

from tango_mlflow.step import MLflowStep
from tango_mlflow.step import MlflowStep


class Model(abc.ABC, Registrable):
Expand Down Expand Up @@ -169,7 +169,7 @@ def run( # type: ignore[override]


@tango.Step.register("train_model")
class TrainModel(MLflowStep):
class TrainModel(MlflowStep):
def run( # type: ignore[override]
self,
dataset: Tuple[numpy.ndarray, numpy.ndarray],
Expand All @@ -190,7 +190,7 @@ def training_callback(metrics: Dict[str, float]) -> None:


@tango.Step.register("train_preprocessor")
class TrainPreprocessor(MLflowStep):
class TrainPreprocessor(MlflowStep):
def run( # type: ignore[override]
self,
dataset: Tuple[numpy.ndarray, numpy.ndarray],
Expand Down
8 changes: 4 additions & 4 deletions tango_mlflow/commands/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
get_mlflow_run_by_tango_run,
is_all_child_run_finished,
)
from tango_mlflow.workspace import MLFlowWorkspace
from tango_mlflow.workspace import MlflowWorkspace

logger = getLogger(__name__)

Expand Down Expand Up @@ -193,15 +193,15 @@ def _objective(
)

workspace = Workspace.from_params(Params(tango_settings.workspace or {}))
assert isinstance(workspace, MLFlowWorkspace)
assert isinstance(workspace, MlflowWorkspace)

mlflow_run = get_mlflow_run_by_tango_run(
workspace.mlflow_client,
workspace.experiment_name,
tango_run=run_name,
)
if mlflow_run is None:
raise RuntimeError(f"Could not find MLFlow run for tango run {run_name}")
raise RuntimeError(f"Could not find MLflow run for tango run {run_name}")

if mlflow_run.info.status == "RUNNING":
status = (
Expand Down Expand Up @@ -244,7 +244,7 @@ def _objective(
)

workspace = Workspace.from_params(tango_settings.workspace or {})
if not isinstance(workspace, MLFlowWorkspace):
if not isinstance(workspace, MlflowWorkspace):
raise ValueError("Tango workspace type must be mlflow.")

optuna_settings = OptunaSettings.from_args(args)
Expand Down
22 changes: 11 additions & 11 deletions tango_mlflow/flax_train_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
from tango.common.exceptions import IntegrationMissingError

from tango_mlflow.util import RunKind, flatten_dict, get_mlflow_run_by_tango_step, get_timestamp
from tango_mlflow.workspace import MLFlowWorkspace
from tango_mlflow.workspace import MlflowWorkspace

with suppress(ModuleNotFoundError, IntegrationMissingError):
import jax
from flax import jax_utils
from tango.integrations.flax.train_callback import TrainCallback

@TrainCallback.register("mlflow::log_flax")
class MLFlowFlaxTrainCallback(TrainCallback):
class MlflowFlaxTrainCallback(TrainCallback):
def __init__(
self,
*args: Any,
Expand All @@ -29,7 +29,7 @@ def __init__(
) -> None:
super().__init__(*args, **kwargs)

if isinstance(self.workspace, MLFlowWorkspace):
if isinstance(self.workspace, MlflowWorkspace):
experiment_name = experiment_name or self.workspace.experiment_name
tracking_uri = tracking_uri or self.workspace.mlflow_tracking_uri

Expand All @@ -46,13 +46,13 @@ def __init__(

@property
def mlflow_client(self) -> mlflow.tracking.MlflowClient:
if isinstance(self.workspace, MLFlowWorkspace):
if isinstance(self.workspace, MlflowWorkspace):
return self.workspace.mlflow_client
return mlflow.tracking.MlflowClient(tracking_uri=self.tracking_uri)

def ensure_mlflow_run(self) -> MlflowRun:
if self.mlflow_run is None:
raise RuntimeError("MLFlow run not initialized")
raise RuntimeError("Mlflow run not initialized")
return self.mlflow_run

def state_dict(self) -> Dict[str, Any]:
Expand All @@ -62,18 +62,18 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.resume = "allow"

def pre_train_loop(self) -> None:
if isinstance(self.workspace, MLFlowWorkspace):
# Use existing MLFlow run created by the MLFlowWorkspace
if isinstance(self.workspace, MlflowWorkspace):
# Use existing Mlflow run created by the MlflowWorkspace
self.mlflow_run = get_mlflow_run_by_tango_step(
self.mlflow_client,
experiment=self.experiment_name,
tango_step=self.step_id,
additional_filter_string="attributes.status = 'RUNNING'",
)
if self.mlflow_run is None:
raise RuntimeError(f"Could not find a running MLFlow run for step {self.step_id}")
raise RuntimeError(f"Could not find a running Mlflow run for step {self.step_id}")
else:
# Create a new MLFlow run and log the config
# Create a new Mlflow run and log the config
self.mlflow_run = self.mlflow_client.create_run(
experiment_id=mlflow.get_experiment_by_name(self.experiment_name).experiment_id,
tags=context_registry.resolve_tags(
Expand All @@ -93,8 +93,8 @@ def pre_train_loop(self) -> None:
self.mlflow_client.log_batch(self.mlflow_run.info.run_id, metrics=metrics)

def post_train_loop(self, step: int, epoch: int) -> None:
if isinstance(self.workspace, MLFlowWorkspace):
# We don't need to do anything here, as the MLFlow run will be closed by the MLFlowWorkspace
if isinstance(self.workspace, MlflowWorkspace):
# We don't need to do anything here, as the Mlflow run will be closed by the MlflowWorkspace
return
mlflow_run = self.ensure_mlflow_run()
self.mlflow_client.set_terminated(mlflow_run.info.run_id)
Expand Down
8 changes: 4 additions & 4 deletions tango_mlflow/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,19 @@

import dacite
import pandas
from mlflow.entities import Run as MLFlowRun
from mlflow.entities import Run as MlflowRun
from mlflow.tracking import MlflowClient
from mlflow.utils.mlflow_tags import MLFLOW_LOGGED_ARTIFACTS
from tango.common import PathOrStr
from tango.format import Format, JsonFormat, JsonFormatIterator


@runtime_checkable
class MLFlowFormat(Protocol):
class MlflowFormat(Protocol):
def get_mlflow_artifact_path(self) -> str:
...

def mlflow_callback(self, client: MlflowClient, run: MLFlowRun) -> None:
def mlflow_callback(self, client: MlflowClient, run: MlflowRun) -> None:
...


Expand Down Expand Up @@ -203,7 +203,7 @@ def read(self, dir: PathOrStr) -> T_TableFormattable:
def get_mlflow_artifact_path(self) -> str:
return self._FILENAME

def mlflow_callback(self, client: MlflowClient, run: MLFlowRun) -> None:
def mlflow_callback(self, client: MlflowClient, run: MlflowRun) -> None:
client.set_tag(
run.info.run_id,
MLFLOW_LOGGED_ARTIFACTS,
Expand Down
30 changes: 15 additions & 15 deletions tango_mlflow/step.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
import typing
from typing import Any, Dict, Optional, Protocol, TypeVar

from mlflow.entities import Run as MLflowRun
from mlflow.entities import Run as MlflowRun
from mlflow.tracking import MlflowClient
from tango.step import Step

T = TypeVar("T")


@typing.runtime_checkable
class MLflowSummaryStep(Protocol):
class MlflowSummaryStep(Protocol):
MLFLOW_SUMMARY: bool


class MLflowLogger:
def __init__(self, mlflow_run: MLflowRun):
class MlflowLogger:
def __init__(self, mlflow_run: MlflowRun):
self._mlflow_run = mlflow_run

@property
def mlflow_run(self) -> MLflowRun:
def mlflow_run(self) -> MlflowRun:
return self._mlflow_run

@property
Expand All @@ -45,28 +45,28 @@ def log_metrics(self, metrics: Dict[str, float]) -> None:
self.log_metric(key, value)


class MLflowStep(Step[T]):
class MlflowStep(Step[T]):
MLFLOW_SUMMARY: bool = False

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._mlflow_run: Optional[MLflowRun] = None
self._mlflow_logger: Optional[MLflowLogger] = None
self._mlflow_run: Optional[MlflowRun] = None
self._mlflow_logger: Optional[MlflowLogger] = None

def setup_mlflow(self, mlflow_run: MLflowRun) -> None:
def setup_mlflow(self, mlflow_run: MlflowRun) -> None:
if self._mlflow_run is not None:
raise RuntimeError("MLflow run already set")
raise RuntimeError("Mlflow run already set")
self._mlflow_run = mlflow_run
self._mlflow_logger = MLflowLogger(self._mlflow_run)
self._mlflow_logger = MlflowLogger(self._mlflow_run)

@property
def mlflow_run(self) -> MLflowRun:
def mlflow_run(self) -> MlflowRun:
if self._mlflow_run is None:
raise RuntimeError("MLflow run not set")
raise RuntimeError("Mlflow run not set")
return self._mlflow_run

@property
def mlflow_logger(self) -> MLflowLogger:
def mlflow_logger(self) -> MlflowLogger:
if self._mlflow_logger is None:
raise RuntimeError("MLflow logger not set")
raise RuntimeError("Mlflow logger not set")
return self._mlflow_logger
10 changes: 5 additions & 5 deletions tango_mlflow/step_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any, Optional, Union, cast

import mlflow
from mlflow.entities import Run as MLFlowRun
from mlflow.entities import Run as MlflowRun
from mlflow.tracking.client import MlflowClient
from tango.common.aliases import PathOrStr
from tango.common.file_lock import AcquireReturnProxy, FileLock
Expand All @@ -17,7 +17,7 @@
from tango.step_caches.local_step_cache import LocalStepCache
from tango.step_info import StepInfo

from tango_mlflow.format import MLFlowFormat
from tango_mlflow.format import MlflowFormat
from tango_mlflow.util import (
RunKind,
get_mlflow_local_artifact_storage_path,
Expand All @@ -29,7 +29,7 @@


@StepCache.register("mlflow")
class MLFlowStepCache(LocalStepCache):
class MlflowStepCache(LocalStepCache):
def __init__(self, experiment_name: str) -> None:
super().__init__(tango_cache_dir() / "mlflow_cache")
self.experiment_name = experiment_name
Expand All @@ -54,7 +54,7 @@ def step_dir(self, step: Union[Step, StepInfo, str]) -> Path:

return mlflow_local_artifact_storage_path

def get_step_result_mlflow_run(self, step: Union[Step, StepInfo]) -> Optional[MLFlowRun]:
def get_step_result_mlflow_run(self, step: Union[Step, StepInfo]) -> Optional[MlflowRun]:
return get_mlflow_run_by_tango_step(
self.mlflow_client,
self.experiment_name,
Expand All @@ -81,7 +81,7 @@ def create_step_result_artifact(
if objects_dir is not None:
self.mlflow_client.log_artifacts(mlflow_run.info.run_id, objects_dir)

if isinstance(step, Step) and isinstance(step.FORMAT, MLFlowFormat):
if isinstance(step, Step) and isinstance(step.FORMAT, MlflowFormat):
step.FORMAT.mlflow_callback(self.mlflow_client, mlflow_run)

def _acquire_step_lock_file(
Expand Down
22 changes: 11 additions & 11 deletions tango_mlflow/torch_train_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
from tango.common.exceptions import IntegrationMissingError

from tango_mlflow.util import RunKind, flatten_dict, get_mlflow_run_by_tango_step, get_timestamp
from tango_mlflow.workspace import MLFlowWorkspace
from tango_mlflow.workspace import MlflowWorkspace

with suppress(ModuleNotFoundError, IntegrationMissingError):
import torch
from tango.integrations.torch.train_callback import TrainCallback
from tango.integrations.torch.util import peak_gpu_memory

@TrainCallback.register("mlflow::log")
class MLFlowTrainCallback(TrainCallback):
class MlflowTrainCallback(TrainCallback):
def __init__(
self,
*args: Any,
Expand All @@ -32,7 +32,7 @@ def __init__(

super().__init__(*args, **kwargs)

if isinstance(self.workspace, MLFlowWorkspace):
if isinstance(self.workspace, MlflowWorkspace):
experiment_name = experiment_name or self.workspace.experiment_name
tracking_uri = tracking_uri or self.workspace.mlflow_tracking_uri

Expand All @@ -49,13 +49,13 @@ def __init__(

@property
def mlflow_client(self) -> mlflow.tracking.MlflowClient:
if isinstance(self.workspace, MLFlowWorkspace):
if isinstance(self.workspace, MlflowWorkspace):
return self.workspace.mlflow_client
return mlflow.tracking.MlflowClient(tracking_uri=self.tracking_uri)

def ensure_mlflow_run(self) -> MlflowRun:
if self.mlflow_run is None:
raise RuntimeError("MLFlow run not initialized")
raise RuntimeError("Mlflow run not initialized")
return self.mlflow_run

def state_dict(self) -> Dict[str, Any]:
Expand All @@ -65,18 +65,18 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
pass

def pre_train_loop(self) -> None:
if isinstance(self.workspace, MLFlowWorkspace):
# Use existing MLFlow run created by the MLFlowWorkspace
if isinstance(self.workspace, MlflowWorkspace):
# Use existing Mlflow run created by the MlflowWorkspace
self.mlflow_run = get_mlflow_run_by_tango_step(
self.mlflow_client,
experiment=self.experiment_name,
tango_step=self.step_id,
additional_filter_string="attributes.status = 'RUNNING'",
)
if self.mlflow_run is None:
raise RuntimeError(f"Could not find a running MLFlow run for step {self.step_id}")
raise RuntimeError(f"Could not find a running Mlflow run for step {self.step_id}")
else:
# Create a new MLFlow run and log the config
# Create a new Mlflow run and log the config
self.mlflow_run = self.mlflow_client.create_run(
experiment_id=mlflow.get_experiment_by_name(self.experiment_name).experiment_id,
tags=context_registry.resolve_tags(
Expand Down Expand Up @@ -104,8 +104,8 @@ def pre_train_loop(self) -> None:
self.mlflow_client.log_batch(self.mlflow_run.info.run_id, metrics=metrics)

def post_train_loop(self, step: int, epoch: int) -> None:
if isinstance(self.workspace, MLFlowWorkspace):
# We don't need to do anything here, as the MLFlow run will be closed by the MLFlowWorkspace
if isinstance(self.workspace, MlflowWorkspace):
# We don't need to do anything here, as the Mlflow run will be closed by the MlflowWorkspace
return
if self.is_local_main_process:
mlflow_run = self.ensure_mlflow_run()
Expand Down
Loading

0 comments on commit e57b609

Please sign in to comment.