# Source code for skrough.rough

"""Rough sets related functions."""

from typing import List, Tuple

import numba
import numpy as np

import skrough.typing as rght
from skrough.homogeneity import get_homogeneity
from skrough.structs.group_index import GroupIndex
from skrough.utils import get_positions_where_values_in


def get_positive_region(
    x: np.ndarray,
    x_counts: np.ndarray,
    y: np.ndarray,
    y_count: int,
    attrs: rght.LocationsLike,
) -> List[int]:
    """Compute the positive region of the decision ``y`` w.r.t. ``attrs``.

    Objects are grouped by their values on ``attrs`` (via ``GroupIndex``); the
    decision distribution of ``y`` is computed per group, and only groups whose
    distribution is homogeneous (as judged by ``get_homogeneity``) are kept.

    Args:
        x: Data table of conditional attributes, one row per object.
        x_counts: Passed to ``GroupIndex.from_data``; presumably the number of
            distinct values of each attribute in ``x`` - TODO confirm.
        y: Decision values, one per object.
        y_count: Passed to ``get_distribution``; presumably the number of
            distinct decision values - TODO confirm.
        attrs: Attributes (columns of ``x``) used to form the groups.

    Returns:
        Positions of the objects that belong to homogeneous groups, as
        produced by ``get_positions_where_values_in``.
    """
    index = GroupIndex.from_data(x, x_counts, attrs)
    is_homogeneous = get_homogeneity(index.get_distribution(y, y_count))
    # Positions within ``is_homogeneous`` are group ids; keep the truthy ones.
    homogeneous_group_ids = is_homogeneous.nonzero()[0]
    # ``index.index`` maps each object (position) to its group id, so select
    # the objects whose group id is among the homogeneous ones.
    return get_positions_where_values_in(index.index, homogeneous_group_ids)
def get_gamma_value(
    x: np.ndarray,
    x_counts: np.ndarray,
    y: np.ndarray,
    y_count: int,
    attrs: rght.LocationsLike,
) -> float:
    """Return the gamma value: the fraction of objects in the positive region.

    Args:
        x: Data table of conditional attributes, one row per object.
        x_counts: Forwarded to ``get_positive_region``; presumably the number
            of distinct values of each attribute in ``x`` - TODO confirm.
        y: Decision values, one per object.
        y_count: Forwarded to ``get_positive_region``; presumably the number
            of distinct decision values - TODO confirm.
        attrs: Attributes (columns of ``x``) used to form the groups.

    Returns:
        ``len(positive_region) / len(x)`` as a float; ``1.0`` when ``x`` has
        no objects (this also avoids a division by zero).
    """
    n_objects = len(x)
    if n_objects == 0:
        # Always return a float so the result matches the ``-> float`` contract
        # (the previous code returned the int ``1`` here).
        return 1.0
    pos = get_positive_region(x, x_counts, y, y_count, attrs)
    return len(pos) / n_objects
@numba.njit
def get_lower_upper_group_ids(
    membership_distr: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
    """Split group ids into lower- and upper-approximation groups.

    Args:
        membership_distr: 2D array with exactly two columns and one row per
            group. Column 0 is read as the amount of group members outside the
            approximated set and column 1 as the amount inside it (the 0/1
            encoding built by ``get_approximations``); presumably these are
            per-group counts - TODO confirm against ``get_distribution``.

    Returns:
        ``(lower, upper)`` arrays of group ids (row indices):
        ``upper`` - groups with a positive value in column 1;
        ``lower`` - those upper groups whose column 0 value is zero.

    Raises:
        ValueError: if ``membership_distr`` is not 2D with exactly two columns.
    """
    if membership_distr.ndim != 2 or membership_distr.shape[1] != 2:
        raise ValueError(
            "Membership distribution should be a 2D array of just two columns"
        )
    lower = []
    upper = []
    ngroup = len(membership_distr)
    # Plain ``range`` on purpose: ``numba.prange`` without ``parallel=True``
    # is just ``range``, and with parallelism enabled the shared-list appends
    # below would be a data race.
    for i in range(ngroup):
        if membership_distr[i, 1] > 0:
            upper.append(i)
            # Nested so every lower group is also an upper group.
            if membership_distr[i, 0] == 0:
                lower.append(i)
    return np.asarray(lower), np.asarray(upper)
def get_approximations(
    x: np.ndarray,
    x_counts: np.ndarray,
    objs: rght.LocationsLike,
    attrs: rght.LocationsLike,
) -> Tuple[List[int], List[int]]:
    """Compute lower and upper approximations of ``objs`` w.r.t. ``attrs``.

    Membership in ``objs`` is treated as a binary decision attribute, and its
    distribution per group (groups formed from ``attrs``) is split into lower
    and upper group ids by ``get_lower_upper_group_ids``.

    Args:
        x: Data table of conditional attributes, one row per object.
        x_counts: Passed to ``GroupIndex.from_data``; presumably the number of
            distinct values of each attribute in ``x`` - TODO confirm.
        objs: Positions (row indices of ``x``) of the approximated objects;
            positions not in ``range(len(x))`` never match.
        attrs: Attributes (columns of ``x``) used to form the groups.

    Returns:
        ``(lower, upper)`` - positions of objects in the lower and in the
        upper approximation, respectively.
    """
    index = GroupIndex.from_data(x, x_counts, attrs)
    # Binary decision encoding: 0 - object not in ``objs``, 1 - object in ``objs``.
    in_objs = np.isin(np.arange(len(x)), objs)
    membership_distr = index.get_distribution(in_objs.astype(int), 2)
    lower_ids, upper_ids = get_lower_upper_group_ids(membership_distr)
    # Map group ids back to object positions (lower first, then upper).
    return (
        get_positions_where_values_in(index.index, lower_ids),
        get_positions_where_values_in(index.index, upper_ids),
    )