Source code for skrough.permutations

"""Permutation related utils."""

from typing import Literal, Mapping, Optional, Union, get_args

import numpy as np

import skrough.typing as rght
from skrough.weights import prepare_weights


def get_permutation(
    start: int,
    stop: int,
    proba: Optional[np.ndarray] = None,
    seed: rght.Seed = None,
) -> np.ndarray:
    """Get permutation.

    Get permutation of values between ``start`` (inclusively) and ``stop``
    (exclusively) using the given probabilistic distribution ``proba`` for the
    values being permuted. The value of ``proba`` should be understood as
    follows:

    - when ``proba`` is given in a form of :class:`numpy.ndarray` - it is
      treated as a discrete probabilistic distribution over values from
      ``start`` to ``stop`` (exclusively) and therefore it should sum to
      ``1.0`` and has the length equal to :code:`stop - start`. The higher the
      probability value on the given position ``i``, the more likely for
      element :code:`range(start, stop)[i]` to appear earlier in the output
      permutation.
    - when :code:`proba is None` - the uniform distribution is used

    When :code:`start >= stop` the interval is empty and an empty array is
    returned without creating or consuming any random generator.

    Args:
        start: Start (inclusively) of the interval being permuted.
        stop: Stop (exclusively) of interval being permuted.
        proba: Probabilistic distribution for interval used to create
            permutation. Defaults to :obj:`None`.
        seed: A seed to initialize random generator. It is handed to
            :func:`numpy.random.default_rng`, so an existing generator passed
            here is used directly and its state is advanced. Defaults to
            :obj:`None`.

    Returns:
        Output permutation as a one-dimensional ``int64`` array.

    Raises:
        ValueError: Propagated from :meth:`numpy.random.Generator.choice`,
            e.g. when ``proba`` has a length other than :code:`stop - start`,
            does not sum to ``1.0``, or has fewer non-zero entries than
            :code:`stop - start` (sampling is done without replacement).
    """
    if start >= stop:
        return np.arange(0, dtype=np.int64)

    rng = np.random.default_rng(seed)
    # The dtype is pinned so that empty and non-empty results agree regardless
    # of the platform's default integer type; the drawn order is unaffected.
    values = np.arange(start, stop, dtype=np.int64)
    return rng.choice(
        values,
        size=stop - start,
        replace=False,
        p=proba,
    )
def _get_objs_attrs_permutation_strategy_one_before(
    objs_before: bool,
    n_objs: int,
    n_attrs: int,
    objs_weights: Optional[Union[int, float, np.ndarray]] = None,
    attrs_weights: Optional[Union[int, float, np.ndarray]] = None,
    rng: rght.Seed = None,
) -> np.ndarray:
    """Permute objects and attributes separately and join the two groups.

    Objects are permuted over the values ``0 .. n_objs - 1`` and attributes
    over ``n_objs .. n_objs + n_attrs - 1``. Each group is permuted on its
    own and the two permutations are concatenated, so elements of the two
    groups never interleave.

    Args:
        objs_before: When true the objects' permutation comes first in the
            result, otherwise the attributes' permutation comes first.
        n_objs: Number of objects.
        n_attrs: Number of attributes.
        objs_weights: Weights for objects; converted with
            :func:`prepare_weights` (``expand_none=False``) into the
            distribution passed to :func:`get_permutation`.
        attrs_weights: Weights for attributes; handled the same way as
            ``objs_weights``.
        rng: Seed or random generator used for both draws.

    Returns:
        Array of length :code:`n_objs + n_attrs` holding both permutations.
    """
    # Build a single generator up front. Forwarding a plain seed to both
    # ``get_permutation`` calls would start each draw from the same state and
    # make the object and attribute orders correlated. An existing generator
    # is returned unchanged by ``default_rng``, so that case behaves as before.
    generator = np.random.default_rng(rng)

    objs_proba = prepare_weights(
        objs_weights,
        n_objs,
        expand_none=False,
    )
    objs = get_permutation(0, n_objs, objs_proba, seed=generator)

    attrs_proba = prepare_weights(
        attrs_weights,
        n_attrs,
        expand_none=False,
    )
    attrs = get_permutation(n_objs, n_objs + n_attrs, attrs_proba, seed=generator)

    ordered = (objs, attrs) if objs_before else (attrs, objs)
    return np.concatenate(ordered)
def get_objs_attrs_permutation_strategy_objs_before(
    n_objs: int,
    n_attrs: int,
    objs_weights: Optional[Union[int, float, np.ndarray]] = None,
    attrs_weights: Optional[Union[int, float, np.ndarray]] = None,
    rng: rght.Seed = None,
) -> np.ndarray:
    """Permutation strategy placing all objects before all attributes.

    Delegates to :func:`_get_objs_attrs_permutation_strategy_one_before` with
    ``objs_before=True``; all other arguments are forwarded unchanged.

    Args:
        n_objs: Number of objects.
        n_attrs: Number of attributes.
        objs_weights: Weights for objects. Defaults to :obj:`None`.
        attrs_weights: Weights for attributes. Defaults to :obj:`None`.
        rng: Seed or random generator. Defaults to :obj:`None`.

    Returns:
        The objects' permutation followed by the attributes' permutation.
    """
    permutation = _get_objs_attrs_permutation_strategy_one_before(
        True,
        n_objs,
        n_attrs,
        objs_weights,
        attrs_weights,
        rng,
    )
    return permutation
def get_objs_attrs_permutation_strategy_attrs_before(
    n_objs: int,
    n_attrs: int,
    objs_weights: Optional[Union[int, float, np.ndarray]] = None,
    attrs_weights: Optional[Union[int, float, np.ndarray]] = None,
    rng: rght.Seed = None,
) -> np.ndarray:
    """Permutation strategy placing all attributes before all objects.

    Delegates to :func:`_get_objs_attrs_permutation_strategy_one_before` with
    ``objs_before=False``; all other arguments are forwarded unchanged.

    Args:
        n_objs: Number of objects.
        n_attrs: Number of attributes.
        objs_weights: Weights for objects. Defaults to :obj:`None`.
        attrs_weights: Weights for attributes. Defaults to :obj:`None`.
        rng: Seed or random generator. Defaults to :obj:`None`.

    Returns:
        The attributes' permutation followed by the objects' permutation.
    """
    permutation = _get_objs_attrs_permutation_strategy_one_before(
        False,
        n_objs,
        n_attrs,
        objs_weights,
        attrs_weights,
        rng,
    )
    return permutation
def get_objs_attrs_permutation_strategy_mixed(
    n_objs: int,
    n_attrs: int,
    objs_weights: Optional[Union[int, float, np.ndarray]] = None,
    attrs_weights: Optional[Union[int, float, np.ndarray]] = None,
    rng: rght.Seed = None,
) -> np.ndarray:
    """Permutation strategy mixing objects and attributes together.

    A single permutation of the values ``0 .. n_objs + n_attrs - 1`` is drawn.
    The first ``n_objs`` values are governed by the objects' weights and the
    remaining ``n_attrs`` values by the attributes' weights.

    Args:
        n_objs: Number of objects.
        n_attrs: Number of attributes.
        objs_weights: Weights for objects. Defaults to :obj:`None`.
        attrs_weights: Weights for attributes. Defaults to :obj:`None`.
        rng: Seed or random generator. Defaults to :obj:`None`.

    Returns:
        Joint permutation of length :code:`n_objs + n_attrs`.
    """
    # Each group's weights are prepared without normalization first; the
    # joined vector is then passed through ``prepare_weights`` once more.
    raw_objs_weights = prepare_weights(
        objs_weights,
        n_objs,
        expand_none=True,
        normalize=False,
    )
    raw_attrs_weights = prepare_weights(
        attrs_weights,
        n_attrs,
        expand_none=True,
        normalize=False,
    )
    joint_weights = np.concatenate((raw_objs_weights, raw_attrs_weights))

    total = n_objs + n_attrs
    joint_proba = prepare_weights(joint_weights, total)
    return get_permutation(0, total, joint_proba, seed=rng)
# Names of the supported strategies for arranging objects and attributes
# within a joint permutation.
ObjsAttrsPermutationStrategy = Literal[
    "attrs_before",
    "mixed",
    "objs_before",
]

# Dispatch table mapping each strategy name to the function implementing it.
OBJS_ATTRS_PERMUTATION_STRATEGIES: Mapping[
    ObjsAttrsPermutationStrategy,
    rght.ObjsAttrsPermutationStrategyFunction,
] = {
    "attrs_before": get_objs_attrs_permutation_strategy_attrs_before,
    "mixed": get_objs_attrs_permutation_strategy_mixed,
    "objs_before": get_objs_attrs_permutation_strategy_objs_before,
}
def get_objs_attrs_permutation(
    n_objs: int,
    n_attrs: int,
    objs_weights: Optional[Union[int, float, np.ndarray]] = None,
    attrs_weights: Optional[Union[int, float, np.ndarray]] = None,
    strategy: ObjsAttrsPermutationStrategy = "mixed",
    seed: rght.Seed = None,
) -> np.ndarray:
    """Get a joint permutation of objects and attributes.

    Validates the arguments, creates a random generator from ``seed`` and
    dispatches to the function registered for ``strategy`` in
    ``OBJS_ATTRS_PERMUTATION_STRATEGIES``.

    Args:
        n_objs: Number of objects.
        n_attrs: Number of attributes.
        objs_weights: Weights for objects. Defaults to :obj:`None`.
        attrs_weights: Weights for attributes. Defaults to :obj:`None`.
        strategy: Name of the permutation strategy. Defaults to ``"mixed"``.
        seed: A seed to initialize random generator. Defaults to :obj:`None`.

    Returns:
        Permutation produced by the selected strategy.

    Raises:
        ValueError: When ``strategy`` is not one of the recognized names, or
            when ``n_objs`` or ``n_attrs`` is negative.
    """
    known_strategies = get_args(ObjsAttrsPermutationStrategy)
    if strategy not in known_strategies:
        raise ValueError("Unrecognized permutation strategy")

    # Checked in this order so that ``n_objs`` is reported before ``n_attrs``.
    non_negative_checks = (
        (n_objs, "`n_objs` cannot be less than zero"),
        (n_attrs, "`n_attrs` cannot be less than zero"),
    )
    for count, message in non_negative_checks:
        if count < 0:
            raise ValueError(message)

    strategy_function = OBJS_ATTRS_PERMUTATION_STRATEGIES[strategy]
    return strategy_function(
        n_objs=n_objs,
        n_attrs=n_attrs,
        objs_weights=objs_weights,
        attrs_weights=attrs_weights,
        rng=np.random.default_rng(seed),
    )
def get_objs_permutation(
    n_objs: int,
    objs_weights: Optional[Union[int, float, np.ndarray]] = None,
    seed: rght.Seed = None,
) -> np.ndarray:
    """Get a permutation of objects only.

    Shortcut for :func:`get_objs_attrs_permutation` called with zero
    attributes and the ``"objs_before"`` strategy.

    Args:
        n_objs: Number of objects.
        objs_weights: Weights for objects. Defaults to :obj:`None`.
        seed: A seed to initialize random generator. Defaults to :obj:`None`.

    Returns:
        Permutation of the objects.
    """
    objs_permutation = get_objs_attrs_permutation(
        n_objs,
        0,
        objs_weights=objs_weights,
        strategy="objs_before",
        seed=seed,
    )
    return objs_permutation
def get_attrs_permutation(
    n_attrs: int,
    attrs_weights: Optional[Union[int, float, np.ndarray]] = None,
    seed: rght.Seed = None,
) -> np.ndarray:
    """Get a permutation of attributes only.

    Shortcut for :func:`get_objs_attrs_permutation` called with zero objects
    and the ``"attrs_before"`` strategy.

    Args:
        n_attrs: Number of attributes.
        attrs_weights: Weights for attributes. Defaults to :obj:`None`.
        seed: A seed to initialize random generator. Defaults to :obj:`None`.

    Returns:
        Permutation of the attributes.
    """
    attrs_permutation = get_objs_attrs_permutation(
        0,
        n_attrs,
        attrs_weights=attrs_weights,
        strategy="attrs_before",
        seed=seed,
    )
    return attrs_permutation