# Source code for skrough.predict.predict_attrs_ensemble

# pylint: disable=duplicate-code

from __future__ import annotations

from typing import Any, Iterable

import numpy as np

import skrough.typing as rght
from skrough.predict.helpers import (
    NoAnswerStrategyKey,
    PredictionResultPreparer,
    PredictStrategyKey,
    check_reference_data,
    predict_ensemble,
)
from skrough.predict.predict_attrs import predict_attrs
from skrough.structs.attrs_subset import AttrsSubset


def predict_attrs_ensemble(
    model_ensemble: Iterable[AttrsSubset],
    reference_data: np.ndarray,
    reference_data_y: np.ndarray,
    predict_data: np.ndarray,
    return_proba: bool = False,
    predict_strategy: PredictStrategyKey = "majority",
    no_answer_strategy: NoAnswerStrategyKey = "missing",
    raw_mode: bool = False,
    fill_missing: Any = np.nan,
    preferred_prediction_dtype: type[np.generic] | None = None,
    seed: rght.Seed = None,
    n_jobs: int | None = None,
):
    """Predict for ``predict_data`` using an ensemble of attribute-subset models.

    Every model in ``model_ensemble`` is evaluated with :func:`predict_attrs`
    through :func:`skrough.predict.helpers.predict_ensemble`. The individual
    answers are then aggregated by that helper.

    Args:
        model_ensemble: Iterable of :class:`AttrsSubset` models forming the
            ensemble.
        reference_data: Reference objects, passed to ``check_reference_data``
            and then to ``predict_ensemble``.
        reference_data_y: Decision values for ``reference_data``. It is used
            to build the :class:`PredictionResultPreparer`, whose ``y`` and
            ``y_uniques`` are what ``predict_ensemble`` actually receives.
        predict_data: Objects to make predictions for.
        return_proba: When ``True``, return the ensemble's probability-like
            output together with the decision values it refers to, instead of
            final predictions.
        predict_strategy: Forwarded to ``predict_ensemble``; see that helper
            for the available strategies and their semantics.
        no_answer_strategy: Forwarded to ``predict_ensemble``; see that helper
            for the available strategies and their semantics.
        raw_mode: Forwarded to ``PredictionResultPreparer.from_reference_data_y``.
        fill_missing: Forwarded to ``PredictionResultPreparer.from_reference_data_y``.
            It only affects the non-proba path (see the note below).
        preferred_prediction_dtype: Forwarded to
            ``PredictionResultPreparer.from_reference_data_y``.
        seed: Randomness seed, forwarded to ``predict_ensemble``.
        n_jobs: Forwarded to ``predict_ensemble``.

    Returns:
        If ``return_proba`` is ``False``: the ``predict_ensemble`` result after
        ``result_preparer.prepare``.

        If ``return_proba`` is ``True``: a tuple ``(result, y_uniques)``.
        ``result`` is the unprocessed ``predict_ensemble`` output, and
        ``y_uniques`` is the preparer's ``y_uniques`` attribute.

    Note:
        ``fill_missing`` is stored in the result preparer, and the preparer is
        applied only when ``return_proba`` is ``False``. Suppose
        ``no_answer_strategy`` is ``"missing"`` but ``fill_missing`` is set to
        an actual decision value ``X``. Then label predictions may contain
        ``X`` for objects the ensemble could not answer. The corresponding
        probability rows, however, are reported by the original author to be
        all NaN. Label predictions and probabilities can therefore be
        inconsistent in that configuration.
    """
    check_reference_data(
        reference_data=reference_data, reference_data_y=reference_data_y
    )
    result_preparer = PredictionResultPreparer.from_reference_data_y(
        reference_data_y=reference_data_y,
        raw_mode=raw_mode,
        fill_missing=fill_missing,
        preferred_prediction_dtype=preferred_prediction_dtype,
    )
    # The ensemble works on the preparer's view of the decisions (presumably
    # an encoded form of ``reference_data_y`` -- TODO confirm) and needs the
    # number of distinct decision values.
    result = predict_ensemble(
        model_predict_fun=predict_attrs,
        model_ensemble=model_ensemble,
        reference_data=reference_data,
        reference_data_y=result_preparer.y,
        reference_data_y_count=len(result_preparer.y_uniques),
        predict_data=predict_data,
        return_proba=return_proba,
        predict_strategy=predict_strategy,
        no_answer_strategy=no_answer_strategy,
        seed=seed,
        n_jobs=n_jobs,
    )
    if not return_proba:
        return result_preparer.prepare(result)
    # Proba output is returned unprocessed, together with the decision values
    # it refers to; ``fill_missing`` is not applied here.
    return result, result_preparer.y_uniques