Source code for skrough.predict.aggregate
import numba
import numba.typed
import numpy as np
@numba.njit
def aggregate_predictions(
    n_objs: int, n_classes: int, predictions_collection: numba.typed.List[np.ndarray]
):
    """Aggregate several prediction vectors into a per-object class distribution.

    Each element of ``predictions_collection`` is a vector of predictions
    (presumably 1-D, one entry per object; verify against callers). Its
    ``i``-th entry is either NaN, meaning "no prediction" and ignored, or a
    value whose integer truncation is the predicted class index for object
    ``i``. Votes are counted per object and class, then normalized per object.

    Args:
        n_objs: Number of objects, i.e. rows of the output distribution.
        n_classes: Number of classes, i.e. columns of the output distribution.
        predictions_collection: Typed list of prediction vectors. A vector may
            be shorter than ``n_objs``; the missing trailing objects simply get
            no vote from it.

    Returns:
        A tuple ``(distribution, counts)``:

        - ``distribution``: float64 array of shape ``(n_objs, n_classes)``.
          Row ``i`` holds the fraction of non-NaN votes per class for object
          ``i``. The row is all NaN when object ``i`` received no votes.
        - ``counts``: float64 array of shape ``(n_objs,)`` with the number of
          non-NaN votes per object.

    Raises:
        ValueError: If a prediction vector is longer than ``n_objs``, or a
            predicted class index falls outside ``[0, n_classes)``. Without
            this check, njit code (bounds checking is off by default) would
            write outside the arrays or wrap negative indices.
    """
    distribution = np.zeros(
        shape=(n_objs, n_classes),
        dtype=np.float64,
    )
    counts = np.zeros(
        shape=n_objs,
        dtype=np.float64,
    )
    for predictions in predictions_collection:
        if len(predictions) > n_objs:
            raise ValueError("prediction vector is longer than n_objs")
        for i in range(len(predictions)):  # pylint: disable=consider-using-enumerate
            value = predictions[i]
            if np.isnan(value):
                # NaN marks a missing prediction; it contributes no vote.
                continue
            label = int(value)
            if label < 0 or label >= n_classes:
                raise ValueError("predicted class index outside [0, n_classes)")
            counts[i] += 1
            distribution[i, label] += 1
    for i in range(n_objs):
        if counts[i] == 0:
            # No votes at all: the distribution is undefined for this object.
            distribution[i, :] = np.nan
        else:
            distribution[i, :] /= counts[i]
    return distribution, counts