#!/usr/bin/env python
# -*- coding: utf-8 -*-
import copy
import operator
import random
import sys
from typing import (
Any,
Callable,
Dict,
Generator,
Generic,
Hashable,
Iterable,
Iterator,
List,
Literal,
Mapping,
MutableSequence,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
get_args,
overload,
)
from typing_extensions import TypeGuard, TypeIs
from pythonwrench.collections.prop import all_eq
from pythonwrench.collections.reducers import reduce_or
from pythonwrench.functools import identity
from pythonwrench.semver import Version
from pythonwrench.typing.checks import is_builtin_scalar, isinstance_generic
from pythonwrench.typing.classes import T_BuiltinScalar
K = TypeVar("K", covariant=True, bound=Hashable)
T = TypeVar("T", covariant=True)
U = TypeVar("U", covariant=True)
V = TypeVar("V", covariant=True)
W = TypeVar("W", covariant=True)
X = TypeVar("X", covariant=True)
Y = TypeVar("Y", covariant=True)
KeyMode = Literal["intersect", "same", "union"]
Order = Literal["left", "right"]
[docs]
class SizedGenerator(Generic[T]):
"""Wraps a generator and size to provide a sized iterable object."""
def __init__(self, generator: Generator[T, None, None], size: int) -> None:
super().__init__()
self._generator = generator
self._size = size
def __iter__(self) -> Iterator[T]:
yield from self._generator
def __len__(self) -> int:
return self._size
[docs]
def contained(
x: T,
include: Optional[Iterable[T]] = None,
exclude: Optional[Iterable[T]] = None,
*,
match_fn: Callable[[T, T], bool] = operator.eq,
order: Literal["left", "right"] = "right",
) -> bool:
"""Returns True if name in include set and not in exclude set."""
if (
include is not None
and find(x, include, match_fn=match_fn, order=order, default=-1) == -1
):
return False
if (
exclude is not None
and find(x, exclude, match_fn=match_fn, order=order, default=-1) != -1
):
return False
return True
@overload
def dict_list_to_list_dict(
dic: Mapping[T, Iterable[U]],
key_mode: Literal["union"] = "union",
default_val: W = None,
) -> List[Dict[T, Union[U, W]]]: ...
@overload
def dict_list_to_list_dict(
dic: Mapping[T, Iterable[U]],
key_mode: Literal["same", "intersect"],
default_val: Any = None,
) -> List[Dict[T, U]]: ...
[docs]
def dict_list_to_list_dict(
dic: Mapping[T, Iterable[U]],
key_mode: KeyMode = "union",
default_val: W = None,
) -> List[Dict[T, Union[U, W]]]:
"""Convert dict of lists with same sizes to list of dicts.
Example 1
---------
>>> dic = {"a": [1, 2], "b": [3, 4]}
>>> dict_list_to_list_dict(dic)
... [{"a": 1, "b": 3}, {"a": 2, "b": 4}]
Example 2
---------
>>> dic = {"a": [1, 2, 3], "b": [4], "c": [5, 6]}
>>> dict_list_to_list_dict(dic, key_mode="union", default=-1)
... [{"a": 1, "b": 4, "c": 5}, {"a": 2, "b": -1, "c": 6}, {"a": 3, "b": -1, "c": -1}]
"""
if len(dic) == 0:
return []
dic = {k: list(v) if not isinstance(v, Sequence) else v for k, v in dic.items()}
lengths = [len(seq) for seq in dic.values()]
if key_mode == "same":
if not all_eq(lengths):
msg = f"Invalid sequences for batch. (found different lengths in sub-lists: {set(lengths)})"
raise ValueError(msg)
length = lengths[0]
elif key_mode == "intersect":
length = min(lengths)
elif key_mode == "union":
length = max(lengths)
else:
msg = f"Invalid argument key_mode={key_mode}. (expected one of {get_args(KeyMode)})"
raise ValueError(msg)
result = [
{k: (v[i] if i < len(v) else default_val) for k, v in dic.items()}
for i in range(length)
]
return result
[docs]
def dump_dict(
dic: Optional[Mapping[str, T]] = None,
/,
join: str = ", ",
fmt: str = "{key}={value}",
ignore_lst: Iterable[T] = (),
**kwargs,
) -> str:
"""Dump dictionary of scalars to string function to customize representation.
Example 1:
----------
>>> d = {"a": 1, "b": 2}
>>> dump_dict(d)
... 'a=1, b=2'
"""
if dic is None:
dic = {}
else:
dic = dict(dic.items())
dic.update(kwargs)
ignore_lst = dict.fromkeys(ignore_lst)
result = join.join(
fmt.format(key=key, value=value)
for key, value in dic.items()
if value not in ignore_lst
)
return result
[docs]
def filter_iterable(
it: Iterable[T],
include: Optional[Iterable[T]] = None,
exclude: Optional[Iterable[T]] = None,
*,
match_fn: Callable[[T, T], bool] = operator.eq,
order: Literal["left", "right"] = "right",
) -> List[T]:
return [
item
for item in it
if contained(
item,
include=include,
exclude=exclude,
match_fn=match_fn,
order=order,
)
]
@overload
def find(
target: T,
it: Iterable[V],
*,
match_fn: Callable[[V, T], bool] = operator.eq,
order: Literal["right"] = "right",
default: U = -1,
return_value: Literal[False] = False,
) -> Union[int, U]: ...
@overload
def find(
target: T,
it: Iterable[V],
*,
match_fn: Callable[[T, V], bool] = operator.eq,
order: Literal["left"],
default: U = -1,
return_value: Literal[False] = False,
) -> Union[int, U]: ...
@overload
def find(
target: T,
it: Iterable[V],
*,
match_fn: Callable[[V, T], bool] = operator.eq,
order: Literal["right"] = "right",
default: U = -1,
return_value: Literal[True],
) -> Tuple[Union[int, U], Union[V, U]]: ...
@overload
def find(
target: T,
it: Iterable[V],
*,
match_fn: Callable[[T, V], bool] = operator.eq,
order: Literal["left"],
default: U = -1,
return_value: Literal[True],
) -> Tuple[Union[int, U], Union[V, U]]: ...
[docs]
def find(
target: Any,
it: Iterable[V],
*,
match_fn: Callable[[Any, Any], bool] = operator.eq,
order: Order = "right",
default: U = -1,
return_value: bool = False,
) -> Union[int, U, Tuple[Union[int, U], Union[V, U]]]:
if not return_value:
result = find(
target,
it,
match_fn=match_fn,
order=order,
default=default,
return_value=True,
)
return result[0]
if order == "right":
pass
elif order == "left":
def revert(f):
def reverted_f(a, b):
return f(b, a)
return reverted_f
match_fn = revert(match_fn)
else:
raise ValueError(
f"Invalid argument {order=}. (expected one of {get_args(Order)})"
)
for i, xi in enumerate(it):
if match_fn(xi, target):
return i, xi
return default, default
@overload
def flatten(
x: T_BuiltinScalar,
start_dim: int = 0,
end_dim: Optional[int] = None,
) -> List[T_BuiltinScalar]: ...
@overload
def flatten( # type: ignore
x: Iterable[T_BuiltinScalar],
start_dim: int = 0,
end_dim: Optional[int] = None,
) -> List[T_BuiltinScalar]: ...
@overload
def flatten(
x: Any,
start_dim: int = 0,
end_dim: Optional[int] = None,
is_scalar_fn: Union[
Callable[[Any], TypeGuard[T]], Callable[[Any], TypeIs[T]]
] = is_builtin_scalar,
) -> List[Any]: ...
[docs]
def flatten(
x: Any,
start_dim: int = 0,
end_dim: Optional[int] = None,
is_scalar_fn: Union[
Callable[[Any], TypeGuard[T]], Callable[[Any], TypeIs[T]]
] = is_builtin_scalar,
) -> List[Any]:
if end_dim is None:
end_dim = sys.maxsize
if start_dim < 0:
raise ValueError(f"Invalid argument {start_dim=}. (expected positive integer)")
if end_dim < 0:
raise ValueError(f"Invalid argument {end_dim=}. (expected positive integer)")
if start_dim > end_dim:
msg = f"Invalid arguments {start_dim=} and {end_dim=}. (expected start_dim <= end_dim)"
raise ValueError(msg)
def flatten_impl(x: Any, start_dim: int, end_dim: int) -> List[Any]:
if is_scalar_fn(x):
return [x]
elif isinstance(x, Iterable):
if start_dim > 0:
return [flatten_impl(xi, start_dim - 1, end_dim - 1) for xi in x]
elif end_dim > 0:
return [
xij
for xi in x
for xij in flatten_impl(xi, start_dim - 1, end_dim - 1)
]
else:
return list(x)
else:
raise TypeError(f"Invalid argument type {type(x)=}.")
return flatten_impl(x, start_dim, end_dim)
[docs]
def flat_dict_of_dict(
nested_dic: Mapping[str, Any],
*,
sep: str = ".",
flat_iterables: bool = False,
overwrite: bool = True,
) -> Dict[str, Any]:
"""Flat a nested dictionary.
Example 1
---------
>>> dic = {
... "a": 1,
... "b": {
... "a": 2,
... "b": 10,
... },
... }
>>> flat_dict_of_dict(dic)
... {"a": 1, "b.a": 2, "b.b": 10}
Example 2
---------
>>> dic = {"a": ["hello", "world"], "b": 3}
>>> flat_dict_of_dict(dic, flat_iterables=True)
... {"a.0": "hello", "a.1": "world", "b": 3}
Args:
nested_dic: Nested mapping containing sub-mappings or iterables.
sep: Separators between keys.
flat_iterables: If True, flat iterable and use index as key.
overwrite: If True, overwrite duplicated keys in output. Otherwise duplicated keys will raises a ValueError.
"""
def _impl(nested_dic: Mapping[str, Any]) -> Dict[str, Any]:
output = {}
for k, v in nested_dic.items():
if isinstance_generic(v, Mapping[str, Any]):
v = _impl(v)
v = {f"{k}{sep}{kv}": vv for kv, vv in v.items()}
output.update(v)
elif flat_iterables and isinstance(v, Iterable) and not isinstance(v, str):
v = {f"{i}": vi for i, vi in enumerate(v)}
v = _impl(v)
v = {f"{k}{sep}{kv}": vv for kv, vv in v.items()}
output.update(v)
elif overwrite or k not in output:
output[k] = v
else:
msg = f"Ambiguous flatten dict with key '{k}'. (with value '{v}')"
raise ValueError(msg)
return output
return _impl(nested_dic)
@overload
def flat_list_of_list(
lst: Iterable[Sequence[T]],
return_sizes: Literal[True] = True,
) -> Tuple[List[T], List[int]]: ...
@overload
def flat_list_of_list(
lst: Iterable[Sequence[T]],
return_sizes: Literal[False],
) -> List[T]: ...
[docs]
def flat_list_of_list(
lst: Iterable[Sequence[T]],
return_sizes: bool = True,
) -> Union[Tuple[List[T], List[int]], List[T]]:
"""Return a flat version of the input list of sublists with each sublist size."""
flatten_lst = [elt for sublst in lst for elt in sublst]
sizes = [len(sents) for sents in lst]
if return_sizes:
return flatten_lst, sizes
else:
return flatten_lst
[docs]
def intersect_lists(lst_of_lst: Sequence[Iterable[T]]) -> List[T]:
"""Performs intersection of elements in lists (like set intersection), but keep their original order."""
if len(lst_of_lst) <= 0:
return []
out = list(dict.fromkeys(lst_of_lst[0]))
for lst_i in lst_of_lst[1:]:
out = [name for name in out if name in lst_i]
if len(out) == 0:
break
return out
@overload
def list_dict_to_dict_list(
lst: Iterable[Mapping[K, V]],
key_mode: Literal["intersect", "same"] = "same",
default_val: Any = None,
*,
default_val_fn: Any = None,
list_fn: None = None,
) -> Dict[K, List[V]]: ...
@overload
def list_dict_to_dict_list(
lst: Iterable[Mapping[K, V]],
key_mode: Literal["union"],
default_val: Any = None,
*,
default_val_fn: Callable[[K], X],
list_fn: None = None,
) -> Dict[K, List[Union[V, X]]]: ...
@overload
def list_dict_to_dict_list(
lst: Iterable[Mapping[K, V]],
key_mode: Literal["union"],
default_val: W = None,
*,
default_val_fn: None = None,
list_fn: None = None,
) -> Dict[K, List[Union[V, W]]]: ...
@overload
def list_dict_to_dict_list(
lst: Iterable[Mapping[K, V]],
key_mode: Union[KeyMode, Iterable[K]] = "same",
default_val: W = None,
*,
default_val_fn: Optional[Callable[[K], X]] = None,
list_fn: Callable[[List[Union[V, W, X]]], Y],
) -> Dict[K, Y]: ...
[docs]
def list_dict_to_dict_list(
lst: Iterable[Mapping[K, V]],
key_mode: Union[KeyMode, Iterable[K]] = "same",
default_val: W = None,
*,
default_val_fn: Optional[Callable[[K], X]] = None,
list_fn: Optional[Callable[[List[Union[V, W, X]]], Y]] = identity,
) -> Dict[K, Y]:
"""Convert list of dicts to dict of lists.
Args:
lst: The list of dict to merge. Cannot be a Generator.
key_mode: Can be "same" or "intersect". \
- If "same", all the dictionaries must contains the same keys otherwise a ValueError will be raised. \
- If "intersect", only the intersection of all keys will be used in output. \
- If "union", the output dict will contains the union of all keys, and the missing value will use the argument default_val. \
- If an iterable of elements, use them as keys for output dict.
default_val: Default value of an element when key_mode is "union". defaults to None.
default_val_fn: Function to return the default value according to a specific key. defaults to None.
list_fn: Optional function to build the values. defaults to identity.
"""
if isinstance(lst, Generator):
msg = f"Invalid argument type {type(lst)}. (expected any Iterable except Generator)"
raise TypeError(msg)
try:
item0 = next(iter(lst))
except StopIteration:
return {}
if isinstance(key_mode, str):
unique_keys = set(item0.keys())
if key_mode == "same":
invalids = [
list(item.keys()) for item in lst if unique_keys != set(item.keys())
]
if len(invalids) > 0:
msg = f"Invalid dict keys for conversion from List[dict] to Dict[list]. (with {key_mode=}, {unique_keys=} and {invalids=})"
raise ValueError(msg)
keys = list(item0.keys())
elif key_mode == "intersect":
keys = intersect_lists([item.keys() for item in lst])
elif key_mode == "union":
keys = union_lists(item.keys() for item in lst)
else:
msg = f"Invalid argument key_mode={key_mode}. (expected one of {get_args(KeyMode)})"
raise ValueError(msg)
else:
keys = list(key_mode)
if list_fn is None:
list_fn = identity # type: ignore
result = {
key: list_fn(
[
item.get(
key,
default_val_fn(key) if default_val_fn is not None else default_val,
)
for item in lst
]
) # type: ignore
for key in keys
}
return result # type: ignore
[docs]
def recursive_generator(x: Any) -> Generator[Tuple[Any, int, int], None, None]:
def recursive_generator_impl(
x: Any,
i: int,
deep: int,
) -> Generator[Tuple[Any, int, int], None, None]:
if is_builtin_scalar(x):
yield x, i, deep
elif isinstance(x, Iterable):
for j, xj in enumerate(x):
if xj == x:
yield xj, i, deep
return
else:
yield from recursive_generator_impl(xj, j, deep + 1)
else:
yield x, i, deep
return
return recursive_generator_impl(x, 0, 0)
@overload
def sorted_dict(
x: Mapping[K, V],
/,
*,
by: Literal["key"] = "key",
key: Optional[Callable[[K], Any]] = None,
reverse: bool = False,
) -> Dict[K, V]: ...
@overload
def sorted_dict(
x: Mapping[K, V],
/,
*,
by: Literal["value"],
key: Optional[Callable[[V], Any]] = None,
reverse: bool = False,
) -> Dict[K, V]: ...
@overload
def sorted_dict(
x: Mapping[K, V],
/,
*,
by: Literal["item"],
key: Optional[Callable[[Tuple[K, V]], Any]] = None,
reverse: bool = False,
) -> Dict[K, V]: ...
[docs]
def sorted_dict(
x: Mapping[K, V],
/,
*,
by: Literal["key", "value", "item"] = "key",
key: Optional[Callable[[Any], Any]] = None,
reverse: bool = False,
) -> Dict[K, V]:
"""Sort a dictionnary by key, value or item."""
if key is None or by == "item":
impl_key = key
elif by == "key":
def by_key_fn(x: Tuple[K, V]) -> Any:
return key(x[0])
impl_key = by_key_fn
elif by == "value":
def by_value_fn(x: Tuple[K, V]) -> Any:
return key(x[1])
impl_key = by_value_fn
else:
msg = f"Invalid argument {by=}. (expected one of {('key', 'value', 'item')})"
raise ValueError(msg)
return {k: v for k, v in sorted(x.items(), key=impl_key, reverse=reverse)} # type: ignore
[docs]
def shuffled(
x: MutableSequence[T],
*,
seed: Optional[int] = None,
deep: bool = False,
) -> MutableSequence[T]:
if deep:
x = copy.deepcopy(x)
else:
x = copy.copy(x)
if seed is None:
random.shuffle(x)
return x
else:
state = random.getstate()
random.seed(seed)
random.shuffle(x)
state = random.setstate(state)
return x
[docs]
def unflat_dict_of_dict(dic: Mapping[str, Any], *, sep: str = ".") -> Dict[str, Any]:
"""Unflat a dictionary.
Example 1
----------
>>> dic = {
"a.a": 1,
"b.a": 2,
"b.b": 3,
"c": 4,
}
>>> unflat_dict_of_dict(dic)
... {"a": {"a": 1}, "b": {"a": 2, "b": 3}, "c": 4}
"""
output = {}
for k, v in dic.items():
if sep not in k:
output[k] = v
else:
idx = k.index(sep)
k, kk = k[:idx], k[idx + 1 :]
if k not in output:
output[k] = {}
elif not isinstance(output[k], Mapping):
msg = f"Invalid dict argument. (found keys {k} and {k}{sep}{kk})"
raise ValueError(msg)
output[k][kk] = v
output = {
k: (unflat_dict_of_dict(v) if isinstance(v, Mapping) else v)
for k, v in output.items()
}
return output
[docs]
def unflat_list_of_list(
flatten_lst: Sequence[T],
sizes: Iterable[int],
) -> List[List[T]]:
"""Unflat a list to a list of sublists of given sizes."""
lst = []
start = 0
stop = 0
for count in sizes:
stop += count
lst.append(flatten_lst[start:stop])
start = stop
return lst
[docs]
def union_dicts(dicts: Iterable[Dict[K, V]]) -> Dict[K, V]:
"""Performs union of dictionaries."""
if Version.python() >= Version("3.9.0"):
return reduce_or(*dicts)
it = iter(dicts)
try:
dic0 = next(it)
except StopIteration:
return {}
for dic in it:
dic0.update(dic)
return dic0
[docs]
def union_lists(lst_of_lst: Iterable[Iterable[K]]) -> List[K]:
"""Performs union of elements in lists (like set union), but keep their original order."""
out = {}
for lst_i in lst_of_lst:
out.update(dict.fromkeys(lst_i))
out = list(out)
return out
@overload
def unzip(lst: Iterable[Tuple[()]]) -> Tuple[()]: ...
@overload
def unzip(lst: Iterable[Tuple[T]]) -> Tuple[List[T]]: ...
@overload
def unzip(lst: Iterable[Tuple[T, U]]) -> Tuple[List[T], List[U]]: ...
@overload
def unzip(lst: Iterable[Tuple[T, U, V]]) -> Tuple[List[T], List[U], List[V]]: ...
@overload
def unzip(
lst: Iterable[Tuple[T, U, V, W]],
) -> Tuple[List[T], List[U], List[V], List[W]]: ...
@overload
def unzip(
lst: Iterable[Tuple[T, U, V, W, X]],
) -> Tuple[List[T], List[U], List[V], List[W], List[X]]: ...
@overload
def unzip(
lst: Iterable[Tuple[T, ...]],
) -> Tuple[List[T], ...]: ...
[docs]
def unzip(lst):
"""Invert function of builtin zip().
Example
-------
>>> lst1 = [1, 2, 3, 4]
>>> lst2 = [5, 6, 7, 8]
>>> zipped_list = list(zip(lst1, lst2))
>>> zipped_list
... [(1, 5), (2, 6), (3, 7), (4, 8)]
>>> unzip(zipped_list)
... [1, 2, 3, 4], [5, 6, 7, 8]
"""
return tuple(map(list, zip(*lst)))
[docs]
def duplicate_list(lst: List[T], sizes: List[int]) -> List[T]:
"""Duplicate elements elements of a list with the corresponding sizes.
Example
-------
>>> lst = ["a", "b", "c", "d", "e"]
>>> sizes = [1, 0, 2, 1, 3]
>>> duplicate_list(lst, sizes)
... ["a", "c", "c", "d", "e", "e", "e"]
"""
if len(lst) != len(sizes):
msg = f"Invalid arguments lengths. (found {len(lst)=} != {len(sizes)=})"
raise ValueError(msg)
out_size = sum(sizes)
out: List[T] = [None for _ in range(out_size)] # type: ignore
curidx = 0
for size, elt in zip(sizes, lst):
out[curidx : curidx + size] = [elt] * size
curidx += size
return out