Source code for nbprint.config.core.parameters

import ast
from json import dumps, loads
from typing import Any, Dict

from ccflow import ContextBase
from nbformat import NotebookNode
from pydantic import Field, field_validator, model_serializer, model_validator

from nbprint.config.base import BaseModel, Role

# Public API of this module (the names exported by ``import *``).
__all__ = ("PapermillParameters", "Parameters")


class Parameters(ContextBase, BaseModel):
    """Notebook cell that assigns this model's fields as Python variables.

    The generated cell always carries the ``parameters`` tag (the tag papermill
    looks for) as well as ``nbprint:parameters``.
    """

    # Cell tags. ``_ensure_tags`` appends the two parameter tags when missing;
    # validate_default=True makes that also hold for the default empty list.
    tags: list[str] = Field(default_factory=list, validate_default=True)
    role: Role = Role.PARAMETERS
    ignore: bool = True

    @field_validator("tags", mode="after")
    @classmethod
    def _ensure_tags(cls, v: list[str]) -> list[str]:
        """Return ``v`` with ``nbprint:parameters`` and ``parameters`` appended if absent.

        Builds a new list rather than mutating the incoming one.
        """
        missing = [tag for tag in ("nbprint:parameters", "parameters") if tag not in v]
        return v + missing

    def generate(self, metadata: dict, **_) -> NotebookNode:
        """Build the parameters cell.

        Every field except ``type``, ``tags``, ``role`` and ``ignore`` becomes one
        ``name = value`` assignment in the cell source.

        Args:
            metadata: forwarded to ``_base_generate_meta`` to create the cell.

        Returns:
            The cell from ``_base_generate_meta`` with ``source`` filled in.
        """
        cell = self._base_generate_meta(metadata=metadata)
        mod = ast.Module(body=[], type_ignores=[])
        # NOTE: use model_dump(mode="json") here to be compatible with
        # papermill-based json parameters
        values = self.model_dump(mode="json", exclude={"type", "tags", "role", "ignore"})
        for i, (k, v) in enumerate(values.items()):
            # mode="json" only yields dict/list/str/int/float/bool/None. Their
            # repr() is a valid Python literal at any nesting depth (True/False/
            # None rather than JSON's true/false/null), and string contents are
            # never rewritten.
            mod.body.append(
                ast.Assign(
                    targets=[ast.Name(id=k, ctx=ast.Store())],
                    value=ast.parse(repr(v), mode="eval").body,
                    lineno=i,
                )
            )
        # NOTE(review): the double-quote escaping presumably serves whatever
        # consumes cell.source downstream; confirm before changing it.
        cell.source = ast.unparse(mod).replace('"', '\\"')
        return cell
class PapermillParameters(Parameters):
    """Papermill-style parameters held in a free-form ``vars`` dict.

    Input keys that are not declared fields are collected into ``vars`` during
    validation. Unknown attribute assignments are stored there as well, and
    serialization flattens ``vars`` back into the top level.
    """

    # Parameter name -> value; ``generate`` emits one assignment per entry.
    vars: Dict[str, Any] = Field(default_factory=dict)

    def generate(self, metadata: dict, **_) -> NotebookNode:
        """Build the parameters cell with one ``name = value`` line per ``vars`` entry.

        Args:
            metadata: forwarded to ``_base_generate_meta`` to create the cell.

        Returns:
            The cell from ``_base_generate_meta`` with ``source`` filled in.

        Raises:
            TypeError: if a value in ``vars`` is not JSON-serializable.
        """
        cell = self._base_generate_meta(metadata=metadata)
        mod = ast.Module(body=[], type_ignores=[])
        for i, (k, v) in enumerate(self.vars.items()):
            # Round-trip through JSON so each value takes the shape papermill's
            # JSON parameters would (e.g. tuples become lists). Then repr() it
            # to get a valid Python literal at any nesting depth
            # (True/False/None rather than JSON's true/false/null).
            literal = repr(loads(dumps(v)))
            mod.body.append(
                ast.Assign(
                    targets=[ast.Name(id=k, ctx=ast.Store())],
                    value=ast.parse(literal, mode="eval").body,
                    lineno=i,
                )
            )
        # NOTE(review): the double-quote escaping presumably serves whatever
        # consumes cell.source downstream; confirm before changing it.
        cell.source = ast.unparse(mod).replace('"', '\\"')
        return cell

    @model_validator(mode="before")
    @classmethod
    def _model_before_validator(cls, data: Any) -> Any:
        """Move every input key that is not a declared field (or ``_target_``) into ``vars``.

        Builds new dicts, so neither the caller's input nor any ``vars`` dict
        inside it is mutated. Non-dict input is returned unchanged for pydantic
        to handle.
        """
        if not isinstance(data, dict):
            return data
        params_fields = set(cls.model_fields) | {"_target_"}
        extra = {k: v for k, v in data.items() if k not in params_fields}
        result = {k: v for k, v in data.items() if k in params_fields}
        # Explicitly supplied ``vars`` come first; loose keys override on clash,
        # matching the previous dict.update() order.
        result["vars"] = {**result.get("vars", {}), **extra}
        return result

    @model_serializer(mode="wrap")
    def _serialize_model(self, handler) -> dict[str, object]:
        """Serialize normally, then flatten the ``vars`` entries into the top level."""
        serialized = handler(self)
        # Use the handler's rendering of ``vars`` so the requested mode
        # (e.g. mode="json") also applies to those values.
        flat_vars = serialized.pop("vars", None) or {}
        serialized.update(flat_vars)
        return serialized

    # NOTE: this shouldve been possible via a wrap or before validator,
    # but alas i could not get it to work
    def __setattr__(self, name: str, value: Any) -> None:
        """Route attribute writes.

        - Declared fields other than ``vars`` are set normally.
        - Assigning ``vars`` merges ``value`` into the existing dict rather than
          replacing it.
        - Any other public name is stored as ``vars[name]``.
        - Underscore-prefixed names are private/internal, never notebook
          parameters, so they are handed to pydantic.
        """
        if name.startswith("_") or (name in type(self).model_fields and name != "vars"):
            super().__setattr__(name, value)
        elif name == "vars":
            self.vars.update(value)
        else:
            self.vars[name] = value