Source code for skrough.attrs_checks

import logging

import numpy as np

import skrough.typing as rght
from skrough.logs import log_start_end
from skrough.structs.group_index import GroupIndex

logger = logging.getLogger(__name__)


@log_start_end(logger)
def check_if_attr_better_than_shuffled(
    group_index: GroupIndex,
    attr_values: np.ndarray,
    attr_values_count: int,
    values: np.ndarray,
    values_count: int,
    allowed_randomness: float,
    probes_count: int,
    smoothing_parameter: float,
    fast: bool,
    disorder_fun: rght.DisorderMeasure,
    rng: np.random.Generator,
) -> bool:
    """Check whether an attribute beats shuffled copies of itself often enough.

    The disorder score obtained by splitting ``group_index`` with
    ``attr_values`` is compared against the scores obtained with shuffled
    copies of ``attr_values``. A probe counts as "attr is better" when the
    original score is strictly lower than the shuffled one.

    Args:
        group_index: Group index on which the split is evaluated.
        attr_values: Values of the attribute under test (never modified;
            shuffling is performed on a copy).
        attr_values_count: Passed through to
            ``group_index.get_disorder_score_after_split`` together with
            ``attr_values``.
        values: Passed through to
            ``group_index.get_disorder_score_after_split``.
        values_count: Passed through to
            ``group_index.get_disorder_score_after_split``.
        allowed_randomness: The smoothed fraction of "better" probes must be
            at least ``1 - allowed_randomness`` for the result to be ``True``.
        probes_count: Maximum number of shuffled probes to evaluate.
        smoothing_parameter: Non-negative Laplace smoothing parameter.
        fast: If ``True``, a single permutation is drawn up front and
            re-applied on every probe; otherwise ``rng.shuffle`` is called
            on every probe.
        disorder_fun: Disorder measure forwarded to
            ``group_index.get_disorder_score_after_split``.
        rng: Random generator used for shuffling.

    Returns:
        ``True`` if the number of probes in which the attribute is better
        than its shuffled copy reaches the threshold derived below, ``False``
        otherwise. When no probes are run (``probes_count <= 0``) the decision
        is made for a "better" count of zero.

    Raises:
        ValueError: If ``smoothing_parameter`` is negative.
    """
    # for result to be True we need `attr_probe_score >= (1 - allowed_randomness)`
    #
    # where `attr_probe_score` is estimated using the Laplace smoothing
    # ```
    # attr_probe_score = (attr_is_better_count + smoothing_parameter) / (
    #     probes_count + smoothing_parameter * smoothing_dims
    # )
    #
    # attr_is_better_count = number of times attr is better than shuffled
    # ```
    #
    # the inequality can be transformed to the following
    #
    # ```
    # attr_is_better_count >= threshold
    #
    # threshold = (1 - allowed_randomness)
    #     * (probes_count + smoothing_parameter * smoothing_dims)
    #     - smoothing_parameter
    # ```
    #
    # alternatively, as
    # `attr_is_better_count = probes_count - attr_is_worse_equal_count`
    # we can transform the above to
    #
    # ```
    # probes_count - attr_is_worse_equal_count >= threshold
    # attr_is_worse_equal_count <= probes_count - threshold
    # ```
    #
    # and therefore we can say that result is False if
    # ```
    # attr_is_worse_equal_count > probes_count - threshold
    # ```
    #
    # therefore (early stopping), sometimes we can determine (even before the
    # loop ends) that the result is:
    # - True, if `CURRENT_attr_is_BETTER_count >= threshold`
    # - False, if `CURRENT_attr_is_WORSE_EQUAL_count > probes_count - threshold`

    if smoothing_parameter < 0:
        raise ValueError("smoothing parameter cannot be less than zero")

    smoothing_dims = 2  # binomial distribution, i.e., better/worse
    threshold = (1 - allowed_randomness) * (
        probes_count + smoothing_parameter * smoothing_dims
    ) - smoothing_parameter

    attr_disorder_score = group_index.get_disorder_score_after_split(
        attr_values,
        attr_values_count,
        values,
        values_count,
        disorder_fun,
    )

    # work on a copy so that the caller's `attr_values` is left untouched
    attr_values_shuffled: np.ndarray = np.array(attr_values)

    # let us prepare a function that shuffles `attr_values_shuffled`
    if fast:
        # one permutation is drawn once and then re-applied on each probe,
        # so the generator is consulted only here
        permutation = rng.permutation(len(attr_values_shuffled))

        def shuffle_values():
            nonlocal attr_values_shuffled
            attr_values_shuffled = attr_values_shuffled[permutation]

    else:

        def shuffle_values():
            rng.shuffle(attr_values_shuffled)

    iterations = 0
    current_attr_is_better_count = 0
    for _ in range(probes_count):
        iterations += 1
        shuffle_values()
        shuffled_disorder_score = group_index.get_disorder_score_after_split(
            attr_values_shuffled,
            attr_values_count,
            values,
            values_count,
            disorder_fun,
        )
        if attr_disorder_score < shuffled_disorder_score:
            current_attr_is_better_count += 1
            # early stopping - positive case
            if current_attr_is_better_count >= threshold:
                result = True
                break
        # early stopping - negative case
        # current_attr_is_worse_equal_count
        #     == iterations - current_attr_is_better_count
        if iterations - current_attr_is_better_count > probes_count - threshold:
            result = False
            break
    else:
        # No early stop was triggered (this includes the case where no probe
        # was run at all, i.e. `probes_count <= 0`), so apply the decision
        # rule directly instead of assuming a positive outcome.
        result = bool(current_attr_is_better_count >= threshold)

    logger.debug("smoothing_parameter == %f", smoothing_parameter)
    logger.debug("threshold == %f", threshold)
    logger.debug("probes_count == %d", probes_count)
    logger.debug("iterations == %d", iterations)
    logger.debug("current_attr_is_better_count == %d", current_attr_is_better_count)
    logger.debug("allowed_randomness == %f", allowed_randomness)

    return result