Source code for pythonwrench.disk_cache

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
import os
import shutil
import time
import warnings
from functools import wraps
from pathlib import Path
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    Literal,
    Optional,
    Tuple,
    TypedDict,
    TypeVar,
    Union,
    get_args,
    overload,
)

from typing_extensions import ParamSpec

from pythonwrench.checksum import checksum_any
from pythonwrench.datetime import get_now
from pythonwrench.inspect import get_argnames, get_fullname

T = TypeVar("T")
P = ParamSpec("P")
U = TypeVar("U")


ChecksumFn = Callable[[Tuple[Callable[P, T], Tuple, Dict[str, Any]]], int]
SavingBackend = Literal["csv", "json", "pickle"]
StoreMode = Literal["outputs_only", "outputs_metadata", "outputs_metadata_inputs"]

_DEFAULT_CACHE_STORE_MODE: StoreMode = "outputs_only"


class _CacheMeta(TypedDict):
    datetime: str
    duration: float
    checksum: int
    fn_fullname: str
    output: Any
    input: Optional[Tuple[Any, Any]]


_DEFAULT_CACHE_DPATH = Path.home().joinpath(".cache", "disk_cache")


logger = logging.getLogger(__name__)


@overload
def disk_cache_decorator(
    fn: None = None,
    *,
    cache_dpath: Union[str, Path, None] = None,
    cache_force: bool = False,
    cache_verbose: int = 0,
    cache_checksum_fn: ChecksumFn = checksum_any,
    cache_saving_backend: Literal["custom"],
    cache_fname_fmt: Union[
        str, Callable[..., str]
    ] = "{fn_name}_{checksum_hex}{suffix}",
    cache_fname_fmt_args: Optional[Iterable[str]] = None,
    cache_dump_fn: Callable[[Any, Path], Any],
    cache_load_fn: Callable[[Path], Any],
    cache_enable: bool = True,
    cache_store_mode: StoreMode,
) -> Callable[[Callable[P, T]], Callable[P, T]]: ...


@overload
def disk_cache_decorator(
    fn: None = None,
    *,
    cache_dpath: Union[str, Path, None] = None,
    cache_force: bool = False,
    cache_verbose: int = 0,
    cache_checksum_fn: ChecksumFn = checksum_any,
    cache_saving_backend: SavingBackend,
    cache_fname_fmt: Union[
        str, Callable[..., str]
    ] = "{fn_name}_{checksum_hex}{suffix}",
    cache_fname_fmt_args: Optional[Iterable[str]] = None,
    cache_dump_fn: None = None,
    cache_load_fn: None = None,
    cache_enable: bool = True,
    cache_store_mode: StoreMode = _DEFAULT_CACHE_STORE_MODE,
) -> Callable[[Callable[P, T]], Callable[P, T]]: ...


@overload
def disk_cache_decorator(
    fn: None = None,
    *,
    cache_dpath: Union[str, Path, None] = None,
    cache_force: bool = False,
    cache_verbose: int = 0,
    cache_checksum_fn: ChecksumFn = checksum_any,
    cache_saving_backend: Union[SavingBackend, Literal["custom", "auto"]] = "auto",
    cache_fname_fmt: Union[
        str, Callable[..., str]
    ] = "{fn_name}_{checksum_hex}{suffix}",
    cache_fname_fmt_args: Optional[Iterable[str]] = None,
    cache_dump_fn: Optional[Callable[[Any, Path], Any]] = None,
    cache_load_fn: Optional[Callable[[Path], Any]] = None,
    cache_enable: bool = True,
    cache_store_mode: StoreMode = _DEFAULT_CACHE_STORE_MODE,
) -> Callable[[Callable[P, T]], Callable[P, T]]: ...


@overload
def disk_cache_decorator(
    fn: Callable[P, T],
    *,
    cache_dpath: Union[str, Path, None] = None,
    cache_force: bool = False,
    cache_verbose: int = 0,
    cache_checksum_fn: ChecksumFn = checksum_any,
    cache_saving_backend: Literal["custom"],
    cache_fname_fmt: Union[
        str, Callable[..., str]
    ] = "{fn_name}_{checksum_hex}{suffix}",
    cache_fname_fmt_args: Optional[Iterable[str]] = None,
    cache_dump_fn: Callable[[Any, Path], Any],
    cache_load_fn: Callable[[Path], Any],
    cache_enable: bool = True,
    cache_store_mode: StoreMode = _DEFAULT_CACHE_STORE_MODE,
) -> Callable[P, T]: ...


@overload
def disk_cache_decorator(
    fn: Callable[P, T],
    *,
    cache_dpath: Union[str, Path, None] = None,
    cache_force: bool = False,
    cache_verbose: int = 0,
    cache_checksum_fn: ChecksumFn = checksum_any,
    cache_saving_backend: Union[SavingBackend, Literal["custom", "auto"]] = "auto",
    cache_fname_fmt: Union[
        str, Callable[..., str]
    ] = "{fn_name}_{checksum_hex}{suffix}",
    cache_fname_fmt_args: Optional[Iterable[str]] = None,
    cache_dump_fn: Optional[Callable[[Any, Path], Any]] = None,
    cache_load_fn: Optional[Callable[[Path], Any]] = None,
    cache_enable: bool = True,
    cache_store_mode: StoreMode = _DEFAULT_CACHE_STORE_MODE,
) -> Callable[P, T]: ...


[docs] def disk_cache_decorator( fn: Optional[Callable[P, T]] = None, *, cache_dpath: Union[str, Path, None] = None, cache_force: bool = False, cache_verbose: int = 0, cache_checksum_fn: ChecksumFn = checksum_any, cache_saving_backend: Union[SavingBackend, Literal["custom", "auto"]] = "auto", cache_fname_fmt: Union[ str, Callable[..., str] ] = "{fn_name}_{checksum_hex}{suffix}", cache_fname_fmt_args: Optional[Iterable[str]] = None, cache_dump_fn: Optional[Callable[[Any, Path], Any]] = None, cache_load_fn: Optional[Callable[[Path], Any]] = None, cache_enable: bool = True, cache_store_mode: StoreMode = _DEFAULT_CACHE_STORE_MODE, ) -> Callable: """Decorator to store function output in a cache file. Cache file is identified by the checksum of the function arguments, and stored by default in `"~/.cache/disk_cache/<Function_name>/"` directory. Example ------- >>> import pythonwrench as pw >>> @pw.disk_cache_decorator >>> def heavy_processing(): >>> # Lot of stuff here >>> ... >>> outputs = heavy_processing() # first time function is called >>> outputs = heavy_processing() # second time outputs is loaded from disk Args: fn: Function to store its output. By default, it must be a callable that returns a pickable object. cache_dpath: Cache directory path. defaults to `"~/.cache/disk_cache"`. cache_force: Force function call and overwrite cache. defaults to False. cache_verbose: Set verbose logging level. Higher means more verbose. defaults to 0. cache_checksum_fn: Checksum function to identify input arguments. defaults to ``pythonwrench.checksum_any``. cache_saving_backend: Optional saving backend. Can be one of ('csv', 'json', 'pickle', 'custom', 'auto'). defaults to 'auto'. cache_fname_fmt: Cache filename format. defaults to "{fn_name}_{checksum_hex}{suffix}". cache_dump_fn: Dump/save function to store outputs and overwrite saving backend. defaults to None. cache_load_fn: Load function to store outputs and overwrite saving backend. defaults to None. cache_enable: Enable disk cache. If False, the function has no effect. defaults to True. cache_store_mode: Disk cache storage mode. By default, it store function output and saved date into the cache file. defaults to 'outputs_metadata'. """ impl_fn = _disk_cache_impl( cache_dpath=cache_dpath, cache_force=cache_force, cache_verbose=cache_verbose, cache_checksum_fn=cache_checksum_fn, cache_saving_backend=cache_saving_backend, cache_fname_fmt=cache_fname_fmt, cache_fname_fmt_args=cache_fname_fmt_args, cache_dump_fn=cache_dump_fn, cache_load_fn=cache_load_fn, cache_enable=cache_enable, cache_store_mode=cache_store_mode, ) if fn is not None: return impl_fn(fn) else: return impl_fn
@overload def disk_cache_call( fn: Callable[..., T], *args, cache_dpath: Union[str, Path, None] = None, cache_force: bool = False, cache_verbose: int = 0, cache_checksum_fn: ChecksumFn = checksum_any, cache_saving_backend: Literal["custom"], cache_fname_fmt: Union[ str, Callable[..., str] ] = "{fn_name}_{checksum_hex}{suffix}", cache_fname_fmt_args: Optional[Iterable[str]] = None, cache_dump_fn: Callable[[Any, Path], Any], cache_load_fn: Callable[[Path], Any], cache_enable: bool = True, cache_store_mode: StoreMode, **kwargs, ) -> T: ... @overload def disk_cache_call( fn: Callable[..., T], *args, cache_dpath: Union[str, Path, None] = None, cache_force: bool = False, cache_verbose: int = 0, cache_checksum_fn: ChecksumFn = checksum_any, cache_saving_backend: SavingBackend, cache_fname_fmt: Union[ str, Callable[..., str] ] = "{fn_name}_{checksum_hex}{suffix}", cache_fname_fmt_args: Optional[Iterable[str]] = None, cache_dump_fn: None = None, cache_load_fn: None = None, cache_enable: bool = True, cache_store_mode: StoreMode = _DEFAULT_CACHE_STORE_MODE, **kwargs, ) -> T: ... @overload def disk_cache_call( fn: Callable[..., T], *args, cache_dpath: Union[str, Path, None] = None, cache_force: bool = False, cache_verbose: int = 0, cache_checksum_fn: ChecksumFn = checksum_any, cache_saving_backend: Union[SavingBackend, Literal["custom", "auto"]] = "auto", cache_fname_fmt: Union[ str, Callable[..., str] ] = "{fn_name}_{checksum_hex}{suffix}", cache_fname_fmt_args: Optional[Iterable[str]] = None, cache_dump_fn: Optional[Callable[[Any, Path], Any]] = None, cache_load_fn: Optional[Callable[[Path], Any]] = None, cache_enable: bool = True, cache_store_mode: StoreMode = _DEFAULT_CACHE_STORE_MODE, **kwargs, ) -> T: ...
[docs] def disk_cache_call( fn: Callable[..., T], *args, cache_dpath: Union[str, Path, None] = None, cache_force: bool = False, cache_verbose: int = 0, cache_checksum_fn: ChecksumFn = checksum_any, cache_saving_backend: Union[SavingBackend, Literal["custom", "auto"]] = "auto", cache_fname_fmt: Union[ str, Callable[..., str] ] = "{fn_name}_{checksum_hex}{suffix}", cache_fname_fmt_args: Optional[Iterable[str]] = None, cache_dump_fn: Optional[Callable[[Any, Path], Any]] = None, cache_load_fn: Optional[Callable[[Path], Any]] = None, cache_enable: bool = True, cache_store_mode: StoreMode = _DEFAULT_CACHE_STORE_MODE, **kwargs, ) -> T: r"""Call function and store output in a cache file. Cache file is identified by the checksum of the function arguments, and stored by default in '~/.cache/disk_cache/<Function_name>/' directory. Example ------- >>> import pythonwrench as pw >>> def heavy_processing(): >>> # Lot of stuff here >>> ... >>> outputs = pw.disk_cache_call(heavy_processing) # first time function is called >>> outputs = pw.disk_cache_call(heavy_processing) # second time outputs is loaded from disk Args: fn: Function to store its output. By default, it must be a callable that returns a pickable object. cache_dpath: Cache directory path. defaults to '~/.cache/disk_cache'. cache_force: Force function call and overwrite cache. defaults to False. cache_verbose: Set verbose logging level. Higher means more verbose. defaults to 0. cache_checksum_fn: Checksum function to identify input arguments. defaults to ``pythonwrench.checksum_any``. cache_saving_backend: Optional saving backend. Can be one of ('csv', 'json', 'pickle', 'custom', 'auto'). defaults to 'auto'. cache_fname_fmt: Cache filename format. defaults to '{fn_name}_{checksum_hex}{suffix}'. cache_dump_fn: Dump/save function to store outputs and overwrite saving backend. defaults to None. cache_load_fn: Load function to store outputs and overwrite saving backend. defaults to None. cache_enable: Enable disk cache. If False, the function has no effect. defaults to True. cache_store_mode: Disk cache storage mode. By default, it store function output and saved date into the cache file. defaults to 'outputs_metadata'. \*args: Positional arguments passed to the function. \*\*kwargs: Keywords arguments passed to the function. """ wrapped_fn = _disk_cache_impl( cache_dpath=cache_dpath, cache_force=cache_force, cache_verbose=cache_verbose, cache_checksum_fn=cache_checksum_fn, cache_saving_backend=cache_saving_backend, cache_fname_fmt=cache_fname_fmt, cache_fname_fmt_args=cache_fname_fmt_args, cache_dump_fn=cache_dump_fn, cache_load_fn=cache_load_fn, cache_enable=cache_enable, cache_store_mode=cache_store_mode, ) return wrapped_fn(fn)(*args, **kwargs)
def _disk_cache_impl( *, cache_dpath: Union[str, Path, None] = None, cache_force: bool = False, cache_verbose: int = 0, cache_checksum_fn: ChecksumFn = checksum_any, cache_saving_backend: Union[SavingBackend, Literal["custom", "auto"]] = "auto", cache_fname_fmt: Union[ str, Callable[..., str] ] = "{fn_name}_{checksum_hex}{suffix}", cache_fname_fmt_args: Optional[Iterable[str]] = None, cache_dump_fn: Optional[Callable[[Any, Path], Any]] = None, cache_load_fn: Optional[Callable[[Path], Any]] = None, cache_enable: bool = True, cache_store_mode: StoreMode = _DEFAULT_CACHE_STORE_MODE, ) -> Callable[[Callable[P, T]], Callable[P, T]]: # for backward compatibility if cache_fname_fmt is None: expected = "{fn_name}_{csum}{suffix}" msg = f"Deprecated argument value {cache_fname_fmt=}. (use {expected} instead)" warnings.warn(msg, DeprecationWarning) cache_fname_fmt = expected if cache_saving_backend is None: expected = "auto" msg = f"Deprecated argument value {cache_saving_backend=}. (use {expected} instead)" warnings.warn(msg, DeprecationWarning) cache_saving_backend = expected if cache_saving_backend == "auto": if cache_dump_fn is not None and cache_load_fn is not None: cache_saving_backend = "custom" else: cache_saving_backend = "pickle" if cache_saving_backend == "pickle": from pythonwrench.pickle import dump_pickle, load_pickle suffix = ".pickle" cache_dump_fn = dump_pickle cache_load_fn = load_pickle elif cache_saving_backend == "json": from pythonwrench.json import dump_json, load_json suffix = ".json" cache_dump_fn = dump_json cache_load_fn = load_json elif cache_saving_backend == "csv": from pythonwrench.csv import dump_csv, load_csv if cache_store_mode != "outputs_only": msg = f"Invalid combinaison of arguments {cache_saving_backend=} with {cache_store_mode=}." raise ValueError(msg) suffix = ".csv" cache_dump_fn = dump_csv cache_load_fn = load_csv elif cache_saving_backend == "custom": if cache_dump_fn is None or cache_load_fn is None: msg = f"If {cache_saving_backend=}, arguments cache_dump_fn and cache_load_fn cannot be None. (found {cache_dump_fn=} {cache_load_fn=})" raise ValueError(msg) suffix = "" else: msg = f"Invalid argument {cache_saving_backend=}. (expected one of {get_args(SavingBackend)})" raise ValueError(msg) if isinstance(cache_fname_fmt, str): cache_fname_fmt = cache_fname_fmt.format def _disk_cache_impl_fn(fn: Callable[P, T]) -> Callable[P, T]: fn_fullname, fn_name = _get_fn_fullname_and_name(fn) cache_fn_dpath = _get_fn_cache_dpath(fn, cache_dpath=cache_dpath) if cache_force: compute_start_msg = f"[{fn_name}] Force mode enabled, computing outputs'... (started at {{now}})" else: compute_start_msg = ( f"[{fn_name}] Cache missed, computing outputs... (started at {{now}})" ) compute_end_msg = ( f"[{fn_name}] Outputs computed in {{duration:.1f}}s. (ended at {{now}})" ) load_start_msg = f"[{fn_name}] Loading cache..." load_end_msg = f"[{fn_name}] Cache loaded." argnames = get_argnames(fn) @wraps(fn) def _disk_cache_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: checksum_args = fn, args, kwargs kwds = {} if cache_fname_fmt_args is None or "fn_name" in cache_fname_fmt_args: kwds["fn_name"] = fn_name if cache_fname_fmt_args is None or "fn_fullname" in cache_fname_fmt_args: kwds["fn_fullname"] = fn_fullname if cache_fname_fmt_args is None or "suffix" in cache_fname_fmt_args: kwds["suffix"] = suffix if cache_fname_fmt_args is None or any( k in cache_fname_fmt_args for k in ("csum", "checksum", "checksum_hex") ): csum = cache_checksum_fn(checksum_args) kwds["checksum"] = csum kwds["checksum_hex"] = hex(csum)[2:] else: csum = None inputs_kwds = { argname: argval for argname, argval in zip(argnames, args) if cache_fname_fmt_args is None or argname in cache_fname_fmt_args } kwds.update(inputs_kwds) kwds.update(kwargs) cache_fname = cache_fname_fmt(**kwds) cache_fpath = cache_fn_dpath.joinpath(cache_fname) if not cache_enable: output = fn(*args, **kwargs) elif cache_force or not cache_fpath.exists(): if cache_verbose > 0: logger.info(compute_start_msg.format(now=get_now())) start = time.perf_counter() output = fn(*args, **kwargs) duration = time.perf_counter() - start if cache_verbose > 0: logger.info( compute_end_msg.format(now=get_now(), duration=duration) ) if cache_store_mode == "outputs_only": cache_content = output elif ( cache_store_mode == "outputs_metadata" or cache_store_mode == "outputs_metadata_inputs" ): input = ( (args, kwargs) if cache_store_mode == "outputs_metadata_inputs" else None ) cache_content = { "datetime": get_now(), "duration": duration, "checksum": csum, "fn_fullname": fn_fullname, "output": output, "input": input, } else: msg = f"Invalid argument {cache_store_mode=}. (expected one of {get_args(StoreMode)})" raise ValueError(msg) cache_fn_dpath.mkdir(parents=True, exist_ok=True) cache_dump_fn(cache_content, cache_fpath) # type: ignore elif cache_fpath.is_file(): if cache_verbose > 0: logger.info(load_start_msg) cache_content: Any = cache_load_fn(cache_fpath) if cache_store_mode == "outputs_only": output = cache_content elif cache_store_mode == "outputs_metadata": output = cache_content["output"] elif cache_store_mode == "outputs_metadata_inputs": output = cache_content["output"] input_ = cache_content["input"] if input_ is not None and input_ != (args, kwargs): os.remove(cache_fpath) return _disk_cache_wrapper(*args, **kwargs) else: msg = f"Invalid argument {cache_store_mode=}. (expected one of {get_args(StoreMode)})" raise ValueError(msg) if cache_verbose > 0: logger.info(load_end_msg) if cache_store_mode != "outputs_only" and cache_verbose > 1: metadata = {k: v for k, v in cache_content.items() if k != "output"} msgs = f"Found cache metadata:\n{metadata}".split("\n") for msg in msgs: logger.debug(msg) else: msg = f"Path {str(cache_fpath)} exists but it is not a file." raise RuntimeError(msg) return output _disk_cache_wrapper.fn = fn # type: ignore return _disk_cache_wrapper return _disk_cache_impl_fn
[docs] def get_cache_dpath(cache_dpath: Union[str, Path, None] = None) -> Path: """Returns defaults disk cache directory path, which is `~/.cache/disk_cache`.""" if cache_dpath is None: cache_dpath = _DEFAULT_CACHE_DPATH else: cache_dpath = Path(cache_dpath) return cache_dpath
[docs] def remove_fn_cache( fn: Callable, *, cache_dpath: Union[str, Path, None] = None, ) -> None: """Removes all caches for a specific function.""" cache_fn_dpath = _get_fn_cache_dpath(fn, cache_dpath=cache_dpath) if cache_fn_dpath.is_dir(): shutil.rmtree(cache_fn_dpath)
def _get_fn_cache_dpath( fn: Callable, *, cache_dpath: Union[str, Path, None] = None, ) -> Path: _, fn_name = _get_fn_fullname_and_name(fn) cache_dpath = get_cache_dpath(cache_dpath) cache_fn_dpath = cache_dpath.joinpath(fn_name) return cache_fn_dpath def _get_fn_fullname_and_name(fn: Callable) -> Tuple[str, str]: fn_fullname = get_fullname(fn, inst_suffix="").replace("<locals>", "_locals_") fn_name = fn_fullname.split(".")[-1] return fn_fullname, fn_name