from contextlib import contextmanager
from typing import Optional, Dict, Any
from .base import BaseTracker
# Tracker implementation backed by the ``wandb`` package.
class WandbTracker(BaseTracker):
    """Tracker that forwards pipeline tracking calls to the ``wandb`` package.

    ``wandb`` is imported inside ``__init__`` rather than at module level, so
    importing this module does not itself require ``wandb`` to be installed.

    Attributes are plain instance fields:
      * ``wandb``   -- the imported ``wandb`` module.
      * ``project`` / ``entity`` -- values given to the constructor.
      * ``run``     -- object returned by ``wandb.init`` (``None`` until
        ``start_pipeline`` is called and again after ``close``).
      * ``kw``      -- keyword arguments passed to ``wandb.init``.
    """

    def __init__(
        self,
        project: str,
        entity: Optional[str] = None,
        run_name: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
    ):
        """Store the run settings; no run is started until ``start_pipeline``.

        Args:
            project: Passed to ``wandb.init`` as ``project``.
            entity: Passed to ``wandb.init`` as ``entity``.
            run_name: Passed to ``wandb.init`` as ``name``. When falsy, the
                name given to ``start_pipeline`` is used instead.
            config: Initial run configuration; ``None`` means an empty dict.
        """
        import wandb  # deferred so the dependency is only needed when used

        self.wandb = wandb
        self.project = project
        self.run = None
        self.entity = entity
        self.kw = dict(
            project=project,
            entity=entity,
            name=run_name,
            # Copy so later changes to the caller's dict do not leak in.
            config=dict(config) if config else {},
        )

    @staticmethod
    def _artifact_name(name) -> str:
        """Return ``name`` reduced to ASCII letters, digits, ``-``, ``_``, ``.``.

        Every other character (for example ``/`` in a path-like name) is
        replaced by ``_``. Names that already use only the allowed characters
        are returned unchanged.
        """
        cleaned = "".join(
            ch if (ch.isascii() and ch.isalnum()) or ch in "-_." else "_"
            for ch in str(name)
        )
        return cleaned or "artifacts"

    def enabled(self) -> bool:
        """Return ``True``; this tracker is always considered enabled."""
        return True

    def start_pipeline(self, run_name, tags=None, resume_run_id=None):
        """Start (or resume) a run via ``wandb.init`` and attach tags.

        Args:
            run_name: Used as the run name only when no name was given to the
                constructor (or by an earlier call).
            tags: Optional tags. A mapping contributes its values; a string
                is treated as a single tag; any other iterable contributes
                its items. Tags are converted with ``str`` and merged with
                the run's existing tags without duplicates, keeping order.
            resume_run_id: When truthy, passed to ``wandb.init`` as ``id``
                together with ``resume="allow"``.

        Returns:
            ``self``, so calls can be chained.
        """
        self.kw["name"] = self.kw.get("name") or run_name

        # Build per-call kwargs so a resume id is not remembered for later runs.
        init_kwargs = dict(self.kw)
        if resume_run_id:
            init_kwargs.update(id=resume_run_id, resume="allow")
        self.run = self.wandb.init(**init_kwargs)

        if tags:
            if isinstance(tags, str):
                incoming = [tags]
            elif hasattr(tags, "values"):
                incoming = list(tags.values())
            else:
                incoming = list(tags)
            # dict.fromkeys de-duplicates while preserving first-seen order.
            merged = dict.fromkeys(
                [*(self.run.tags or ()), *(str(tag) for tag in incoming)]
            )
            self.run.tags = list(merged)
        return self

    @contextmanager
    def stage(self, name):
        """Context manager that logs ``stage/start`` and ``stage/end`` for ``name``.

        The end marker is logged in a ``finally`` clause, so it is emitted
        even when the body raises. Yields ``self``.
        """
        self.wandb.run.log({"stage/start": name})
        try:
            yield self
        finally:
            self.wandb.run.log({"stage/end": name})

    def log_params(self, params):
        """Merge ``params`` into ``wandb.config``, allowing existing values to change."""
        self.wandb.config.update(params, allow_val_change=True)

    def log_metrics(self, metrics, step=None):
        """Log ``metrics`` through ``wandb.log``.

        When ``step`` is given it is logged as an ordinary ``"step"`` key
        next to the metrics (the caller's dict is not modified); it is not
        passed as the ``step`` argument of ``wandb.log``.
        """
        payload = metrics if step is None else {**metrics, "step": step}
        self.wandb.log(payload)

    def log_artifact(self, path, artifact_path=None):
        """Log a local file or directory as an artifact of type ``"artifact"``.

        Args:
            path: Local file or directory to attach.
            artifact_path: Used as the artifact name after sanitising with
                ``_artifact_name``; defaults to ``"artifacts"``.
        """
        import os

        artifact = self.wandb.Artifact(
            self._artifact_name(artifact_path or "artifacts"), type="artifact"
        )
        if os.path.isdir(path):
            artifact.add_dir(path)
        else:
            artifact.add_file(path)
        self.wandb.run.log_artifact(artifact)

    def _log_temp_file(self, artifact_file, artifact_name, write):
        """Write a UTF-8 text file in a temporary directory and log it.

        ``write`` is called with the open file handle. The artifact is logged
        before the temporary directory is removed.
        """
        import os
        import tempfile

        with tempfile.TemporaryDirectory() as tmp_dir:
            target = os.path.join(tmp_dir, artifact_file)
            # ``artifact_file`` may contain sub-directories ("reports/x.json").
            os.makedirs(os.path.dirname(target), exist_ok=True)
            with open(target, "w", encoding="utf-8") as handle:
                write(handle)
            self.log_artifact(target, artifact_name)

    def log_dict(self, d, artifact_file, artifact_path=None):
        """Serialise ``d`` as indented JSON into ``artifact_file`` and log it.

        The artifact name defaults to ``"data"`` when ``artifact_path`` is falsy.
        """
        import json

        self._log_temp_file(
            artifact_file,
            artifact_path or "data",
            lambda handle: json.dump(d, handle, indent=2),
        )

    def log_text(self, text, artifact_file, artifact_path=None):
        """Write ``text`` into ``artifact_file`` and log it as an artifact.

        The artifact name defaults to ``"notes"`` when ``artifact_path`` is falsy.
        """
        self._log_temp_file(
            artifact_file,
            artifact_path or "notes",
            lambda handle: handle.write(text),
        )

    def log_pytorch(self):
        """Do nothing; present as a no-op in this tracker."""
        pass

    def log_model_pytorch(self, model, artifact_path: str):
        """Save ``model.state_dict()`` to ``model.pt`` and log it as a ``"model"`` artifact.

        Args:
            model: Object exposing ``state_dict()`` (saved with ``torch.save``).
            artifact_path: Used as the artifact name after sanitising with
                ``_artifact_name``.
        """
        import os
        import tempfile

        import torch

        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = os.path.join(tmp_dir, "model.pt")
            torch.save(model.state_dict(), model_path)
            artifact = self.wandb.Artifact(
                self._artifact_name(artifact_path), type="model"
            )
            artifact.add_file(model_path)
            self.wandb.run.log_artifact(artifact)

    def log_data(self, sample_data, data_name, artifact_path, tags=None):
        """Log ``sample_data`` as a ``wandb.Table`` under the key ``data_name``.

        ``sample_data`` is passed as the ``dataframe`` argument of
        ``wandb.Table``. ``artifact_path`` and ``tags`` are accepted but not
        used by this implementation.
        """
        table = self.wandb.Table(dataframe=sample_data)
        self.wandb.log({data_name: table})

    def close(self):
        """Finish the active run, never raising.

        Failures of ``wandb.finish`` are logged as a warning instead of being
        silently discarded, so shutdown stays best-effort but is diagnosable.
        """
        import logging

        try:
            self.wandb.finish()
        except Exception:  # best-effort shutdown: report, do not propagate
            logging.getLogger(__name__).warning(
                "wandb.finish() failed; the run may not be marked as finished",
                exc_info=True,
            )
        finally:
            self.run = None