Source code for pythonwrench.math

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import math
import struct
from numbers import Real
from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar

from pythonwrench.functools import compose, function_alias

T = TypeVar("T")
T_Real = TypeVar("T_Real", bound=Real)


[docs] def clip( x: T_Real, xmin: Optional[T_Real] = None, xmax: Optional[T_Real] = None, ) -> T_Real: if xmin is not None: x = max(x, xmin) if xmax is not None: x = min(x, xmax) return x
[docs] @function_alias(clip) def clamp(*args, **kwargs): ...
[docs] def argmax(x: Iterable) -> int: max_index, _max_value = max(enumerate(x), key=lambda t: t[1]) return max_index
[docs] def argmin(x: Iterable) -> int: min_index, _max_value = min(enumerate(x), key=lambda t: t[1]) return min_index
[docs] def argsort( x: Iterable[T], *, key: Optional[Callable[[T], Any]] = None, reverse: bool = False, ) -> List[int]: def get_second(t: Tuple[int, T]) -> T: return t[1] if key is None: key_fn = get_second else: key_fn = compose(get_second, key) sorted_x = sorted(enumerate(x), key=key_fn, reverse=reverse) # type: ignore indices = [idx for idx, _ in sorted_x] return indices
[docs] def nextdown(x: float) -> float: return -_nextup(-x)
[docs] def nextafter(x: float, y: float) -> float: """Equivalent to `math.nextafter` for python <=3.8.""" # BASED on https://stackoverflow.com/questions/10420848/how-do-you-get-the-next-value-in-the-floating-point-sequence/10426033#10426033 # If either argument is a NaN, return that argument. # This matches the implementation in decimal.Decimal if math.isnan(x): return x if math.isnan(y): return y if y == x: return y elif y > x: return _nextup(x) else: return nextdown(x)
def _nextup(x: float) -> float: # NaNs and positive infinity map to themselves. if math.isnan(x) or (math.isinf(x) and x > 0): return x # 0.0 and -0.0 both map to the smallest +ve float. if x == 0.0: x = 0.0 n = struct.unpack("<q", struct.pack("<d", x))[0] if n >= 0: n += 1 else: n -= 1 return struct.unpack("<d", struct.pack("<q", n))[0]