Source code for rdf2vecgpu.logger.mlflow_logger
from contextlib import contextmanager
from typing_extensions import override
from .base import BaseTracker
from loguru import logger
from typing import Dict, Any, Optional
import os
import tempfile
try:
import mlflow
except ImportError as e:
logger.exception("mlflow is not installed. Please install it to use MLflowLogger.")
raise
[docs]
class MlflowTracker(BaseTracker):
    """Experiment tracker that forwards pipeline logging calls to MLflow.

    A pipeline run is opened by ``start_pipeline``; each ``stage`` opens a
    nested MLflow run on top of it. ``close`` ends any still-open stage runs
    and then the pipeline run.
    """

    def __init__(
        self, experiment: str, tracking_uri: str, registry_uri: Optional[str] = None
    ):
        """Configure MLflow's tracking/registry URIs and select the experiment.

        Args:
            experiment: Experiment name passed to ``mlflow.set_experiment``.
            tracking_uri: URI passed to ``mlflow.set_tracking_uri``.
            registry_uri: Optional URI passed to ``mlflow.set_registry_uri``;
                skipped when ``None`` or empty.
        """
        mlflow.set_tracking_uri(tracking_uri)
        if registry_uri:
            mlflow.set_registry_uri(registry_uri)
        mlflow.set_experiment(experiment)
        # ActiveRun returned by start_pipeline, or None when no pipeline run is open.
        self._parent_run = None
        # Stack of nested stage runs, innermost last.
        self._active_stage_runs: list = []

    @override
    def enabled(self) -> bool:
        """Return ``True``: this tracker always records to MLflow."""
        return True

    @override
    def start_pipeline(
        self,
        run_name: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
    ) -> "MlflowTracker":
        """Start the top-level (parent) MLflow run for a pipeline.

        Args:
            run_name: Optional run name; ``None`` lets MLflow pick one.
            tags: Optional tags set on the new run (skipped when empty).

        Returns:
            This tracker, to allow call chaining.
        """
        # ``run_name=None`` is mlflow.start_run's own default, so no branch is needed.
        self._parent_run = mlflow.start_run(run_name=run_name)
        if tags:
            mlflow.set_tags(tags)
        return self

    @override
    @contextmanager
    def stage(self, name: str):
        """Run the enclosed block inside a nested MLflow run named ``name``.

        On exit the nested run is ended with status ``"FINISHED"``. If the
        block raised, the status is ``"FAILED"`` (mirroring
        ``with mlflow.start_run(): ...``) and the exception is re-raised.

        Yields:
            This tracker.
        """
        run = mlflow.start_run(run_name=name, nested=True)
        self._active_stage_runs.append(run)
        status = "FINISHED"
        try:
            yield self
        except BaseException:
            status = "FAILED"
            raise
        finally:
            # close() may already have ended and untracked this run (e.g. when
            # called inside the with-block); only unwind it if it is still the
            # innermost tracked stage, instead of popping an empty list.
            if self._active_stage_runs and self._active_stage_runs[-1] is run:
                mlflow.end_run(status=status)
                self._active_stage_runs.pop()

    @override
    def log_params(self, params: Dict[str, Any]):
        """Log a batch of parameters to the active run."""
        mlflow.log_params(params)

    @override
    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
        """Log a batch of metrics to the active run.

        Args:
            metrics: Mapping of metric name to value.
            step: Optional step index. ``0`` is a valid step and is forwarded;
                ``None`` defers to MLflow's default.
        """
        # The previous ``if step:`` check dropped step=0. ``None`` is
        # mlflow.log_metrics' own default, so it can always be passed through.
        mlflow.log_metrics(metrics, step=step)

    @override
    def log_artifact(self, path, artifact_path: Optional[str] = None):
        """Upload the local file at ``path`` to the active run.

        Args:
            path: Local path of the file to upload.
            artifact_path: Optional run-relative directory to place it under.
        """
        mlflow.log_artifact(path, artifact_path=artifact_path)

    @override
    def log_data(
        self,
        sample_data,
        data_name,
        artifact_path,
        tags: Optional[Dict[str, str]] = None,
    ):
        """Log a data sample as an MLflow dataset input and as a table artifact.

        Args:
            sample_data: Presumably a ``pandas.DataFrame`` (it is passed to
                ``mlflow.data.from_pandas`` and ``mlflow.log_table``).
            data_name: Dataset name. It is also the table file stem when
                ``artifact_path`` is a directory.
            artifact_path: Run-relative location of the table. A value ending
                in ``.json`` is used as the file path as-is. Any other value
                (including empty/None) is treated as a directory that receives
                ``<data_name>.json``.
            tags: Optional tags attached to the dataset input.
        """
        import posixpath

        dataset = mlflow.data.from_pandas(sample_data, name=data_name)
        # mlflow.log_table accepts (data, artifact_file, run_id) only: it takes
        # neither a Dataset object nor artifact_path/tags keywords. Dataset
        # tags belong to mlflow.log_input.
        mlflow.log_input(dataset, tags=tags)

        if artifact_path and str(artifact_path).endswith(".json"):
            artifact_file = str(artifact_path)
        else:
            file_name = f"{data_name or 'data'}.json"
            artifact_file = posixpath.join(artifact_path or "", file_name)
        mlflow.log_table(data=sample_data, artifact_file=artifact_file)

    @override
    def log_pytorch(self):
        """Enable MLflow PyTorch autologging without logging model artifacts."""
        mlflow.pytorch.autolog(log_models=False)

    @override
    def log_model_pytorch(self, model, artifact_path: str):
        """Log a PyTorch model to the active run under ``artifact_path``."""
        mlflow.pytorch.log_model(model, artifact_path=artifact_path)

    @override
    def close(self):
        """End all open stage runs (innermost first), then the pipeline run."""
        while self._active_stage_runs:
            mlflow.end_run()
            self._active_stage_runs.pop()
        if self._parent_run is not None:
            mlflow.end_run()
            self._parent_run = None