import logging
from typing import Any, List, Optional
import pandas as pd
from attrs import define
from sklearn.base import BaseEstimator
import skrough.typing as rght
from skrough.algorithms.exceptions import LoopBreak
from skrough.algorithms.meta.describe import (
autogenerate_description_node,
describe,
inspect_config_keys,
inspect_input_data_keys,
inspect_values_keys,
)
from skrough.algorithms.meta.helpers import normalize_sequence
from skrough.algorithms.meta.visual_block import sk_visual_block
from skrough.logs import log_start_end
from skrough.structs.state import ProcessingState
# Module-level logger; it is handed to the ``log_start_end`` decorators
# applied to the aggregate methods defined in this module.
logger = logging.getLogger(__name__)
class AggregateMixin(rght.Describable):
    # Shared description / key-introspection behavior for hook aggregates.
    #
    # Every method below reads ``self.normalized_hooks``; the mixin itself does
    # not declare that attribute, so concrete aggregates must provide it.
    #
    # NOTE(review): documented with comments rather than a class docstring
    # because ``get_description_graph`` passes ``process_docstring=True`` --
    # presumably a class docstring would surface in generated descriptions;
    # confirm before adding one.

    # Reuse scikit-learn's ``BaseEstimator`` rich-repr hook together with the
    # project's ``sk_visual_block`` implementation.
    # pylint: disable-next=protected-access
    _repr_mimebundle_ = BaseEstimator._repr_mimebundle_
    _sk_visual_block_ = sk_visual_block

    def get_description_graph(self):
        """Return the description of an aggregate processing element."""
        result = autogenerate_description_node(
            processing_element=self, process_docstring=True
        )
        # The aggregate's own node adopts the children produced by describing
        # the list of hooks, so the hooks appear directly under the aggregate.
        hooks_list_description = describe(self.normalized_hooks)  # type: ignore
        result.children = hooks_list_description.children
        return result

    def get_config_keys(self) -> List[str]:
        """Return config keys gathered from the aggregated hooks.

        Delegates to ``_get_keys_from_elements`` (not defined in this class;
        presumably inherited from ``rght.Describable``) using
        ``inspect_config_keys`` as the per-element inspector.
        """
        return self._get_keys_from_elements(
            children=self.normalized_hooks,  # type: ignore
            inspect_keys_function=inspect_config_keys,
        )

    def get_input_data_keys(self) -> List[str]:
        """Return input-data keys gathered from the aggregated hooks.

        Mirrors ``get_config_keys`` and ``get_values_keys``: delegates to
        ``_get_keys_from_elements`` using ``inspect_input_data_keys`` as the
        per-element inspector.
        """
        return self._get_keys_from_elements(
            children=self.normalized_hooks,  # type: ignore
            inspect_keys_function=inspect_input_data_keys,
        )

    def get_values_keys(self) -> List[str]:
        """Return values keys gathered from the aggregated hooks.

        Delegates to ``_get_keys_from_elements`` (not defined in this class;
        presumably inherited from ``rght.Describable``) using
        ``inspect_values_keys`` as the per-element inspector.
        """
        return self._get_keys_from_elements(
            children=self.normalized_hooks,  # type: ignore
            inspect_keys_function=inspect_values_keys,
        )
@define
class StopHooksAggregate(AggregateMixin):
    # Callable aggregate combining several stop hooks with a logical OR.
    #
    # NOTE(review): kept as comments rather than a class docstring because
    # ``AggregateMixin.get_description_graph`` passes ``process_docstring=True``
    # -- presumably a docstring here would alter the generated description.

    normalized_hooks: List[rght.StopHook]

    @classmethod
    @log_start_end(logger)
    def from_hooks(
        cls,
        hooks: rght.OneOrSequence[rght.StopHook],
    ):
        """Create the aggregate from a single hook or a sequence of hooks.

        ``hooks`` is passed through ``normalize_sequence`` with
        ``optional=False`` and the result is stored as ``normalized_hooks``.
        """
        return cls(
            normalized_hooks=normalize_sequence(hooks, optional=False),
        )

    @log_start_end(logger)
    def __call__(
        self,
        state: ProcessingState,
        raise_loop_break: bool,
    ) -> bool:
        """Report whether any aggregated hook returns a truthy value.

        Hooks are called in order with ``state`` and evaluation stops at the
        first truthy result.

        Returns:
            ``True`` if some hook was truthy, ``False`` otherwise.

        Raises:
            LoopBreak: if some hook was truthy and ``raise_loop_break`` is set.
        """
        triggered = False
        for hook in self.normalized_hooks:
            if hook(state):
                # First truthy hook decides; remaining hooks are not called.
                triggered = True
                break
        if triggered and raise_loop_break:
            raise LoopBreak()
        return triggered
@define
class InnerStopHooksAggregate(AggregateMixin):
    # Callable aggregate combining several inner stop hooks with a logical OR;
    # each hook receives both the state and the current elements.
    #
    # NOTE(review): kept as comments rather than a class docstring because
    # ``AggregateMixin.get_description_graph`` passes ``process_docstring=True``
    # -- presumably a docstring here would alter the generated description.

    normalized_hooks: List[rght.InnerStopHook]

    @classmethod
    @log_start_end(logger)
    def from_hooks(
        cls,
        hooks: rght.OneOrSequence[rght.InnerStopHook],
    ):
        """Create the aggregate from a single hook or a sequence of hooks.

        ``hooks`` is passed through ``normalize_sequence`` with
        ``optional=False`` and the result is stored as ``normalized_hooks``.
        """
        return cls(
            normalized_hooks=normalize_sequence(hooks, optional=False),
        )

    @log_start_end(logger)
    def __call__(
        self,
        state: ProcessingState,
        elements: rght.Elements,
        raise_loop_break: bool,
    ) -> bool:
        """Report whether any aggregated hook returns a truthy value.

        Hooks are called in order with ``state`` and ``elements`` passed as
        keyword arguments; evaluation stops at the first truthy result.

        Returns:
            ``True`` if some hook was truthy, ``False`` otherwise.

        Raises:
            LoopBreak: if some hook was truthy and ``raise_loop_break`` is set.
        """
        triggered = False
        for hook in self.normalized_hooks:
            if hook(state=state, elements=elements):
                # First truthy hook decides; remaining hooks are not called.
                triggered = True
                break
        if triggered and raise_loop_break:
            raise LoopBreak()
        return triggered
@define
class UpdateStateHooksAggregate(AggregateMixin):
    # Callable aggregate that runs every update-state hook, in order, on the
    # same state object.
    #
    # NOTE(review): kept as comments rather than a class docstring because
    # ``AggregateMixin.get_description_graph`` passes ``process_docstring=True``
    # -- presumably a docstring here would alter the generated description.

    normalized_hooks: List[rght.UpdateStateHook]

    @classmethod
    @log_start_end(logger)
    def from_hooks(
        cls,
        hooks: Optional[rght.OneOrSequence[rght.UpdateStateHook]],
    ):
        """Create the aggregate from an optional hook or sequence of hooks.

        ``hooks`` is passed through ``normalize_sequence`` with
        ``optional=True`` and the result is stored as ``normalized_hooks``.
        """
        return cls(
            normalized_hooks=normalize_sequence(hooks, optional=True),
        )

    @log_start_end(logger)
    def __call__(
        self,
        state: ProcessingState,
    ) -> None:
        """Call each aggregated hook with ``state``; hook results are ignored."""
        for update_hook in self.normalized_hooks:
            update_hook(state)
@define
class ProduceElementsHooksAggregate(AggregateMixin):
    # Callable aggregate that concatenates the elements produced by every hook
    # and returns the de-duplicated result.
    #
    # NOTE(review): kept as comments rather than a class docstring because
    # ``AggregateMixin.get_description_graph`` passes ``process_docstring=True``
    # -- presumably a docstring here would alter the generated description.

    normalized_hooks: List[rght.ProduceElementsHook]

    @classmethod
    @log_start_end(logger)
    def from_hooks(
        cls,
        hooks: Optional[rght.OneOrSequence[rght.ProduceElementsHook]],
    ):
        """Create the aggregate from an optional hook or sequence of hooks.

        ``hooks`` is passed through ``normalize_sequence`` with
        ``optional=True`` and the result is stored as ``normalized_hooks``.
        """
        return cls(
            normalized_hooks=normalize_sequence(hooks, optional=True),
        )

    @log_start_end(logger)
    def __call__(
        self,
        state: ProcessingState,
    ) -> rght.Elements:
        """Gather elements from all hooks and return the unique ones.

        Each hook is called with ``state``; the items it yields are appended,
        hook by hook, to a single list which is then reduced with
        ``pandas.Series.unique``.
        """
        gathered: List[Any] = [
            item
            for produce_hook in self.normalized_hooks
            for item in produce_hook(state)
        ]
        return pd.Series(gathered).unique()
@define
class ProcessElementsHooksAggregate(AggregateMixin):
    # Callable aggregate that feeds the SAME input elements to every hook,
    # concatenates the outputs and returns the de-duplicated result.
    #
    # NOTE(review): kept as comments rather than a class docstring because
    # ``AggregateMixin.get_description_graph`` passes ``process_docstring=True``
    # -- presumably a docstring here would alter the generated description.

    normalized_hooks: List[rght.ProcessElementsHook]

    @classmethod
    @log_start_end(logger)
    def from_hooks(
        cls,
        hooks: Optional[rght.OneOrSequence[rght.ProcessElementsHook]],
    ):
        """Create the aggregate from an optional hook or sequence of hooks.

        ``hooks`` is passed through ``normalize_sequence`` with
        ``optional=True`` and the result is stored as ``normalized_hooks``.
        """
        return cls(
            normalized_hooks=normalize_sequence(hooks, optional=True),
        )

    @log_start_end(logger)
    def __call__(
        self,
        state: ProcessingState,
        elements: rght.Elements,
    ) -> rght.Elements:
        """Run every hook on ``elements`` and return the unique outputs.

        Each hook is called with ``(state, elements)``; the items it yields
        are appended, hook by hook, to a single list which is then reduced
        with ``pandas.Series.unique``.
        """
        gathered: List[Any] = [
            item
            for process_hook in self.normalized_hooks
            for item in process_hook(state, elements)
        ]
        return pd.Series(gathered).unique()
@define
class ChainProcessElementsHooksAggregate(AggregateMixin):
    # Callable aggregate that pipes elements through the hooks one after
    # another: the output of each hook is the input of the next.
    #
    # NOTE(review): kept as comments rather than a class docstring because
    # ``AggregateMixin.get_description_graph`` passes ``process_docstring=True``
    # -- presumably a docstring here would alter the generated description.

    normalized_hooks: List[rght.ProcessElementsHook]

    @classmethod
    @log_start_end(logger)
    def from_hooks(
        cls,
        hooks: Optional[rght.OneOrSequence[rght.ProcessElementsHook]],
    ):
        """Create the aggregate from an optional hook or sequence of hooks.

        ``hooks`` is passed through ``normalize_sequence`` with
        ``optional=True`` and the result is stored as ``normalized_hooks``.
        """
        return cls(
            normalized_hooks=normalize_sequence(hooks, optional=True),
        )

    @log_start_end(logger)
    def __call__(
        self,
        state: ProcessingState,
        elements: rght.Elements,
    ) -> rght.Elements:
        """Thread ``elements`` through the hooks in order.

        Each hook is called with ``(state, current)`` and its return value
        becomes ``current`` for the next hook. With no hooks, ``elements`` is
        returned unchanged.
        """
        current = elements
        for process_hook in self.normalized_hooks:
            current = process_hook(state, current)
        return current