Source code for skrough.predict.helpers

from __future__ import annotations

from typing import Any, Callable, Iterable, Literal, Mapping, cast, get_args

import joblib
import numba
import numba.typed
import numpy as np
from attrs import define

import skrough.typing as rght
from skrough.algorithms.meta.processing import RNG_INTEGERS_PARAM
from skrough.dataprep import prepare_factorized_array, prepare_factorized_vector
from skrough.permutations import get_objs_permutation
from skrough.predict.aggregate import aggregate_predictions
from skrough.structs.group_index import GroupIndex
from skrough.unique import get_uniques, get_uniques_and_positions


def check_reference_data(
    reference_data: np.ndarray, reference_data_y: np.ndarray
) -> None:
    """Validate the shapes of the reference data and its targets.

    Args:
        reference_data: reference objects, expected to be a 2d array.
        reference_data_y: reference targets, expected to be a 1d vector.

    Raises:
        ValueError: if ``reference_data`` is not 2d, ``reference_data_y`` is
            not 1d, or the two have different lengths.
    """
    if reference_data.ndim != 2:
        raise ValueError("the reference data should be 2d array")
    if reference_data_y.ndim != 1:
        # the message names the targets argument, which is the one being checked
        raise ValueError("the reference data targets should be 1d vector")
    if len(reference_data) != len(reference_data_y):
        raise ValueError("the reference data and targets should be of equal length")
@define
class PredictionResultPreparer:
    """Turn strategy output (decision codes, NaN = no answer) into final predictions.

    In raw mode the predictions are passed through unchanged. Otherwise each
    non-NaN prediction is used as an integer index into ``y_uniques`` and every
    NaN position receives ``fill_missing``.
    """

    # targets exactly as supplied by the caller
    reference_data_y: np.ndarray
    # when True, targets are used as-is and ``prepare`` is a no-op
    raw_mode: bool
    # targets used for prediction: ``reference_data_y`` itself in raw mode,
    # otherwise the output of ``prepare_factorized_vector`` (presumably codes)
    y: np.ndarray
    # unique target values; indexed by prediction codes in non-raw mode
    y_uniques: np.ndarray
    # value written where a prediction is NaN (no answer)
    fill_missing: Any
    # optional extra dtype taken into account when choosing the output dtype
    preferred_prediction_dtype: type[np.generic] | None

    @classmethod
    def from_reference_data_y(
        cls,
        reference_data_y: np.ndarray,
        raw_mode: bool,
        fill_missing: Any,
        preferred_prediction_dtype: type[np.generic] | None,
    ):
        """Build a preparer, factorizing the targets unless ``raw_mode`` is set."""
        if not raw_mode:
            y, _, y_uniques = prepare_factorized_vector(
                reference_data_y, return_unique_values=True
            )
        else:
            y = reference_data_y
            y_uniques = get_uniques(y)
        return cls(
            reference_data_y=reference_data_y,
            raw_mode=raw_mode,
            y=y,
            y_uniques=y_uniques,
            fill_missing=fill_missing,
            preferred_prediction_dtype=preferred_prediction_dtype,
        )

    def determine_dtype(self):
        """Pick the dtype of the prepared predictions.

        Uses ``np.result_type`` over the targets' dtype, ``fill_missing`` and,
        when set, ``preferred_prediction_dtype``. Falls back to ``np.object_``
        when NumPy cannot combine them (``TypeError``).
        """
        dtype_sources = [self.reference_data_y.dtype, self.fill_missing]
        if self.preferred_prediction_dtype is not None:
            dtype_sources.append(self.preferred_prediction_dtype)
        try:
            return np.result_type(*dtype_sources)
        except TypeError:
            return np.object_

    def prepare(self, predictions: np.ndarray) -> np.ndarray:
        """Map decision codes back to target values; NaN becomes ``fill_missing``."""
        if self.raw_mode:
            return predictions
        decoded = np.full_like(
            predictions,
            fill_value=self.fill_missing,
            dtype=self.determine_dtype(),
        )
        answered = ~np.isnan(predictions)
        labels = cast(np.ndarray, self.y_uniques)
        decoded[answered] = labels[predictions[answered].astype(int)]
        return decoded
@numba.njit
def _predict(
    reference_group_ids: np.ndarray,
    reference_decisions_offsets: np.ndarray,
    reference_decisions: np.ndarray,
    input_group_ids: np.ndarray,
):
    """Look up a decision for every input group id (compiled with ``numba.njit``).

    The decision for ``reference_group_ids[i]`` is
    ``reference_decisions[reference_decisions_offsets[i]]``. If a group id
    occurs more than once in ``reference_group_ids``, the later entry
    overwrites the earlier one.

    Args:
        reference_group_ids: group ids known from the reference data.
        reference_decisions_offsets: for each entry of ``reference_group_ids``,
            the position in ``reference_decisions`` that holds its decision.
        reference_decisions: decision values. They are stored into a float64
            result, so presumably numeric (TODO confirm for raw-mode targets).
        input_group_ids: group ids of the objects to predict.

    Returns:
        A float64 array of length ``len(input_group_ids)``. Entries whose group
        id does not appear in ``reference_group_ids`` are NaN.
    """
    # create group_id -> dec mapping
    group_id_to_dec = {}
    # pylint: disable-next=consider-using-enumerate
    for i in range(len(reference_group_ids)):
        group_id_to_dec[reference_group_ids[i]] = reference_decisions[
            reference_decisions_offsets[i]
        ]
    # create empty result, filled with nan (nan marks "no answer")
    result = np.full(len(input_group_ids), fill_value=np.nan, dtype=np.float64)
    # pylint: disable-next=consider-using-enumerate
    for i in range(len(input_group_ids)):
        # check if input_group_id is in the group_id -> dec mapping
        # if so, add the decision to the result
        if input_group_ids[i] in group_id_to_dec:
            result[i] = group_id_to_dec[input_group_ids[i]]
    return result
def predict_strategy_original_order(
    reference_ids: np.ndarray,
    reference_data_y: np.ndarray,
    predict_ids: np.ndarray,
    seed: rght.Seed = None,  # pylint: disable=unused-argument
) -> Any:
    """Predict using the decision stored at each group's position in the reference data.

    ``get_uniques_and_positions`` provides, for every distinct reference group
    id, a position into ``reference_data_y``. Presumably this is the first
    occurrence, hence "original order" (TODO confirm against that helper).
    Input groups missing from the reference data get NaN.

    ``seed`` is accepted for interface compatibility and is not used.
    """
    group_ids, decision_positions = get_uniques_and_positions(reference_ids)
    return _predict(group_ids, decision_positions, reference_data_y, predict_ids)
def predict_strategy_randomized_order(
    reference_ids: np.ndarray,
    reference_data_y: np.ndarray,
    predict_ids: np.ndarray,
    seed: rght.Seed = None,
) -> Any:
    """Shuffle the reference objects, then apply the original-order strategy.

    The permutation comes from ``get_objs_permutation`` driven by ``seed``.
    Which reference object supplies each group's decision therefore depends
    on the seed rather than on the input order.
    """
    permutation = get_objs_permutation(len(reference_ids), seed=seed)
    return predict_strategy_original_order(
        reference_ids=reference_ids[permutation],
        reference_data_y=reference_data_y[permutation],
        predict_ids=predict_ids,
        seed=seed,
    )
def predict_strategy_majority(
    reference_ids: np.ndarray,
    reference_data_y: np.ndarray,
    predict_ids: np.ndarray,
    seed: rght.Seed = None,  # pylint: disable=unused-argument
) -> Any:
    """Predict the most frequent reference decision within each group.

    Assumes ``reference_data_y`` holds non-negative integer codes, because
    ``max() + 1`` is used as the number of decision columns. Ties are resolved
    toward the smallest code, since ``np.argmax`` returns the first maximum.
    Input groups absent from the reference data get NaN.

    ``seed`` is accepted for interface compatibility and is not used.

    Returns:
        float64 array of length ``len(predict_ids)``.
    """
    if len(reference_data_y) == 0:
        # no reference decisions at all: every input object is unanswered,
        # which is exactly what ``_predict`` yields for unknown groups.
        # Without this guard ``reference_data_y.max()`` raises on empty input.
        return np.full(len(predict_ids), fill_value=np.nan, dtype=np.float64)
    group_index = GroupIndex.from_index(reference_ids)
    n_decisions = reference_data_y.max() + 1
    # per-group decision counts; rows presumably indexed by group id and
    # columns by decision code (inferred from the axis=1 operations below)
    distribution = group_index.get_distribution(reference_data_y, n_decisions)
    present_ids = np.flatnonzero(distribution.any(axis=1))
    # selecting the present rows first and then taking argmax was observed to
    # be faster than `distribution.argmax(axis=1)[present_ids]` when there
    # are many present_ids
    present_ids_decisions = distribution[present_ids, :].argmax(axis=1)
    # each present id maps to its own entry of present_ids_decisions
    result = _predict(
        present_ids,
        np.arange(len(present_ids)),
        present_ids_decisions,
        predict_ids,
    )
    return result
# Names of the available prediction strategies.
PredictStrategyKey = Literal["original_order", "randomized_order", "majority"]

# Dispatch table from strategy name to the function implementing it.
PREDICT_STRATEGIES: Mapping[PredictStrategyKey, rght.PredictStrategyFunction] = dict(
    original_order=predict_strategy_original_order,
    randomized_order=predict_strategy_randomized_order,
    majority=predict_strategy_majority,
)
class PredictStrategyRunner(rght.PredictStrategyFunction):
    """Callable that dispatches to a prediction strategy chosen by name."""

    def __init__(self, strategy: PredictStrategyKey) -> None:
        """Resolve ``strategy``; raise ``ValueError`` for an unknown name."""
        valid_names = get_args(PredictStrategyKey)
        if strategy not in valid_names:
            raise ValueError("Unrecognized prediction strategy")
        self.predict_strategy = PREDICT_STRATEGIES[strategy]

    def __call__(
        self,
        reference_ids: np.ndarray,
        reference_data_y: np.ndarray,
        predict_ids: np.ndarray,
        seed: rght.Seed = None,
    ):
        """Forward all arguments, by keyword, to the selected strategy."""
        strategy_fn = self.predict_strategy
        return strategy_fn(
            reference_ids=reference_ids,
            reference_data_y=reference_data_y,
            predict_ids=predict_ids,
            seed=seed,
        )
def no_answer_strategy_missing(
    reference_data_y: np.ndarray,  # pylint: disable=unused-argument
    seed: rght.Seed = None,  # pylint: disable=unused-argument
):
    """Leave unanswered objects as missing: always returns ``np.nan``.

    Both arguments exist only to match the no-answer strategy interface.
    """
    missing_marker = np.nan
    return missing_marker
def no_answer_strategy_most_frequent(
    reference_data_y: np.ndarray,
    seed: rght.Seed = None,
):
    """Return the most frequent value of ``reference_data_y``.

    Args:
        reference_data_y: reference targets.
        seed: used only when several values share the highest count. One of
            them is then drawn with ``np.random.default_rng(seed)``.

    Returns:
        The most frequent target value. Returns ``np.nan`` (the same as the
        "missing" strategy) when ``reference_data_y`` is empty.
    """
    values, counts = np.unique(reference_data_y, return_counts=True)
    if len(values) == 0:
        # nothing to choose from - fall back to "no answer"
        return np.nan
    candidates = values[counts == counts.max()]
    if len(candidates) == 1:
        return candidates[0]
    # break ties randomly but reproducibly for a given seed
    rng = np.random.default_rng(seed)
    return rng.choice(candidates)
# Names of the available no-answer strategies.
NoAnswerStrategyKey = Literal["missing", "most_frequent"]

# Dispatch table from no-answer strategy name to its implementation.
NO_ANSWER_STRATEGIES: Mapping[NoAnswerStrategyKey, rght.NoAnswerStrategyFunction] = dict(
    missing=no_answer_strategy_missing,
    most_frequent=no_answer_strategy_most_frequent,
)
class NoAnswerStrategyRunner(rght.NoAnswerStrategyFunction):
    """Callable that dispatches to a no-answer strategy chosen by name."""

    def __init__(self, strategy: NoAnswerStrategyKey) -> None:
        """Resolve ``strategy``; raise ``ValueError`` for an unknown name."""
        valid_names = get_args(NoAnswerStrategyKey)
        if strategy not in valid_names:
            raise ValueError("Unrecognized no-answer strategy")
        self.no_answer_strategy = NO_ANSWER_STRATEGIES[strategy]

    def __call__(
        self,
        reference_data_y: np.ndarray,
        seed: rght.Seed = None,
    ):
        """Forward the arguments, by keyword, to the selected strategy."""
        strategy_fn = self.no_answer_strategy
        return strategy_fn(reference_data_y=reference_data_y, seed=seed)
def get_group_ids_for_reference_and_predict_data(
    reference_data: np.ndarray,
    predict_data: np.ndarray,
):
    """Get group ids for reference and for predict/input data.

    Both arrays are stacked row-wise and factorized together, so the two
    returned id arrays share one numbering and can be compared directly.

    Returns:
        ``[reference_ids, predict_ids]``, split at ``len(reference_data)``.
    """
    # np.vstack is the non-deprecated equivalent of np.row_stack (NumPy >= 2.0)
    data_x = np.vstack([reference_data, predict_data])
    x, x_counts = prepare_factorized_array(data_x)
    group_index = GroupIndex.from_data(x, x_counts)
    return np.split(group_index.index, [len(reference_data)])
def get_predictions_from_proba(
    predict_proba: np.ndarray,
    counts: np.ndarray,
    no_answer_value=np.nan,
) -> np.ndarray:
    """Turn per-class scores into class predictions.

    Each row's prediction is the column index of its highest score (first one
    on ties). Rows whose ``counts`` entry is 0 get ``no_answer_value`` instead.
    """
    best_class = np.argmax(predict_proba, axis=1)
    unanswered = counts == 0
    return np.where(unanswered, no_answer_value, best_class)
def predict_single(
    reference_data: np.ndarray,
    reference_data_y: np.ndarray,
    predict_data: np.ndarray,
    predict_strategy: PredictStrategyKey,
    no_answer_strategy: NoAnswerStrategyKey,
    seed: rght.Seed = None,
):
    """Predict decisions for ``predict_data`` from one reference table.

    Reference and predict objects are grouped jointly. Each predict object then
    receives a decision from its group according to ``predict_strategy``.
    Objects the strategy cannot answer (NaN) are replaced in place with the
    value returned by ``no_answer_strategy``.

    Two integer seeds are drawn from ``np.random.default_rng(seed)``: the first
    for the prediction strategy, the second for the no-answer strategy.
    """
    predict_strategy_runner = PredictStrategyRunner(predict_strategy)
    no_answer_strategy_runner = NoAnswerStrategyRunner(no_answer_strategy)
    rng = np.random.default_rng(seed)
    # pylint: disable-next=unbalanced-tuple-unpacking
    reference_ids, predict_ids = get_group_ids_for_reference_and_predict_data(
        reference_data=reference_data,
        predict_data=predict_data,
    )
    result = predict_strategy_runner(
        reference_ids=reference_ids,
        reference_data_y=reference_data_y,
        predict_ids=predict_ids,
        seed=rng.integers(RNG_INTEGERS_PARAM),
    )
    # the no-answer strategy chooses its value from the targets, not from the
    # feature matrix (previously `reference_data` was passed here by mistake)
    no_answer_value = no_answer_strategy_runner(
        reference_data_y=reference_data_y, seed=rng.integers(RNG_INTEGERS_PARAM)
    )
    # fix no-answer in-place
    np.nan_to_num(result, copy=False, nan=no_answer_value)
    return result
# TODO: define protocol for model_predict_fun
def predict_ensemble(
    model_predict_fun: Callable,
    model_ensemble: Iterable,
    reference_data: np.ndarray,
    reference_data_y: np.ndarray,
    reference_data_y_count: int,
    predict_data: np.ndarray,
    predict_strategy: PredictStrategyKey,
    no_answer_strategy: NoAnswerStrategyKey,
    return_proba: bool = False,
    seed: rght.Seed = None,
    n_jobs: int | None = None,
):
    """Predict with every model of an ensemble and aggregate the results.

    Each model is evaluated through ``model_predict_fun`` (in parallel via
    joblib) in raw mode, using the "missing" no-answer strategy and its own
    integer seed drawn from ``np.random.default_rng(seed)``. The per-model
    predictions are combined by ``aggregate_predictions`` into per-class
    scores and counts.

    Returns:
        The aggregated scores when ``return_proba`` is true. Otherwise, class
        predictions from ``get_predictions_from_proba``, where objects with a
        zero count get the value of ``no_answer_strategy``. The seed for that
        strategy is drawn after all model seeds.
    """
    no_answer_runner = NoAnswerStrategyRunner(no_answer_strategy)
    rng = np.random.default_rng(seed)

    # lazily built jobs; seeds are drawn in model order as joblib consumes them
    jobs = (
        joblib.delayed(model_predict_fun)(
            model=model,
            reference_data=reference_data,
            reference_data_y=reference_data_y,
            predict_data=predict_data,
            predict_strategy=predict_strategy,
            no_answer_strategy="missing",
            raw_mode=True,
            seed=rng.integers(RNG_INTEGERS_PARAM),
        )
        for model in model_ensemble
    )
    predictions_collection = joblib.Parallel(n_jobs=n_jobs)(jobs)

    proba, counts = aggregate_predictions(
        n_objs=len(predict_data),
        n_classes=reference_data_y_count,
        predictions_collection=numba.typed.List(predictions_collection),
    )
    if return_proba:
        return proba

    no_answer_value = no_answer_runner(
        reference_data_y=reference_data_y, seed=rng.integers(RNG_INTEGERS_PARAM)
    )
    return get_predictions_from_proba(
        proba,
        counts,
        no_answer_value=no_answer_value,
    )