Source code for nbprint.config.hydra
from ast import literal_eval
from logging import getLogger
from pathlib import Path
from typing import List, Optional
from ccflow import ModelRegistry
from hydra import compose, initialize_config_dir
from nbformat import read as nb_read
from omegaconf import DictConfig, OmegaConf, open_dict
# Public API of this module: only the config loader is exported.
__all__ = ("load_config",)
# Module-level logger, named after this module's import path.
_logger = getLogger(__name__)
[docs]
def load_config(
    path: str,
    overrides: List[str] | None = None,
) -> "ModelRegistry":
    """Compose an nbprint configuration with hydra and load it into the root registry.

    Args:
        path: Path to a ``.yaml`` config file or an ``.ipynb`` notebook. A
            ``.yaml`` file is composed from its own directory; any other
            path falls back to the ``base.yaml`` located next to this module.
        overrides: Hydra override strings. Any value that is not a ``list``
            is treated as "no overrides", and empty strings are dropped.

    Returns:
        The root ``ModelRegistry``, after the composed config has been loaded
        into it and a ``"callable"`` entry has been registered.
    """
    # convert to Path
    path = Path(path)

    if not isinstance(overrides, list):
        # maybe running via python, reset
        overrides = []

    # prune any empty strings
    overrides = [o for o in overrides if o]

    # TODO: right now, nbprint runs off a specific config file, whereas
    # hydra takes a config dir and config name. For now we use the nbprint
    # style, so we adjust accordingly
    if path.suffix == ".yaml":
        config_dir = str(path.parent.absolute().resolve())
        config_name = str(path.name)
        is_default = False
    else:
        # Use base
        config_dir = str(Path(__file__).parent.absolute().resolve())
        config_name = "base.yaml"
        is_default = True

    with initialize_config_dir(config_dir=config_dir, version_base=None):
        if not is_default:
            # First pass: compose without overrides only to discover the
            # searchpath declared by the config, then re-compose with the
            # config's own directory appended to that searchpath.
            cfg = compose(config_name=config_name, overrides=[], return_hydra_config=True)
            searchpaths = cfg["hydra"]["searchpath"]
            searchpaths.append(config_dir)
            overrides = [*overrides, f"hydra.searchpath=[{','.join(searchpaths)}]"]
        cfg = compose(config_name=config_name, overrides=overrides)

    if "nbprint" not in cfg:
        _logger.warning("No 'nbprint' config found in the provided configuration. Assuming entire config is for nbprint.")
        # open_dict lifts struct mode so new keys can be added to the config
        with open_dict(cfg):
            if "_target_" not in cfg:
                cfg._target_ = "nbprint.Configuration"
            cfg.nbprint = cfg.copy()

    # bridge hydra and non-hydra
    if "name" not in cfg.nbprint:
        with open_dict(cfg.nbprint):
            # Strip only a trailing extension; the same text occurring in
            # the middle of the file name must be left alone.
            cfg.nbprint.name = path.name.removesuffix(".yaml").removesuffix(".ipynb")

    # If its a notebook, record its path in the config so it can be run directly
    if path.suffix == ".ipynb":
        with open_dict(cfg):
            cfg.nbprint.notebook = path

    if isinstance(cfg, DictConfig):
        # Resolve interpolations and convert to plain containers
        cfg = OmegaConf.to_container(cfg, resolve=True)

    registry = ModelRegistry.root()
    registry.load_config(cfg=cfg, overwrite=True)

    # Pick which registry entry is exposed as "callable": a top-level
    # "callable" key wins, then "nbprint.callable", else "nbprint" itself.
    if "callable" in cfg:
        target = cfg["callable"]
    elif "callable" in cfg.get("nbprint", {}):
        target = cfg["nbprint"]["callable"]
    else:
        target = "nbprint"
    registry.add("callable", registry[target], overwrite=True)
    return registry