diff --git a/pnpxai/__init__.py b/pnpxai/__init__.py
index d383906b..3d0e3016 100644
--- a/pnpxai/__init__.py
+++ b/pnpxai/__init__.py
@@ -1,10 +1,4 @@
 from pnpxai.core.detector import detect_model_architecture
 from pnpxai.core.recommender import XaiRecommender
-from pnpxai.core.experiment import (
-    Experiment,
-    AutoExplanation,
-    AutoExplanationForImageClassification,
-    AutoExplanationForTextClassification,
-    AutoExplanationForVisualQuestionAnswering,
-    AutoExplanationForTSClassification,
-)
+from pnpxai.core.experiment import Experiment, AutoExplanation
+from pnpxai.core.modality import Modality
diff --git a/pnpxai/core/__init__.py b/pnpxai/core/__init__.py
index 9e174e1c..5b400b13 100644
--- a/pnpxai/core/__init__.py
+++ b/pnpxai/core/__init__.py
@@ -1,10 +1,5 @@
 from pnpxai.core.detector import detect_model_architecture
 from pnpxai.core.recommender import XaiRecommender
-from pnpxai.core.experiment import (
-    Experiment,
-    AutoExplanation,
-    AutoExplanationForImageClassification,
-    AutoExplanationForTextClassification,
-    AutoExplanationForVisualQuestionAnswering,
-    AutoExplanationForTSClassification,
-)
\ No newline at end of file
+from pnpxai.core.experiment import Experiment, AutoExplanation
+from pnpxai.core.utils import ModelWrapper
+from pnpxai.core.modality import Modality
\ No newline at end of file
diff --git a/pnpxai/core/_types.py b/pnpxai/core/_types.py
index c0cf0157..37f56bfb 100644
--- a/pnpxai/core/_types.py
+++ b/pnpxai/core/_types.py
@@ -13,6 +13,7 @@
 Task = Literal["classification"]
 ExplanationType = Literal["attribution"]
 
+
 class ConfigKeys(Enum):
     EXPLAINERS = 'explainers'
     METRICS = 'metrics'
diff --git a/pnpxai/core/detector/__init__.py b/pnpxai/core/detector/__init__.py
index bce9f23b..2d6e630a 100644
--- a/pnpxai/core/detector/__init__.py
+++ b/pnpxai/core/detector/__init__.py
@@ -1 +1,6 @@
-from pnpxai.core.detector.detector import detect_model_architecture, symbolic_trace, extract_graph_data
+from pnpxai.core.detector.detector import (
+    detect_model_architecture,
+    symbolic_trace,
+    extract_graph_data,
+    detect_data_modality,
+)
diff --git a/pnpxai/core/detector/_core.py b/pnpxai/core/detector/_core.py
deleted file mode 100644
index 72cdacac..00000000
--- a/pnpxai/core/detector/_core.py
+++ /dev/null
@@ -1,430 +0,0 @@
-import _operator
-from dataclasses import dataclass, asdict
-from typing import Literal, List, Optional, Callable, Union
-
-import torch
-from torch.fx import Node, symbolic_trace, Tracer, GraphModule
-
-SUPPORTED_FUNCTION_MODULES = {
-    "torch": torch,
-    "operator": _operator,
-}
-
-REPLACE_PREFIX = "_replaced_"
-
-
-@dataclass
-class NodeInfo:
-    """
-    Represents information about a node in a computation graph.
-
-    Attributes:
-    - opcode (Literal[str]): The operation code associated with the node.
-    - name (str): The name of the node.
-    - target (Union[Callable, str]): The target of the node, which could be a callable object or a string.
-    """
-    opcode: Literal[
-        "placeholder", "get_attr", "call_function",
-        "call_module", "call_method", "output",
-    ]
-    name: str
-    target: Union[Callable, str]
-
-    @classmethod
-    def from_node(cls, n: Node):
-        """
-        Constructs a NodeInfo object from a given Node.
-
-        Args:
-        - n (Node): The node from which to construct the NodeInfo object.
-
-        Returns:
-        - NodeInfo: The NodeInfo object constructed from the given Node.
-        """
-        self = cls(
-            opcode=n.op,
-            name=n.name,
-            target=n._pretty_print_target(n.target),
-        )
-        self._mode = "from_node"
-        self._set_node(n.graph)
-        return self
-
-    @classmethod
-    def from_module(cls, module: torch.nn.Module, **kwargs):
-        """
-        Constructs a NodeInfo object from a given torch.nn.Module.
-
-        Args:
-        - module (torch.nn.Module): The module from which to construct the NodeInfo object.
-        - **kwargs: Additional keyword arguments.
-
-        Returns:
-        - NodeInfo: The NodeInfo object constructed from the given module.
-        """
-        self = cls(
-            opcode="call_module",
-            name=None,
-            target=None,
-        )
-        self._mode = "from_module"
-        self._set_operator(module)
-        self._kwargs = kwargs
-        return self
-
-    @classmethod
-    def from_function(cls, func: Callable, **kwargs):
-        """
-        Constructs a NodeInfo object from a given callable function.
-
-        Args:
-        - func (Callable): The function from which to construct the NodeInfo object.
-        - **kwargs: Additional keyword arguments.
-
-        Returns:
-        - NodeInfo: The NodeInfo object constructed from the given function.
-        """
-        self = cls(
-            opcode="call_function",
-            name=None,
-            target=None,
-        )
-        self._mode = "from_function"
-        self._set_operator(func)
-        self._kwargs = kwargs
-        return self
-
-    # set original node from the graph
-    def _set_node(self, graph):
-        assert self._mode == "from_node"
-        # [GH] no `get_node` method in `torch.fx.Graph`. just find.
-        for n in graph.nodes:
-            if n.name == self.name:
-                self._node = n
-
-    def _set_operator(self, operator: Callable):
-        assert self._mode != "from_node"
-        self._valid_operator(operator)
-        self._operator = operator
-        return self
-
-    # [TODO] validation for operator
-    def _valid_operator(self, operator: Callable):
-        pass
-
-    def _get_operator(self, targets: List[str], root_module=None):
-        operator = root_module if root_module else self._node.graph.owning_module
-        for s in targets:
-            operator = getattr(operator, s)
-        return operator
-
-    # cloning main attributions
-    @property
-    def meta(self):
-        """
-        Property to access the meta information of the node.
-
-        Returns:
-        - Any: The meta information of the node.
-        """
-        return self._node.meta
-
-    @property
-    def args(self):
-        """
-        Property to access the arguments of the node.
-
-        Returns:
-        - tuple: The arguments of the node.
-        """
-        return tuple([
-            NodeInfo.from_node(a) if isinstance(a, Node) else a
-            for a in self._node.args
-        ])
-
-    @property
-    def kwargs(self):
-        """
-        Property to access the keyword arguments of the node.
-
-        Returns:
-        - dict: The keyword arguments of the node.
-        """
-        return {
-            k: NodeInfo.from_node(v) if isinstance(v, Node) else v
-            for k, v in self._node.kwargs.items()
-        }
-
-    @property
-    def users(self):
-        """
-        Property to access the users of the node.
-
-        Returns:
-        - tuple: The users of the node.
-        """
-        return tuple([
-            NodeInfo.from_node(u)
-            for u in self._node.users.keys()
-        ])
-
-    @property
-    def next(self):
-        """
-        Property to access the next node.
-
-        Returns:
-        - NodeInfo: The next node.
-        """
-        if self._node.next.op == "root":
-            return
-        return NodeInfo.from_node(self._node.next)
-
-    @property
-    def prev(self):
-        """
-        Property to access the previous node.
-
-        Returns:
-        - NodeInfo: The previous node.
-        """
-        if self._node.prev.op == "root":
-            return
-        return NodeInfo.from_node(self._node.prev)
-
-    # additional properties for detection
-    @property
-    def operator(self) -> Optional[Callable]:
-        """
-        Property to access the operator.
-
-        Returns:
-        - Optional[Callable]: The operator.
-        """
-        if self._mode == "from_node":
-            if self.opcode == "call_module":
-                return self._get_operator(self.target.split("."))
-            elif self.opcode == "call_function":
-                targets = self.target.split(".")
-                root_module = SUPPORTED_FUNCTION_MODULES.get(targets.pop(0))
-                if root_module:
-                    return self._get_operator(targets, root_module)
-            return
-        return self._operator
-
-    @property
-    def owning_module(self) -> Optional[str]:
-        """
-        Property to access the owning module.
-
-        Returns:
-        - Optional[str]: The owning module.
-        """
-        if self.opcode in ["call_module", "call_function"]:
-            if self.meta.get("nn_module_stack"):
-                nm = next(iter(self.meta["nn_module_stack"]))
-                return nm
-        return
-
-    # convert data format
-    def to_dict(self):
-        """
-        Converts the NodeInfo object to a dictionary.
-
-        Returns:
-        - dict: The dictionary representation of the NodeInfo object.
-        """
-        return {**asdict(self), "operator": self.operator}
-
-    # [TODO] to_json for visualization
-    def to_json_serializable(self):
-        pass
-
-
-class ModelArchitecture:
-    """
-    Represents the architecture of a model with methods for manipulating nodes.
-
-    Attributes:
-    - model: The model for which the architecture is defined.
-    - tracer: A tracer to trace the model. By default, the model is traced by `torch.fx.symbolic_trace`.
-    """
-
-    def __init__(self, model, tracer:Optional[Tracer]=None):
-        self.model = model
-        self.tracer = tracer
-
-        self._traced_model = self._trace(model)
-        self._replacing = False
-
-    def _trace(self, model):
-        if self.tracer is None:
-            return symbolic_trace(model)
-        graph = self.tracer.trace(model)
-        name = (
-            model.__class__.__name__ if isinstance(model, torch.nn.Module) else model.__name__
-        )
-        return GraphModule(self.tracer.root, graph, name)
-
-    def list_nodes(self) -> List[NodeInfo]:
-        """
-        Lists all nodes in the model.
-
-        Returns:
-        - List[NodeInfo]: A list of NodeInfo objects representing the nodes in the model.
-        """
-        return [NodeInfo.from_node(n) for n in self._traced_model.graph.nodes]
-
-    def get_node(self, name: str) -> NodeInfo:
-        """
-        Retrieves a node by its name.
-
-        Args:
-        - name (str): The name of the node to retrieve.
-
-        Returns:
-        - NodeInfo: The NodeInfo object corresponding to the specified node name.
-        """
-        for n in self._traced_model.graph.nodes:
-            if n.name == name:
-                return NodeInfo.from_node(n)
-        return  # [TODO] Error for no result?
-
-    def find_node(
-        self,
-        filter_func: Callable[[NodeInfo], bool],
-        root: Optional[NodeInfo] = None,
-        get_all: bool = False
-    ) -> Union[NodeInfo, List[NodeInfo]]:
-        """
-        Finds a node based on a filtering function.
-
-        Args:
-            filter_func (Callable): The function used to filter nodes.
-            root (Optional[NodeInfo]): The root node from which to start the search.
-            get_all (bool): Whether to find all nodes matching the criteria.
-
-        Returns:
-            Union[NodeInfo, List[NodeInfo]]: The found node(s) or None if no node is found.
-        """
-        if root is None:
-            # Take the first node
-            root = NodeInfo.from_node(next(iter(self._traced_model.graph.nodes)))
-
-        node = root
-        nodes = []
-        while node is not None:
-            is_found = filter_func(node)
-            if is_found:
-                if not get_all:
-                    return node
-                nodes.append(node)
-            node = node.next
-        if len(nodes) == 0 and not get_all:
-            return None
-        return nodes
-
-    def _validate_new_node(self, new_node: NodeInfo):
-        if new_node.name is not None:
-            exists = self.get_node(name=new_node.name)
-            assert not exists, f"A node named {new_node.name} already exists."
-        return True
-
-    def _ensure_graph(self) -> None:
-        self._traced_model.graph.lint()
-        self._traced_model.recompile()
-
-    def replace_node(self, node: NodeInfo, new_node: NodeInfo) -> NodeInfo:
-        """
-        Replaces a node in the model with a new node.
-
-        Args:
-        - node (NodeInfo): The node to replace.
-        - new_node (NodeInfo): The new node to insert.
-
-        Returns:
-        - NodeInfo: The inserted node.
-        """
-        self._replacing = True
-        self._validate_new_node(new_node)
-        try:
-            if new_node._mode == "from_module":
-                new_node.name = f"{REPLACE_PREFIX}{node.name}"
-            inserted = self.insert_node(new_node, base_node=node)
-            self._traced_model.graph.erase_node(node._node)
-            self._ensure_graph()
-        finally:
-            self._replacing = False
-        return inserted
-
-    def insert_node(self, new_node: NodeInfo, base_node: NodeInfo, before=False) -> NodeInfo:
-        """
-        Inserts a new node into the model.
-
-        Args:
-        - new_node (NodeInfo): The new node to insert.
-        - base_node (NodeInfo): The node before which or after which to insert the new node.
-        - before (bool): Whether to insert the new node before the base node.
-
-        Returns:
-        - NodeInfo: The inserted node.
-        """
-        self._validate_new_node(new_node)
-        inserting = self._traced_model.graph.inserting_before if before else self._traced_model.graph.inserting_after
-        if self._replacing:
-            _inserted_args = base_node._node.args
-            _inserted_kwargs = base_node._node.kwargs
-            pass
-        elif before:
-            _inserted_args = tuple(
-                arg for arg in base_node._node.args if isinstance(arg, Node))
-            _inserted_kwargs = {
-                kw: arg for kw, arg in base_node._node.kwargs.items() if isinstance(arg, Node)}
-            _inserted_kwargs = {**_inserted_kwargs, **new_node._kwargs}
-        else:
-            _inserted_args = (base_node,)
-            _inserted_kwargs = new_node._kwargs
-
-        # [TODO] validation for new_node: if new_node._mode=="from_module", new_node.name exists
-        # insert
-        with inserting(base_node._node):
-            if new_node._mode == "from_module":
-                self._traced_model.add_submodule(
-                    new_node.name, new_node.operator)
-                _inserted = self._traced_model.graph.call_module(
-                    new_node.name,
-                    _inserted_args,
-                    _inserted_kwargs,
-                )
-            elif new_node._mode == "from_function":
-                _inserted = self._traced_model.graph.call_function(
-                    new_node.operator,
-                    _inserted_args,
-                    _inserted_kwargs,
-                )
-
-        if self._replacing or not before:
-            base_node._node.replace_all_uses_with(_inserted)
-            pass
-        elif before:
-            _not_inserted_args = [
-                arg for arg in base_node._node.args if not isinstance(arg, Node)]
-            base_node._node.args = tuple([_inserted] + _not_inserted_args)
-            base_node._node.kwargs = {
-                kw: arg for kw, arg in base_node._node.kwargs.items() if not isinstance(arg, Node)}
-        self._ensure_graph()
-        return NodeInfo.from_node(_inserted)
-
-    def to_dict(self):
-        """
-        Converts the model architecture to a dictionary representation.
-
-        Returns:
-        - dict: A dictionary containing nodes and edges of the model architecture.
-        """
-        nodes = []
-        edges = []
-        for n in self.list_nodes():
-            nodes.append(n.to_dict())
-            edges += [{"source": n.name, "target": u.name} for u in n.users]
-        return {"nodes": nodes, "edges": edges}
diff --git a/pnpxai/core/detector/detector.py b/pnpxai/core/detector/detector.py
index 52c9028e..0526d226 100644
--- a/pnpxai/core/detector/detector.py
+++ b/pnpxai/core/detector/detector.py
@@ -1,7 +1,7 @@
-from typing import Set, Tuple, Optional
+from typing import Set, Tuple, Optional, Union, Sequence
 from torch import fx, nn
+from pnpxai.utils import format_into_tuple
 from pnpxai.core._types import Model
-from pnpxai.core.detector.utils import get_target_module_of
 from pnpxai.core.detector.types import (
     ModuleType,
     Linear,
@@ -11,6 +11,7 @@
     Attention,
     Embedding,
 )
+# from pnpxai.core.modality.modality import _Modality
 
 DEFAULT_MODULE_TYPES_TO_DETECT = (
     Linear,
@@ -105,3 +106,24 @@
             continue
         detected.add(module_type)
     return detected
+
+
+DATA_MODALITY_MAYBE = {
+    (float, 2): 'tabular or time-series',
+    (float, 4): 'image',
+    (int, 2): 'text',
+}
+
+
+def _data_modality_maybe(dtype, ndims):
+    return DATA_MODALITY_MAYBE.get((dtype, ndims))
+
+
+def detect_data_modality(modality: Union["Modality", Sequence["Modality"]]):
+    nms = []
+    for mod in format_into_tuple(modality):
+        mod_nm = _data_modality_maybe(mod.dtype_key, mod.ndims)
+        if mod_nm is None:
+            raise ValueError('Cannot match data modality')
+        nms.append(mod_nm)
+    return nms
diff --git a/pnpxai/core/detector/types.py b/pnpxai/core/detector/types.py
index be6a61da..0411ff87 100644
--- a/pnpxai/core/detector/types.py
+++ b/pnpxai/core/detector/types.py
@@ -86,5 +86,4 @@ class Embedding(metaclass=SubclassMeta):
     )
 
 
-
 ModuleType = Union[Linear, Convolution, RNN, LSTM, Attention, Pool, Embedding]
diff --git a/pnpxai/core/experiment/__init__.py b/pnpxai/core/experiment/__init__.py
index 89f4ee7f..257d4edd 100644
--- a/pnpxai/core/experiment/__init__.py
+++ b/pnpxai/core/experiment/__init__.py
@@ -1,8 +1,2 @@
 from pnpxai.core.experiment.experiment import Experiment
-from pnpxai.core.experiment.auto_explanation import (
-    AutoExplanation,
-    AutoExplanationForImageClassification,
-    AutoExplanationForTextClassification,
-    AutoExplanationForVisualQuestionAnswering,
-    AutoExplanationForTSClassification,
-)
+from pnpxai.core.experiment.auto_explanation import AutoExplanation
diff --git a/pnpxai/core/experiment/auto_explanation.py b/pnpxai/core/experiment/auto_explanation.py
index 1caf801b..27edc335 100644
--- a/pnpxai/core/experiment/auto_explanation.py
+++ b/pnpxai/core/experiment/auto_explanation.py
@@ -1,39 +1,25 @@
-from typing import Callable, Optional, Tuple, List
-import itertools
+from typing import Callable, Optional, List, Union, Any
 
-from torch.nn.modules import Module
-from torch.utils.data import DataLoader
+import torch
 
 from pnpxai.core._types import Model, DataSource
 from pnpxai.core.experiment.experiment import Experiment
-from pnpxai.core.modality.modality import Modality, ImageModality, TextModality, TimeSeriesModality
+from pnpxai.core.modality.modality import Modality
 from pnpxai.core.recommender import XaiRecommender
-from pnpxai.explainers.types import TargetLayer
-from pnpxai.evaluator.metrics import PIXEL_FLIPPING_METRICS
+from pnpxai.explainers.types import TargetLayerOrTupleOfTargetLayers
 from pnpxai.evaluator.metrics import (
-    MuFidelity,
-    Sensitivity,
-    Complexity,
     MoRF,
     LeRF,
     AbPC,
 )
-from pnpxai.utils import format_into_tuple
+from pnpxai.utils import _camel_to_snake
-METRICS_BASELINE_FN_REQUIRED = PIXEL_FLIPPING_METRICS
-METRICS_CHANNEL_DIM_REQUIRED = PIXEL_FLIPPING_METRICS
-DEFAULT_METRICS_FOR_TEXT = [
+DEFAULT_METRICS = [
     MoRF,
     LeRF,
     AbPC,
 ]
-DEFAULT_METRICS = [
-    MuFidelity,
-    AbPC,
-    Sensitivity,
-    Complexity,
-]
 
 
 class AutoExplanation(Experiment):
@@ -59,306 +45,39 @@ def __init__(
         self,
         model: Model,
         data: DataSource,
         modality: Modality,
-        input_extractor: Optional[Callable] = None,
-        label_extractor: Optional[Callable] = None,
-        target_extractor: Optional[Callable] = None,
-        target_labels: bool = False
+        target_layer: Optional[TargetLayerOrTupleOfTargetLayers] = None,
+        target_input_keys: Optional[List[Union[str, int]]] = None,
+        additional_input_keys: Optional[List[Union[str, int]]] = None,
+        output_modifier: Optional[Callable[[Any], torch.Tensor]] = None,
+        target_class_extractor: Optional[Callable[[Any], Any]] = None,
+        label_key: Optional[Union[str, int]] = -1,
+        target_labels: bool = False,
+        cache_device: Optional[Union[torch.device, str]] = None,
     ):
-        self.recommended = XaiRecommender().recommend(modality=modality, model=model)
-        self.modality = modality
-
         super().__init__(
             model=model,
             data=data,
             modality=modality,
-            explainers=self._load_default_explainers(model),
-            postprocessors=self._load_default_postprocessors(),
-            metrics=self._load_default_metrics(model),
-            input_extractor=input_extractor,
-            label_extractor=label_extractor,
-            target_extractor=target_extractor,
+            target_layer=target_layer,
+            target_input_keys=target_input_keys,
+            additional_input_keys=additional_input_keys,
+            output_modifier=output_modifier,
+            target_class_extractor=target_class_extractor,
+            label_key=label_key,
             target_labels=target_labels,
+            cache_device=cache_device,
         )
-
-    def _load_default_explainers(self, model):
-        explainers = []
-        for explainer_type in self.recommended.explainers:
-            explainer = explainer_type(model=model)
-            default_kwargs = self._generate_default_kwargs_for_explainer()
-            for k, v in default_kwargs.items():
-                if hasattr(explainer, k):
-                    explainer = explainer.set_kwargs(**{k: v})
-            explainers.append(explainer)
-        return explainers
-
-    def _load_default_metrics(self, model):
-        empty_metrics = []  # empty means that explainer is not assigned yet
-        for metric_type in DEFAULT_METRICS:
-            metric = metric_type(model=model)
-            default_kwargs = self._generate_default_kwargs_for_metric()
-            for k, v in default_kwargs.items():
-                if hasattr(metric, k):
-                    metric = metric.set_kwargs(**{k: v})
-            empty_metrics.append(metric)
-        return empty_metrics
-
-    def _load_default_postprocessors(self):
-        modalities = format_into_tuple(self.modality)
-        if len(modalities) == 1:
-            return self.modality.get_default_postprocessors()
-        return list(itertools.product(*tuple(
-            modality.get_default_postprocessors()
-            for modality in modalities
-        )))
-
-    def _generate_default_kwargs_for_explainer(self):
-        return {
-            'feature_mask_fn': self.modality.get_default_feature_mask_fn(),
-            'baseline_fn': self.modality.get_default_baseline_fn(),
-        }
-
-    def _generate_default_kwargs_for_metric(self):
-        return {
-            'baseline_fn': self.modality.get_default_baseline_fn(),
-            'channel_dim': self.modality.channel_dim,
-        }
-
-
-class AutoExplanationForImageClassification(AutoExplanation):
-    """
-    An extension of AutoExplanation class with modality set to the ImageModality.
-
-    Parameters:
-        model (Model): The machine learning model to be analyzed.
-        data (DataSource): The data source used for the experiment.
-        input_extractor (Optional[Callable]): Custom function to extract input features.
-        label_extractor (Optional[Callable]): Custom function to extract labels features.
-        target_extractor (Optional[Callable]): Custom function to extract target features.
-        target_labels (Optional[bool]): Whether to use target labels.
-        channel_dim (int): Channel dimension.
-
-    Attributes:
-        modality (ImageModality): An object to specify modality-specific workflow.
-        recommended (RecommenderOutput): A data object, containing recommended explainers.
-    """
-    def __init__(
-        self,
-        model: Module,
-        data: DataLoader,
-        input_extractor: Optional[Callable] = None,
-        label_extractor: Optional[Callable] = None,
-        target_extractor: Optional[Callable] = None,
-        target_labels: bool = False,
-        channel_dim: int = 1,
-    ):
-        super().__init__(
-            model=model,
-            data=data,
-            modality=ImageModality(channel_dim),
-            input_extractor=input_extractor,
-            label_extractor=label_extractor,
-            target_extractor=target_extractor,
-            target_labels=target_labels
-        )
-
-
-class AutoExplanationForTextClassification(AutoExplanation):
-    """
-    An extension of AutoExplanation class with modality set to the TextModality.
-
-    Parameters:
-        model (Model): The machine learning model to be analyzed.
-        data (DataSource): The data source used for the experiment.
-        layer (TargetLayer): A Module or its string representation to select a target layer for analysis.
-        mask_token_id (int): A mask token id.
-        input_extractor (Optional[Callable], optional): Custom function to extract input features.
-        forward_arg_extractor (Optional[Callable]): Custom function to extract forward arguments.
-        additional_forward_arg_extractor (Optional[Callable]): Custom function to extract additional forward arguments.
-        label_extractor (Optional[Callable], optional): Custom function to extract labels features.
-        target_extractor (Optional[Callable], optional): Custom function to extract target features.
-        target_labels (Optional[bool]): Whether to use target labels.
-        channel_dim (int): Channel dimension.
-
-    Attributes:
-        modality (ImageModality): An object to specify modality-specific workflow.
-        recommended (RecommenderOutput): A data object, containing recommended explainers.
-    """
-    def __init__(
-        self,
-        model: Module,
-        data: DataLoader,
-        layer: TargetLayer,
-        mask_token_id: int,
-        input_extractor: Optional[Callable] = None,
-        forward_arg_extractor: Optional[Callable] = None,
-        additional_forward_arg_extractor: Optional[Callable] = None,
-        label_extractor: Optional[Callable] = None,
-        target_extractor: Optional[Callable] = None,
-        target_labels: bool = False,
-        channel_dim: int = -1,
-    ):
-        self.layer = layer
-        self.mask_token_id = mask_token_id
-        self.forward_arg_extractor = forward_arg_extractor
-        self.additional_forward_arg_extractor = additional_forward_arg_extractor
-
-        super().__init__(
-            model=model,
-            data=data,
-            modality=TextModality(
-                channel_dim=channel_dim, mask_token_id=mask_token_id),
-            input_extractor=input_extractor,
-            label_extractor=label_extractor,
-            target_extractor=target_extractor,
-            target_labels=target_labels
-        )
-
-    def _generate_default_kwargs_for_explainer(self):
-        return {
-            'layer': self.layer,
-            'forward_arg_extractor': self.forward_arg_extractor,
-            'additional_forward_arg_extractor': self.additional_forward_arg_extractor,
-            'feature_mask_fn': self.modality.get_default_feature_mask_fn(),
-            'baseline_fn': self.modality.get_default_baseline_fn(),
-        }
-
-    def _generate_default_kwargs_for_metric(self):
-        return {
-            'baseline_fn': self.modality.get_default_baseline_fn(),
-            'channel_dim': self.modality.channel_dim,
-        }
-
-    def _load_default_metrics(self, model):
-        empty_metrics = []  # empty means that explainer is not assigned yet
-        for metric_type in DEFAULT_METRICS_FOR_TEXT:
-            metric = metric_type(model=model)
-            default_kwargs = self._generate_default_kwargs_for_metric()
-            for k, v in default_kwargs.items():
-                if hasattr(metric, k):
-                    metric = metric.set_kwargs(**{k: v})
-            empty_metrics.append(metric)
-        return empty_metrics
-
-
-class AutoExplanationForVisualQuestionAnswering(AutoExplanation):
-    """
-    An extension of AutoExplanation class with multiple modalities, namely ImageModality and TextModality.
-
-    Parameters:
-        model (Model): The machine learning model to be analyzed.
-        data (DataSource): The data source used for the experiment.
-        layer (TargetLayer): A Module or its string representation to select a target layer for analysis.
-        mask_token_id (int): A mask token id.
-        modality (Modality): An object to specify modality-specific workflow.
-        input_extractor (Optional[Callable], optional): Custom function to extract input features.
-        forward_arg_extractor (Optional[Callable]): Custom function to extract forward arguments.
-        additional_forward_arg_extractor (Optional[Callable]): Custom function to extract additional forward arguments.
-        label_extractor (Optional[Callable], optional): Custom function to extract labels features.
-        target_extractor (Optional[Callable], optional): Custom function to extract target features.
-        target_labels (Optional[bool]): Whether to use target labels.
-        channel_dim (Tuple[int]): Channel dimension. Requires a tuple channel dimensions for image and text modalities.
-
-    Attributes:
-        modality (Tuple[ImageModality, TextModality]): A tuple of objects to specify modality-specific workflow.
-        recommended (RecommenderOutput): A data object, containing recommended explainers.
-    """
-    def __init__(
-        self,
-        model: Module,
-        data: DataLoader,
-        layer: List[TargetLayer],
-        mask_token_id: int,
-        input_extractor: Optional[Callable] = None,
-        forward_arg_extractor: Optional[Callable] = None,
-        additional_forward_arg_extractor: Optional[Callable] = None,
-        label_extractor: Optional[Callable] = None,
-        target_extractor: Optional[Callable] = None,
-        target_labels: bool = False,
-        channel_dim: Tuple[int] = (1, -1),
-    ):
-        self.layer = layer
-        self.mask_token_id = mask_token_id
-        self.forward_arg_extractor = forward_arg_extractor
-        self.additional_forward_arg_extractor = additional_forward_arg_extractor
-        super().__init__(
+        self.recommended = XaiRecommender().recommend(
+            modality=modality,
             model=model,
-            data=data,
-            modality=(
-                ImageModality(channel_dim=channel_dim[0]),
-                TextModality(channel_dim=channel_dim[1], ),
-            ),
-            input_extractor=input_extractor,
-            label_extractor=label_extractor,
-            target_extractor=target_extractor,
-            target_labels=target_labels
         )
+        self._load_default_explainers()
+        self._load_default_metrics()
 
-    def _generate_default_kwargs_for_explainer(self):
-        return {
-            'layer': self.layer,
-            'forward_arg_extractor': self.forward_arg_extractor,
-            'additional_forward_arg_extractor': self.additional_forward_arg_extractor,
-            'feature_mask_fn': tuple(
-                modality.get_default_feature_mask_fn() for modality in self.modality
-            ),
-            'baseline_fn': tuple(
-                modality.get_default_baseline_fn() for modality in self.modality
-            ),
-        }
-
-    def _generate_default_kwargs_for_metric(self):
-        return {
-            'baseline_fn': tuple(
-                modality.get_default_baseline_fn() for modality in self.modality
-            ),
-            'channel_dim': tuple(modality.channel_dim for modality in self.modality),
-        }
-
-
-class AutoExplanationForTSClassification(AutoExplanation):
-    """
-    An extension of AutoExplanation class with modality set to the TimeSeriesModality.
-
-    Parameters:
-        model (Model): The machine learning model to be analyzed.
-        data (DataSource): The data source used for the experiment.
-        input_extractor (Optional[Callable], optional): Custom function to extract input features.
-        label_extractor (Optional[Callable], optional): Custom function to extract labels features.
-        target_extractor (Optional[Callable], optional): Custom function to extract target features.
-        target_labels (Optional[bool]): Whether to use target labels.
-        sequence_dim (Tuple[int]): Sequence dimension.
-        mask_agg_dim (Tuple[int]): A dimension for aggregating mask values. Usually, a channel dimension.
-
-    Attributes:
-        modality (TimeSeriesModality): An object to specify modality-specific workflow.
-        recommended (RecommenderOutput): A data object, containing recommended explainers.
-    """
-    def __init__(
-        self,
-        model: Module,
-        data: DataLoader,
-        input_extractor: Optional[Callable] = None,
-        label_extractor: Optional[Callable] = None,
-        target_extractor: Optional[Callable] = None,
-        target_labels: bool = False,
-        sequence_dim: int = -1,
-        mask_agg_dim: int = -2,
-    ):
-        self.mask_agg_dim = mask_agg_dim
-        super().__init__(
-            model=model,
-            data=data,
-            modality=TimeSeriesModality(sequence_dim),
-            input_extractor=input_extractor,
-            label_extractor=label_extractor,
-            target_extractor=target_extractor,
-            target_labels=target_labels
-        )
+    def _load_default_explainers(self):
+        for explainer_type in self.recommended.explainers:
+            self.explainers.add(key=explainer_type.alias[0], value=explainer_type)
 
-    def _generate_default_kwargs_for_metric(self):
-        return {
-            'baseline_fn': self.modality.get_default_baseline_fn(),
-            'feature_mask_fn': self.modality.get_default_feature_mask_fn(),
-            'channel_dim': self.modality.channel_dim,
-            'mask_agg_dim': self.mask_agg_dim,
-        }
+    def _load_default_metrics(self):
+        for tp in DEFAULT_METRICS:
+            self.metrics.add(key=tp.alias[0], value=tp)
diff --git a/pnpxai/core/experiment/experiment.py b/pnpxai/core/experiment/experiment.py
index f7e324e2..f60d6073 100644
--- a/pnpxai/core/experiment/experiment.py
+++ b/pnpxai/core/experiment/experiment.py
@@ -1,4 +1,7 @@
-from typing import Any, Callable, Optional, Sequence, Union, List, Literal
+from typing import Any, Callable, Optional, Sequence, Union, List, Literal, Type, Dict
+import itertools
+import inspect
+from collections import defaultdict
 
 import torch
 from torch import Tensor
@@ -6,11 +9,21 @@
 import optuna
 
 from pnpxai.core._types import DataSource, Model
-from pnpxai.core.modality.modality import Modality, TextModality
-from pnpxai.core.experiment.experiment_metrics_defaults import EVALUATION_METRIC_REVERSE_SORT, EVALUATION_METRIC_SORT_PRIORITY
+from pnpxai.core.modality.modality import Modality
+from pnpxai.core.experiment.experiment_metrics_defaults import (
+    EVALUATION_METRIC_REVERSE_SORT,
+    EVALUATION_METRIC_SORT_PRIORITY,
+)
 from pnpxai.core.experiment.manager import ExperimentManager
+from pnpxai.core.utils import ModelWrapper
 from pnpxai.explainers import Explainer, Lime, KernelShap
-from pnpxai.explainers.utils.postprocess import Identity
+from pnpxai.explainers.base import Tunable
+from pnpxai.explainers.utils import FunctionSelector
+from pnpxai.explainers.utils.postprocess import Identity, PostProcessor
+from pnpxai.explainers.types import (
+    TensorOrTupleOfTensors,
+    TargetLayerOrTupleOfTargetLayers,
+)
 from pnpxai.evaluator.optimizer.types import OptimizationOutput
 from pnpxai.evaluator.optimizer.objectives import Objective
 from pnpxai.evaluator.optimizer.utils import (
@@ -18,6 +31,7 @@
     get_default_n_trials,
 )
 from pnpxai.evaluator.metrics.base import Metric
+from pnpxai.core.experiment.types import ExperimentOutput
 from pnpxai.utils import (
     class_to_string, Observable, to_device,
@@ -50,7 +64,7 @@
         metrics (Optional[Sequence[Metric]]): Evaluation metrics used to assess model interpretability.
         input_extractor (Optional[Callable[[Any], Any]]): Function to extract inputs from data.
         label_extractor (Optional[Callable[[Any], Any]]): Function to extract labels from data.
-        target_extractor (Optional[Callable[[Any], Any]]): Function to extract targets from data.
+        target_class_extractor (Optional[Callable[[Any], Any]]): Function to extract targets from data.
         cache_device (Optional[Union[torch.device, str]]): Device to cache data and results.
         target_labels (bool): True if the target is a label, False otherwise.
@@ -69,38 +83,50 @@ def __init__(
         self,
         model: Model,
         data: DataSource,
         modality: Modality,
-        explainers: Sequence[Explainer],
-        postprocessors: Sequence[Callable],
-        metrics: Sequence[Metric],
-        input_extractor: Optional[Callable[[Any], Any]] = None,
-        label_extractor: Optional[Callable[[Any], Any]] = None,
-        target_extractor: Optional[Callable[[Any], Any]] = None,
-        cache_device: Optional[Union[torch.device, str]] = None,
+        explainers: Optional[Dict[str, Type[Explainer]]] = None,
+        metrics: Optional[Dict[str, Type[Metric]]] = None,
+        target_layer: Optional[TargetLayerOrTupleOfTargetLayers] = None,
+        target_input_keys: Optional[List[Union[str, int]]] = None,
+        additional_input_keys: Optional[List[Union[str, int]]] = None,
+        output_modifier: Optional[Callable[[Any], torch.Tensor]] = None,
+        target_class_extractor: Optional[Callable[[Any], Any]] = None,
+        label_key: Optional[Union[str, int]] = -1,
         target_labels: bool = False,
+        cache_device: Optional[Union[torch.device, str]] = None,
     ):
         super(Experiment, self).__init__()
+
+        # set model
         self.model = model
         self.model_device = next(self.model.parameters()).device
+        self.target_input_keys = target_input_keys
+        self.additional_input_keys = additional_input_keys
+        self.output_modifier = output_modifier
 
+        # set data
         self.manager = ExperimentManager(data=data, cache_device=cache_device)
-        for explainer in explainers:
-            self.manager.add_explainer(explainer)
-        for postprocessor in postprocessors:
-            self.manager.add_postprocessor(postprocessor)
-        for metric in metrics:
-            self.manager.add_metric(metric)
-
-        self.input_extractor = input_extractor \
-            if input_extractor is not None \
-            else default_input_extractor
-        self.label_extractor = label_extractor \
-            if label_extractor is not None \
-            else default_target_extractor
-        self.target_extractor = target_extractor \
-            if target_extractor is not None \
-            else default_target_extractor
-        self.target_labels = target_labels
         self.modality = modality
+
+        # set explainer choices
+        self.explainers = FunctionSelector()
+        if explainers is not None:
+            for k, explainer_type in explainers.items():
+                self.explainers.add(k, explainer_type)
+
+        # set metrics
+        self.metrics = FunctionSelector()
+        if metrics is not None:
+            for k, metric_type in metrics.items():
+                self.metrics.add(k, metric_type)
+
+        self.target_layer = target_layer
+        self.target_class_extractor = target_class_extractor or default_target_extractor
+        self.target_labels = target_labels
+        self.label_key = label_key
+
+        self._explainer_key_to_id = {}
+        self._postprocessor_key_to_id = {}
+        self._metric_key_to_id = {}
         self.reset_errors()
 
     def reset_errors(self):
@@ -113,19 +139,24 @@
     def to_device(self, x):
         return to_device(x, self.model_device)
 
+    def _validate_choice(self, selector: FunctionSelector, choice: str):
+        if choice not in selector.choices:
+            raise ValueError(f"'{choice}' not found in {selector}")
+
     def run_batch(
         self,
-        explainer_id: int,
-        postprocessor_id: int,
-        metric_id: int,
+        explainer_key: str,
+        metric_key: str,
         data_ids: Optional[Sequence[int]] = None,
-    ) -> dict:
+        pooling_method: Optional[str] = None,
+        normalization_method: Optional[str] = None,
+    ) -> ExperimentOutput:
         """
-        Runs the experiment for selected batch of data, explainer, postprocessor and metric.
+        Runs the experiment for a selected batch of data with the given explainer and metric.
 
         Args:
             data_ids (Sequence[int]): A sequence of data IDs to specify the subset of data to process.
-            explainer_id (int): ID of explainer to use for the run.
+            explainer_key (str): Key of the explainer to use for the run.
             postprocessor_id (int): ID of postprocessor to use for the run.
             metrics_id (int): ID of metric to use for the run.
@@ -137,31 +168,55 @@ def run_batch(
             Note: The input parameters allow for flexibility in specifying subset of data, explainer, postprocessor and metric to process.
         """
+
+        # validate choices
+        self._validate_choice(self.explainers, explainer_key)
+        self._validate_choice(self.metrics, metric_key)
+
         data_ids = data_ids if data_ids is not None else self.manager.get_data_ids(
             data_ids)
         self.predict_batch(data_ids)
-        self.explain_batch(data_ids, explainer_id)
-        self.evaluate_batch(
-            data_ids, explainer_id, postprocessor_id, metric_id)
-        data = self.manager.batch_data_by_ids(data_ids)
-        return {
-            'inputs': self.input_extractor(data),
-            'labels': self.label_extractor(data),
-            'outputs': self.manager.batch_outputs_by_ids(data_ids),
-            'targets': self._get_targets(data_ids),
-            'explainer': self.manager.get_explainer_by_id(explainer_id),
-            'explanation': self.manager.batch_explanations_by_ids(data_ids, explainer_id),
-            'postprocessor': self.manager.get_postprocessor_by_id(postprocessor_id),
-            'postprocessed': self.postprocess_batch(data_ids, explainer_id, postprocessor_id),
-            'metric': self.manager.get_metric_by_id(metric_id),
-            'evaluation': self.manager.batch_evaluations_by_ids(data_ids, explainer_id, postprocessor_id, metric_id),
-        }
+        _, explainer = self.explain_batch(
+            data_ids, explainer_key, return_explainer=True)
+        attrs_pp, pp_methods = self.postprocess_batch(
+            data_ids, explainer_key, pooling_method, normalization_method,
+            return_methods=True,
+        )
+        evals, metric = self.evaluate_batch(
+            data_ids, explainer_key, metric_key,
+            *pp_methods, return_metric=True,
+        )
+        return ExperimentOutput(
+            explainer=explainer,
+            metric=metric,
+            explanations=attrs_pp,
+            evaluations=evals,
+        )
+
+    @property
+    def _wrapped_model(self):
+        return ModelWrapper(
+            model=self.model,
+            target_input_keys=self.target_input_keys,
+            additional_input_keys=self.additional_input_keys,
+            output_modifier=self.output_modifier,
+        )
+
+    def input_extractor(self, batch) -> TensorOrTupleOfTensors:
+        return self._wrapped_model.format_inputs(batch)
+
+    def label_extractor(self, batch) -> Tensor:
+        return batch[self.label_key]
+
+    def _forward_batch(self, batch):
+        formatted_inputs = self._wrapped_model.format_inputs(batch)
+        return self._wrapped_model(*formatted_inputs)
 
     def predict_batch(
         self,
         data_ids: Sequence[int],
-    ):
+    ) -> TensorOrTupleOfTensors:
         """
         Predicts results of the experiment for selected batch of data.
@@ -179,17 +234,99 @@ def predict_batch(
             if self.manager.get_output_by_id(idx) is None
         ]
         if len(data_ids_pred) > 0:
-            data = self.manager.batch_data_by_ids(data_ids_pred)
-            outputs = self.model(
-                *format_into_tuple(self.input_extractor(data)))
+            batch = self.manager.batch_data_by_ids(data_ids_pred)
+            outputs = self._forward_batch(batch)
             self.manager.cache_outputs(data_ids_pred, outputs)
         return self.manager.batch_outputs_by_ids(data_ids)
 
+    def _create_instance(
+        self,
+        instance_type: Union[Type[Explainer], Type[Metric]],
+        **kwargs,
+    ) -> Union[Explainer, Metric]:
+        # collect constructor parameters shared with the experiment
+        required_params = dict(inspect.signature(instance_type).parameters)
+        expr_params = dict(inspect.signature(self.__class__).parameters)
+        params_base = {}
+        for param_nm, param in expr_params.items():
+            if param_nm in required_params:
+                required_param = required_params.pop(param_nm)
+                params_base[required_param.name] = getattr(self, required_param.name)
+
+        # update additional constructor parameters
+        additional_params = defaultdict(list)
+        modalities = format_into_tuple(self.modality)
+        for param_nm, param in required_params.items():
+            if param_nm in kwargs:
+                if param_nm in modalities[0].util_functions:
+                    assert isinstance(kwargs[param_nm], Sequence), (
+                        f"'{param_nm}' must be a tuple."
+                    )
+                    assert len(kwargs[param_nm]) == len(modalities), (
+                        f"'{param_nm}' must have same length with modality."
+                    )
+                additional_params[param_nm] = kwargs[param_nm]
+            else:
+                if param_nm in modalities[0].util_functions:
+                    for modality in modalities:
+                        default_fn_key = modality.util_functions[param_nm].choices[0]
+                        default_fn = modality.util_functions[param_nm].select(
+                            default_fn_key,
+                        )
+                        additional_params[param_nm].append(default_fn)
+                elif param.default is inspect.Parameter.empty:
+                    # raise TypeError when returning
+                    continue
+                else:
+                    additional_params[param_nm] = param.default
+
+        # format additional constructor parameters
+        additional_params = {
+            k: tuple(v) if isinstance(v, list) else v
+            for k, v in additional_params.items()
+        }
+        return instance_type(**params_base, **additional_params)
+
+    def create_explainer(self, explainer_key: str, **kwargs) -> Explainer:
+        explainer_type = self.explainers.get(explainer_key)
+        explainer = self._create_instance(explainer_type, **kwargs)
+
+        # set util function space for tunable explainer
+        modalities = format_into_tuple(self.modality)
+        if isinstance(explainer, Tunable):
+            for tunable_param in explainer.tunable_params:
+                if not tunable_param.is_leaf and tunable_param.space is None:
+                    # tunable_param may be one of tunable util functions
+                    fn_key, *_keys = tunable_param.name.split('.')
+                    if _keys:
+                        modality_key, *_ = _keys
+                    else:
+                        modality_key = 0
+                    modality = modalities[int(modality_key)]
+                    tunable_param.set_selector(
+                        modality.util_functions[fn_key], set_space=True)
+        return explainer
+
+    def create_metric(self, metric_key, **kwargs) -> Metric:
+        metric_type = self.metrics.get(metric_key)
+        metric = self._create_instance(metric_type, **kwargs)
+
+        # modify pooling dim
+        modalities = format_into_tuple(self.modality)
+        pooling_dims = tuple()
+        if hasattr(metric, 'pooling_dim'):
+            for modality in modalities:
+                pooling_dims += (modality.pooling_dim,)
+            metric.pooling_dim = format_out_tuple_if_single(pooling_dims)
+        return metric
+
     def explain_batch(
         self,
         data_ids: Sequence[int],
-        explainer_id: int,
-    ):
+        explainer_key: str,
+        return_explainer: bool = True,
+        **kwargs,
+    ) -> TensorOrTupleOfTensors:
         """
         Explains selected batch of data within experiment.
@@ -202,33 +339,48 @@
             This method orchestrates the experiment by configuring the manager, obtaining explainer instance, processing data, and generating explanations.
             It then caches the results in the manager, and returns back to the user.
         """
+
+        # collect data_ids not cached yet
+        explainer_id = self._explainer_key_to_id.get(explainer_key, '_placeholder')
         data_ids_expl = [
             idx for idx in data_ids
             if self.manager.get_explanation_by_id(idx, explainer_id) is None
         ]
-        if len(data_ids_expl):
-            data = self.manager.batch_data_by_ids(data_ids_expl)
-            inputs = self.input_extractor(data)
-            targets = self._get_targets(data_ids_expl)
-            explainer = self.manager.get_explainer_by_id(explainer_id)
-            explanations = explainer.attribute(inputs, targets)
+        if len(data_ids_expl) > 0:
+            batch = self.manager.batch_data_by_ids(data_ids_expl)
+            inputs = self._wrapped_model.extract_inputs(batch)
+            targets = self.get_targets_by_id(data_ids_expl)
+            explainer = self.create_explainer(explainer_key, **kwargs)
+
+            if explainer_id == '_placeholder':
+                explainer_id = self.manager.add_explainer(explainer)
+                self._explainer_key_to_id[explainer_key] = explainer_id
+
+            attrs = explainer.attribute(inputs, targets)
             self.manager.cache_explanations(
-                explainer_id, data_ids_expl, explanations)
-        return self.manager.batch_explanations_by_ids(data_ids, explainer_id)
+                explainer_id, data_ids_expl, attrs)
+        attrs = self.manager.batch_explanations_by_ids(data_ids, explainer_id)
+        if return_explainer:
+            explainer = self.manager.get_explainer_by_id(explainer_id)
+            return attrs, explainer
+        return attrs
 
     def postprocess_batch(
         self,
         data_ids: List[int],
-        explainer_id: int,
-        postprocessor_id: int,
-    ):
+        explainer_key: str,
+        pooling_method: Optional[str] = None,
+        normalization_method: Optional[str] = None,
+        return_methods: bool = False,
+    ) -> TensorOrTupleOfTensors:
         """
         Postprocesses selected batch of data within experiment.
 
         Args:
             data_ids (Sequence[int]): A sequence of data IDs to specify the subset of data to postprocess.
-            explainer_id (int): An explainer ID to specify the explainer to use.
+            explainer_key (str): A key of the explainer to use.
             postprocessor_id (int): A postprocessor ID to specify the postprocessor to use.
+            return_methods (bool): Returns the applied postprocess methods if True.
 
         Returns:
             Batched postprocessed model explanations corresponding to data ids.
@@ -236,42 +388,75 @@
             This method orchestrates the experiment by configuring the manager, obtaining explainer instance, processing data, and generating explanations.
             It then caches the results in the manager, and returns back to the user.
""" - explanations = self.manager.batch_explanations_by_ids( + explainer_id = self._explainer_key_to_id.get(explainer_key) + pp_key = self._generate_postprocessor_key( + pooling_method, normalization_method) + pp_id = self._postprocessor_key_to_id.get(pp_key, '_placeholder') + mods = format_into_tuple(self.modality) + pools = format_into_tuple(pooling_method) + norms = format_into_tuple(normalization_method) + attrs = self.manager.batch_explanations_by_ids( data_ids, explainer_id) - postprocessor = self.manager.get_postprocessor_by_id(postprocessor_id) - - modalities = format_into_tuple(self.modality) - explanations = format_into_tuple(explanations) - postprocessors = format_into_tuple(postprocessor) - - batch = [] - explainer = self.manager.get_explainer_by_id(explainer_id) - for mod, attr, pp in zip(modalities, explanations, postprocessors): - if ( - isinstance(explainer, (Lime, KernelShap)) - and isinstance(mod, TextModality) - and not isinstance(pp.pooling_fn, Identity) - ): - raise ValueError( - f'postprocessor {postprocessor_id} does not support explainer {explainer_id}.') - batch.append(pp(attr)) - return format_out_tuple_if_single(batch) + attrs = format_into_tuple(attrs) + zipped = itertools.zip_longest(mods, pools, norms, attrs) + attrs_pp = tuple() + applied_pps = tuple() + applied_pools = tuple() + applied_norms = tuple() + explainer_type = self.explainers.get(explainer_key) + for modality, pool, norm, attr in zipped: + # given or default pooling method + pool = pool or modality.util_functions['pooling_fn'].choices[0] + + # TODO: more elegant way + skip_pool = ( + explainer_type in (KernelShap, Lime) + and modality.dtype_key == int + ) + if skip_pool: + pool = 'identity' + modality.util_functions['pooling_fn'].add_fallback_option( + key=pool, value=Identity) + norm = norm or modality.util_functions['normalization_fn'].choices[0] + pp = PostProcessor(modality, pool, norm) + attrs_pp += (pp(attr),) + applied_pps += (pp,) + applied_pools += (pool,) + applied_norms += (norm,) + attrs_pp = format_out_tuple_if_single(attrs_pp) + if pp_id == '_placeholder': + pp_id = self.manager.add_postprocessor(applied_pps) + pp_key = self._generate_postprocessor_key(applied_pools, applied_norms) + self._postprocessor_key_to_id[pp_key] = pp_id + if not return_methods: + return attrs_pp + applied_pools = format_out_tuple_if_single(applied_pools) + applied_norms = format_out_tuple_if_single(applied_norms) + return attrs_pp, (applied_pools, applied_norms) + + def _generate_postprocessor_key(self, pooling_method, normalization_method): + pools = format_into_tuple(pooling_method) + norms = format_into_tuple(normalization_method) + return '-'.join( + ['_'.join([pool, norm]) for pool, norm in zip(pools, norms)]) def evaluate_batch( self, data_ids: List[int], - explainer_id: int, - postprocessor_id: int, - metric_id: int, - ): + explainer_key: str, + metric_key: str, + pooling_method: Optional[str] = None, + normalization_method: Optional[str] = None, + return_metric: bool = False, + ) -> TensorOrTupleOfTensors: """ Evaluates selected batch of data within experiment. Args: data_ids (Sequence[int]): A sequence of data IDs to specify the subset of data to postprocess. - explainer_id (int): An explainer ID to specify the explainer to use. + explainer (str): An explainer ID to specify the explainer to use. postprocessor_id (int): A postprocessor ID to specify the postprocessor to use. - metric_id (int): A metric ID to evaluate the model explanations. 
+            metric_key (str): A key of the metric used to evaluate the model explanations.
 
         Returns:
             Batched model evaluations corresponding to data ids.
@@ -279,53 +464,92 @@
             This method orchestrates the experiment by configuring the manager, obtaining explainer instance, processing data, generating explanations, and evaluating results.
             It then caches the results in the manager, and returns back to the user.
         """
+        explainer_id = self._explainer_key_to_id.get(explainer_key)
+        pp_id = self._postprocessor_key_to_id.get(self._generate_postprocessor_key(
+            pooling_method, normalization_method,
+        ))
+        metric_id = self._metric_key_to_id.get(metric_key, '_placeholder')
         data_ids_eval = [
             idx for idx in data_ids
             if self.manager.get_evaluation_by_id(
-                idx, explainer_id, postprocessor_id, metric_id) is None
+                idx, explainer_id, pp_id, metric_id) is None
         ]
-        if len(data_ids_eval):
-            data = self.manager.batch_data_by_ids(data_ids_eval)
-            inputs = self.input_extractor(data)
-            targets = self._get_targets(data_ids_eval)
-            postprocessed = self.postprocess_batch(
-                data_ids_eval, explainer_id, postprocessor_id)
+        if len(data_ids_eval) > 0:
+            batch = self.manager.batch_data_by_ids(data_ids_eval)
+            inputs = self._wrapped_model.extract_inputs(batch)
+            targets = self.get_targets_by_id(data_ids_eval)
+            attrs_pp = self.postprocess_batch(
+                data_ids_eval, explainer_key, pooling_method, normalization_method)
             explainer = self.manager.get_explainer_by_id(explainer_id)
-            metric = self.manager.get_metric_by_id(metric_id)
-            evaluations = metric.set_explainer(explainer).evaluate(
-                inputs, targets, postprocessed)
+            metric = self.create_metric(metric_key)
+            if metric_id == '_placeholder':
+                metric_id = self.manager.add_metric(metric)
+                self._metric_key_to_id[metric_key] = metric_id
+            evals = metric.set_explainer(explainer).evaluate(
+                inputs, targets, attrs_pp,
+            )
             self.manager.cache_evaluations(
-                explainer_id, postprocessor_id, metric_id,
-                data_ids_eval, evaluations
+                explainer_id, pp_id, metric_id,
+                data_ids_eval, evals,
             )
-        return self.manager.batch_evaluations_by_ids(
-            data_ids, explainer_id, postprocessor_id, metric_id
+        evals = self.manager.batch_evaluations_by_ids(
+            data_ids, explainer_id, pp_id, metric_id
         )
+        if return_metric:
+            metric = self.manager.get_metric_by_id(metric_id)
+            return evals, metric
+        return evals
 
-    def _get_targets(self, data_ids):
+    def get_targets_by_id(self, data_ids):
         if self.target_labels:
-            return self.label_extractor(self.manager.batch_data_by_ids(data_ids))
+            batch = self.manager.batch_data_by_ids(data_ids)
+            labels = batch[self.label_key]
+            return labels.to(self.model_device)
         outputs = self.manager.batch_outputs_by_ids(data_ids)
-        return self.target_extractor(outputs)
+        return self.target_class_extractor(outputs).to(self.model_device)
+
+    def is_tunable(self, explainer_key):
+        is_tunable_explainer = issubclass(
+            self.explainers.get(explainer_key), Tunable)
+        modalities = format_into_tuple(self.modality)
+        is_tunable_by_pool = any(
+            len(modality.util_functions['pooling_fn'].choices) > 1
+            for modality in modalities
+        )
+        is_tunable_by_norm = any(
+            len(modality.util_functions['normalization_fn'].choices) > 1
+            for modality in modalities
+        )
+        return any([
+            is_tunable_explainer,
+            is_tunable_by_pool,
+            is_tunable_by_norm,
+        ])
 
     def optimize(
         self,
-        data_ids: Union[int, Sequence[int]],
-        explainer_id: int,
-        metric_id: int,
+        explainer_key: str,
+        metric_key: str,
+        metric_options: Optional[Dict[str, Any]] = None,
+        disable_tunable_params: Optional[Dict[str, Any]] = None,
         direction: Literal['minimize', 'maximize'] = 'maximize',
         sampler: Literal['grid', 'random', 'tpe'] = 'tpe',
+        data_ids: Optional[Union[int, Sequence[int]]] = None,
         n_trials: Optional[int] = None,
         timeout: Optional[float] = None,
+        num_threads: Optional[int] = None,
+        show_progress: bool = False,
+        n_jobs: int = 1,
+        errors: Literal['raise', 'ignore'] = 'raise',
         **kwargs,  # sampler kwargs
-    ):
+    ) -> OptimizationOutput:
         """
         Optimize experiment hyperparameters by processing data, generating explanations, evaluating with metrics, caching and retrieving the data.
 
         Args:
             data_ids (Union[int, Sequence[int]]): A single data ID or sequence of data IDs to specify the subset of data to process.
-            explainer_id (int): An explainer ID to specify the explainer to use.
-            metric_id (int): A metric ID to evaluate optimizer decisions.
+            explainer_key (str): A key of the explainer to use.
+            metric_key (str): A key of the metric used to evaluate optimizer decisions.
             direction (Literal['minimize', 'maximize']): A string to specify the direction of optimization.
             sampler (Literal['grid', 'random', 'tpe']): A string to specify the sampler to use for optimization.
             n_trials (Optional[int]): An integer to specify the number of trials for optimization. If none passed, the number of trials is inferred from `timeout`.
@@ -340,21 +564,23 @@
             Note: The input parameters allow for flexibility in specifying subsets of data, explainers, and metrics to process. If not provided, the method processes all available data, explainers, postprocessors, and metrics.
         """
-        data_ids = [data_ids] if isinstance(data_ids, int) else data_ids
-        data = self.manager.batch_data_by_ids(data_ids)
-        explainer = self.manager.get_explainer_by_id(explainer_id)
-        postprocessor = self.manager.get_postprocessor_by_id(
-            0)  # sample postprocessor to ensure channel_dim
-        metric = self.manager.get_metric_by_id(metric_id)
-
-        objective = Objective(
-            explainer=explainer,
-            postprocessor=postprocessor,
-            metric=metric,
+        org_num_threads = torch.get_num_threads()
+        if num_threads is not None:
+            torch.set_num_threads(num_threads)
+        metric_options = metric_options or {}
+        obj = Objective(
             modality=self.modality,
-            inputs=self.input_extractor(data),
-            targets=self._get_targets(data_ids),
+            explainer=self.create_explainer(explainer_key),
+            metric=self.create_metric(metric_key, **metric_options),
+            data=self.manager.get_data(data_ids)[0],
+            disable_tunable_params=disable_tunable_params,
+            target_class_extractor=self.target_class_extractor,
+            label_key=self.label_key,
+            target_labels=self.target_labels,
+            show_progress=show_progress,
+            errors=errors,
         )
+        # TODO: grid search
         if timeout is None:
             n_trials = n_trials or get_default_n_trials(sampler)
@@ -365,20 +591,38 @@
             direction=direction,
         )
         study.optimize(
-            objective,
+            obj,
             n_trials=n_trials,
             timeout=timeout,
-            n_jobs=1,
+            n_jobs=n_jobs,
         )
-        opt_explainer = study.best_trial.user_attrs['explainer']
-        opt_postprocessor = study.best_trial.user_attrs['postprocessor']
+
+        opt_explainer = self.create_explainer(explainer_key)
+        opt_pps = tuple(
+            PostProcessor(modality=mod) for mod in format_into_tuple(self.modality))
+        for key in study.best_params:
+            div, *additional_keys = key.split('.')
+            if div == Objective.EXPLAINER_KEY:
+                opt_explainer.update_current_value(
+                    '.'.join(additional_keys), study.best_params[key])
+            elif div == Objective.POSTPROCESSOR_KEY:
+                mod_loc, _key = additional_keys
+                opt_pps[int(mod_loc)].update_current_value(
+                    _key, study.best_params[key])
+        opt_pps = format_out_tuple_if_single(opt_pps)
         return OptimizationOutput(
             explainer=opt_explainer,
-            postprocessor=opt_postprocessor,
+            postprocessor=opt_pps,
             study=study,
+            value=study.best_trial.value,
+            params=study.best_params,
         )
 
-    def get_inputs_flattened(self, data_ids: Optional[Sequence[int]] = None) -> Sequence[Tensor]:
+
+    def get_inputs_flattened(
+        self,
+        data_ids: Optional[Sequence[int]] = None,
+    ) -> Sequence[Tensor]:
         """
         Retrieve and flatten last run input data.
 
@@ -406,10 +650,14 @@ def get_all_inputs_flattened(self) -> Sequence[Tensor]:
         This method retrieves input data from all available data points using the input extractor and flattens it.
         """
         data = self.manager.get_all_data()
-        data = [self.input_extractor(datum) for datum in data]
+        data = [format_out_tuple_if_single(
+            self.input_extractor(datum)) for datum in data]
         return self.manager.flatten_if_batched(data, data)
 
-    def get_labels_flattened(self, data_ids: Optional[Sequence[int]] = None) -> Sequence[Tensor]:
+    def get_labels_flattened(
+        self,
+        data_ids: Optional[Sequence[int]] = None,
+    ) -> Sequence[Tensor]:
         """
         Retrieve and flatten labels data.
 
@@ -425,7 +673,10 @@
         labels = [self.label_extractor(datum) for datum in data]
         return self.manager.flatten_if_batched(labels, data)
 
-    def get_targets_flattened(self, data_ids: Optional[Sequence[int]] = None) -> Sequence[Tensor]:
+    def get_targets_flattened(
+        self,
+        data_ids: Optional[Sequence[int]] = None,
+    ) -> Sequence[Tensor]:
         """
         Retrieve and flatten target data.
 
@@ -442,13 +693,13 @@
         if self.target_labels:
             return self.get_labels_flattened(data_ids)
         data, _ = self.manager.get_data(data_ids)
-        targets = [self._get_targets(data_ids)]
+        targets = [self.get_targets_by_id(data_ids)]
         return self.manager.flatten_if_batched(targets, data)
-        # targets = [self.label_extractor(datum) for datum in data] \
-        #     if self.target_labels else [self._get_targets(data_ids)]
-        # return self.manager.flatten_if_batched(targets, data)
 
-    def get_outputs_flattened(self, data_ids: Optional[Sequence[int]] = None) -> Sequence[Tensor]:
+    def get_outputs_flattened(
+        self,
+        data_ids: Optional[Sequence[int]] = None,
+    ) -> Sequence[Tensor]:
         """
         Retrieve and flatten model outputs.
 
@@ -464,7 +715,10 @@
         """
         return self.manager.get_flat_outputs(data_ids)
 
-    def get_explanations_flattened(self, data_ids: Optional[Sequence[int]] = None) -> Sequence[Sequence[Tensor]]:
+    def get_explanations_flattened(
+        self,
+        data_ids: Optional[Sequence[int]] = None,
+    ) -> Sequence[Sequence[Tensor]]:
         """
         Retrieve and flatten explanations from all explainers.
 
@@ -480,11 +734,14 @@
         """
         _, explainer_ids = self.manager.get_explainers()
         return [
-            self.manager.get_flat_explanations(explainer_id, data_ids)
-            for explainer_id in explainer_ids
+            self.manager.get_flat_explanations(explainer, data_ids)
+            for explainer in explainer_ids
         ]
 
-    def get_evaluations_flattened(self, data_ids: Optional[Sequence[int]] = None) -> Sequence[Sequence[Sequence[Tensor]]]:
+    def get_evaluations_flattened(
+        self,
+        data_ids: Optional[Sequence[int]] = None,
+    ) -> Sequence[Sequence[Sequence[Tensor]]]:
         """
         Retrieve and flatten evaluations for all explainers and metrics.
@@ -505,10 +762,10 @@ def get_evaluations_flattened(self, data_ids: Optional[Sequence[int]] = None) -> formatted = [[[ self.manager.get_flat_evaluations( - explainer_id, postprocessor_id, metric_id, data_ids) - for metric_id in metric_ids + explainer, postprocessor_id, metric, data_ids) + for metric in metric_ids ] for postprocessor_id in postprocessor_ids] - for explainer_id in explainer_ids] + for explainer in explainer_ids] return formatted @@ -560,3 +817,5 @@ def get_explainers_ranks(self) -> Optional[Sequence[Sequence[int]]]: @property def has_explanations(self): return self.manager.has_explanations + + diff --git a/pnpxai/core/experiment/manager.py b/pnpxai/core/experiment/manager.py index 7622d9ff..c87e7f34 100644 --- a/pnpxai/core/experiment/manager.py +++ b/pnpxai/core/experiment/manager.py @@ -22,7 +22,7 @@ def __init__( self.set_data_ids() self._explainers: List[Explainer] = [] self._explainer_ids: List[int] = [] - self._postprocessors: List[PostProcessor] =[] + self._postprocessors: List[PostProcessor] = [] self._postprocessor_ids: List[int] = [] self._metrics: List[Metric] = [] self._metric_ids: List[int] = [] @@ -358,7 +358,6 @@ def save_outputs(self, outputs: DataSource, data: DataSource, data_ids: Sequence for idx, output in zip(data_ids, outputs): self._cache.set_output(idx, output) - def _get_batch_size(self, data: DataSource) -> Optional[int]: if torch.is_tensor(data): return len(data) diff --git a/pnpxai/core/experiment/types.py b/pnpxai/core/experiment/types.py new file mode 100644 index 00000000..bfc7993c --- /dev/null +++ b/pnpxai/core/experiment/types.py @@ -0,0 +1,12 @@ +from dataclasses import dataclass +from pnpxai.core._types import TensorOrTupleOfTensors +from pnpxai.explainers.base import Explainer +from pnpxai.evaluator.metrics.base import Metric + + +@dataclass +class ExperimentOutput: + explainer: Explainer + metric: Metric + explanations: TensorOrTupleOfTensors + evaluations: TensorOrTupleOfTensors diff --git a/pnpxai/core/modality/__init__.py b/pnpxai/core/modality/__init__.py index 8a519e43..f1f1013a 100644 --- a/pnpxai/core/modality/__init__.py +++ b/pnpxai/core/modality/__init__.py @@ -1,6 +1 @@ -from pnpxai.core.modality.modality import ( - Modality, - ImageModality, - TextModality, - TimeSeriesModality -) +from pnpxai.core.modality.modality import Modality \ No newline at end of file diff --git a/pnpxai/core/modality/modality.py b/pnpxai/core/modality/modality.py index 13923114..6ac67f9b 100644 --- a/pnpxai/core/modality/modality.py +++ b/pnpxai/core/modality/modality.py @@ -1,353 +1,74 @@ -from typing import Optional, List, Callable, Type, Any, Dict -from abc import ABC, abstractmethod +from typing import Optional, Any +from collections import defaultdict -from pnpxai.explainers.utils import UtilFunction from pnpxai.explainers.utils.postprocess import ( - PostProcessor, POOLING_FUNCTIONS, - POOLING_FUNCTIONS_FOR_IMAGE, - POOLING_FUNCTIONS_FOR_TEXT, - POOLING_FUNCTIONS_FOR_TIME_SERIES, - NORMALIZATION_FUNCTIONS_FOR_IMAGE, - NORMALIZATION_FUNCTIONS_FOR_TEXT, - NORMALIZATION_FUNCTIONS_FOR_TIME_SERIES, - PoolingFunction, - NormalizationFunction + NORMALIZATION_FUNCTIONS, ) from pnpxai.explainers.utils.function_selectors import FunctionSelector from pnpxai.explainers.utils.baselines import ( - BaselineFunction, BASELINE_FUNCTIONS, - BASELINE_FUNCTIONS_FOR_IMAGE, - BASELINE_FUNCTIONS_FOR_TEXT, - BASELINE_FUNCTIONS_FOR_TIME_SERIES, + TokenBaselineFunction, + MeanBaselineFunction, ) -from pnpxai.explainers.utils.feature_masks import ( - FeatureMaskFunction, 
- FEATURE_MASK_FUNCTIONS, - FEATURE_MASK_FUNCTIONS_FOR_IMAGE, - FEATURE_MASK_FUNCTIONS_FOR_TEXT, - FEATURE_MASK_FUNCTIONS_FOR_TIME_SERIES, -) -from pnpxai.explainers import ( - Gradient, - GradientXInput, - SmoothGrad, - VarGrad, - IntegratedGradients, - LRPUniformEpsilon, - LRPEpsilonPlus, - LRPEpsilonGammaBox, - LRPEpsilonAlpha2Beta1, - KernelShap, - Lime, - AVAILABLE_EXPLAINERS -) - - -class Modality(ABC): - """ - An abstract class describing modality-specific workflow. The class is used to define both default and available - explainers, baselines, feature masks, pooling methods, and normalization methods for the modality. +from pnpxai.explainers.utils.feature_masks import FEATURE_MASK_FUNCTIONS - Parameters: - channel_dim (int): Target sequence dimension. - baseline_fn_selector (Optional[FunctionSelector]): Selector of baselines for the modality's explainers. If None selected, all BASELINE_FUNCTIONS will be used. - feature_mask_fn_selector (Optional[FunctionSelector]): Selector of feature masks for the modality's explainers. If None selected, all FEATURE_MASK_FUNCTIONS will be used. - pooling_fn_selector (Optional[FunctionSelector]): Selector of pooling methods for the modality's explainers. If None selected, all POOLING_FUNCTIONS will be used. - normalization_fn_selector (Optional[FunctionSelector]): Selector of normalization methods for the modality's explainers. If None selected, all NORMALIZATION_FUNCTIONS_FOR_IMAGE will be used. - Attributes: - EXPLAINERS (Tuple[Explainer]): Tuple of all available explainers. - """ - - # Copies the tuple without preserving the reference - EXPLAINERS = tuple(iter(AVAILABLE_EXPLAINERS)) +class Modality: + UTIL_FUNCTIONS = { + 'baseline_fn': BASELINE_FUNCTIONS, + 'feature_mask_fn': FEATURE_MASK_FUNCTIONS, + 'pooling_fn': POOLING_FUNCTIONS, + 'normalization_fn': NORMALIZATION_FUNCTIONS, + } def __init__( self, - channel_dim: int, - baseline_fn_selector: Optional[FunctionSelector] = None, - feature_mask_fn_selector: Optional[FunctionSelector] = None, - pooling_fn_selector: Optional[FunctionSelector] = None, - normalization_fn_selector: Optional[FunctionSelector] = None, - **kwargs + dtype: Any, + ndims: Any, + pooling_dim: Optional[int] = None, + mask_token_id: Optional[int] = None, ): - self.channel_dim = channel_dim - self.baseline_fn_selector = baseline_fn_selector or FunctionSelector(BASELINE_FUNCTIONS) - self.feature_mask_fn_selector = feature_mask_fn_selector or FunctionSelector(FEATURE_MASK_FUNCTIONS) - self.pooling_fn_selector = pooling_fn_selector or FunctionSelector(POOLING_FUNCTIONS) - self.normalization_fn_selector = normalization_fn_selector or FunctionSelector(NORMALIZATION_FUNCTIONS_FOR_IMAGE) - - @abstractmethod - def get_default_feature_mask_fn(self) -> Callable: - """ - Defines default baseline function for the modality's explainers. - - Returns: - BaselineFunction: Zeros baseline function. - """ - raise NotImplementedError - - @abstractmethod - def get_default_baseline_fn(self) -> Callable: - """ - Defines default feature mask function for the modality's explainers. - - Returns: - FeatureMaskFunction: No Mask baseline function. - """ - raise NotImplementedError - - @abstractmethod - def get_default_postprocessors(self) -> List[Callable]: - """ - Defines default post-processors list for the modality's explainers. - - Returns: - List[PostProcessor]: Identity PostProcessors. 
- """ - raise NotImplementedError - - def map_fn_selector(self, method_type: Type[Any]) -> Dict[Type[UtilFunction], callable]: - """ - Selects custom optimizable hyperparameter functions. - - Returns: - Dict[Type[UtilFunction], callable]: Identity PostProcessors. - """ - return { - BaselineFunction: self.baseline_fn_selector, - FeatureMaskFunction: self.feature_mask_fn_selector, - PoolingFunction: self.pooling_fn_selector, - NormalizationFunction: self.normalization_fn_selector, - }.get(method_type, None) - - -class ImageModality(Modality): - """ - An extension of Modality class for Image domain with automatic explainers and evaluation metrics recommendation. - - Parameters: - channel_dim (int): Target sequence dimension. - baseline_fn_selector (Optional[FunctionSelector]): Selector of baselines for the modality's explainers. If None selected, BASELINE_FUNCTIONS_FOR_TIME_SERIES will be used. - feature_mask_fn_selector (Optional[FunctionSelector]): Selector of feature masks for the modality's explainers. If None selected, FEATURE_MASK_FUNCTIONS_FOR_TIME_SERIES will be used. - pooling_fn_selector (Optional[FunctionSelector]): Selector of pooling methods for the modality's explainers. If None selected, POOLING_FUNCTIONS_FOR_TIME_SERIES will be used. - normalization_fn_selector (Optional[FunctionSelector]): Selector of normalization methods for the modality's explainers. If None selected, NORMALIZATION_FUNCTIONS_FOR_TIME_SERIES will be used. - """ - def __init__( - self, - channel_dim: int = 1, - baseline_fn_selector: Optional[FunctionSelector] = None, - feature_mask_fn_selector: Optional[FunctionSelector] = None, - pooling_fn_selector: Optional[FunctionSelector] = None, - normalization_fn_selector: Optional[FunctionSelector] = None, - ): - super(ImageModality, self).__init__( - channel_dim, - baseline_fn_selector=baseline_fn_selector or FunctionSelector( - data=BASELINE_FUNCTIONS_FOR_IMAGE, - default_kwargs={'dim': channel_dim}, - ), - feature_mask_fn_selector=feature_mask_fn_selector or FunctionSelector( - data=FEATURE_MASK_FUNCTIONS_FOR_IMAGE - ), - pooling_fn_selector=pooling_fn_selector or FunctionSelector( - data=POOLING_FUNCTIONS_FOR_IMAGE, - default_kwargs={'channel_dim': channel_dim}, - ), - normalization_fn_selector=normalization_fn_selector or FunctionSelector( - data=NORMALIZATION_FUNCTIONS_FOR_IMAGE - ), - ) - - def get_default_baseline_fn(self) -> BaselineFunction: - """ - Defines default baseline function for the modality's explainers. - - Returns: - BaselineFunction: Zeros baseline function. - """ - return self.baseline_fn_selector.select('zeros') - - def get_default_feature_mask_fn(self) -> FeatureMaskFunction: - """ - Defines default feature mask function for the modality's explainers. - - Returns: - FeatureMaskFunction: Felzenszwalb baseline function. - """ - return self.feature_mask_fn_selector.select('felzenszwalb', scale=250) - - def get_default_postprocessors(self) -> List[PostProcessor]: - """ - Defines default post-processors list for the modality's explainers. - - Returns: - List[PostProcessor]: All available PostProcessors. - """ - return [ - PostProcessor( - pooling_fn=self.pooling_fn_selector.select(pm), - normalization_fn=self.normalization_fn_selector.select(nm), - ) for pm in self.pooling_fn_selector.choices - for nm in self.normalization_fn_selector.choices - ] - - -class TextModality(Modality): - """ - An extension of Modality class for Text domain with automatic explainers and evaluation metrics recommendation. 
- - Parameters: - channel_dim (int): Target sequence dimension. - baseline_fn_selector (Optional[FunctionSelector]): Selector of baselines for the modality's explainers. If None selected, BASELINE_FUNCTIONS_FOR_TEXT will be used. - feature_mask_fn_selector (Optional[FunctionSelector]): Selector of feature masks for the modality's explainers. If None selected, FEATURE_MASK_FUNCTIONS_FOR_TEXT will be used. - pooling_fn_selector (Optional[FunctionSelector]): Selector of pooling methods for the modality's explainers. If None selected, POOLING_FUNCTIONS_FOR_TEXT will be used. - normalization_fn_selector (Optional[FunctionSelector]): Selector of normalization methods for the modality's explainers. If None selected, NORMALIZATION_FUNCTIONS_FOR_TEXT will be used. - """ - EXPLAINERS = ( - Gradient, - GradientXInput, - SmoothGrad, - VarGrad, - IntegratedGradients, - LRPUniformEpsilon, - LRPEpsilonPlus, - LRPEpsilonGammaBox, - LRPEpsilonAlpha2Beta1, - KernelShap, - Lime, - ) - - def __init__( - self, - channel_dim: int = -1, - mask_token_id: int = 0, - baseline_fn_selector: Optional[FunctionSelector] = None, - feature_mask_fn_selector: Optional[FunctionSelector] = None, - pooling_fn_selector: Optional[FunctionSelector] = None, - normalization_fn_selector: Optional[FunctionSelector] = None, - ): - super(TextModality, self).__init__( - channel_dim, - baseline_fn_selector=baseline_fn_selector or FunctionSelector( - data=BASELINE_FUNCTIONS_FOR_TEXT, - default_kwargs={'token_id': mask_token_id}, - ), - feature_mask_fn_selector=feature_mask_fn_selector or FunctionSelector( - data=FEATURE_MASK_FUNCTIONS_FOR_TEXT - ), - pooling_fn_selector=pooling_fn_selector or FunctionSelector( - data=POOLING_FUNCTIONS_FOR_TEXT, - default_kwargs={'channel_dim': channel_dim}, - ), - normalization_fn_selector=normalization_fn_selector or FunctionSelector( - data=NORMALIZATION_FUNCTIONS_FOR_TEXT, - ), - ) + self.dtype = dtype + self.ndims = ndims + self.pooling_dim = pooling_dim self.mask_token_id = mask_token_id - def get_default_baseline_fn(self) -> BaselineFunction: - """ - Defines default baseline function for the modality's explainers. - - Returns: - BaselineFunction: Token baseline function. - """ - return self.baseline_fn_selector.select('token') - - def get_default_feature_mask_fn(self) -> FeatureMaskFunction: - """ - Defines default feature mask function for the modality's explainers. - - Returns: - FeatureMaskFunction: No Mask baseline function. - """ - return self.feature_mask_fn_selector.select('no_mask_1d') - - def get_default_postprocessors(self) -> List[PostProcessor]: - """ - Defines default post-processors list for the modality's explainers. - - Returns: - List[PostProcessor]: All PostProcessors. - """ - return [ - PostProcessor( - pooling_fn=self.pooling_fn_selector.select(pm), - normalization_fn=self.normalization_fn_selector.select(nm), - ) for pm in self.pooling_fn_selector.choices - for nm in self.normalization_fn_selector.choices - ] - - -class TimeSeriesModality(Modality): - """ - An extension of Modality class for Time Series domain with automatic explainers and evaluation metrics recommendation. - - Parameters: - channel_dim (int): Target sequence dimension. - baseline_fn_selector (Optional[FunctionSelector]): Selector of baselines for the modality's explainers. If None selected, BASELINE_FUNCTIONS_FOR_TIME_SERIES will be used. - feature_mask_fn_selector (Optional[FunctionSelector]): Selector of feature masks for the modality's explainers. 
If None selected, FEATURE_MASK_FUNCTIONS_FOR_TIME_SERIES will be used. - pooling_fn_selector (Optional[FunctionSelector]): Selector of pooling methods for the modality's explainers. If None selected, POOLING_FUNCTIONS_FOR_TIME_SERIES will be used. - normalization_fn_selector (Optional[FunctionSelector]): Selector of normalization methods for the modality's explainers. If None selected, NORMALIZATION_FUNCTIONS_FOR_TIME_SERIES will be used. - """ - def __init__( - self, - channel_dim: int = -1, - baseline_fn_selector: Optional[FunctionSelector] = None, - feature_mask_fn_selector: Optional[FunctionSelector] = None, - pooling_fn_selector: Optional[FunctionSelector] = None, - normalization_fn_selector: Optional[FunctionSelector] = None, - ): - super(TimeSeriesModality, self).__init__( - channel_dim, - baseline_fn_selector=baseline_fn_selector or FunctionSelector( - data=BASELINE_FUNCTIONS_FOR_TIME_SERIES, # [zeros, mean] - default_kwargs={'dim': channel_dim}, - ), - feature_mask_fn_selector=feature_mask_fn_selector or FunctionSelector( - data=FEATURE_MASK_FUNCTIONS_FOR_TIME_SERIES, - ), - pooling_fn_selector=pooling_fn_selector or FunctionSelector( - data=POOLING_FUNCTIONS_FOR_TIME_SERIES, # [identity] - default_kwargs={'channel_dim': channel_dim}, - ), - normalization_fn_selector=normalization_fn_selector or FunctionSelector( - data=NORMALIZATION_FUNCTIONS_FOR_TIME_SERIES, # [identity] - ), - ) - - def get_default_baseline_fn(self) -> BaselineFunction: - """ - Defines default baseline function for the modality's explainers. - - Returns: - BaselineFunction: Zeros baseline function. - """ - return self.baseline_fn_selector.select('zeros') - - def get_default_feature_mask_fn(self) -> FeatureMaskFunction: - """ - Defines default feature mask function for the modality's explainers. - - Returns: - FeatureMaskFunction: No Mask baseline function. - """ - return self.feature_mask_fn_selector.select('no_mask_2d') - - def get_default_postprocessors(self) -> List[PostProcessor]: - """ - Defines default post-processors list for the modality's explainers. + self._util_functions = defaultdict(FunctionSelector) + self._set_util_functions() + + @property + def dtype_key(self): + # TODO: more elegant way + if 'float' in str(self.dtype): + return float + if 'int' in str(self.dtype) or 'long' in str(self.dtype): + return int + raise ValueError(f'No matched key for {self.dtype}') + + @property + def util_functions(self): + return self._util_functions + + def _set_util_functions(self): + for space_key, spaces in self.UTIL_FUNCTIONS.items(): + space = spaces[(self.dtype_key, self.ndims)] + for fn_key, fn_type in space.items(): + # set non-varying kwargs such as pooling dim or token id + if ( + space_key == 'pooling_fn' + and 'pooling_dim' not in self._util_functions[space_key].default_kwargs + ): + self._util_functions[space_key].add_default_kwargs( + 'pooling_dim', self.pooling_dim) + if fn_type is TokenBaselineFunction: + self._util_functions[space_key].add_default_kwargs( + 'token_id', self.mask_token_id, choice='token') + if fn_type is MeanBaselineFunction: + dim = self.pooling_dim or 0 + self._util_functions[space_key].add_default_kwargs( + 'dim', dim, choice='mean') + + # add fn_type to space + self._util_functions[space_key].add(fn_key, fn_type) - Returns: - List[PostProcessor]: Identity PostProcessors. 
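[Editor's note] A minimal sketch of how the new generic Modality above might be instantiated; the torch dtypes and ndims values are assumptions inferred from this diff:

import torch
from pnpxai import Modality

# Image-like float input of shape (B, C, H, W): dtype_key resolves to float,
# ndims=4, and pooled attributions refer to the channel dimension.
image_modality = Modality(dtype=torch.float32, ndims=4, pooling_dim=1)

# Text-like token-id input of shape (B, L): dtype_key resolves to int,
# ndims=2, and the mask token id feeds the token baseline function.
text_modality = Modality(dtype=torch.long, ndims=2, mask_token_id=0)

# Utility-function selectors are built from the UTIL_FUNCTIONS space
# registered for the (dtype_key, ndims) combination.
baseline_selector = image_modality.util_functions['baseline_fn']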
- """ - return [ - PostProcessor( - pooling_fn=self.pooling_fn_selector.select(pm), - normalization_fn=self.normalization_fn_selector.select(nm), - ) for pm in self.pooling_fn_selector.choices - for nm in self.normalization_fn_selector.choices - ] diff --git a/pnpxai/core/recommender/recommender.py b/pnpxai/core/recommender/recommender.py index b4b1d4db..367529f9 100644 --- a/pnpxai/core/recommender/recommender.py +++ b/pnpxai/core/recommender/recommender.py @@ -1,35 +1,39 @@ from typing import List, Type, Dict, Set, Any, Sequence, Tuple, Union +from collections import defaultdict from dataclasses import dataclass, asdict from tabulate import tabulate +import itertools from pnpxai.core._types import Model from pnpxai.core.modality.modality import Modality -from pnpxai.core.detector import detect_model_architecture -from pnpxai.core.detector.types import ( - ModuleType, - Linear, - Convolution +from pnpxai.core.detector import ( + detect_model_architecture, + detect_data_modality, ) +from pnpxai.core.detector.detector import _data_modality_maybe +from pnpxai.core.detector.types import ModuleType, Attention from pnpxai.explainers.base import Explainer from pnpxai.explainers import ( GradCam, GuidedGradCam, + RAP, AVAILABLE_EXPLAINERS ) from pnpxai.utils import format_into_tuple -CAM_BASED_EXPLAINERS = {GradCam, GuidedGradCam} +ATTENTION_NOT_SUPPORTED_EXPLAINERS = {GradCam, GuidedGradCam, RAP} @dataclass class RecommenderOutput: + detected_modality: str detected_architectures: Set[ModuleType] - explainers: list + explainers: List[Explainer] def print_tabular(self): print(tabulate([ - [k, [v.__name__ for v in vs]] + [k, [v.__name__ if not isinstance(v, str) else v for v in vs]] for k, vs in asdict(self).items() ])) @@ -59,51 +63,82 @@ class XaiRecommender: """ def __init__(self): - self.architecture_to_explainers_map = self._build_architecture_to_explainers_map() - - def _build_architecture_to_explainers_map(self): - map_data = {} + self._map_by_architecture = self._build_map_by_architecture() + self._map_by_modality = self._build_map_by_modality() + + @property + def map_by_architecture(self): + return self._map_by_architecture + + @property + def map_by_modality(self): + return self._map_by_modality + + def _build_map_by_architecture(self): + map_data = defaultdict(set) for explainer_type in AVAILABLE_EXPLAINERS: for arch in explainer_type.SUPPORTED_MODULES: - if arch not in map_data: - map_data[arch] = set() map_data[arch].add(explainer_type) return RecommendationMap(map_data, ["architecture", "explainers"]) + def _build_map_by_modality(self): + map_data = defaultdict(set) + for explainer_type in AVAILABLE_EXPLAINERS: + if not hasattr(explainer_type, 'SUPPORTED_DTYPES'): + continue + combs = itertools.product( + explainer_type.SUPPORTED_DTYPES, + explainer_type.SUPPORTED_NDIMS, + ) + for comb in combs: + mod_nm = _data_modality_maybe(*comb) + if not mod_nm: + continue + map_data[mod_nm].add(explainer_type) + return RecommendationMap(map_data, ["modality", "explainers"]) + def _filter_explainers( self, modality: Union[Modality, Tuple[Modality]], - arch: Set[ModuleType], + architecture: Set[ModuleType], ) -> List[Type[Explainer]]: """ Filters explainers based on the user's question, task, and model architecture. Args: - modaltiy (Union[Modality, Tuple[Modality]]): Modality of the input data (e.g., ImageModality, TextModality, TabularModality). - - arch (Set[ModuleType]): Set of neural network module types (e.g., nn.Linear, nn.Conv2d). 
+ - architecture (Set[ModuleType]): Set of neural network module types (e.g., nn.Linear, nn.Conv2d). Returns: - List[Set[Type[Explainer]]]: List of compatible explainers based on the given inputs. """ # question_to_method = QUESTION_TO_EXPLAINERS.get(question, set()) - explainers = [] + explainers = defaultdict(set) + explainers['modality'].update(AVAILABLE_EXPLAINERS) for mod in format_into_tuple(modality): - modality_to_explainers = set(mod.EXPLAINERS) - arch_to_explainers = set.union(*( - self.architecture_to_explainers_map.data.get( - module_type, set()) - for module_type in arch - )) - explainers_mod = set.intersection( - modality_to_explainers, arch_to_explainers) - if arch.difference({Convolution, Linear}): - explainers_mod = explainers_mod.difference( - CAM_BASED_EXPLAINERS) - explainers.append(explainers_mod) - explainers = set.intersection(*explainers) + mod_nm = _data_modality_maybe(mod.dtype_key, mod.ndims) + if mod_nm is None: + raise ValueError('Cannot match data modality') + explainers['modality'].difference_update( + explainers['modality'].difference( + self._map_by_modality.data[mod_nm] + ) + ) + for arch in architecture: + if arch in self._map_by_architecture.data: + explainers['architecture'].update( + self._map_by_architecture.data[arch] + ) + if Attention in architecture: + explainers['architecture'].difference_update(ATTENTION_NOT_SUPPORTED_EXPLAINERS) + explainers = set.intersection(*explainers.values()) return list(explainers) - def recommend(self, modality: Union[Modality, Tuple[Modality]], model: Model) -> RecommenderOutput: + def recommend( + self, + modality: Union[Modality, Tuple[Modality]], + model: Model, + ) -> RecommenderOutput: """ Recommends explainers and evaluation metrics based on the user's input. @@ -114,11 +149,11 @@ def recommend(self, modality: Union[Modality, Tuple[Modality]], model: Model) -> Returns: - RecommenderOutput: An object containing recommended explainers. 
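[Editor's note] A hypothetical sketch of the updated recommendation flow defined above and completed just below; the torchvision model is a placeholder:

import torch
from torchvision.models import resnet18
from pnpxai import Modality, XaiRecommender

recommender = XaiRecommender()
modality = Modality(dtype=torch.float32, ndims=4, pooling_dim=1)
output = recommender.recommend(modality, resnet18())
output.print_tabular()  # detected modality, detected architectures, explainers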
""" - arch = detect_model_architecture(model) - explainers = self._filter_explainers(modality, arch) - # metrics = self._suggest_metrics(explainers) + mod_nms = detect_data_modality(modality) + architecture = detect_model_architecture(model) + explainers = self._filter_explainers(modality, architecture) return RecommenderOutput( - detected_architectures=arch, + detected_modality=mod_nms, + detected_architectures=architecture, explainers=_sort_by_name(explainers), - # metrics=_sort_by_name(metrics), ) diff --git a/pnpxai/core/utils.py b/pnpxai/core/utils.py new file mode 100644 index 00000000..a8a83d0d --- /dev/null +++ b/pnpxai/core/utils.py @@ -0,0 +1,110 @@ +from typing import Sequence, Optional, Callable, Any, List, Dict, Union +import torch +import inspect + +from pnpxai.utils import format_into_tuple + + +def default_output_modifier(outputs): + return outputs + + +class ModelWrapper(torch.nn.Module): + def __init__( + self, + model: torch.nn.Module, + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], torch.Tensor]] = None, + ): + super().__init__() + forward_params = inspect.signature(model.forward).parameters + + self.model = model + self.target_input_keys = target_input_keys or [next(iter(forward_params))] + self.additional_input_keys = additional_input_keys or [] + self.output_modifier = output_modifier or default_output_modifier + + self._device = next(iter(model.parameters())).device + self._required_order = self.target_input_keys + self.additional_input_keys + self._validate_input_keys(forward_params) + _key = next(iter(self._required_order)) + assert isinstance(_key, (str, int)), ( + f'Unsupported key type: {type(_key)}. 
Must be one of [str, int].', ) self._key_type = type(_key) def _validate_input_keys(self, forward_params): for key in self._required_order: # validations for dict-like batch if isinstance(key, str): if key not in forward_params: raise ValueError( f"input_key '{key}' not found in model forward params.") # validations for tuple-like batch elif isinstance(key, int): if key > len(forward_params)-1: raise ValueError( f"input_key '{key}' must be less than {len(forward_params)}.") @property def device(self): return self._device @property def required_order(self): return self._required_order def extract_inputs(self, batch: Union[Sequence, Dict]): if isinstance(batch, Sequence): batch = list(batch) return tuple(batch[key].to(self.device) for key in self._required_order) else: return {key: batch[key].to(self.device) for key in self._required_order} def extract_target_inputs(self, batch: Union[Sequence, Dict]): if isinstance(batch, Sequence): batch = list(batch) return tuple(batch[key].to(self.device) for key in self.target_input_keys) else: return {key: batch[key].to(self.device) for key in self.target_input_keys} def extract_additional_inputs(self, batch: Union[Sequence, Dict]): if isinstance(batch, Sequence): batch = list(batch) return tuple(batch[key].to(self.device) for key in self.additional_input_keys) else: return {key: batch[key].to(self.device) for key in self.additional_input_keys} def format_inputs(self, inputs: Union[torch.Tensor, Sequence, Dict]): if isinstance(inputs, torch.Tensor): inputs = format_into_tuple(inputs) if isinstance(inputs, Sequence): inputs = list(inputs) return tuple(dict(zip(self._required_order, inputs)).values()) return tuple(inputs[key].to(self.device) for key in self._required_order) def format_target_inputs(self, inputs: Union[torch.Tensor, Sequence, Dict]): if isinstance(inputs, torch.Tensor): inputs = format_into_tuple(inputs) if isinstance(inputs, Sequence): inputs = list(inputs) return tuple(dict(zip(self.target_input_keys, inputs)).values()) return tuple(inputs[key].to(self.device) for key in self.target_input_keys) def format_additional_inputs(self, inputs: Union[Sequence, Dict]): if isinstance(inputs, torch.Tensor): inputs = format_into_tuple(inputs) if isinstance(inputs, Sequence): inputs = list(inputs) return tuple( inputs[key].to(self.device) for key in self.additional_input_keys) def forward(self, *formatted_inputs): if self._key_type is str: outputs = self.model.forward( **dict(zip(self._required_order, formatted_inputs))) elif self._key_type is int: outputs = self.model.forward(*formatted_inputs) return self.output_modifier(outputs) diff --git a/pnpxai/evaluator/metrics/base.py b/pnpxai/evaluator/metrics/base.py index ba8040d0..eee0b540 100644 --- a/pnpxai/evaluator/metrics/base.py +++ b/pnpxai/evaluator/metrics/base.py @@ -1,16 +1,14 @@ import abc import sys -import warnings -from typing import Optional, Union, Callable, Tuple +from typing import Optional, Union, Callable, Tuple, List, Any, Dict import copy import torch from torch import nn +from pnpxai.core.utils import ModelWrapper from pnpxai.core._types import ExplanationType -from pnpxai.explainers import GradCam from pnpxai.explainers.base import Explainer -from pnpxai.explainers.utils.postprocess import PostProcessor # Ensure compatibility with Python 2/3 ABC = abc.ABC if sys.version_info >= (3, 4) else abc.ABCMeta(str("ABC"), (), {}) @@ 
-41,13 +39,30 @@ class Metric(ABC): SUPPORTED_EXPLANATION_TYPE: ExplanationType = "attribution" def __init__( - self, model: nn.Module, explainer: Optional[Explainer] = None, **kwargs + self, + model: nn.Module, + explainer: Optional[Explainer] = None, + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], torch.Tensor]] = None, + **kwargs, ): self.model = model.eval() # Set the model to evaluation mode self.explainer = explainer - self.device = next( - model.parameters() - ).device # Determine the device used by the model + self._wrapped_model = ModelWrapper( + model=model, + target_input_keys=target_input_keys, + additional_input_keys=additional_input_keys, + output_modifier=output_modifier, + ) + + @property + def wrapped_model(self): + return self._wrapped_model + + @property + def device(self): + return next(self.model.parameters()).device def __repr__(self): """ @@ -65,6 +80,11 @@ def __repr__(self): ) return f"{self.__class__.__name__}({displayed_attrs})" + def format_inputs(self, inputs: Union[Tuple, Dict]): + forward_args = self._wrapped_model.format_target_inputs(inputs) + additional_forward_args = self._wrapped_model.format_additional_inputs(inputs) + return forward_args, additional_forward_args + def copy(self): """ Creates a shallow copy of the Metric object. @@ -78,10 +98,19 @@ def set_explainer(self, explainer: Explainer): """ Sets the explainer for the metric, ensuring it is associated with the same model. """ - assert self.model is explainer.model, "Must have same model of metric." - clone = self.copy() - clone.explainer = explainer - return clone + for alias in ['model', '_model', 'forward_func']: + if hasattr(explainer, alias): + explainer_model = getattr(explainer, alias) + break + if isinstance(explainer_model, ModelWrapper): + explainer_model = explainer_model.model + # explainer_model = explainer._model if hasattr(explainer, '_model') else explainer.model + assert self.model is explainer_model or self._wrapped_model is explainer_model, "Must have same model of metric." + self.explainer = explainer + return self + # clone = self.copy() + # clone.explainer = explainer + # return clone def set_kwargs(self, **kwargs): """ diff --git a/pnpxai/evaluator/metrics/complexity.py b/pnpxai/evaluator/metrics/complexity.py index 229586e1..6adab96a 100644 --- a/pnpxai/evaluator/metrics/complexity.py +++ b/pnpxai/evaluator/metrics/complexity.py @@ -29,6 +29,8 @@ class Complexity(Metric): Reference: U. Bhatt, A. Weller, and J. M. F. Moura. Evaluating and aggregating feature-based model attributions. In Proceedings of the IJCAI (2020). """ + alias = ['complexity'] + def __init__( self, model: Model, explainer: Optional[Explainer] = None, n_bins: int = 10 @@ -63,4 +65,4 @@ def evaluate( hist, _ = np.histogram(attr.detach().cpu(), bins=self.n_bins) prob_mass = hist / hist.sum() evaluations.append(entropy(prob_mass)) - return torch.tensor(evaluations) + return torch.tensor(evaluations).to(attributions.dtype).to(self.device) diff --git a/pnpxai/evaluator/metrics/mu_fidelity.py b/pnpxai/evaluator/metrics/mu_fidelity.py index f28c406f..c17e196e 100644 --- a/pnpxai/evaluator/metrics/mu_fidelity.py +++ b/pnpxai/evaluator/metrics/mu_fidelity.py @@ -34,6 +34,7 @@ class MuFidelity(Metric): Reference: U. Bhatt, A. Weller, and J. M. F. Moura. Evaluating and aggregating feature-based model attributions. In Proceedings of the IJCAI (2020). 
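[Editor's note] A minimal sketch of the new ModelWrapper from pnpxai/core/utils.py that Metric now builds internally; the HF-style classifier, the tensors, and the key names are assumptions:

from pnpxai.core.utils import ModelWrapper

wrapped = ModelWrapper(
    model=bert_classifier,                     # placeholder; forward(input_ids, attention_mask)
    target_input_keys=['input_ids'],           # inputs that receive attributions
    additional_input_keys=['attention_mask'],  # forwarded to the model but not attributed
    output_modifier=lambda out: out.logits,    # reduce the output object to a tensor
)
batch = {'input_ids': input_ids, 'attention_mask': attention_mask}  # placeholder tensors
formatted = wrapped.format_inputs(batch)  # device-moved tuple ordered by required_order
logits = wrapped(*formatted)              # forward(**kwargs) followed by output_modifier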
""" + alias = ['mu_fidelity', 'mufid'] def __init__( self, diff --git a/pnpxai/evaluator/metrics/pixel_flipping.py b/pnpxai/evaluator/metrics/pixel_flipping.py index 58b293f3..7ff74ee4 100644 --- a/pnpxai/evaluator/metrics/pixel_flipping.py +++ b/pnpxai/evaluator/metrics/pixel_flipping.py @@ -1,19 +1,23 @@ -from typing import Callable, Optional, Union, Tuple, Dict, Any, Type +from typing import Callable, Optional, Union, Tuple, Dict, Any, Type, List import torch import torch.nn.functional as F from torch.nn.modules import Module -from pnpxai.explainers.types import Tensor, TensorOrTupleOfTensors, ForwardArgumentExtractor +from pnpxai.explainers.types import Tensor, TensorOrTupleOfTensors from pnpxai.explainers.base import Explainer -from pnpxai.explainers import GradCam -from pnpxai.explainers.utils.baselines import ZeroBaselineFunction +from pnpxai.explainers.utils.baselines import BaselineFunction, ZeroBaselineFunction from pnpxai.utils import format_into_tuple, format_into_tuple_all -from pnpxai.explainers.utils.postprocess import PostProcessor from pnpxai.evaluator.metrics.base import Metric -BaselineFunction = Union[Callable[[Tensor], Tensor], Tuple[Callable[[Tensor], Tensor]]] +def default_prob_fn(outputs: torch.Tensor) -> torch.Tensor: + return outputs.softmax(-1) + + +def default_pred_fn(outputs: torch.Tensor) -> torch.Tensor: + return outputs.argmax(-1) + class PixelFlipping(Metric): """ @@ -27,7 +31,7 @@ class PixelFlipping(Metric): Attributes: model (Module): The model. explainer (Optional[Explainer]=None): The explainer whose explanations are being evaluated. - channel_dim (int): Target channel dimension. + pooling_dim (int): Target channel dimension. n_steps (int): The number of perturbation steps. baseline_fn (Optional[BaselineFunction]): Function to generate baseline inputs for perturbation. prob_fn (Optional[Callable[[Tensor], Tensor]]): Function to compute probabilities from model outputs. 
@@ -43,23 +47,25 @@ class PixelFlipping(Metric): def __init__( self, model: Module, - explainer: Optional[Explainer]=None, - channel_dim: int=1, - n_steps: int=10, - baseline_fn: Optional[BaselineFunction]=None, - prob_fn: Optional[Callable[[Tensor], Tensor]]=lambda outputs: outputs.softmax(-1), - pred_fn: Optional[Callable[[Tensor], Tensor]]=lambda outputs: outputs.argmax(-1), - forward_arg_extractor: Optional[ForwardArgumentExtractor]=None, - additional_forward_arg_extractor: Optional[ForwardArgumentExtractor]=None, + explainer: Optional[Explainer] = None, + pooling_dim: int = 1, + n_steps: int = 10, + baseline_fn: Optional[BaselineFunction] = None, + prob_fn: Optional[Callable[[Tensor], Tensor]] = None, + pred_fn: Optional[Callable[[Tensor], Tensor]] = None, + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], torch.Tensor]] = None, ): - super().__init__(model, explainer) - self.channel_dim = channel_dim + super().__init__( + model, explainer, target_input_keys, + additional_input_keys, output_modifier + ) + self.pooling_dim = pooling_dim self.n_steps = n_steps self.baseline_fn = baseline_fn or ZeroBaselineFunction() - self.prob_fn = prob_fn - self.pred_fn = pred_fn - self.forward_arg_extractor = forward_arg_extractor - self.additional_forward_arg_extractor = additional_forward_arg_extractor + self.prob_fn = prob_fn or default_prob_fn + self.pred_fn = pred_fn or default_pred_fn @torch.no_grad() def evaluate( @@ -67,8 +73,8 @@ def evaluate( inputs: TensorOrTupleOfTensors, targets: Tensor, attributions: TensorOrTupleOfTensors, - attention_mask: Optional[TensorOrTupleOfTensors]=None, - descending: bool=True, + attention_mask: Optional[TensorOrTupleOfTensors] = None, + descending: bool = True, ) -> Union[Dict[str, Tensor], Tuple[Dict[str, Tensor]]]: """ Evaluate the explainer's correctness based on the attributions by observing changes in model predictions. @@ -84,12 +90,12 @@ def evaluate( Union[Dict[str, Tensor], Tuple[Dict[str, Tensor]]]: A dictionary or tuple of dictionaries containing the probabilities and predictions at each perturbation step. 
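[Editor's note] A hypothetical evaluate() call for the metric above; inputs, targets, and attributions are placeholder tensors shaped like one batch:

results = metric.evaluate(inputs, targets, attributions)  # flips most relevant features first by default
step_probs = results['probs']  # target-class probabilities after each flipping step
step_preds = results['preds']  # model predictions after each flipping step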
""" - forward_args, additional_forward_args = self._extract_forward_args(inputs) + forward_args, additional_forward_args = self.format_inputs(inputs) formatted: Dict[str, Tuple[Any]] = format_into_tuple_all( forward_args=forward_args, additional_forward_args=additional_forward_args, attributions=attributions, - channel_dim=self.channel_dim, + pooling_dim=self.pooling_dim or (None,)*len(format_into_tuple(forward_args)), baseline_fn=self.baseline_fn, attention_mask=attention_mask or (None,)*len(format_into_tuple(forward_args)), ) @@ -101,25 +107,27 @@ def evaluate( bsz = formatted['forward_args'][0].size(0) results = [] - outputs = self.model( + outputs = self._wrapped_model( *formatted['forward_args'], *formatted['additional_forward_args'], ) init_probs = self.prob_fn(outputs) init_preds = self.pred_fn(outputs) - for pos, forward_arg in enumerate(formatted['forward_args']): - baseline_fn = formatted['baseline_fn'][pos] - attrs = formatted['attributions'][pos] + for loc, forward_arg in enumerate(formatted['forward_args']): + baseline_fn = formatted['baseline_fn'][loc] + attrs = formatted['attributions'][loc] + attrs, original_size = _flatten_if_not_1d(attrs) - if formatted['attention_mask'][pos] is not None: - attn_mask, _ = _flatten_if_not_1d(formatted['attention_mask'][pos]) + if formatted['attention_mask'][loc] is not None: + attn_mask, _ = _flatten_if_not_1d(formatted['attention_mask'][loc]) mask_value = -torch.inf if descending else torch.inf attrs = torch.where(attn_mask == 1, attrs, mask_value) valid_n_features = (~attrs.isinf()).sum(-1) n_flipped_per_step = valid_n_features // self.n_steps - n_flipped_per_step = n_flipped_per_step.clamp(min=1) # ensure at least a pixel flipped + # ensure at least a pixel flipped + n_flipped_per_step = n_flipped_per_step.clamp(min=1) sorted_indices = torch.argsort( attrs, descending=descending, @@ -131,6 +139,10 @@ def evaluate( n_flipped = n_flipped_per_step * step if step + 1 == self.n_steps: n_flipped = valid_n_features + if any((n_flipped - 1) >= attrs.size(-1)): + # All features flipped already + # This break condition works when n_features < self.n_steps + break is_index_of_flipped = ( F.one_hot(n_flipped-1, num_classes=attrs.size(-1)).to(self.device) .flip(-1).cumsum(-1).flip(-1) @@ -140,7 +152,7 @@ def evaluate( is_flipped = _recover_shape_if_flattened(is_flipped, original_size) is_flipped = _match_channel_dim_if_pooled( is_flipped, - formatted['channel_dim'][pos], + formatted['pooling_dim'][loc], forward_arg.size() ) @@ -148,10 +160,10 @@ def evaluate( flipped_forward_arg = baseline * is_flipped + forward_arg * (1 - is_flipped) flipped_forward_args = tuple( - flipped_forward_arg if i == pos else formatted['forward_args'][i] + flipped_forward_arg if i == loc else formatted['forward_args'][i] for i in range(len(formatted['forward_args'])) ) - flipped_outputs = self.model( + flipped_outputs = self._wrapped_model( *flipped_forward_args, *formatted['additional_forward_args'], ) @@ -178,7 +190,6 @@ def _extract_forward_args( return forward_args, additional_forward_args - class MoRF(PixelFlipping): """ A metric class for evaluating the correctness of explanations or attributions using the @@ -191,7 +202,7 @@ class MoRF(PixelFlipping): Attributes: model (Module): The model. explainer (Optional[Explainer]=None): The explainer whose explanations are being evaluated. - channel_dim (int): Target channel dimension. + pooling_dim (int): Target channel dimension. n_steps (int): The number of perturbation steps. 
baseline_fn (Optional[BaselineFunction]): Function to generate baseline inputs for perturbation. prob_fn (Optional[Callable[[Tensor], Tensor]]): Function to compute probabilities from model outputs. @@ -203,24 +214,27 @@ class MoRF(PixelFlipping): evaluate(inputs, targets, attributions, attention_mask=None): Evaluate the explainer's correctness using the MoRF technique by observing changes in model predictions. """ + alias = ['morf'] def __init__( self, model: Module, - explainer: Optional[Explainer]=None, - channel_dim: int=1, - n_steps: int=10, - baseline_fn: Optional[BaselineFunction]=None, - prob_fn: Optional[Callable[[Tensor], Tensor]]=lambda outputs: outputs.softmax(-1), - pred_fn: Optional[Callable[[Tensor], Tensor]]=lambda outputs: outputs.argmax(-1), - forward_arg_extractor: Optional[ForwardArgumentExtractor]=None, - additional_forward_arg_extractor: Optional[ForwardArgumentExtractor]=None, + explainer: Optional[Explainer] = None, + pooling_dim: int = 1, + n_steps: int = 10, + baseline_fn: Optional[BaselineFunction] = None, + prob_fn: Optional[Callable[[Tensor], Tensor]] = None, + pred_fn: Optional[Callable[[Tensor], Tensor]] = None, + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], torch.Tensor]] = None, ): super().__init__( - model, explainer, channel_dim, n_steps, + model, explainer, pooling_dim, n_steps, baseline_fn, prob_fn, pred_fn, - forward_arg_extractor, - additional_forward_arg_extractor, + target_input_keys, + additional_input_keys, + output_modifier, ) def evaluate( @@ -228,7 +242,7 @@ def evaluate( inputs: TensorOrTupleOfTensors, targets: Tensor, attributions: TensorOrTupleOfTensors, - attention_mask: Optional[TensorOrTupleOfTensors]=None + attention_mask: Optional[TensorOrTupleOfTensors] = None ) -> TensorOrTupleOfTensors: """ Evaluate the explainer's correctness using the MoRF technique by observing changes in model predictions. @@ -251,7 +265,6 @@ def evaluate( return morf - class LeRF(PixelFlipping): """ A metric class for evaluating the correctness of explanations or attributions using the @@ -264,7 +277,7 @@ class LeRF(PixelFlipping): Attributes: model (Module): The model. explainer (Optional[Explainer]=None): The explainer whose explanations are being evaluated. - channel_dim (int): Target channel dimension. + pooling_dim (int): Target channel dimension. n_steps (int): The number of perturbation steps. baseline_fn (Optional[BaselineFunction]): Function to generate baseline inputs for perturbation. prob_fn (Optional[Callable[[Tensor], Tensor]]): Function to compute probabilities from model outputs. @@ -276,24 +289,27 @@ class LeRF(PixelFlipping): evaluate(inputs, targets, attributions, attention_mask=None): Evaluate the explainer's correctness using the LeRF technique by observing changes in model predictions. 
""" + alias = ['lerf'] def __init__( self, model: Module, - explainer: Optional[Explainer]=None, - channel_dim: int=1, - n_steps: int=10, - baseline_fn: Optional[BaselineFunction]=None, - prob_fn: Optional[Callable[[Tensor], Tensor]]=lambda outputs: outputs.softmax(-1), - pred_fn: Optional[Callable[[Tensor], Tensor]]=lambda outputs: outputs.argmax(-1), - forward_arg_extractor: Optional[ForwardArgumentExtractor]=None, - additional_forward_arg_extractor: Optional[ForwardArgumentExtractor]=None, + explainer: Optional[Explainer] = None, + pooling_dim: int = 1, + n_steps: int = 10, + baseline_fn: Optional[BaselineFunction] = None, + prob_fn: Optional[Callable[[Tensor], Tensor]] = None, + pred_fn: Optional[Callable[[Tensor], Tensor]] = None, + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], torch.Tensor]] = None, ): super().__init__( - model, explainer, channel_dim, n_steps, + model, explainer, pooling_dim, n_steps, baseline_fn, prob_fn, pred_fn, - forward_arg_extractor, - additional_forward_arg_extractor, + target_input_keys, + additional_input_keys, + output_modifier, ) def evaluate( @@ -301,7 +317,7 @@ def evaluate( inputs: TensorOrTupleOfTensors, targets: Tensor, attributions: TensorOrTupleOfTensors, - attention_mask: TensorOrTupleOfTensors=None, + attention_mask: TensorOrTupleOfTensors = None, ) -> TensorOrTupleOfTensors: """ Evaluate the explainer's correctness using the LeRF technique by observing changes in model predictions. @@ -316,7 +332,8 @@ def evaluate( TensorOrTupleOfTensors: The mean probabilities at each perturbation step, indicating the impact of perturbing the least relevant features first. """ - pf_results = super().evaluate(inputs, targets, attributions, attention_mask, False) + pf_results = super().evaluate( + inputs, targets, attributions, attention_mask, False) pf_results = format_into_tuple(pf_results) lerf = tuple(result['probs'].mean(-1) for result in pf_results) if len(lerf) == 1: @@ -324,7 +341,6 @@ def evaluate( return lerf - class AbPC(PixelFlipping): """ A metric class for evaluating the correctness of explanations or attributions using the @@ -338,7 +354,7 @@ class AbPC(PixelFlipping): Attributes: model (Module): The model. explainer (Optional[Explainer]=None): The explainer whose explanations are being evaluated. - channel_dim (int): Target channel dimension. + pooling_dim (int): Target channel dimension. n_steps (int): The number of perturbation steps. baseline_fn (Optional[BaselineFunction]): Function to generate baseline inputs for perturbation. prob_fn (Optional[Callable[[Tensor], Tensor]]): Function to compute probabilities from model outputs. @@ -351,25 +367,28 @@ class AbPC(PixelFlipping): evaluate(inputs, targets, attributions, attention_mask=None): Evaluate the explainer's correctness using the AbPC technique by observing changes in model predictions. 
""" + alias = ['abpc'] def __init__( self, model: Module, explainer: Optional[Explainer]=None, - channel_dim: int=1, + pooling_dim: int=1, n_steps: int=10, baseline_fn: Optional[BaselineFunction]=None, prob_fn: Optional[Callable[[Tensor], Tensor]]=lambda outputs: outputs.softmax(-1), pred_fn: Optional[Callable[[Tensor], Tensor]]=lambda outputs: outputs.argmax(-1), - forward_arg_extractor: Optional[ForwardArgumentExtractor]=None, - additional_forward_arg_extractor: Optional[ForwardArgumentExtractor]=None, lb: float=-1., + target_input_keys: Optional[List[Union[str, int]]]=None, + additional_input_keys: Optional[List[Union[str, int]]]=None, + output_modifier: Optional[Callable[[Any], torch.Tensor]]=None, ): super().__init__( - model, explainer, channel_dim, n_steps, + model, explainer, pooling_dim, n_steps, baseline_fn, prob_fn, pred_fn, - forward_arg_extractor, - additional_forward_arg_extractor, + target_input_keys, + additional_input_keys, + output_modifier, ) self.lb = lb @@ -378,7 +397,7 @@ def evaluate( inputs: TensorOrTupleOfTensors, targets: Tensor, attributions: TensorOrTupleOfTensors, - attention_mask: Optional[TensorOrTupleOfTensors]=None, + attention_mask: Optional[TensorOrTupleOfTensors] = None, return_pf=False, ) -> TensorOrTupleOfTensors: """ @@ -396,11 +415,13 @@ def evaluate( indicating the impact of perturbing the most and least relevant features. """ # pf by ascending order: lerf - pf_ascs = super().evaluate(inputs, targets, attributions, attention_mask, False) + pf_ascs = super().evaluate( + inputs, targets, attributions, attention_mask, False) pf_ascs = format_into_tuple(pf_ascs) # pf by descending order: morf - pf_descs = super().evaluate(inputs, targets, attributions, attention_mask, True) + pf_descs = super().evaluate( + inputs, targets, attributions, attention_mask, True) pf_descs = format_into_tuple(pf_descs) # abpc @@ -419,6 +440,7 @@ def _extract_target_probs(probs, targets): # please ensure probs.size() == (batch_size, n_classes) return probs[torch.arange(probs.size(0)), targets] + def _sort_by_order(x, permutation): d1, d2 = x.size() ret = x[ @@ -427,22 +449,26 @@ def _sort_by_order(x, permutation): ].view(d1, d2) return ret + def _flatten_if_not_1d(batch): if batch.dim() > 2: original_size = batch.size() return batch.flatten(1), original_size return batch, batch.size() + def _recover_shape_if_flattened(batch, original_size): if batch.size() == original_size: return batch return batch.view(*original_size) -def _match_channel_dim_if_pooled(batch, channel_dim, x_batch_size): + +def _match_channel_dim_if_pooled(batch, pooling_dim, x_batch_size): if batch.size() == x_batch_size: return batch - n_channels = x_batch_size[channel_dim] - n_repeats = tuple(n_channels if d == channel_dim else 1 for d in range(len(x_batch_size))) - return batch.unsqueeze(channel_dim).repeat(*n_repeats) - - + n_channels = x_batch_size[pooling_dim] + n_repeats = tuple( + n_channels if d == pooling_dim else 1 + for d in range(len(x_batch_size)) + ) + return batch.unsqueeze(pooling_dim).repeat(*n_repeats) diff --git a/pnpxai/evaluator/optimizer/objectives.py b/pnpxai/evaluator/optimizer/objectives.py index 6c87e542..48974679 100644 --- a/pnpxai/evaluator/optimizer/objectives.py +++ b/pnpxai/evaluator/optimizer/objectives.py @@ -1,14 +1,15 @@ -from typing import Optional, Any -from torch import Tensor + +from typing import Optional, Any, Callable, Union, Dict, Literal + +from tqdm import tqdm from optuna.trial import Trial, TrialState -from pnpxai.core._types import TensorOrTupleOfTensors -from 
pnpxai.core.modality.modality import Modality, TextModality +from pnpxai.core._types import DataSource +from pnpxai.core.modality.modality import Modality from pnpxai.explainers import Explainer, KernelShap, Lime from pnpxai.explainers.utils.postprocess import PostProcessor, Identity from pnpxai.evaluator.metrics.base import Metric -from pnpxai.evaluator.optimizer.suggestor import suggest -from pnpxai.utils import format_into_tuple, format_out_tuple_if_single, generate_param_key +from pnpxai.utils import format_into_tuple, format_out_tuple_if_single class Objective: @@ -38,64 +39,36 @@ class is designed to be callable and can be used within an optimization framewor def __init__( self, + modality: Modality, explainer: Explainer, - postprocessor: PostProcessor, metric: Metric, - modality: Modality, - inputs: Optional[TensorOrTupleOfTensors] = None, - targets: Optional[Tensor] = None, + data: DataSource, + disable_tunable_params: Optional[Dict[str, Any]] = None, + target_class_extractor: Optional[Callable[[Any], Any]] = None, + label_key: Optional[Union[str, int]] = -1, + target_labels: bool = False, + show_progress: bool = False, + errors: Literal['raise', 'ignore'] = 'raise', ): + self.modality = modality self.explainer = explainer - self.postprocessor = postprocessor self.metric = metric - self.modality = modality - self.inputs = inputs - self.targets = targets - - def set_inputs(self, inputs): - """ - Sets the input data for the objective. + self.data = data + self.target_class_extractor = target_class_extractor + self.label_key = label_key + self.target_labels = target_labels + self.show_progress = show_progress + self.disable_tunable_params = disable_tunable_params or {} + self.disable_tunable_params_() + self.errors = errors + + def disable_tunable_params_(self): + if self.explainer.is_tunable(): + for key, value in self.disable_tunable_params.items(): + param = getattr(self.explainer, key) + param.update_value(value) # fix value + param.disable() # disable tuning - Parameters: - inputs (TensorOrTupleOfTensors): - The input data to be used by the explainer. - - Returns: - Objective: The updated Objective instance. - """ - self.inputs = inputs - return self - - def set_targets(self, targets): - """ - Sets the target labels for the objective. - - Parameters: - targets (Tensor): - The target labels corresponding to the input data. - - Returns: - Objective: The updated Objective instance. - """ - self.targets = targets - return self - - def set_data(self, inputs, targets): - """ - Sets both the input data and target labels for the objective. - - Parameters: - inputs (TensorOrTupleOfTensors): - The input data to be used by the explainer. - targets (Tensor): - The target labels corresponding to the input data. - - Returns: - Objective: The updated Objective instance. - """ - self.set_inputs(inputs) - self.set_targets(targets) - return self def __call__(self, trial: Trial) -> float: """ @@ -114,61 +87,75 @@ def __call__(self, trial: Trial) -> float: results contain non-countable values like `nan` or `inf`. 
""" # suggest explainer - explainer = suggest(trial, self.explainer, self.modality, key=self.EXPLAINER_KEY) + if self.explainer.__class__.is_tunable(): + suggested_explainer = self.explainer.suggest(trial, key=self.EXPLAINER_KEY) + else: + suggested_explainer = self.explainer # suggest postprocessor modalities = format_into_tuple(self.modality) - is_multi_modal = len(modalities) > 1 - postprocessor = [] - for pp, modality in zip( - format_into_tuple(self.postprocessor), - modalities, - ): - force_params = {} + suggested_pps = tuple() + for pp_loc, modality in enumerate(modalities): if ( - isinstance(explainer, (Lime, KernelShap)) - and isinstance(modality, TextModality) + isinstance(suggested_explainer, (Lime, KernelShap)) + and modality.dtype_key == int ): - force_params['pooling_fn'] = Identity() - postprocessor.append(suggest( - trial, pp, modality, - key=generate_param_key( - self.POSTPROCESSOR_KEY, - modality.__class__.__name__ if is_multi_modal else None - ), - force_params=force_params, - )) - postprocessor = format_out_tuple_if_single(tuple(postprocessor)) - - # Ignore duplicated samples + modality.util_functions['pooling_fn'].add_fallback_option( + key='identity', value=Identity) + pp_base = PostProcessor(modality, 'identity') + pp_base.disable_tunable_param('pooling_method') + else: + pp_base = PostProcessor(modality) + suggested_pp = pp_base.suggest( + trial, + key=f'{self.POSTPROCESSOR_KEY}.{pp_loc}', + ) + suggested_pps += (suggested_pp,) + + ''' + Although the number of trials is larger than the number of search grids, + the number of actual trials will be limited to the number of search grids + by the following exception. + ''' + # ignore duplicated samples states_to_consider = (TrialState.COMPLETE,) - trials_to_consider = trial.study.get_trials(deepcopy=False, states=states_to_consider) + trials_to_consider = trial.study.get_trials( + deepcopy=False, + states=states_to_consider, + ) for t in reversed(trials_to_consider): if trial.params == t.params: - trial.set_user_attr('explainer', explainer) - trial.set_user_attr('postprocessor', postprocessor) return t.value - # Explain and postprocess - attrs = explainer.attribute(self.inputs, self.targets) - postprocessed = tuple( - pp(attr) for pp, attr in zip( - format_into_tuple(postprocessor), - format_into_tuple(attrs), + # actual trial + evals_sum = 0. 
+ pbar = self.data + if self.show_progress: + pbar = tqdm(pbar, total=len(pbar)) + for batch in pbar: + inputs = suggested_explainer._wrapped_model.extract_inputs(batch) + if self.target_labels: + targets = batch[self.label_key] + else: + formatted = suggested_explainer._wrapped_model.format_inputs(inputs) + outputs = suggested_explainer._wrapped_model(*formatted) + targets = self.target_class_extractor(outputs) + try: + attrs = format_into_tuple(suggested_explainer.attribute(inputs, targets)) + except Exception as e: + if self.errors == 'raise': + raise e + return float('nan') + attrs_pp = tuple() + for pp, attr in zip(suggested_pps, attrs): + attr_pp = pp(attr) + if any(a.isnan().sum() > 0 or a.isinf().sum() > 0 for a in attr_pp): + return float('nan') + attrs_pp += (attr_pp,) + attrs_pp = format_out_tuple_if_single(attrs_pp) + evals = self.metric.set_explainer(suggested_explainer).evaluate( + inputs, targets, attrs_pp, ) - ) - - if any(pp.isnan().sum() > 0 or pp.isinf().sum() > 0 for pp in postprocessed): - # Treat a failure as nan - return float('nan') - - postprocessed = format_out_tuple_if_single(postprocessed) - metric = self.metric.set_explainer(explainer) - evals = format_into_tuple( - metric.evaluate(self.inputs, self.targets, postprocessed) - ) - - # Keep current explainer and postprocessor on trial - trial.set_user_attr('explainer', explainer) - trial.set_user_attr('postprocessor', postprocessor) - return (sum(*evals) / len(evals)).item() + for ev in format_into_tuple(evals): + evals_sum += ev.sum().item() + return evals_sum / len(self.data.dataset) diff --git a/pnpxai/evaluator/optimizer/optimize.py b/pnpxai/evaluator/optimizer/optimize.py index 2dfa7b1b..a32abcc7 100644 --- a/pnpxai/evaluator/optimizer/optimize.py +++ b/pnpxai/evaluator/optimizer/optimize.py @@ -6,6 +6,7 @@ get_default_n_trials, ) + def optimize( objective: Objective, direction: Literal['maximize', 'minimize'] = 'maximize', diff --git a/pnpxai/evaluator/optimizer/suggestor.py b/pnpxai/evaluator/optimizer/suggestor.py deleted file mode 100644 index 29ef377c..00000000 --- a/pnpxai/evaluator/optimizer/suggestor.py +++ /dev/null @@ -1,105 +0,0 @@ -from typing import Optional, Any, Dict, Type, Union, Tuple -from optuna import Trial - -from pnpxai.core.modality.modality import Modality -from pnpxai.explainers.utils import UtilFunction -from pnpxai.utils import format_into_tuple, format_out_tuple_if_single, generate_param_key - - -def map_suggest_method( - trial: Trial, - method_type: Type[Any], -): - return { - list: trial.suggest_categorical, - int: trial.suggest_int, - float: trial.suggest_float, - }.get(method_type, None) - - -def suggest( - trial: Trial, - obj: Any, - modality: Union[Modality, Tuple[Modality]], - key: Optional[str] = None, - force_params: Optional[Dict[str, Any]] = None, -): - """ - A utility function that suggests parameters for a given object based on an optimization trial. - The function recursively tunes the parameters of the object according to the modality (or - modalities) provided. - - Parameters: - trial (Trial): - The trial object from an optimization framework like Optuna, used to suggest - parameters for tuning. - obj (Any): - The object whose parameters are being tuned. This object must implement - `get_tunables()` and `set_kwargs()` methods. - modality (Union[Modality, Tuple[Modality]]): - The modality (e.g., image, text) or tuple of modalities the object is operating on. - If multiple modalities are provided, the function handles multi-modal tuning. 
- key (Optional[str], optional): - An optional key to uniquely identify the set of parameters being tuned, - useful for differentiating parameters in multi-modal scenarios. Defaults to None. - - Returns: - Any: - The object with its parameters set according to the trial suggestions. - - Notes: - - The function uses `map_suggest_method` to map the tuning method based on the method - type provided in the tunables. - - It supports multi-modal tuning, where different modalities may require different - parameters to be tuned. - - For utility functions (`UtilFunction`), the function further tunes parameters - based on the selected function from the modality. - - Example: - Assuming `trial` is an instance of `optuna.trial.Trial`, and `explainer: Explainer` is an object - with tunable parameters, you can tune it as follows: - - ```python - tuned_explainer = suggest(trial, explainer, modality) - ``` - """ - is_multi_modal = len(format_into_tuple(modality)) > 1 - force_params = force_params or {} - for param_nm, (method_type, method_kwargs) in obj.get_tunables().items(): - if param_nm in force_params: - param = force_params[param_nm] - else: - method = map_suggest_method(trial, method_type) - if method is not None: - param = method( - name=generate_param_key(key, param_nm), - **method_kwargs - ) - elif issubclass(method_type, UtilFunction): - param = [] - for mod in format_into_tuple(modality): - fn_selector = mod.map_fn_selector(method_type) - _param_nm, (_method_type, _method_kwargs) = next( - iter(fn_selector.get_tunables().items()) - ) - _param_nm = generate_param_key( - param_nm, - mod.__class__.__name__ if is_multi_modal else None, - _param_nm, - ) # update param_nm - _method = map_suggest_method(trial, _method_type) - fn_nm = _method( - name=generate_param_key(key, _param_nm), - **_method_kwargs - ) - fn = fn_selector.select(fn_nm) - param.append(suggest( - trial, fn, mod, - key=generate_param_key( - key, param_nm, - mod.__class__.__name__ if is_multi_modal else None, - ), - )) - param = format_out_tuple_if_single(tuple(param)) - obj = obj.set_kwargs(**{param_nm: param}) - return obj diff --git a/pnpxai/evaluator/optimizer/types.py b/pnpxai/evaluator/optimizer/types.py index e32d4343..445b6b3d 100644 --- a/pnpxai/evaluator/optimizer/types.py +++ b/pnpxai/evaluator/optimizer/types.py @@ -1,3 +1,4 @@ +from typing import Dict, Any, Optional from dataclasses import dataclass import optuna @@ -24,6 +25,8 @@ class OptimizationOutput: contains information about the optimization trials and results. 
""" - explainer: Explainer - postprocessor: PostProcessor - study: optuna.study.Study + explainer: Optional[Explainer] + postprocessor: Optional[PostProcessor] + study: Optional[optuna.study.Study] + value: Optional[float] + params: Optional[Dict[str, Any]] diff --git a/pnpxai/evaluator/optimizer/utils.py b/pnpxai/evaluator/optimizer/utils.py index f0dacfdc..fa7ea6a2 100644 --- a/pnpxai/evaluator/optimizer/utils.py +++ b/pnpxai/evaluator/optimizer/utils.py @@ -15,21 +15,10 @@ 'tpe': 50, } + def load_sampler(sampler: Literal['grid', 'random', 'tpe']='tpe', **kwargs): return AVAILABLE_SAMPLERS[sampler](**kwargs) + def get_default_n_trials(sampler): return DEFAULT_N_TRIALS[sampler] - -def nest_params(flattened_params): - nested = {} - for k, v in flattened_params.items(): - ref = nested - splits = k.split('.') - while splits: - s = splits.pop(0) - if s not in ref: - _v = {} if len(splits) > 0 else v - ref[s] = _v - ref = ref[s] - return nested diff --git a/pnpxai/explainers/__init__.py b/pnpxai/explainers/__init__.py index b19d4ea9..86a59d36 100644 --- a/pnpxai/explainers/__init__.py +++ b/pnpxai/explainers/__init__.py @@ -62,8 +62,8 @@ KernelShap, Lime, Gfgp, - AttentionRollout, - TransformerAttribution, + # AttentionRollout, + # TransformerAttribution, ) EXPLAINERS_FOR_TABULAR = [] diff --git a/pnpxai/explainers/attention_rollout.py b/pnpxai/explainers/attention_rollout.py index 266384d4..793d58cf 100644 --- a/pnpxai/explainers/attention_rollout.py +++ b/pnpxai/explainers/attention_rollout.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Callable, Tuple, List, Sequence, Optional, Union, Literal +from typing import Callable, Tuple, List, Sequence, Optional, Union, Literal, Any import torch from torch import Tensor @@ -15,20 +15,25 @@ from pnpxai.explainers.attentions.attributions import SavingAttentionAttributor from pnpxai.explainers.attentions.rules import CGWAttentionPropagation from pnpxai.explainers.attentions.module_converters import default_attention_converters +from pnpxai.explainers.base import Tunable +from pnpxai.explainers.types import TunableParameter from pnpxai.explainers.zennit.attribution import Gradient, LayerGradient from pnpxai.explainers.zennit.base import ZennitExplainer -from pnpxai.explainers.utils import captum_wrap_model_input -from pnpxai.explainers.types import ForwardArgumentExtractor +from pnpxai.explainers.utils import ModelWrapperForLayerAttribution + def rollout_min_head_fusion_function(attn_weights): return attn_weights.min(axis=1).values + def rollout_max_head_fusion_function(attn_weights): return attn_weights.max(axis=1).values + def rollout_mean_head_fusion_function(attn_weights): return attn_weights.mean(axis=1) + def _get_rollout_head_fusion_function(method: Literal['min', 'max', 'mean']): if method == 'min': return rollout_min_head_fusion_function @@ -38,7 +43,7 @@ def _get_rollout_head_fusion_function(method: Literal['min', 'max', 'mean']): return rollout_mean_head_fusion_function -class AttentionRolloutBase(ZennitExplainer): +class AttentionRolloutBase(ZennitExplainer, Tunable): """ Base class for `AttentionRollout` and `TransformerAttribution` explainers. 
@@ -61,30 +66,56 @@ class AttentionRolloutBase(ZennitExplainer): """ SUPPORTED_MODULES = [Attention] + SUPPORTED_DTYPES = [float] + SUPPORTED_NDIMS = [4] def __init__( self, model: Module, - interpolate_mode: Literal['bilinear']='bilinear', - head_fusion_method: Literal['min', 'max', 'mean']='min', - discard_ratio: float=0.9, - forward_arg_extractor: Optional[ForwardArgumentExtractor]=None, - additional_forward_arg_extractor: Optional[ForwardArgumentExtractor]=None, - n_classes: Optional[int]=None, + interpolate_mode: Literal['bilinear', 'bicubic'] = 'bilinear', + head_fusion_method: Literal['min', 'max', 'mean'] = 'min', + discard_ratio: float = 0.9, + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], Tensor]] = None, + n_classes: Optional[int] = None, ) -> None: - super().__init__( + self.interpolate_mode = TunableParameter( + name='interpolate_mode', + current_value=interpolate_mode, + dtype=str, + is_leaf=True, + space={'choices': ['bilinear', 'bicubic']}, + ) + self.head_fusion_method = TunableParameter( + name='head_fusion_method', + current_value=head_fusion_method, + dtype=str, + is_leaf=True, + space={'choices': ['min', 'max', 'mean']}, + ) + self.discard_ratio = TunableParameter( + name='discard_ratio', + current_value=discard_ratio, + dtype=float, + is_leaf=True, + space={'low': 0., 'high': .95, 'step': .05}, + ) + ZennitExplainer.__init__( + self, model, - forward_arg_extractor, - additional_forward_arg_extractor, + target_input_keys, + additional_input_keys, + output_modifier, n_classes, ) - self.interpolate_mode = interpolate_mode - self.head_fusion_method = head_fusion_method - self.discard_ratio = discard_ratio + Tunable.__init__(self) + self.register_tunable_params([ + self.interpolate_mode, self.head_fusion_method, self.discard_ratio]) @property def head_fusion_function(self): - return _get_rollout_head_fusion_function(self.head_fusion_method) + return _get_rollout_head_fusion_function(self.head_fusion_method.current_value) @abstractmethod def collect_attention_map(self, inputs, targets): @@ -95,17 +126,20 @@ def rollout(self, *args): raise NotImplementedError def _discard(self, fused_attn_map): - org_size = fused_attn_map.size() # keep size to recover it after discard + org_size = fused_attn_map.size() # keep size to recover it after discard + flattened = fused_attn_map.flatten(1) bsz, n_tokens = flattened.size() - attn_cls = flattened[:, 0] # keep attn scores of cls token to recover them after discard + + # keep attn scores of cls token to recover them after discard + attn_cls = flattened[:, 0] _, indices = flattened.topk( - k=int(n_tokens*self.discard_ratio), + k=int(n_tokens*self.discard_ratio.current_value), dim=-1, largest=False, ) - flattened[torch.arange(bsz)[:, None], indices] = 0. # discard - flattened[:, 0] = attn_cls # recover attn scores of cls token + flattened[torch.arange(bsz)[:, None], indices] = 0. # discard + flattened[:, 0] = attn_cls # recover attn scores of cls token discarded = flattened.view(*org_size) return discarded @@ -124,6 +158,12 @@ def attribute( Returns: torch.Tensor: The result of the explanation. """ + forward_args, _ = self.format_inputs(inputs) + assert ( + len(forward_args) == 1 + ), "AttentionRollout for multiple inputs is not supported." 
+ + # inputs = inputs[0] attn_maps = self.collect_attention_map(inputs, targets) with torch.no_grad(): @@ -131,8 +171,8 @@ def attribute( # attn btw cls and patches attrs = rollout[:, 0, 1:] - n_patches = attrs.size(-1) - bsz, _, h, w = inputs.size() + n_patches = attrs.size(-1) + bsz, _, h, w = forward_args[0].size() p_h = int(h / w * n_patches ** .5) p_w = n_patches // p_h attrs = attrs.view(bsz, 1, p_h, p_w) @@ -141,27 +181,10 @@ def attribute( attrs = LayerAttribution.interpolate( layer_attribution=attrs, interpolate_dims=(h, w), - interpolate_mode=self.interpolate_mode, + interpolate_mode=self.interpolate_mode.current_value, ) return attrs - def get_tunables(self): - """ - Provides Tunable parameters for the optimizer - - Tunable parameters: - `interpolate_mode` (str): Value can be selected of `"bilinear"` and `"bicubic"` - - `head_fusion_method` (str): Value can be selected of `"min"`, `"max"`, and `"mean"` - - `discard_ratio` (float): Value can be selected in the range of `range(0, 0.95, 0.05)` - """ - return { - 'interpolate_mode': (list, {'choices': ['bilinear', 'bicubic']}), - 'head_fusion_method': (list, {'choices': ['min', 'max', 'mean']}), - 'discard_ratio': (float, {'low': 0., 'high': .95, 'step': .05}), - } - class AttentionRollout(AttentionRolloutBase): """ @@ -182,30 +205,36 @@ class AttentionRollout(AttentionRolloutBase): Reference: Samira Abnar, Willem Zuidema. Quantifying Attention Flow in Transformers. """ + alias = ['attention_rollout', 'ar'] + def __init__( self, model: Module, - interpolate_mode: Literal['bilinear']='bilinear', - head_fusion_method: Literal['min', 'max', 'mean']='min', - discard_ratio: float=0.9, - forward_arg_extractor: Optional[ForwardArgumentExtractor]=None, - additional_forward_arg_extractor: Optional[ForwardArgumentExtractor]=None, - n_classes: Optional[int]=None, + interpolate_mode: Literal['bilinear'] = 'bilinear', + head_fusion_method: Literal['min', 'max', 'mean'] = 'min', + discard_ratio: float = 0.9, + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], torch.Tensor]] = None, + n_classes: Optional[int] = None, ) -> None: super().__init__( model, interpolate_mode, head_fusion_method, discard_ratio, - forward_arg_extractor, - additional_forward_arg_extractor, + target_input_keys, + additional_input_keys, + output_modifier, n_classes ) def collect_attention_map(self, inputs, targets): + forward_args, additional_forward_args = self.format_inputs(inputs) # get all attn maps - with SavingAttentionAttributor(model=self.model) as attributor: - weights_all = attributor(inputs, None) + with SavingAttentionAttributor(model=self._wrapped_model) as attributor: + forward_args += additional_forward_args + weights_all = attributor(forward_args, None) return (weights_all,) def rollout(self, weights_all): @@ -247,36 +276,39 @@ class TransformerAttribution(AttentionRolloutBase): """ SUPPORTED_MODULES = [Attention] + alias = ['transformer_attribution', 'ta'] def __init__( self, model: Module, - interpolate_mode: Literal['bilinear']='bilinear', - head_fusion_method: Literal['min', 'max', 'mean']='mean', - discard_ratio: float=0.9, - alpha: float=2., - beta: float=1., - stabilizer: float=1e-6, - zennit_canonizers: Optional[List[Canonizer]]=None, - layer: Optional[Union[Module, Sequence[Module]]]=None, - forward_arg_extractor: Optional[ForwardArgumentExtractor]=None, - additional_forward_arg_extractor: Optional[ForwardArgumentExtractor]=None, - 
n_classes: Optional[int]=None + interpolate_mode: Literal['bilinear'] = 'bilinear', + head_fusion_method: Literal['min', 'max', 'mean'] = 'mean', + discard_ratio: float = 0.9, + alpha: float = 2., + beta: float = 1., + stabilizer: float = 1e-6, + zennit_canonizers: Optional[List[Canonizer]] = None, + target_layer: Optional[Union[Module, Sequence[Module]]] = None, + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], torch.Tensor]] = None, + n_classes: Optional[int] = None ) -> None: super().__init__( model, interpolate_mode, head_fusion_method, discard_ratio, - forward_arg_extractor, - additional_forward_arg_extractor, + target_input_keys, + additional_input_keys, + output_modifier, n_classes ) self.alpha = alpha self.beta = beta self.stabilizer = stabilizer self.zennit_canonizers = zennit_canonizers or [] - self.layer = layer + self.target_layer = target_layer @staticmethod def default_head_fusion_fn(attns): @@ -302,32 +334,34 @@ def zennit_composite(self): @property def _layer_gradient(self) -> LayerGradient: - wrapped_model = captum_wrap_model_input(self.model) + wrapped_model = ModelWrapperForLayerAttribution(self._wrapped_model) layers = [ - wrapped_model.input_maps[layer] if isinstance(layer, str) - else layer for layer in self.layer - ] if isinstance(self.layer, Sequence) else self.layer + wrapped_model.input_maps[target_layer] if isinstance(target_layer, str) + else target_layer for target_layer in format_into_tuple(self.target_layer) + ] + layers = format_out_tuple_if_single(layers) return LayerGradient( model=wrapped_model, - layer=layers, + target_layer=layers, composite=self.zennit_composite, ) @property def _gradient(self) -> Gradient: return Gradient( - model=self.model, + model=self._wrapped_model, composite=self.zennit_composite, ) @property def attributor(self): - if self.layer is None: + if self.target_layer is None: return self._gradient return self._layer_gradient def collect_attention_map(self, inputs, targets): - forward_args, additional_forward_args = self._extract_forward_args(inputs) + forward_args, additional_forward_args = self.format_inputs(inputs) with self.attributor as attributor: attributor.forward( forward_args=forward_args, @@ -360,10 +394,10 @@ class GenericAttention(AttentionRolloutBase): def __init__( self, model: Module, - alpha: float=2., - beta: float=1., - stabilizer: float=1e-6, - head_fusion_function: Optional[Callable[[Tensor], Tensor]]=None, - n_classes: Optional[int]=None - ) -> None: - raise NotImplementedError \ No newline at end of file + alpha: float = 2., + beta: float = 1., + stabilizer: float = 1e-6, + head_fusion_function: Optional[Callable[[Tensor], Tensor]] = None, + n_classes: Optional[int] = None + ) -> None: + raise NotImplementedError diff --git a/pnpxai/explainers/attentions/attributions.py b/pnpxai/explainers/attentions/attributions.py index eabea4ab..d4544620 100644 --- a/pnpxai/explainers/attentions/attributions.py +++ b/pnpxai/explainers/attentions/attributions.py @@ -5,18 +5,20 @@ from pnpxai.explainers.attentions.rules import SavingAttention from pnpxai.explainers.attentions.module_converters import default_attention_converters -from pnpxai.explainers.utils import _format_to_tuple +from pnpxai.utils import format_into_tuple class SavingAttentionAttributor(Attributor): def __init__(self, model: Module): layer_map = [(MultiheadAttention, SavingAttention())] - composite = 
LayerMapComposite(layer_map=layer_map, canonizers=default_attention_converters) + composite = LayerMapComposite( + layer_map=layer_map, + canonizers=default_attention_converters, + ) super().__init__(model, composite, None) - def forward(self, input, attr_output_fn): - input = _format_to_tuple(input) + input = format_into_tuple(input) _ = self.model(*input) attn_output_weights_all = [ hook_ref.stored_tensors[hook_ref.saved_name] diff --git a/pnpxai/explainers/base.py b/pnpxai/explainers/base.py index d27a0923..09c9d16a 100644 --- a/pnpxai/explainers/base.py +++ b/pnpxai/explainers/base.py @@ -1,17 +1,18 @@ +from typing import Tuple, Optional, Union, Type, Dict, List, Callable, Any import abc from abc import abstractmethod + import sys -from typing import Tuple, Optional, Union, Type, Dict -import math +import inspect +from collections import defaultdict -import copy from torch import Tensor from torch.nn.modules import Module +import optuna -from pnpxai.core._types import ExplanationType -from pnpxai.explainers.types import ForwardArgumentExtractor -from pnpxai.explainers.utils import UtilFunction, BaselineFunction, FeatureMaskFunction -from pnpxai.utils import format_into_tuple, format_out_tuple_if_single +from pnpxai.core.utils import ModelWrapper +from pnpxai.explainers.types import TunableParameter +from pnpxai.utils import generate_param_key # Ensure compatibility with Python 2/3 @@ -25,6 +26,8 @@ "device", "n_classes", "zennit_composite", + "_wrapped_model", + "_tunable_params", ] @@ -44,21 +47,35 @@ class Explainer(ABC): - Subclasses must implement the `attribute` method to define how attributions are computed. - The `forward_arg_extractor` and `additional_forward_arg_extractor` functions allow for customization in extracting forward arguments from the inputs. 
""" - - EXPLANATION_TYPE: ExplanationType = "attribution" - SUPPORTED_MODULES = [] - TUNABLES = {} + SUPPORTED_MODULES: List[Type[Module]] = [] + SUPPORTED_DTYPES: List[Type] = [] + SUPPORTED_NDIMS: List[int] = [] def __init__( self, model: Module, - forward_arg_extractor: Optional[ForwardArgumentExtractor] = None, - additional_forward_arg_extractor: Optional[ForwardArgumentExtractor] = None, + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], Tensor]] = None, **kwargs, ) -> None: self.model = model.eval() - self.forward_arg_extractor = forward_arg_extractor - self.additional_forward_arg_extractor = additional_forward_arg_extractor + self.target_input_keys = target_input_keys + self.additional_input_keys = additional_input_keys + self.output_modifier = output_modifier + + @property + def _wrapped_model(self): + return ModelWrapper( + model=self.model, + target_input_keys=self.target_input_keys, + additional_input_keys=self.additional_input_keys, + output_modifier=self.output_modifier, + ) + + @property + def wrapped_model(self): + return self._wrapped_model @property def device(self): @@ -72,83 +89,11 @@ def __repr__(self): ) return "{}({})".format(self.__class__.__name__, kwargs_repr) - def _extract_forward_args( - self, inputs: Union[Tensor, Tuple[Tensor]] - ) -> Tuple[Union[Tensor, Tuple[Tensor], Type[None]]]: - forward_args = ( - self.forward_arg_extractor(inputs) if self.forward_arg_extractor else inputs - ) - additional_forward_args = ( - self.additional_forward_arg_extractor(inputs) - if self.additional_forward_arg_extractor - else None - ) + def format_inputs(self, inputs: Union[Tuple, Dict]): + forward_args = self._wrapped_model.format_target_inputs(inputs) + additional_forward_args = self._wrapped_model.format_additional_inputs(inputs) return forward_args, additional_forward_args - def copy(self): - return copy.copy(self) - - def set_kwargs(self, **kwargs): - clone = self.copy() - for k, v in kwargs.items(): - setattr(clone, k, v) - return clone - - def _load_util_fn( - self, util_attr: str, util_fn_class: Type[UtilFunction] - ) -> Optional[Union[UtilFunction, Tuple[UtilFunction]]]: - attr = getattr(self, util_attr) - if attr is None: - return None - - attr_values = [] - for attr_value in format_into_tuple(attr): - if isinstance(attr_value, str): - attr_value = util_fn_class.from_method(method=attr_value) - attr_values.append(attr_value) - attr_values = tuple(attr_values) - return format_out_tuple_if_single(attr_values) - - def _get_baselines(self, forward_args) -> Union[Tensor, Tuple[Tensor]]: - baseline_fns = self._load_util_fn("baseline_fn", BaselineFunction) - if baseline_fns is None: - return None - - forward_args = format_into_tuple(forward_args) - baseline_fns = format_into_tuple(baseline_fns) - - assert len(forward_args) == len(baseline_fns) - baselines = tuple( - baseline_fn(forward_arg) - for baseline_fn, forward_arg in zip(baseline_fns, forward_args) - ) - return format_out_tuple_if_single(baselines) - - def _get_feature_masks(self, forward_args) -> Union[Tensor, Tuple[Tensor]]: - feature_mask_fns = self._load_util_fn("feature_mask_fn", FeatureMaskFunction) - if feature_mask_fns is None: - return None - - feature_mask_fns = format_into_tuple(feature_mask_fns) - forward_args = format_into_tuple(forward_args) - - assert len(forward_args) == len(feature_mask_fns) - feature_masks = [] - max_vals = None - for feature_mask_fn, forward_arg in 
zip(feature_mask_fns, forward_args): - feature_mask = feature_mask_fn(forward_arg) - if max_vals is not None: - feature_mask += ( - max_vals[(...,) + (None,) * (feature_mask.dim() - 1)] + 1 - ) - feature_masks.append(feature_mask) - - # update max_vals - bsz, *size = feature_mask.size() - max_vals = feature_mask.view(-1, math.prod(size)).max(axis=1).values - feature_masks = tuple(feature_masks) - return format_out_tuple_if_single(feature_masks) - @abstractmethod def attribute( self, @@ -167,11 +112,132 @@ def attribute( """ raise NotImplementedError - def get_tunables(self) -> Dict[str, Tuple[type, dict]]: - """ - Returns a dictionary of tunable parameters for the explainer. + @classmethod + def is_tunable(cls): + return issubclass(cls, Tunable) - Returns: - Dict[str, Tuple[type, dict]]: Dictionary of tunable parameters. - """ - return {} + +class Tunable: + def __init__(self, params: Optional[List[TunableParameter]] = None): + self._tunable_params = params or [] + + @property + def tunable_params(self): + return [tp for tp in self._tunable_params] + + @property + def non_tunable_params(self): + tunable_params = [tp.name.split('.')[0] for tp in self.tunable_params] + required_params = [ + param for param in inspect.signature(self.__class__).parameters] + return [param for param in required_params if param not in tunable_params] + + def get_current_tunable_param_values(self): + tps = {} + for tp in self._tunable_params: + if tp.is_leaf: + tps[tp.name] = tp.current_value + continue + if not hasattr(tp, 'selector'): + tps[tp.name] = tp.current_value + continue + tps[tp.name] = tp.selector.get_key(tp.current_value.__class__) + if isinstance(tp.current_value, Tunable): + leaf_values = tp.current_value.get_current_tunable_param_values() + for k, v in leaf_values.items(): + tps[f'{tp.name}.{k}'] = v + return tps + + def register_tunable_params( + self, + params: List[Union[TunableParameter, Tuple[TunableParameter]]], + ): + for param in params: + self._register_tunable_param(param) + + def _register_tunable_param( + self, + param: Union[TunableParameter, Tuple[TunableParameter]], + key=None, + ): + if isinstance(param, tuple): + for i, _param in enumerate(param): + self._register_tunable_param(_param, i) + else: + if key is not None: + param.rename(f'{param.name}.{key}') + self._tunable_params.append(param) + + def disable_tunable_param(self, param_name): + for tp in self._tunable_params: + if tp.name.startswith(param_name): + tp.disable() + + def enable_tunable_param(self, param_name): + for tp in self._tunable_params: + if tp.name.startswith(param_name): + tp.enable() + + def suggest(self, trial: optuna.Trial, key=None): + suggested = defaultdict(tuple) + disabled_nms = [] + for tp in self.tunable_params: + if not tp.disabled: + suggest = { + str: trial.suggest_categorical, + int: trial.suggest_int, + float: trial.suggest_float + }.get(tp.dtype) + suggested_value = suggest( + name=generate_param_key(key, tp.name), + **tp.space + ) + if not tp.is_leaf: # maybe util functions + # init default util function + suggested_value = tp.selector.select(suggested_value) + # recursively suggest + if isinstance(suggested_value, Tunable): + suggested_value = suggested_value.suggest( + trial, key=generate_param_key(key, tp.name), + ) + nm, *keys = tp.name.split('.') + if keys: + suggested[nm] += (suggested_value,) + else: + suggested[nm] = suggested_value + else: + suggested[tp.name] = tp.current_value + disabled_nms.append(tp.name) + + non_tunable_params = {k: getattr(self, k) for k in 
self.non_tunable_params} + suggested_obj = self.__class__( + **non_tunable_params, + **suggested, + ) + for nm in disabled_nms: + suggested_obj.disable_tunable_param(nm) + return suggested_obj + + + # def get_tunable_param(self, key, is_leaf=True): + # for tp in self.tunable_params: + # if tp.name == key: + # return tp, is_leaf + # return self.get_tunable_param('.'.join(key.split('.')[:-1]), False) + + def get_tunable_param(self, key, parent=False): + for tp in self.tunable_params: + if tp.name == key: + return tp + if not parent: + return None + return self.get_tunable_param('.'.join(key.split('.')[:-1])) + + def update_current_value(self, key, value): + tp = self.get_tunable_param(key) + if tp is None: + parent_tp = self.get_tunable_param(key, parent=True) + parent_tp.current_value.update_current_value(key.split('.')[-1], value) + else: + value = value if tp.is_leaf else tp.selector.select(value) + tp.update_value(value) diff --git a/pnpxai/explainers/deep_lift_shap.py b/pnpxai/explainers/deep_lift_shap.py index 0a4e0011..4281206c 100644 --- a/pnpxai/explainers/deep_lift_shap.py +++ b/pnpxai/explainers/deep_lift_shap.py @@ -6,7 +6,6 @@ from pnpxai.core.detector.types import Convolution from pnpxai.explainers.base import Explainer from torch.nn.modules import Module -from pnpxai.explainers.base import Explainer class DeepLiftShap(Explainer): @@ -23,6 +22,9 @@ class DeepLiftShap(Explainer): Scott M. Lundberg, Su-In Lee, A Unified Approach to Interpreting Model Predictions """ SUPPORTED_MODULES = [Convolution] + SUPPORTED_DTYPES = [float] + SUPPORTED_NDIMS = [2] + alias = ['deep_lift_shap', 'dls'] def __init__( self, diff --git a/pnpxai/explainers/full_grad.py b/pnpxai/explainers/full_grad.py index 9168b51c..37af29ba 100644 --- a/pnpxai/explainers/full_grad.py +++ b/pnpxai/explainers/full_grad.py @@ -1,10 +1,12 @@ -from typing import Optional, Literal +from typing import Optional, Literal, List, Union, Callable, Any import torchvision.transforms.functional as TF from torch.nn.modules import Module from pnpxai.core.detector.types import Convolution from pnpxai.core._types import Tensor +from pnpxai.explainers.base import Tunable +from pnpxai.explainers.types import TunableParameter from pnpxai.explainers.zennit.base import ZennitExplainer from pnpxai.explainers.zennit.attribution import FullGradient as FullGradAttributor @@ -21,14 +23,14 @@ def _format_pooling_method(method): def _format_interpolate_mode(mode): if mode == 'nearest': return TF.InterpolationMode.NEAREST - elif mode == 'nearest_exact': + elif mode == 'nearest-exact': return TF.InterpolationMode.NEAREST_EXACT elif mode == 'bicubic': return TF.InterpolationMode.BICUBIC return TF.InterpolationMode.BILINEAR -class FullGrad(ZennitExplainer): +class FullGrad(ZennitExplainer, Tunable): """ FullGrad explainer. @@ -37,26 +39,52 @@ class FullGrad(ZennitExplainer): Parameters: model (Module): The PyTorch model for which attribution is to be computed. pooling_method (Optional[str]): The pooling mode used by the explainer. Available methods are: `"abssum"` (absolute sum) and `"possum"` (positive sum) - interpolate_mode (Optional[str]): The interpolation mode used by the explainer. Available methods are: `"bilinear"`, `"nearest"`, `"nearest_exact"`, and `"bicubic"` + interpolate_mode (Optional[str]): The interpolation mode used by the explainer. 
Available methods are: `"bilinear"`, `"nearest"`, `"nearest-exact"`, and `"bicubic"` n_classes (Optional[int]): The number of classes **kwargs: Keyword arguments that are forwarded to the base implementation of the Explainer Reference: Suraj Srinivas, Francois Fleuret. Full-Gradient Representation for Neural Network Visualization. """ - SUPPORTED_MODULES = [Convolution] - + SUPPORTED_DTYPES = [float] + SUPPORTED_NDIMS = [4] + alias = ['full_grad', 'fg'] + def __init__( self, model: Module, - pooling_method: Literal['abssum', 'possum']='abssum', - interpolate_mode: Literal['bilinear', 'nearest', 'nearest_exact', 'bicubic']='blinear', - n_classes: Optional[int]=None, + pooling_method: Literal['abssum', 'possum'] = 'abssum', + interpolate_mode: Literal['bilinear', 'bicubic', 'nearest', 'nearest-exact'] = 'bilinear', + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], Tensor]] = None, + n_classes: Optional[int] = None, ): - super().__init__(model=model, n_classes=n_classes) - self.pooling_method = pooling_method - self.interpolate_mode = interpolate_mode + self.pooling_method = TunableParameter( + name='pooling_method', + current_value=pooling_method, + dtype=str, + is_leaf=True, + space={'choices': ['abssum', 'possum']}, + ) + self.interpolate_mode = TunableParameter( + name='interpolate_mode', + current_value=interpolate_mode, + dtype=str, + is_leaf=True, + space={'choices': ['bilinear', 'bicubic', 'nearest', 'nearest-exact']}, + ) + ZennitExplainer.__init__( + self, + model, + target_input_keys, + additional_input_keys, + output_modifier, + n_classes, + ) + Tunable.__init__(self) + self.register_tunable_params([self.pooling_method, self.interpolate_mode]) def attribute(self, inputs: Tensor, targets: Tensor): """ @@ -71,22 +99,10 @@ def attribute(self, inputs: Tensor, targets: Tensor): """ with FullGradAttributor( model=self.model, - pooling_method=_format_pooling_method(self.pooling_method), - interpolate_mode=_format_interpolate_mode(self.interpolate_mode), + pooling_method=_format_pooling_method( + self.pooling_method.current_value), + interpolate_mode=_format_interpolate_mode( + self.interpolate_mode.current_value), ) as attributor: attrs = attributor(inputs, self._format_targets(targets)) return attrs - - def get_tunables(self): - """ - Provides Tunable parameters for the optimizer - - Tunable parameters: `pooling_method` (str): Value can be selected of `"abssum"` and `"possum"` - `interpolate_mode` (str): Value can be selected of `"bilinear"`, `"nearest"`, `"nearest_exact"`, and `"bicubic"` - """ - return { - 'pooling_method': (list, {'choices': ['abssum', 'possum']}), - 'interpolate_mode': (list, {'choices': ['bilinear', 'nearest', 'nearest_exact', 'bicubic']}), - } diff --git a/pnpxai/explainers/grad_cam.py b/pnpxai/explainers/grad_cam.py index 22ede9aa..d81be453 100644 --- a/pnpxai/explainers/grad_cam.py +++ b/pnpxai/explainers/grad_cam.py @@ -1,15 +1,15 @@ -from typing import Dict, Tuple +from typing import Optional, List, Union, Any, Callable from torch import Tensor, nn from captum.attr import LayerGradCam, LayerAttribution -from pnpxai.utils import format_into_tuple from pnpxai.core.detector.types import Convolution -from pnpxai.explainers.base import Explainer +from pnpxai.explainers.base import Explainer, Tunable from pnpxai.explainers.utils import find_cam_target_layer +from pnpxai.explainers.types import TunableParameter from pnpxai.explainers.errors 
import NoCamTargetLayerAndNotTraceableError -class GradCam(Explainer): +class GradCam(Explainer, Tunable): """ GradCAM explainer. @@ -25,18 +25,40 @@ class GradCam(Explainer): """ SUPPORTED_MODULES = [Convolution] + SUPPORTED_DTYPES = [float] + SUPPORTED_NDIMS = [2, 4] + alias = ['grad_cam', 'gcam'] def __init__( - self, model: nn.Module, interpolate_mode: str = "bilinear", **kwargs + self, + model: nn.Module, + interpolate_mode: str = "bilinear", + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], Tensor]] = None, ) -> None: - super().__init__(model, **kwargs) - self.interpolate_mode = interpolate_mode + self.interpolate_mode = TunableParameter( + name='interpolate_mode', + current_value=interpolate_mode, + dtype=str, + is_leaf=True, + space={'choices': ['bilinear', 'bicubic', 'nearest', 'nearest-exact']}, + ) + Explainer.__init__( + self, + model, + target_input_keys, + additional_input_keys, + output_modifier, + ) + Tunable.__init__(self) + self.register_tunable_params([self.interpolate_mode]) @property def layer(self): try: return self._layer or find_cam_target_layer(self.model) - except: + except Exception: raise NoCamTargetLayerAndNotTraceableError( 'You did not set cam target layer and', 'it does not automatically determined.', @@ -59,10 +81,7 @@ def attribute(self, inputs: Tensor, targets: Tensor) -> Tensor: Returns: torch.Tensor: The result of the explanation. """ - forward_args, additional_forward_args = self._extract_forward_args(inputs) - forward_args = format_into_tuple(forward_args) - additional_forward_args = format_into_tuple(additional_forward_args) - + forward_args, additional_forward_args = self.format_inputs(inputs) assert ( len(forward_args) == 1 ), "GradCam for multiple inputs is not supported yet." @@ -78,17 +97,6 @@ def attribute(self, inputs: Tensor, targets: Tensor) -> Tensor: upsampled = LayerAttribution.interpolate( layer_attribution=attrs, interpolate_dims=forward_args[0].shape[2:], - interpolate_mode=self.interpolate_mode, + interpolate_mode=self.interpolate_mode.current_value, ) - return upsampled - - def get_tunables(self) -> Dict[str, Tuple[type, dict]]: - """ - Provides Tunable parameters for the optimizer - - Tunable parameters: - `interpolate_mode` (str): Value can be selected of `"bilinear"` and `"bicubic"` - """ - return { - "interpolate_mode": (list, {"choices": ["bilinear", "bicubic"]}), - } + return upsampled \ No newline at end of file diff --git a/pnpxai/explainers/grad_x_input.py b/pnpxai/explainers/grad_x_input.py index e7942a16..97b90dc2 100644 --- a/pnpxai/explainers/grad_x_input.py +++ b/pnpxai/explainers/grad_x_input.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, List, Tuple, Union, Sequence +from typing import Callable, Optional, List, Tuple, Union, Sequence, Any from torch import Tensor from torch.nn.modules import Module @@ -7,7 +7,9 @@ from pnpxai.core.detector.types import Linear, Convolution, LSTM, RNN, Attention from pnpxai.explainers.base import Explainer -from pnpxai.explainers.utils import captum_wrap_model_input +from pnpxai.explainers.utils import ModelWrapperForLayerAttribution +from pnpxai.explainers.types import TargetLayerOrTupleOfTargetLayers +from pnpxai.utils import format_into_tuple, format_out_tuple_if_single class GradientXInput(Explainer): @@ -18,7 +20,7 @@ class GradientXInput(Explainer): Parameters: model (Module): The PyTorch model for which attribution is to be computed. 
- layer (Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]]): The target module to be explained + target_layer (Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]]): The target module to be explained forward_arg_extractor: A function that extracts forward arguments from the input batch(s) where the attribution scores are assigned. additional_forward_arg_extractor: A secondary function that extract additional forward arguments from the input batch(s). **kwargs: Keyword arguments that are forwarded to the base implementation of the Explainer @@ -28,25 +30,34 @@ class GradientXInput(Explainer): """ SUPPORTED_MODULES = [Linear, Convolution, LSTM, RNN, Attention] + SUPPORTED_DTYPES = [float, int] + SUPPORTED_NDIMS = [2, 4] + alias = ['grad_x_input', 'gi'] def __init__( self, model: Module, - layer: Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]] = None, - forward_arg_extractor: Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]] = None, - additional_forward_arg_extractor: Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]] = None, + target_layer: TargetLayerOrTupleOfTargetLayers = None, + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], Tensor]] = None, ) -> None: - super().__init__(model, forward_arg_extractor, additional_forward_arg_extractor) - self.layer = layer - + super().__init__( + model, + target_input_keys, + additional_input_keys, + output_modifier, + ) + self.target_layer = target_layer @property def _layer_explainer(self) -> CaptumLayerGradientXInput: - wrapped_model = captum_wrap_model_input(self.model) + wrapped_model = ModelWrapperForLayerAttribution(self._wrapped_model) layers = [ - wrapped_model.input_maps[layer] if isinstance(layer, str) - else layer for layer in self.layer - ] if isinstance(self.layer, Sequence) else self.layer + wrapped_model.input_maps[target_layer] if isinstance(target_layer, str) + else target_layer for target_layer in format_into_tuple(self.target_layer) + ] + layers = format_out_tuple_if_single(layers) return CaptumLayerGradientXInput( forward_func=wrapped_model, layer=layers, @@ -54,11 +65,11 @@ def _layer_explainer(self) -> CaptumLayerGradientXInput: @property def _explainer(self) -> CaptumGradientXInput: - return CaptumGradientXInput(forward_func=self.model) + return CaptumGradientXInput(forward_func=self._wrapped_model) @property def explainer(self) -> Union[CaptumGradientXInput, CaptumLayerGradientXInput]: - if self.layer is None: + if self.target_layer is None: return self._explainer return self._layer_explainer @@ -78,7 +89,7 @@ def attribute( Union[torch.Tensor, Tuple[torch.Tensor]]: The result of the explanation. """ - forward_args, additional_forward_args = self._extract_forward_args(inputs) + forward_args, additional_forward_args = self.format_inputs(inputs) attrs = self.explainer.attribute( inputs=forward_args, target=targets, @@ -89,80 +100,3 @@ def attribute( if isinstance(attrs, tuple) and len(attrs) == 1: attrs = attrs[0] return attrs - - -""" -Class for computing Integrated Gradients attributions for a given model and layer. - -Args: - model (torch.nn.Module): The PyTorch model to explain. - layer (torch.nn.Module or List[str or torch.nn.Module]): - The layer(s) for which to compute attributions. To target an input layer, a string of argument name - is available. 
- forward_arg_extractor (Callable[[Tuple[Tensor]], Tensor or Tuple[Tensor]] or None, optional): - A function to extract arguments for each target layer from the tuple of inputs. - Defaults to None. - additional_forward_arg_extractor (Callable[[Tuple[Tensor]], Tensor or Tuple[Tensor]] or None, optional): - A function to extract additional arguments not to be forwarded through the target layer. - Defaults to None. - baseline_fn (Callable or List[Callable] or None, optional): - The function(s) to generate baseline of forward arguments. Must have same length as a tuple of - forward arguments extracted by `forward_arg_extractor`. Defaults to None. - n_step (int, optional): The number of steps for numerical approximation. Defaults to 20. - -Raises: - AssertionError: If the type of the `layer` or `baseline_fn` argument is not correct. - -Example: - For a given VQA model and dataset, - ``` - # ./models.py - - class MyVQAModel(Module): - ... - - def forward(img, qst, qst_len): - x = self.vision_model(img) - embedded = self.embedding(qst) - y = self.question_model(embedded, qst_len) - z = self.answer_model(x, y) - return z - - # ./dataset.py - class MyVQADataset(Dataset): - ... - - def __getitem__(self, idx): - ... - return img, qst, qst_len - - ``` - - Computes Integrated Gradients attributions for input image and embedded question. - - ``` - # model and data - model = MyVQAModel().eval() - dataloader = DataLoader(MyVQADataset(), batch_size=8, shuffle=False) - imgs, qsts, qst_lens, answers = next(iter(dataloader)) - inputs = (imgs, qsts, qst_lens) - outputs = model(*inputs) - targets = outputs.argmax(1) - - # explainer - layer_ig = LayerIntegratedGradients( - model=model, - layer=["img", model.embedding], - forward_arg_extractor=lambda inputs: tuple(inputs[:2]), # (imgs, qsts) - additional_forward_arg_extractor=lambda inputs: inputs[-1], # qst_lens - baseline_fn=[ - lambda imgs: torch.zeros_like(imgs), # baseline function for images - lambda qsts: torch.zeros_like(qsts).long(), # baseline function for questions - ], - n_step=20, - ) - - # attribute - img_attrs, qst_attrs = layer_ig.attribute(inputs, targets) - ``` -""" \ No newline at end of file diff --git a/pnpxai/explainers/gradient.py b/pnpxai/explainers/gradient.py index e78cda87..a8ce7cfb 100644 --- a/pnpxai/explainers/gradient.py +++ b/pnpxai/explainers/gradient.py @@ -1,13 +1,15 @@ -from typing import Tuple, Callable, Sequence, Union, Optional +from typing import Tuple, Callable, Sequence, Union, Optional, Any, List from torch import Tensor from torch.nn.modules import Module from pnpxai.core.detector.types import Linear, Convolution, LSTM, RNN, Attention +from pnpxai.utils import format_into_tuple, format_out_tuple_if_single from pnpxai.explainers.zennit.attribution import Gradient as GradientAttributor from pnpxai.explainers.zennit.attribution import LayerGradient as LayerGradientAttributor from pnpxai.explainers.zennit.base import ZennitExplainer -from pnpxai.explainers.utils import captum_wrap_model_input +from pnpxai.explainers.utils import ModelWrapperForLayerAttribution +from pnpxai.explainers.types import TargetLayerOrTupleOfTargetLayers class Gradient(ZennitExplainer): @@ -18,8 +20,8 @@ class Gradient(ZennitExplainer): Parameters: model (Module): The PyTorch model for which attribution is to be computed. - layer (Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]]): The target module to be explained. - n_classes (int): The number of classes. 
+ target_layer (Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]]): The target module to be explained. + n_classes (int): The number of classes. forward_arg_extractor: A function that extracts forward arguments from the input batch(s) where the attribution scores are assigned. additional_forward_arg_extractor: A secondary function that extract additional forward arguments from the input batch(s). **kwargs: Keyword arguments that are forwarded to the base implementation of the Explainer @@ -28,32 +30,36 @@ class Gradient(ZennitExplainer): Gabriel Erion, Joseph D. Janizek, Pascal Sturmfels, Scott Lundberg, Su-In Lee. Improving performance of deep learning models with axiomatic attribution priors and expected gradients. """ SUPPORTED_MODULES = [Linear, Convolution, LSTM, RNN, Attention] + SUPPORTED_DTYPES = [float, int] + SUPPORTED_NDIMS = [2, 4] + alias = ['gradient', 'vanilla_gradient'] def __init__( self, model: Module, - forward_arg_extractor: Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]]=None, - additional_forward_arg_extractor: Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]]=None, - layer: Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]]=None, + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], Tensor]] = None, + target_layer: Optional[TargetLayerOrTupleOfTargetLayers] = None, n_classes: Optional[int] = None, ) -> None: super().__init__( model, - forward_arg_extractor, - additional_forward_arg_extractor, + target_input_keys, + additional_input_keys, + output_modifier, n_classes ) - self.layer = layer + self.target_layer = target_layer @property def _layer_attributor(self) -> LayerGradientAttributor: - wrapped_model = captum_wrap_model_input(self.model) + wrapped_model = ModelWrapperForLayerAttribution(self._wrapped_model) layers = [ - wrapped_model.input_maps[layer] if isinstance(layer, str) - else layer for layer in self.layer - ] if isinstance(self.layer, Sequence) else [self.layer] - if len(layers) == 1: - layers = layers[0] + wrapped_model.input_maps[target_layer] if isinstance(target_layer, str) + else target_layer for target_layer in format_into_tuple(self.target_layer) + ] + layers = format_out_tuple_if_single(layers) return LayerGradientAttributor( model=wrapped_model, layer=layers, @@ -61,11 +67,11 @@ def _layer_attributor(self) -> LayerGradientAttributor: @property def _attributor(self) -> GradientAttributor: - return GradientAttributor(self.model) + return GradientAttributor(self._wrapped_model) @property def attributor(self) -> Union[GradientAttributor, LayerGradientAttributor]: - if self.layer is None: + if self.target_layer is None: return self._attributor return self._layer_attributor @@ -84,10 +90,10 @@ def attribute( Returns: Union[torch.Tensor, Tuple[torch.Tensor]]: The result of the explanation. 
""" - forward_args, additional_forward_args = self._extract_forward_args(inputs) + forward_args, additional_forward_args = self.format_inputs(inputs) attrs = self.attributor.forward( forward_args, targets, additional_forward_args, ) - return attrs \ No newline at end of file + return attrs diff --git a/pnpxai/explainers/guided_backprop.py b/pnpxai/explainers/guided_backprop.py index e90c468a..b651745c 100644 --- a/pnpxai/explainers/guided_backprop.py +++ b/pnpxai/explainers/guided_backprop.py @@ -7,10 +7,7 @@ from zennit.types import Linear from pnpxai.explainers.lrp import LRPBase, canonizers_base -from pnpxai.explainers.types import ( - ForwardArgumentExtractor, - TargetLayer, -) +from pnpxai.explainers.types import TargetLayer class GuidedBackpropRule(BasicHook): @@ -30,18 +27,20 @@ def __init__( self, model: Module, stabilizer: float=1e-6, - forward_arg_extractor: Optional[ForwardArgumentExtractor]=None, - additional_forward_arg_extractor: Optional[ForwardArgumentExtractor]=None, - layer: Optional[TargetLayer]=None, + target_input_keys: Optional[List[Union[str, int]]]=None, + additional_input_keys: Optional[List[Union[str, int]]]=None, + output_modifier: Optional[Callable[[Any], torch.Tensor]]=None, + target_layer: Optional[TargetLayer]=None, n_classes: Optional[int]=None, ) -> None: self.stabilizer = stabilizer super().__init__( model, self._composite, - forward_arg_extractor, - additional_forward_arg_extractor, - layer, + target_input_keys, + additional_input_keys, + output_modifier, + target_layer, n_classes ) diff --git a/pnpxai/explainers/guided_grad_cam.py b/pnpxai/explainers/guided_grad_cam.py index 6fdd360b..4454ef21 100644 --- a/pnpxai/explainers/guided_grad_cam.py +++ b/pnpxai/explainers/guided_grad_cam.py @@ -1,16 +1,16 @@ -from typing import Dict, Optional, Tuple +from typing import Optional, List, Union, Any, Callable from torch import Tensor from torch.nn.modules import Module from captum.attr import GuidedGradCam as CaptumGuidedGradCam from pnpxai.core.detector.types import Convolution -from pnpxai.explainers.base import Explainer +from pnpxai.explainers.base import Explainer, Tunable +from pnpxai.explainers.types import TunableParameter from pnpxai.explainers.utils import find_cam_target_layer -from pnpxai.utils import format_into_tuple from pnpxai.explainers.errors import NoCamTargetLayerAndNotTraceableError -class GuidedGradCam(Explainer): +class GuidedGradCam(Explainer, Tunable): """ GuidedGradCam explainer. 
@@ -26,22 +26,42 @@ class GuidedGradCam(Explainer): """ SUPPORTED_MODULES = [Convolution] + SUPPORTED_DTYPES = [float] + SUPPORTED_NDIMS = [4] + alias = ['guided_grad_cam', 'ggcam'] def __init__( self, model: Module, layer: Optional[Module] = None, interpolate_mode: str = "nearest", + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], Tensor]] = None, ) -> None: - super().__init__(model) + Explainer.__init__( + self, + model, + target_input_keys, + additional_input_keys, + output_modifier, + ) self._layer = layer - self.interpolate_mode = interpolate_mode + self.interpolate_mode = TunableParameter( + name='interpolate_mode', + current_value=interpolate_mode, + dtype=str, + is_leaf=True, + space={'choices': ['nearest', 'area']} + ) + Tunable.__init__(self) + self.register_tunable_params([self.interpolate_mode]) @property def layer(self): try: return self._layer or find_cam_target_layer(self.model) - except: + except Exception: raise NoCamTargetLayerAndNotTraceableError( 'You did not set cam target layer and', 'it does not automatically determined.', @@ -64,27 +84,14 @@ def attribute(self, inputs: Tensor, targets: Tensor) -> Tensor: Returns: torch.Tensor: The result of the explanation. """ - forward_args, additional_forward_args = self._extract_forward_args( - inputs) - forward_args = format_into_tuple(forward_args) - additional_forward_args = format_into_tuple(additional_forward_args) - assert len( - forward_args) == 1, 'GuidedGradCam for multiple inputs is not supported yet.' + forward_args, additional_forward_args = self.format_inputs(inputs) + assert len(forward_args) == 1, ( + 'GuidedGradCam for multiple inputs is not supported yet.' + ) explainer = CaptumGuidedGradCam(model=self.model, layer=self.layer) attrs = explainer.attribute( inputs=forward_args[0], target=targets, - interpolate_mode=self.interpolate_mode, + interpolate_mode=self.interpolate_mode.current_value, ) return attrs - - def get_tunables(self) -> Dict[str, Tuple[type, dict]]: - """ - Provides Tunable parameters for the optimizer - - Tunable parameters: `interpolate_mode` (str): Value can be selected of `"bilinear"` and `"bicubic"` - """ - return { - 'interpolate_mode': (list, {'choices': ['nearest', 'area']}), - } diff --git a/pnpxai/explainers/integrated_gradients.py b/pnpxai/explainers/integrated_gradients.py index f2c7f117..064e2b07 100644 --- a/pnpxai/explainers/integrated_gradients.py +++ b/pnpxai/explainers/integrated_gradients.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Tuple, Union, Sequence, Dict +from typing import Callable, Optional, Tuple, Union, Sequence, List, Any from torch import Tensor from torch.nn.modules import Module @@ -6,13 +6,23 @@ from captum.attr import LayerIntegratedGradients as CaptumLayerIntegratedGradients from pnpxai.core.detector.types import Linear, Convolution, Attention -from pnpxai.utils import format_into_tuple, format_out_tuple_if_single -from pnpxai.explainers.utils.baselines import BaselineMethodOrFunction, BaselineFunction -from pnpxai.explainers.base import Explainer -from pnpxai.explainers.utils import captum_wrap_model_input - - -class IntegratedGradients(Explainer): +from pnpxai.utils import ( + format_multimodal_supporting_input, + run_multimodal_supporting_util_fn, + format_into_tuple, + format_out_tuple_if_single, +) +from pnpxai.explainers.base import Explainer, Tunable +from pnpxai.explainers.types import ( + 
TargetLayerOrTupleOfTargetLayers, + TunableParameter, +) +from pnpxai.explainers.utils import ModelWrapperForLayerAttribution +from pnpxai.explainers.utils.types import BaselineFunctionOrTupleOfBaselineFunctions +from pnpxai.explainers.utils.baselines import ZeroBaselineFunction + + +class IntegratedGradients(Explainer, Tunable): """ IntegratedGradients explainer. @@ -22,7 +32,7 @@ class IntegratedGradients(Explainer): model (Module): The PyTorch model for which attribution is to be computed. baseline_fn (Union[BaselineMethodOrFunction, Tuple[BaselineMethodOrFunction]]): The baseline function, accepting the attribution input, and returning the baseline accordingly. n_steps (int): The Number of steps the algorithm makes - layer (Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]]): The target module to be explained + target_layer (Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]]): The target module to be explained n_classes (Optional[int]): Number of classes forward_arg_extractor: A function that extracts forward arguments from the input batch(s) where the attribution scores are assigned. additional_forward_arg_extractor: A secondary function that extract additional forward arguments from the input batch(s). @@ -33,32 +43,55 @@ class IntegratedGradients(Explainer): """ SUPPORTED_MODULES = [Linear, Convolution, Attention] + SUPPORTED_DTYPES = [float, int] + SUPPORTED_NDIMS = [2, 4] + alias = ['integrated_gradients', 'ig'] def __init__( self, model: Module, n_steps: int = 20, - baseline_fn: Union[BaselineMethodOrFunction, - Tuple[BaselineMethodOrFunction]] = 'zeros', - layer: Optional[Callable[[Tuple[Tensor]], - Union[Tensor, Tuple[Tensor]]]] = None, - forward_arg_extractor: Optional[Callable[[ - Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]] = None, - additional_forward_arg_extractor: Optional[Callable[[ - Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]] = None, + baseline_fn: Optional[BaselineFunctionOrTupleOfBaselineFunctions] = None, + target_layer: Optional[TargetLayerOrTupleOfTargetLayers] = None, + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], Tensor]] = None, ) -> None: - super().__init__(model, forward_arg_extractor, additional_forward_arg_extractor) - self.layer = layer - self.n_steps = n_steps - self.baseline_fn = baseline_fn + self.target_layer = target_layer + self.n_steps = TunableParameter( + name='n_steps', + current_value=n_steps, + dtype=int, + is_leaf=True, + space={'low': 10, 'high': 100, 'step': 10}, + ) + baseline_fn = baseline_fn or ZeroBaselineFunction() + self.baseline_fn = format_multimodal_supporting_input( + baseline_fn, + format=TunableParameter, + input_key='current_value', + name='baseline_fn', + dtype=str, + is_leaf=False, + ) + Explainer.__init__( + self, + model, + target_input_keys, + additional_input_keys, + output_modifier, + ) + Tunable.__init__(self) + self.register_tunable_params([self.n_steps, self.baseline_fn]) @property def _layer_explainer(self) -> CaptumLayerIntegratedGradients: - wrapped_model = captum_wrap_model_input(self.model) + wrapped_model = ModelWrapperForLayerAttribution(self._wrapped_model) layers = [ - wrapped_model.input_maps[layer] if isinstance(layer, str) - else layer for layer in self.layer - ] if isinstance(self.layer, Sequence) else self.layer + wrapped_model.input_maps[target_layer] if isinstance(target_layer, str) + else target_layer for target_layer in 
format_into_tuple(self.target_layer) + ] + layers = format_out_tuple_if_single(layers) return CaptumLayerIntegratedGradients( forward_func=wrapped_model, layer=layers, @@ -66,11 +99,11 @@ def _layer_explainer(self) -> CaptumLayerIntegratedGradients: @property def _explainer(self) -> CaptumIntegratedGradients: - return CaptumIntegratedGradients(forward_func=self.model) + return CaptumIntegratedGradients(forward_func=self._wrapped_model) @property def explainer(self) -> Union[CaptumIntegratedGradients, CaptumLayerIntegratedGradients]: - if self.layer is None: + if self.target_layer is None: return self._explainer return self._layer_explainer @@ -89,31 +122,15 @@ def attribute( Returns: Union[torch.Tensor, Tuple[torch.Tensor]]: The result of the explanation. """ - forward_args, additional_forward_args = self._extract_forward_args( - inputs) - forward_args = format_into_tuple(forward_args) - baselines = format_into_tuple(self._get_baselines(forward_args)) + forward_args, additional_forward_args = self.format_inputs(inputs) + baselines = run_multimodal_supporting_util_fn(forward_args, self.baseline_fn) attrs = self.explainer.attribute( inputs=forward_args, baselines=baselines, target=targets, additional_forward_args=additional_forward_args, - n_steps=self.n_steps, + n_steps=self.n_steps.current_value, ) if isinstance(attrs, tuple): attrs = format_out_tuple_if_single(attrs) return attrs - - def get_tunables(self) -> Dict[str, Tuple[type, dict]]: - """ - Provides Tunable parameters for the optimizer - - Tunable parameters: - `noise_level` (float): Value can be selected in the range of `range(10, 100, 10)` - - `baseline_fn` (callable): BaselineFunction selects suitable values in accordance with the modality - """ - return { - 'n_steps': (int, {'low': 10, 'high': 100, 'step': 10}), - 'baseline_fn': (BaselineFunction, {}), - } diff --git a/pnpxai/explainers/kernel_shap.py b/pnpxai/explainers/kernel_shap.py index 20566364..395a7ba7 100644 --- a/pnpxai/explainers/kernel_shap.py +++ b/pnpxai/explainers/kernel_shap.py @@ -1,18 +1,25 @@ -from typing import Tuple, Union, Optional, Dict +from typing import Tuple, Union, Optional, Dict, List, Any, Callable from torch import Tensor from torch.nn.modules import Module from captum.attr import KernelShap as CaptumKernelShap from pnpxai.core.detector.types import Linear, Convolution, LSTM, RNN, Attention -from pnpxai.explainers.base import Explainer -from pnpxai.explainers.types import ForwardArgumentExtractor -from pnpxai.explainers.utils.baselines import BaselineMethodOrFunction, BaselineFunction -from pnpxai.explainers.utils.feature_masks import FeatureMaskMethodOrFunction, FeatureMaskFunction -from pnpxai.utils import format_into_tuple, format_out_tuple_if_single - - -class KernelShap(Explainer): +from pnpxai.utils import ( + format_multimodal_supporting_input, + run_multimodal_supporting_util_fn, +) +from pnpxai.explainers.base import Explainer, Tunable +from pnpxai.explainers.types import TunableParameter +from pnpxai.explainers.utils.types import ( + BaselineFunctionOrTupleOfBaselineFunctions, + FeatureMaskFunctionOrTupleOfFeatureMaskFunctions, +) +from pnpxai.explainers.utils.baselines import ZeroBaselineFunction +from pnpxai.explainers.utils.feature_masks import Felzenszwalb + + +class KernelShap(Explainer, Tunable): """ KernelSHAP explainer. 
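A hedged usage sketch for the refactored `IntegratedGradients` above: `n_steps` is now a `TunableParameter`, so its live value is read from `.current_value`, and `ZeroBaselineFunction` is the default baseline. The model and tensors are placeholders, and a plain tensor batch is assumed to pass through the wrapper unchanged:

```python
import torch
from torchvision.models import resnet18
from pnpxai.explainers.integrated_gradients import IntegratedGradients

model = resnet18().eval()
explainer = IntegratedGradients(model, n_steps=20)  # ZeroBaselineFunction by default
print(explainer.n_steps.current_value)  # 20

x = torch.randn(2, 3, 224, 224)
attrs = explainer.attribute(x, model(x).argmax(dim=1))
```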
@@ -23,7 +30,6 @@ class KernelShap(Explainer): n_samples (int): Number of samples baseline_fn (Union[BaselineMethodOrFunction, Tuple[BaselineMethodOrFunction]]): The baseline function, accepting the attribution input, and returning the baseline accordingly. feature_mask_fn (Union[FeatureMaskMethodOrFunction, Tuple[FeatureMaskMethodOrFunction]): The feature mask function, accepting the attribution input, and returning the feature mask accordingly. - mask_token_id (Optional[int]): The token id of the mask, used for modalities, utilizing tokenization forward_arg_extractor: A function that extracts forward arguments from the input batch(s) where the attribution scores are assigned. additional_forward_arg_extractor: A secondary function that extract additional forward arguments from the input batch(s). **kwargs: Keyword arguments that are forwarded to the base implementation of the Explainer @@ -33,28 +39,53 @@ class KernelShap(Explainer): """ SUPPORTED_MODULES = [Linear, Convolution, LSTM, RNN, Attention] + SUPPORTED_DTYPES = [float, int] + SUPPORTED_NDIMS = [2, 4] + alias = ['kernel_shap', 'ks'] def __init__( self, model: Module, n_samples: int = 25, - baseline_fn: Union[BaselineMethodOrFunction, - Tuple[BaselineMethodOrFunction]] = 'zeros', - feature_mask_fn: Union[FeatureMaskMethodOrFunction, - Tuple[FeatureMaskMethodOrFunction]] = 'felzenszwalb', - forward_arg_extractor: Optional[ForwardArgumentExtractor] = None, - additional_forward_arg_extractor: Optional[ForwardArgumentExtractor] = None, - mask_token_id: Optional[int] = None, + baseline_fn: Optional[BaselineFunctionOrTupleOfBaselineFunctions] = None, + feature_mask_fn: Optional[FeatureMaskFunctionOrTupleOfFeatureMaskFunctions] = None, + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], Tensor]] = None, ) -> None: - super().__init__( + self.n_samples = TunableParameter( + name='n_samples', + current_value=n_samples, + dtype=int, + is_leaf=True, + space={'low': 10, 'high': 50, 'step': 10}, + ) + self.baseline_fn = format_multimodal_supporting_input( + baseline_fn or ZeroBaselineFunction(), + format=TunableParameter, + input_key='current_value', + name='baseline_fn', + dtype=str, + is_leaf=False, + ) + self.feature_mask_fn = format_multimodal_supporting_input( + feature_mask_fn or Felzenszwalb(), + format=TunableParameter, + input_key='current_value', + name='feature_mask_fn', + dtype=str, + is_leaf=False, + ) + Explainer.__init__( + self, model, - forward_arg_extractor, - additional_forward_arg_extractor + target_input_keys, + additional_input_keys, + output_modifier, ) - self.n_samples = n_samples - self.baseline_fn = baseline_fn - self.feature_mask_fn = feature_mask_fn - self.mask_token_id = mask_token_id + Tunable.__init__(self) + self.register_tunable_params([ + self.n_samples, self.baseline_fn, self.feature_mask_fn]) def attribute( self, @@ -71,34 +102,17 @@ def attribute( Returns: Union[torch.Tensor, Tuple[torch.Tensor]]: The result of the explanation.
""" - forward_args, additional_forward_args = self._extract_forward_args( + forward_args, additional_forward_args = self.format_inputs( inputs) - forward_args = format_into_tuple(forward_args) - explainer = CaptumKernelShap(self.model) - attrs = explainer.attribute( + baselines = run_multimodal_supporting_util_fn(forward_args, self.baseline_fn) + feature_masks = run_multimodal_supporting_util_fn(forward_args, self.feature_mask_fn) + _explainer = CaptumKernelShap(self._wrapped_model) + attrs = _explainer.attribute( inputs=forward_args, target=targets, - baselines=self._get_baselines(forward_args), - feature_mask=self._get_feature_masks(forward_args), - n_samples=self.n_samples, + baselines=baselines, + feature_mask=feature_masks, + n_samples=self.n_samples.current_value, additional_forward_args=additional_forward_args, ) - attrs = format_out_tuple_if_single(attrs) return attrs - - def get_tunables(self) -> Dict[str, Tuple[type, Dict]]: - """ - Provides Tunable parameters for the optimizer - - Tunable parameters: - `n_samples` (int): Value can be selected in the range of `range(10, 50, 10)` - - `baseline_fn` (callable): BaselineFunction selects suitable values in accordance with the modality - - `feature_mask_fn` (callable): FeatureMaskFunction selects suitable values in accordance with the modality - """ - return { - 'n_samples': (int, {'low': 10, 'high': 50, 'step': 10}), - 'baseline_fn': (BaselineFunction, {}), - 'feature_mask_fn': (FeatureMaskFunction, {}) - } diff --git a/pnpxai/explainers/lime.py b/pnpxai/explainers/lime.py index 01c93934..41715d7f 100644 --- a/pnpxai/explainers/lime.py +++ b/pnpxai/explainers/lime.py @@ -1,18 +1,25 @@ -from typing import Callable, Tuple, Union, Optional, Dict, Any +from typing import Callable, Tuple, Union, Optional, Any, List -import torch from torch import Tensor from torch.nn.modules import Module from captum.attr import Lime as CaptumLime from pnpxai.core.detector.types import Linear, Convolution, LSTM, RNN, Attention -from pnpxai.explainers.base import Explainer -from pnpxai.explainers.utils.baselines import BaselineMethodOrFunction, BaselineFunction -from pnpxai.explainers.utils.feature_masks import FeatureMaskMethodOrFunction, FeatureMaskFunction -from pnpxai.utils import format_into_tuple - - -class Lime(Explainer): +from pnpxai.utils import ( + format_multimodal_supporting_input, + run_multimodal_supporting_util_fn, +) +from pnpxai.explainers.base import Explainer, Tunable +from pnpxai.explainers.types import TunableParameter +from pnpxai.explainers.utils.types import ( + BaselineFunctionOrTupleOfBaselineFunctions, + FeatureMaskFunctionOrTupleOfFeatureMaskFunctions, +) +from pnpxai.explainers.utils.baselines import ZeroBaselineFunction +from pnpxai.explainers.utils.feature_masks import Felzenszwalb + + +class Lime(Explainer, Tunable): """ Lime explainer. 
@@ -33,31 +40,61 @@ class Lime(Explainer): """ SUPPORTED_MODULES = [Linear, Convolution, LSTM, RNN, Attention] + SUPPORTED_DTYPES = [float, int] + SUPPORTED_NDIMS = [2, 4] + alias = ['lime'] def __init__( self, model: Module, n_samples: int = 25, - baseline_fn: Union[BaselineMethodOrFunction, - Tuple[BaselineMethodOrFunction]] = 'zeros', - feature_mask_fn: Union[FeatureMaskMethodOrFunction, - Tuple[FeatureMaskMethodOrFunction]] = 'felzenszwalb', + baseline_fn: Optional[BaselineFunctionOrTupleOfBaselineFunctions] = None, + feature_mask_fn: Optional[FeatureMaskFunctionOrTupleOfFeatureMaskFunctions] = None, perturb_fn: Optional[Callable[[Tensor], Tensor]] = None, - forward_arg_extractor: Optional[Callable[[ - Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]] = None, - additional_forward_arg_extractor: Optional[Callable[[ - Tuple[Tensor]], Tuple[Tensor]]] = None, + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], Tensor]] = None, ) -> None: - super().__init__(model, forward_arg_extractor, additional_forward_arg_extractor) - self.baseline_fn = baseline_fn or torch.zeros_like - self.feature_mask_fn = feature_mask_fn + self.n_samples = TunableParameter( + name='n_samples', + current_value=n_samples, + dtype=int, + is_leaf=True, + space={'low': 10, 'high': 50, 'step': 10}, + ) + baseline_fn = baseline_fn or ZeroBaselineFunction() + self.baseline_fn = format_multimodal_supporting_input( + baseline_fn, + format=TunableParameter, + input_key='current_value', + name='baseline_fn', + dtype=str, + is_leaf=False, + ) + self.feature_mask_fn = format_multimodal_supporting_input( + feature_mask_fn or Felzenszwalb(), + format=TunableParameter, + input_key='current_value', + name='feature_mask_fn', + dtype=str, + is_leaf=False, + ) self.perturb_fn = perturb_fn - self.n_samples = n_samples + Explainer.__init__( + self, + model, + target_input_keys, + additional_input_keys, + output_modifier, + ) + Tunable.__init__(self) + self.register_tunable_params([ + self.n_samples, self.baseline_fn, self.feature_mask_fn]) def attribute( - self, - inputs: Tensor, - targets: Optional[Tensor] = None, + self, + inputs: Tensor, + targets: Optional[Tensor] = None, ) -> Union[Tensor, Tuple[Tensor]]: """ Computes attributions for the given inputs and targets. @@ -69,37 +106,18 @@ def attribute( Args: inputs (torch.Tensor): The input data. targets (torch.Tensor): The target labels. Returns: Union[torch.Tensor, Tuple[torch.Tensor]]: The result of the explanation.
""" - forward_args, additional_forward_args = self._extract_forward_args( - inputs) - forward_args = format_into_tuple(forward_args) + forward_args, additional_forward_args = self.format_inputs(inputs) + baselines = run_multimodal_supporting_util_fn(forward_args, self.baseline_fn) + feature_masks = run_multimodal_supporting_util_fn(forward_args, self.feature_mask_fn) - explainer = CaptumLime(self.model, perturb_func=self.perturb_fn) - attrs = explainer.attribute( + _explainer = CaptumLime(self._wrapped_model, perturb_func=self.perturb_fn) + attrs = _explainer.attribute( inputs=forward_args, target=targets, - baselines=self._get_baselines(forward_args), - feature_mask=self._get_feature_masks(forward_args), - n_samples=self.n_samples, + baselines=baselines, + feature_mask=feature_masks, + n_samples=self.n_samples.current_value, additional_forward_args=additional_forward_args, ) - if isinstance(attrs, tuple) and len(attrs) == 1: - attrs = attrs[0] return attrs - - - def get_tunables(self) -> Dict[str, Tuple[type, Dict]]: - """ - Provides Tunable parameters for the optimizer - - Tunable parameters: - `n_samples` (int): Value can be selected in the range of `range(10, 100, 10)` - - `baseline_fn` (callable): BaselineFunction selects suitable values in accordance with the modality - - `feature_mask_fn` (callable): FeatureMaskFunction selects suitable values in accordance with the modality - """ - return { - 'n_samples': (int, {'low': 10, 'high': 100, 'step': 10}), - 'baseline_fn': (BaselineFunction, {}), - 'feature_mask_fn': (FeatureMaskFunction, {}) - } + \ No newline at end of file diff --git a/pnpxai/explainers/lrp.py b/pnpxai/explainers/lrp.py index 60c0c0d6..f7a85f81 100644 --- a/pnpxai/explainers/lrp.py +++ b/pnpxai/explainers/lrp.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple, Callable, Sequence, Union, Optional +from typing import Dict, List, Tuple, Callable, Sequence, Union, Optional, Any import _operator import warnings @@ -6,7 +6,6 @@ from torch import nn, fx, Tensor from torch.nn.modules import Module -from zennit.attribution import Gradient from zennit.core import Composite from zennit.composites import ( layer_map_base, @@ -15,19 +14,24 @@ EpsilonPlus, EpsilonAlpha2Beta1, ) -from zennit.types import Linear from zennit.rules import Epsilon from zennit.canonizers import SequentialMergeBatchNorm, Canonizer from pnpxai.core.detector.types import Linear, Convolution, LSTM, RNN, Attention +from pnpxai.core.utils import ModelWrapper from pnpxai.explainers.attentions.module_converters import default_attention_converters from pnpxai.explainers.attentions.rules import ConservativeAttentionPropagation +from pnpxai.explainers.base import Tunable from pnpxai.explainers.zennit.attribution import Gradient, LayerGradient from pnpxai.explainers.zennit.rules import LayerNormRule from pnpxai.explainers.zennit.base import ZennitExplainer from pnpxai.explainers.zennit.layer import StackAndSum -from pnpxai.explainers.utils import captum_wrap_model_input -from pnpxai.explainers.types import ForwardArgumentExtractor, TargetLayerOrListOfTargetLayers +from pnpxai.explainers.utils import ModelWrapperForLayerAttribution +from pnpxai.explainers.types import ( + TargetLayerOrTupleOfTargetLayers, + TunableParameter, +) +from pnpxai.utils import format_into_tuple, format_out_tuple_if_single class LRPBase(ZennitExplainer): @@ -37,7 +41,7 @@ class LRPBase(ZennitExplainer): Parameters: model (Module): The PyTorch model for which attribution is to be computed. 
zennit_composite (Composite): The Composite object applies canonizers and register hooks to modules. One Composite instance may only be applied to a single module at a time. - layer (Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]]): The target module to be explained + target_layer (Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]]): The target module to be explained n_classes (Optional[int]): Number of classes forward_arg_extractor: A function that extracts forward arguments from the input batch(s) where the attribution scores are assigned. additional_forward_arg_extractor: A secondary function that extract additional forward arguments from the input batch(s). @@ -50,41 +54,46 @@ def __init__( self, model: Module, zennit_composite: Composite, - forward_arg_extractor: Optional[ForwardArgumentExtractor] = None, - additional_forward_arg_extractor: Optional[ForwardArgumentExtractor] = None, - layer: Optional[TargetLayerOrListOfTargetLayers] = None, + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], Tensor]] = None, + target_layer: Optional[TargetLayerOrTupleOfTargetLayers] = None, n_classes: Optional[int] = None, ) -> None: super().__init__( model, - forward_arg_extractor, - additional_forward_arg_extractor, + target_input_keys, + additional_input_keys, + output_modifier, n_classes ) self.zennit_composite = zennit_composite - self.layer = layer - - def _layer_explainer(self, model: Union[Module, fx.GraphModule]) -> LayerGradient: - wrapped_model = captum_wrap_model_input(model) - stack = self.layer.copy() if isinstance( - self.layer, Sequence) else [self.layer] - layers = [] + self.target_layer = target_layer + + @property + def _layer_attributor(self) -> LayerGradient: + preprocessed_model = self._preprocess_model() + wrapped_model = ModelWrapperForLayerAttribution(preprocessed_model) + stack = list(format_into_tuple(self.target_layer)) + layers = () while stack: - layer = stack.pop(0) - if isinstance(layer, str): - layers.append(wrapped_model.input_maps[layer]) + target_layer = stack.pop(0) + if isinstance(target_layer, str): + layers += (wrapped_model.input_maps[target_layer],) continue - if isinstance(model, fx.GraphModule): + if isinstance(preprocessed_model, fx.GraphModule): child_nodes = [] found = False - for node in model.graph.nodes: + for node in preprocessed_model.graph.nodes: if node.op == "call_module": try: module = self.model.get_submodule(node.target) except AttributeError: continue - if module is layer: - layers.append(layer) + if module is target_layer: + layers += (target_layer,) found = True break path_to_node = node.target.split(".")[:-1] @@ -95,32 +104,46 @@ def _layer_explainer(self, model: Union[Module, fx.GraphModule]) -> LayerGradien ".".join(path_to_node[:i+1])) for i in range(len(path_to_node)) ] - if any(anc is layer for anc in ancestors): + if any(anc is target_layer for anc in ancestors): child_nodes.append(node) if not found: last_child = self.model.get_submodule( child_nodes[-1].target) - layers.append(last_child) - elif isinstance(model, Module): - layers.append(layer) - if len(layers) == 1: - layers = layers[0] + layers += (last_child,) + elif isinstance(preprocessed_model, Module): + layers += (target_layer,) + layers = format_out_tuple_if_single(layers) return LayerGradient( model=wrapped_model,
layer=layers, composite=self.zennit_composite, ) - def _explainer(self, model) -> Gradient: + @property + def _attributor(self) -> Gradient: return Gradient( - model=model, + model=self._preprocess_model(), composite=self.zennit_composite ) - def explainer(self, model) -> Union[Gradient, LayerGradient]: - if self.layer is None: - return self._explainer(model) - return self._layer_explainer(model) + @property + def attributor(self) -> Union[Gradient, LayerGradient]: + if self.target_layer is None: + return self._attributor + return self._layer_attributor + + def _preprocess_model(self): + model, treated = _replace_add_function_with_sum_module(self.model) + if treated: + # rewrap treated model + return ModelWrapper( + model=model, + target_input_keys=self.target_input_keys, + additional_input_keys=self.additional_input_keys, + output_modifier=self.output_modifier, + ) + return self._wrapped_model + def attribute( self, @@ -137,10 +160,10 @@ def attribute( Returns: torch.Tensor: The result of the explanation. """ - model = _replace_add_function_with_sum_module(self.model) - forward_args, additional_forward_args = self._extract_forward_args( - inputs) - with self.explainer(model=model) as attributor: + forward_args, additional_forward_args = self.format_inputs(inputs) + + # the composite is registered by __enter__ method + with self.attributor as attributor: attrs = attributor.forward( forward_args, targets, @@ -149,7 +172,7 @@ def attribute( return attrs -class LRPUniformEpsilon(LRPBase): +class LRPUniformEpsilon(LRPBase, Tunable): """ LRPUniformEpsilon explainer. @@ -160,12 +183,15 @@ class LRPUniformEpsilon(LRPBase): epsilon (Union[float, Callable[[Tensor], Tensor]]): The epsilon value. stabilizer (Union[float, Callable[[Tensor], Tensor]]): The stabilizer value zennit_canonizers (Optional[List[Canonizer]]): An optional list of canonizers. Canonizers modify modules temporarily such that certain attribution rules can properly be applied. 
- layer (Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]]): The target module to be explained + target_layer (Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]]): The target module to be explained n_classes (Optional[int]): Number of classes **kwargs: Keyword arguments that are forwarded to the base implementation of the Explainer """ SUPPORTED_MODULES = [Linear, Convolution, LSTM, RNN, Attention] + SUPPORTED_DTYPES = [float, int] + SUPPORTED_NDIMS = [2, 4] + alias = ['lrp_uniform_epsilon', 'lrp_e'] def __init__( self, @@ -173,39 +199,39 @@ def __init__( epsilon: Union[float, Callable[[Tensor], Tensor]] = .25, stabilizer: Union[float, Callable[[Tensor], Tensor]] = 1e-6, zennit_canonizers: Optional[List[Canonizer]] = None, - forward_arg_extractor: Optional[ForwardArgumentExtractor] = None, - additional_forward_arg_extractor: Optional[ForwardArgumentExtractor] = None, - layer: Optional[TargetLayerOrListOfTargetLayers] = None, + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], Tensor]] = None, + target_layer: Optional[TargetLayerOrTupleOfTargetLayers] = None, n_classes: Optional[int] = None ) -> None: - self.epsilon = epsilon + self.epsilon = TunableParameter( + name='epsilon', + current_value=epsilon, + dtype=float, + is_leaf=True, + space={"low": 1e-6, "high": 1, "log": True}, + ) self.stabilizer = stabilizer self.zennit_canonizers = zennit_canonizers zennit_composite = _get_uniform_epsilon_composite( - epsilon, stabilizer, zennit_canonizers) - super().__init__( + self.epsilon.current_value, stabilizer, zennit_canonizers) + LRPBase.__init__( + self, model, zennit_composite, - forward_arg_extractor, - additional_forward_arg_extractor, - layer, + target_input_keys, + additional_input_keys, + output_modifier, + target_layer, n_classes ) + Tunable.__init__(self) + self.register_tunable_params([self.epsilon]) - def get_tunables(self) -> Dict[str, Tuple[type, dict]]: - """ - Provides Tunable parameters for the optimizer - - Tunable parameters: - `epsilon` (float): Value can be selected in the range of `range(1e-6, 1)` - """ - return { - 'epsilon': (float, {"low": 1e-6, "high": 1, "log": True}), - } - -class LRPEpsilonGammaBox(LRPBase): +class LRPEpsilonGammaBox(LRPBase, Tunable): """ LRPEpsilonGammaBox explainer. @@ -219,12 +245,16 @@ class LRPEpsilonGammaBox(LRPBase): epsilon (Union[float, Callable[[Tensor], Tensor]]): The epsilon value. stabilizer (Union[float, Callable[[Tensor], Tensor]]): The stabilizer value zennit_canonizers (Optional[List[Canonizer]]): An optional list of canonizers. Canonizers modify modules temporarily such that certain attribution rules can properly be applied. 
- layer (Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]]): The target module to be explained + target_layer (Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]]): The target module to be explained n_classes (Optional[int]): Number of classes **kwargs: Keyword arguments that are forwarded to the base implementation of the Explainer """ SUPPORTED_MODULES = [Convolution] + SUPPORTED_DTYPES = [float] + SUPPORTED_NDIMS = [4] + alias = ['lrp_epsilon_gamma_box', 'lrp_egb'] + def __init__( self, @@ -235,45 +265,50 @@ def __init__( gamma: float = .25, stabilizer: Union[float, Callable[[Tensor], Tensor]] = 1e-6, zennit_canonizers: Optional[List[Canonizer]] = None, - forward_arg_extractor: Optional[ForwardArgumentExtractor] = None, - additional_forward_arg_extractor: Optional[ForwardArgumentExtractor] = None, - layer: Optional[TargetLayerOrListOfTargetLayers] = None, + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], Tensor]] = None, + target_layer: Optional[TargetLayerOrTupleOfTargetLayers] = None, n_classes: Optional[int] = None, ) -> None: self.low = low self.high = high - self.epsilon = epsilon - self.gamma = gamma + self.epsilon = TunableParameter( + name='epsilon', + current_value=epsilon, + dtype=float, + is_leaf=True, + space={"low": 1e-6, "high": 1, "log": True}, + ) + self.gamma = TunableParameter( + name='gamma', + current_value=gamma, + dtype=float, + is_leaf=True, + space={"low": 1e-6, "high": 1, "log": True}, + ) self.stabilizer = stabilizer self.zennit_canonizers = zennit_canonizers zennit_composite = _get_epsilon_gamma_box_composite( - low, high, epsilon, gamma, stabilizer, zennit_canonizers) - super().__init__( + low, high, self.epsilon.current_value, self.gamma.current_value, + stabilizer, zennit_canonizers, + ) + LRPBase.__init__( + self, model, zennit_composite, - forward_arg_extractor, - additional_forward_arg_extractor, - layer, + target_input_keys, + additional_input_keys, + output_modifier, + target_layer, n_classes ) + Tunable.__init__(self) + self.register_tunable_params([self.epsilon, self.gamma]) - def get_tunables(self) -> Dict[str, Tuple[type, dict]]: - """ - Provides Tunable parameters for the optimizer - - Tunable parameters: - `epsilon` (float): Value can be selected in the range of `range(1e-6, 1)` - `gamma` (float): Value can be selected in the range of `range(1e-6, 1)` - """ - return { - 'epsilon': (float, {"low": 1e-6, "high": 1, "log": True}), - 'gamma': (float, {"low": 1e-6, "high": 1, "log": True}), - } - - -class LRPEpsilonPlus(LRPBase): +class LRPEpsilonPlus(LRPBase, Tunable): """ LRPEpsilonPlus explainer. @@ -284,12 +319,15 @@ class LRPEpsilonPlus(LRPBase): epsilon (Union[float, Callable[[Tensor], Tensor]]): The epsilon value. stabilizer (Union[float, Callable[[Tensor], Tensor]]): The stabilizer value zennit_canonizers (Optional[List[Canonizer]]): An optional list of canonizers. Canonizers modify modules temporarily such that certain attribution rules can properly be applied. 
- layer (Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]]): The target module to be explained + target_layer (Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]]): The target module to be explained n_classes (Optional[int]): Number of classes **kwargs: Keyword arguments that are forwarded to the base implementation of the Explainer """ SUPPORTED_MODULES = [Convolution] + SUPPORTED_DTYPES = [float] + SUPPORTED_NDIMS = [4] + alias = ['lrp_epsilon_plus', 'lrp_ep'] def __init__( self, @@ -297,39 +335,39 @@ def __init__( epsilon: Union[float, Callable[[Tensor], Tensor]] = 1e-6, stabilizer: Union[float, Callable[[Tensor], Tensor]] = 1e-6, zennit_canonizers: Optional[List[Canonizer]] = None, - forward_arg_extractor: Optional[ForwardArgumentExtractor] = None, - additional_forward_arg_extractor: Optional[ForwardArgumentExtractor] = None, - layer: Optional[TargetLayerOrListOfTargetLayers] = None, + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], Tensor]] = None, + target_layer: Optional[TargetLayerOrTupleOfTargetLayers] = None, n_classes: Optional[int] = None ) -> None: - self.epsilon = epsilon + self.epsilon = TunableParameter( + name='epsilon', + current_value=epsilon, + dtype=float, + is_leaf=True, + space={"low": 1e-6, "high": 1, "log": True}, + ) self.stabilizer = stabilizer self.zennit_canonizers = zennit_canonizers zennit_composite = _get_epsilon_plus_composite( - epsilon, stabilizer, zennit_canonizers) - super().__init__( + self.epsilon.current_value, stabilizer, zennit_canonizers) + LRPBase.__init__( + self, model, zennit_composite, - forward_arg_extractor, - additional_forward_arg_extractor, - layer, + target_input_keys, + additional_input_keys, + output_modifier, + target_layer, n_classes ) - - def get_tunables(self) -> Dict[str, Tuple[type, dict]]: - """ - Provides Tunable parameters for the optimizer - - Tunable parameters: - `epsilon` (float): Value can be selected in the range of `range(1e-6, 1)` - """ - return { - 'epsilon': (float, {"low": 1e-6, "high": 1, "log": True}), - } + Tunable.__init__(self) + self.register_tunable_params([self.epsilon]) -class LRPEpsilonAlpha2Beta1(LRPBase): +class LRPEpsilonAlpha2Beta1(LRPBase, Tunable): """ LRPEpsilonAlpha2Beta1 explainer. @@ -340,12 +378,15 @@ class LRPEpsilonAlpha2Beta1(LRPBase): epsilon (Union[float, Callable[[Tensor], Tensor]]): The epsilon value. stabilizer (Union[float, Callable[[Tensor], Tensor]]): The stabilizer value zennit_canonizers (Optional[List[Canonizer]]): An optional list of canonizers. Canonizers modify modules temporarily such that certain attribution rules can properly be applied. 
- layer (Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]]): The target module to be explained + target_layer (Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]]): The target module to be explained n_classes (Optional[int]): Number of classes **kwargs: Keyword arguments that are forwarded to the base implementation of the Explainer """ SUPPORTED_MODULES = [Convolution] + SUPPORTED_DTYPES = [float] + SUPPORTED_NDIMS = [4] + alias = ['lrp_epsilon_alpha2_beta1', 'lrp_ea2b1'] def __init__( self, @@ -353,43 +394,43 @@ def __init__( epsilon: Union[float, Callable[[Tensor], Tensor]] = 1e-6, stabilizer: Union[float, Callable[[Tensor], Tensor]] = 1e-6, zennit_canonizers: Optional[List[Canonizer]] = None, - forward_arg_extractor: Optional[ForwardArgumentExtractor] = None, - additional_forward_arg_extractor: Optional[ForwardArgumentExtractor] = None, - layer: Optional[TargetLayerOrListOfTargetLayers] = None, + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], Tensor]] = None, + target_layer: Optional[TargetLayerOrTupleOfTargetLayers] = None, n_classes: Optional[int] = None ) -> None: - self.epsilon = epsilon + self.epsilon = TunableParameter( + name='epsilon', + current_value=epsilon, + dtype=float, + is_leaf=True, + space={"low": 1e-6, "high": 1, "log": True}, + ) self.stabilizer = stabilizer self.zennit_canonizers = zennit_canonizers zennit_composite = _get_epsilon_alpha2_beta1_composite( - epsilon, stabilizer, zennit_canonizers) - super().__init__( + self.epsilon.current_value, stabilizer, zennit_canonizers) + LRPBase.__init__( + self, model, zennit_composite, - forward_arg_extractor, - additional_forward_arg_extractor, - layer, + target_input_keys, + additional_input_keys, + output_modifier, + target_layer, n_classes ) - - def get_tunables(self) -> Dict[str, Tuple[type, dict]]: - """ - Provides Tunable parameters for the optimizer - - Tunable parameters: - `epsilon` (float): Value can be selected in the range of `range(1e-6, 1)` - """ - return { - 'epsilon': (float, {"low": 1e-6, "high": 1, "log": True}), - } + Tunable.__init__(self) + self.register_tunable_params([self.epsilon]) def _get_uniform_epsilon_composite(epsilon, stabilizer, zennit_canonizers): zennit_canonizers = zennit_canonizers or [] canonizers = canonizers_base() + default_attention_converters + zennit_canonizers layer_map = ( - [(Linear, Epsilon(epsilon=epsilon))] + [(Linear, Epsilon(epsilon=epsilon)), (Convolution, Epsilon(epsilon=epsilon))] + transformer_layer_map(stabilizer=stabilizer) + layer_map_base(stabilizer=stabilizer) ) @@ -397,7 +438,14 @@ def _get_uniform_epsilon_composite(epsilon, stabilizer, zennit_canonizers): return composite -def _get_epsilon_gamma_box_composite(low, high, epsilon, gamma, stabilizer, zennit_canonizers): +def _get_epsilon_gamma_box_composite( + low, + high, + epsilon, + gamma, + stabilizer, + zennit_canonizers, +): zennit_canonizers = zennit_canonizers or [] canonizers = canonizers_base() + default_attention_converters + zennit_canonizers composite = EpsilonGammaBox( @@ -449,13 +497,13 @@ def canonizers_base(): def _replace_add_function_with_sum_module(model: Module) -> fx.GraphModule: + treated = False try: traced_model = fx.symbolic_trace(model) - except: + except Exception: warnings.warn( - "Your model cannot be traced by torch.fx.symbolic_trace.") - return model - treated = False + "Your model may not be traced by torch.fx.symbolic_trace.") + 
return model, treated for node in traced_model.graph.nodes: if node.target is _operator.add: treated = True @@ -467,8 +515,8 @@ def _replace_add_function_with_sum_module(model: Module) -> fx.GraphModule: node.replace_all_uses_with(replacement) traced_model.graph.erase_node(node) if not treated: - return model + return model, treated # ensure changes traced_model.graph.lint() traced_model.recompile() - return traced_model + return traced_model, treated diff --git a/pnpxai/explainers/rap/attribute.py b/pnpxai/explainers/rap/attribute.py index 59579dce..fc9f41bc 100644 --- a/pnpxai/explainers/rap/attribute.py +++ b/pnpxai/explainers/rap/attribute.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Dict +from typing import Any, Optional, Dict, List, Union, Callable import torch from torch import Tensor, nn @@ -23,9 +23,23 @@ class RAP(Explainer): """ SUPPORTED_MODULES = [Linear, Convolution] + SUPPORTED_DTYPES = [float] + SUPPORTED_NDIMS = [4] + alias = ['rap'] - def __init__(self, model: Model): - super().__init__(model) + def __init__( + self, + model: Model, + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], Tensor]] = None, + ): + super().__init__( + model, + target_input_keys, + additional_input_keys, + output_modifier, + ) self.method = RelativeAttributePropagation(model) def compute_pred(self, output: Tensor) -> Tensor: @@ -59,7 +73,11 @@ def attribute(self, inputs: DataSource, targets: DataSource, *args: Any, **kwarg Returns: DataSource: RAP attributions. """ - outputs = self.method.run(inputs) + target_inputs, _ = self.format_inputs(inputs) + assert ( + len(target_inputs) == 1 + ), "RAP for multiple target inputs is not supported." + outputs = self.method.run(*target_inputs) preds = self.compute_pred(outputs) relprop = self.method.relprop(preds) return relprop diff --git a/pnpxai/explainers/smooth_grad.py b/pnpxai/explainers/smooth_grad.py index 001fd793..b080702d 100644 --- a/pnpxai/explainers/smooth_grad.py +++ b/pnpxai/explainers/smooth_grad.py @@ -1,19 +1,21 @@ -from typing import Tuple, Callable, Sequence, Union, Optional, Dict +from typing import Tuple, Callable, Sequence, Union, Optional, Any, List from torch import Tensor from torch.nn.modules import Module from pnpxai.core.detector.types import Linear, Convolution, LSTM, RNN, Attention from pnpxai.utils import format_into_tuple, format_out_tuple_if_single +from pnpxai.explainers.base import Tunable +from pnpxai.explainers.types import TunableParameter, TargetLayerOrTupleOfTargetLayers from pnpxai.explainers.zennit.attribution import SmoothGradient as SmoothGradAttributor from pnpxai.explainers.zennit.attribution import ( LayerSmoothGradient as LayerSmoothGradAttributor, ) from pnpxai.explainers.zennit.base import ZennitExplainer -from pnpxai.explainers.utils import captum_wrap_model_input +from pnpxai.explainers.utils import ModelWrapperForLayerAttribution -class SmoothGrad(ZennitExplainer): +class SmoothGrad(ZennitExplainer, Tunable): """ SmoothGrad explainer. @@ -23,7 +25,7 @@ class SmoothGrad(ZennitExplainer): model (Module): The PyTorch model for which attribution is to be computed. noise_level (float): The added noise level.
n_iter (int): The Number of iterations the algorithm makes - layer (Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]]): The target module to be explained + target_layer (Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]]): The target module to be explained n_classes (Optional[int]): Number of classes forward_arg_extractor: A function that extracts forward arguments from the input batch(s) where the attribution scores are assigned. additional_forward_arg_extractor: A secondary function that extract additional forward arguments from the input batch(s). @@ -34,58 +36,72 @@ class SmoothGrad(ZennitExplainer): """ SUPPORTED_MODULES = [Linear, Convolution, LSTM, RNN, Attention] + SUPPORTED_DTYPES = [float, int] + SUPPORTED_NDIMS = [2, 4] + alias = ['smooth_grad', 'sg'] def __init__( self, model: Module, noise_level: float = 0.1, n_iter: int = 20, - forward_arg_extractor: Optional[ - Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]] - ] = None, - additional_forward_arg_extractor: Optional[ - Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]] - ] = None, - layer: Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]] = None, + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], Tensor]] = None, + target_layer: Optional[TargetLayerOrTupleOfTargetLayers] = None, n_classes: Optional[int] = None, ) -> None: - super().__init__( - model, forward_arg_extractor, additional_forward_arg_extractor, n_classes + ZennitExplainer.__init__( + self, + model, + target_input_keys, + additional_input_keys, + output_modifier, + n_classes, ) - self.noise_level = noise_level - self.n_iter = n_iter - self.layer = layer + self.noise_level = TunableParameter( + name='noise_level', + current_value=noise_level, + dtype=float, + is_leaf=True, + space={"low": 0.05, "high": 0.95, "step": 0.05}, + ) + self.n_iter = TunableParameter( + name='n_iter', + current_value=n_iter, + dtype=int, + is_leaf=True, + space={"low": 10, "high": 100, "step": 10}, + ) + self.target_layer = target_layer + Tunable.__init__(self) + self.register_tunable_params([self.noise_level, self.n_iter]) @property def _layer_attributor(self) -> LayerSmoothGradAttributor: - wrapped_model = captum_wrap_model_input(self.model) - layers = ( - [ - wrapped_model.input_maps[layer] if isinstance(layer, str) else layer - for layer in self.layer - ] - if isinstance(self.layer, Sequence) - else [self.layer] - ) - if len(layers) == 1: - layers = layers[0] + wrapped_model = ModelWrapperForLayerAttribution(self._wrapped_model) + layers = [ + wrapped_model.input_maps[target_layer] if isinstance(target_layer, str) else target_layer + for target_layer in format_into_tuple(self.target_layer) + ] + layers = format_out_tuple_if_single(layers) return LayerSmoothGradAttributor( model=wrapped_model, layer=layers, - noise_level=self.noise_level, - n_iter=self.n_iter, + noise_level=self.noise_level.current_value, + n_iter=self.n_iter.current_value, ) @property def _attributor(self) -> SmoothGradAttributor: return SmoothGradAttributor( - model=self.model, - noise_level=self.noise_level, - n_iter=self.n_iter, + model=self._wrapped_model, + noise_level=self.noise_level.current_value, + n_iter=self.n_iter.current_value, ) def attributor(self) -> Union[SmoothGradAttributor, LayerSmoothGradAttributor]: - if self.layer is None: + if self.target_layer is None: return self._attributor return self._layer_attributor @@ -104,28 
+120,14 @@ def attribute( Returns: Union[torch.Tensor, Tuple[torch.Tensor]]: The result of the explanation. """ - forward_args, additional_forward_args = self._extract_forward_args(inputs) + forward_args, additional_forward_args = self.format_inputs(inputs) with self.attributor() as attributor: grads = format_into_tuple( attributor.forward( - forward_args, + format_out_tuple_if_single(forward_args), targets, additional_forward_args, return_squared=False, ) ) return format_out_tuple_if_single(grads) - - def get_tunables(self) -> Dict[str, Tuple[type, dict]]: - """ - Provides Tunable parameters for the optimizer - - Tunable parameters: - `noise_level` (float): Value can be selected in the range of `range(0, 0.95, 0.05)` - - `n_iter` (int): Value can be selected in the range of `range(10, 100, 10)` - """ - return { - "noise_level": (float, {"low": 0.05, "high": 0.95, "step": 0.05}), - "n_iter": (int, {"low": 10, "high": 100, "step": 10}), - } diff --git a/pnpxai/explainers/types.py b/pnpxai/explainers/types.py index 48d618ea..78e0fdd4 100644 --- a/pnpxai/explainers/types.py +++ b/pnpxai/explainers/types.py @@ -1,8 +1,85 @@ -from typing import Tuple, Union, Callable, List +from typing import Tuple, Union, Callable, Any, Type, Optional + from torch import Tensor from torch.nn.modules import Module -TensorOrTupleOfTensors = Union[Tensor, Tuple[Tensor]] -ForwardArgumentExtractor = Callable[[TensorOrTupleOfTensors], TensorOrTupleOfTensors] + +TensorOrTupleOfTensors = Union[Tensor, Tuple[Tensor, ...]] TargetLayer = Union[str, Module] -TargetLayerOrListOfTargetLayers = Union[TargetLayer, List[TargetLayer]] \ No newline at end of file +TargetLayerOrTupleOfTargetLayers = Union[TargetLayer, Tuple[TargetLayer, ...]] + + +class TunableParameter: + def __init__( + self, + name: str, + current_value: Any, + dtype: Type, + is_leaf: bool, + space: Optional[Any] = None, + selector: Optional[Any] = None, + ): + self._name = name + self._current_value = current_value + self.dtype = dtype + self.is_leaf = is_leaf + self._space = space + self._selector = selector + if self.is_leaf and self.space is None: + raise ValueError("If 'is_leaf' is True, 'space' cannot be None.") + self._disabled = False + + def __repr__(self): + return repr(self._current_value) + + @property + def name(self): + return self._name + + @property + def current_value(self): + return self._current_value + + @property + def space(self): + return self._space + + @property + def selector(self): + return self._selector + + def rename(self, name): + self._name = name + return self + + def update_value(self, value): + self._current_value = value + return self + + def set_space(self, space: Any): + self._space = space + return self + + def set_selector(self, function_selector, set_space=True): + self._selector = function_selector + if set_space: + self.set_space({'choices': function_selector.choices}) + return self + + def is_callable(self): + return isinstance(self._current_value, Callable) + + @property + def disabled(self): + return self._disabled + + def disable(self): + self._disabled = True + + def enable(self): + self._disabled = False + + def __call__(self, *args, **kwargs): + if not self.is_callable(): + raise TypeError(f'{self._current_value} is not callable.') + return self._current_value(*args, **kwargs) diff --git a/pnpxai/explainers/utils/__init__.py b/pnpxai/explainers/utils/__init__.py index adbecf52..210d2660 100644 --- a/pnpxai/explainers/utils/__init__.py +++ b/pnpxai/explainers/utils/__init__.py @@ -1,8 +1,6 @@ from 
pnpxai.explainers.utils.utils import ( find_cam_target_layer, - get_default_feature_mask_fn, - captum_wrap_model_input, - _format_to_tuple, + ModelWrapperForLayerAttribution, ) from pnpxai.explainers.utils.base import UtilFunction from pnpxai.explainers.utils.baselines import BaselineFunction diff --git a/pnpxai/explainers/utils/baselines.py b/pnpxai/explainers/utils/baselines.py index bf5d8699..5467f414 100644 --- a/pnpxai/explainers/utils/baselines.py +++ b/pnpxai/explainers/utils/baselines.py @@ -1,11 +1,11 @@ -from typing import Literal, Union, Literal +from typing import Literal, Union import torch import torchvision.transforms.functional as TF +from pnpxai.explainers.base import Tunable +from pnpxai.explainers.types import TunableParameter from pnpxai.explainers.utils.base import UtilFunction -BaselineMethod = Literal['zeros', 'invert', 'gaussian_blur', 'token'] - class BaselineFunction(UtilFunction): """ @@ -22,16 +22,9 @@ class BaselineFunction(UtilFunction): def __init__(self, *args, **kwargs): pass - @classmethod - def from_method(cls, method, **kwargs): - baseline_fn = BASELINE_FUNCTIONS.get(method, None) - if baseline_fn is None: - raise ValueError - return baseline_fn(**kwargs) - class TokenBaselineFunction(BaselineFunction): - def __init__(self, token_id, **kwargs): + def __init__(self, token_id): super().__init__() self.token_id = token_id @@ -40,7 +33,7 @@ def __call__(self, inputs: torch.Tensor): class ZeroBaselineFunction(BaselineFunction): - def __init__(self, *args, **kwargs): + def __init__(self): super().__init__() def __call__(self, inputs: torch.Tensor): @@ -48,7 +41,7 @@ def __call__(self, inputs: torch.Tensor): class MeanBaselineFunction(BaselineFunction): - def __init__(self, dim, **kwargs): + def __init__(self, dim): super().__init__() self.dim = dim @@ -59,64 +52,79 @@ def __call__(self, inputs: torch.Tensor): class InvertBaselineFunction(BaselineFunction): - def __init__(self, *args, **kwargs): + def __init__(self): super().__init__() def __call__(self, inputs: torch.Tensor): return TF.invert(inputs) -class GaussianBlurBaselineFunction(BaselineFunction): +class GaussianBlurBaselineFunction(BaselineFunction, Tunable): def __init__( self, kernel_size_x: int = 3, kernel_size_y: int = 3, sigma_x: float = .5, sigma_y: float = .5, - **kwargs ): - super().__init__() - self.kernel_size_x = kernel_size_x - self.kernel_size_y = kernel_size_y - self.sigma_x = sigma_x - self.sigma_y = sigma_y + BaselineFunction.__init__(self) + self.kernel_size_x = TunableParameter( + name='kernel_size_x', + current_value=kernel_size_x, + dtype=int, + is_leaf=True, + space={'low': 1, 'high': 11, 'step': 2}, + ) + self.kernel_size_y = TunableParameter( + name='kernel_size_y', + current_value=kernel_size_y, + dtype=int, + is_leaf=True, + space={'low': 1, 'high': 11, 'step': 2}, + ) + self.sigma_x = TunableParameter( + name='sigma_x', + current_value=sigma_x, + dtype=float, + is_leaf=True, + space={'low': .05, 'high': 2., 'step': .05}, + ) + self.sigma_y = TunableParameter( + name='sigma_y', + current_value=sigma_y, + dtype=float, + is_leaf=True, + space={'low': .05, 'high': 2., 'step': .05}, + ) + Tunable.__init__(self) + self.register_tunable_params([ + self.kernel_size_x, self.kernel_size_y, + self.sigma_x, self.sigma_y, + ]) def __call__(self, inputs: torch.Tensor): return TF.gaussian_blur( inputs, - kernel_size=[self.kernel_size_x, self.kernel_size_y], - sigma=[self.sigma_x, self.sigma_y], + kernel_size=[ + self.kernel_size_x.current_value, + self.kernel_size_y.current_value, + ], + 
sigma=[self.sigma_x.current_value, self.sigma_y.current_value], ) - def get_tunables(self): - return { - 'kernel_size_x': (int, {'low': 1, 'high': 11, 'step': 2}), - 'kernel_size_y': (int, {'low': 1, 'high': 11, 'step': 2}), - 'sigma_x': (float, {'low': .05, 'high': 2., 'step': .05}), - 'sigma_y': (float, {'low': .05, 'high': 2., 'step': .05}), - } - - -BaselineMethodOrFunction = Union[BaselineMethod, BaselineFunction] - -BASELINE_FUNCTIONS_FOR_IMAGE = { - 'zeros': ZeroBaselineFunction, - 'mean': MeanBaselineFunction, - 'invert': InvertBaselineFunction, - 'gaussian_blur': GaussianBlurBaselineFunction, -} - -BASELINE_FUNCTIONS_FOR_TEXT = { - 'token': TokenBaselineFunction, -} -BASELINE_FUNCTIONS_FOR_TIME_SERIES = { - 'zeros': ZeroBaselineFunction, - 'mean': MeanBaselineFunction, +BASELINE_FUNCTIONS = { # (dtype, ndims): {available util functions} + (float, 2): { + 'zeros': ZeroBaselineFunction, + 'mean': MeanBaselineFunction, + }, + (float, 4): { + 'zeros': ZeroBaselineFunction, + 'mean': MeanBaselineFunction, + 'invert': InvertBaselineFunction, + 'gaussian_blur': GaussianBlurBaselineFunction, + }, + (int, 2): { + 'token': TokenBaselineFunction, + }, } - -BASELINE_FUNCTIONS = { - **BASELINE_FUNCTIONS_FOR_IMAGE, - **BASELINE_FUNCTIONS_FOR_TEXT, - **BASELINE_FUNCTIONS_FOR_TIME_SERIES, -} \ No newline at end of file diff --git a/pnpxai/explainers/utils/feature_masks.py b/pnpxai/explainers/utils/feature_masks.py index f6a68ab4..f3f2dea3 100644 --- a/pnpxai/explainers/utils/feature_masks.py +++ b/pnpxai/explainers/utils/feature_masks.py @@ -1,6 +1,7 @@ -from typing import Literal, Union +from typing import Literal, Union, Sequence import torch +import numpy as np from skimage.segmentation import ( felzenszwalb, quickshift, @@ -8,7 +9,11 @@ watershed, ) +from pnpxai.explainers.base import Tunable +from pnpxai.explainers.types import TunableParameter from pnpxai.explainers.utils.base import UtilFunction +from torchvision.transforms import InterpolationMode, Resize +import torchvision.transforms.functional as TF class FeatureMaskFunction(UtilFunction): @@ -41,13 +46,6 @@ class FeatureMaskFunction(UtilFunction): def __init__(self): pass - @classmethod - def from_method(cls, method, **kwargs): - feature_mask_fn = FEATURE_MASK_FUNCTIONS.get(method, None) - if feature_mask_fn is None: - raise ValueError - return feature_mask_fn(**kwargs) - @staticmethod def _skseg_for_tensor(fn, inputs: torch.Tensor, **kwargs) -> torch.Tensor: """ @@ -73,46 +71,106 @@ def _skseg_for_tensor(fn, inputs: torch.Tensor, **kwargs) -> torch.Tensor: return torch.stack(feature_mask).long().to(inputs.device) -class Felzenszwalb(FeatureMaskFunction): +class Checkerboard(FeatureMaskFunction): + def __init__( + self, + size: Sequence[int] = (20, 20), + ): + assert len(size) == 2 + self.size = size + self._n_checkers = size[0] * size[1] + + def __call__(self, inputs: torch.Tensor): + assert inputs.dim() == 4 + + bsz, c, h, w = inputs.size() + + resize = Resize([h, w], interpolation=InterpolationMode.NEAREST) + + patch_masks = [] + for i in range(self._n_checkers): + mask = np.zeros(self._n_checkers) + mask[i] = i + mask = resize(torch.Tensor(mask).reshape(-1, self.size[0], self.size[1])).unsqueeze(1) + patch_masks.append(mask.numpy()) + return torch.from_numpy(sum(patch_masks)).squeeze(1).repeat(bsz, 1, 1).long().to(inputs.device) + + +class Felzenszwalb(FeatureMaskFunction, Tunable): def __init__( self, scale: float = 250., sigma: float = 1., - min_size: int = 50, ): - super().__init__() - self.scale = scale -
self.sigma = sigma + FeatureMaskFunction.__init__(self) + self.scale = TunableParameter( + name='scale', + current_value=scale, + dtype=float, + is_leaf=True, + space={'low': 1e0, 'high': 1e3, 'log': True}, + ) + self.sigma = TunableParameter( + name='sigma', + current_value=sigma, + dtype=float, + is_leaf=True, + space={'low': 0., 'high': 2., 'step': .05}, + ) + Tunable.__init__(self) + self.register_tunable_params([self.scale, self.sigma]) def __call__(self, inputs: torch.Tensor): return self._skseg_for_tensor( felzenszwalb, inputs, - scale=self.scale, - sigma=self.sigma, + scale=self.scale.current_value, + sigma=self.sigma.current_value, min_size=50, ) - def get_tunables(self): - return { - 'scale': (float, {'low': 1e0, 'high': 1e3, 'log': True}), - 'sigma': (float, {'low': 0., 'high': 2., 'step': .05}), - } - -class Quickshift(FeatureMaskFunction): +class Quickshift(FeatureMaskFunction, Tunable): def __init__( self, ratio: float = 1., kernel_size: float = 5, - max_dist: float = 10., + max_dist: int = 10, sigma: float = 0., ): - super().__init__() - self.ratio = ratio - self.kernel_size = kernel_size - self.max_dist = max_dist - self.sigma = sigma + FeatureMaskFunction.__init__(self) + self.ratio = TunableParameter( + name='ratio', + current_value=ratio, + dtype=float, + is_leaf=True, + space={'low': 0., 'high': 1., 'step': .1}, + ) + self.kernel_size = TunableParameter( + name='kernel_size', + current_value=kernel_size, + dtype=float, + is_leaf=True, + space={'low': 1., 'high': 10., 'step': 1.}, + ) + self.max_dist = TunableParameter( + name='max_dist', + current_value=max_dist, + dtype=int, + is_leaf=True, + space={'low': 1, 'high': 20, 'step': 1}, + ) + self.sigma = TunableParameter( + name='sigma', + current_value=sigma, + dtype=float, + is_leaf=True, + space={'low': 0., 'high': 2., 'step': .1}, + ) + Tunable.__init__(self) + self.register_tunable_params([ + self.ratio, self.kernel_size, self.max_dist, self.sigma]) def __call__(self, inputs: torch.Tensor): if inputs.size(1) == 1: @@ -120,49 +178,55 @@ def __call__(self, inputs: torch.Tensor): return self._skseg_for_tensor( quickshift, inputs, - ratio=self.ratio, - kernel_size=self.kernel_size, - max_dist=self.max_dist, - sigma=self.sigma, + ratio=self.ratio.current_value, + kernel_size=self.kernel_size.current_value, + max_dist=self.max_dist.current_value, + sigma=self.sigma.current_value, ) - def get_tunables(self): - return { - 'ratio': (float, {'low': 0., 'high': 1., 'step': .1}), - 'kernel_size': (float, {'low': 1., 'high': 10., 'step': 1.}), - 'max_dist': (int, {'low': 1, 'high': 20, 'step': 1}), - 'sigma': (float, {'low': 0., 'high': 2., 'step': .1,}) - } - -class Slic(FeatureMaskFunction): +class Slic(FeatureMaskFunction, Tunable): def __init__( self, n_segments: int = 150, compactness: float = 1., sigma: float = 0.
): - super().__init__() - self.n_segments = n_segments - self.compactness = compactness - self.sigma = sigma + FeatureMaskFunction.__init__(self) + self.n_segments = TunableParameter( + name='n_segments', + current_value=n_segments, + dtype=int, + is_leaf=True, + space={'low': 100, 'high': 500, 'step': 10}, + ) + self.compactness = TunableParameter( + name='compactness', + current_value=compactness, + dtype=float, + is_leaf=True, + space={'low': 1e-2, 'high': 1e2, 'log': True}, + ) + self.sigma = TunableParameter( + name='sigma', + current_value=sigma, + dtype=float, + is_leaf=True, + space={'low': 0., 'high': 2., 'step': .1}, + ) + Tunable.__init__(self) + self.register_tunable_params([ + self.n_segments, self.compactness, self.sigma]) def __call__(self, inputs: torch.Tensor): return self._skseg_for_tensor( slic, inputs, - n_segments=self.n_segments, - compactness=self.compactness, - sigma=self.sigma, + n_segments=self.n_segments.current_value, + compactness=self.compactness.current_value, + sigma=self.sigma.current_value, ) - def get_tunables(self): - return { - 'n_segments': (float, {'low': 100, 'high': 500, 'step': 10}), - 'compactness': (float, {'low': 1e-2, 'high': 1e2, 'log': True}), - 'sigma': (float, {'low': 0., 'high': 2., 'step': .1}), - } - class Watershed(FeatureMaskFunction): def __init__( @@ -170,24 +234,32 @@ def __init__( markers: int, compactness: float, ): - super().__init__() - self.markers = markers - self.compactness = compactness + FeatureMaskFunction.__init__(self) + self.markers = TunableParameter( + name='markers', + current_value=markers, + dtype=int, + is_leaf=True, + space={'low': 10, 'high': 200, 'step': 10}, + ) + self.compactness = TunableParameter( + name='compactness', + current_value=compactness, + dtype=float, + is_leaf=True, + space={'low': 1e-6, 'high': 1., 'log': True}, + ) + Tunable.__init__(self) + self.register_tunable_params([self.markers, self.compactness]) def __call__(self, inputs: torch.Tensor): return self._skseg_for_tensor( watershed, inputs, - markers=self.markers, - compactness=self.compactness, + markers=self.markers.current_value, + compactness=self.compactness.current_value, ) - def get_tunables(self): - return { - 'markers': (int, {'low': 10, 'high': 200, 'step': 10}), - 'compactness': (float, {'low': 1e-6, 'high': 1., 'log': True}), - } - class NoMask1d(FeatureMaskFunction): def __init__(self): @@ -209,30 +281,20 @@ def __call__(self, inputs): return seq_masks.to(inputs.device) - -FEATURE_MASK_FUNCTIONS_FOR_IMAGE = { - 'felzenszwalb': Felzenszwalb, - 'quickshift': Quickshift, - 'slic': Slic, - # 'watershed': watershed_for_tensor, TODO: watershed -} - -FEATURE_MASK_FUNCTIONS_FOR_TEXT = { - 'no_mask_1d': NoMask1d, -} - -FEATURE_MASK_FUNCTIONS_FOR_TIME_SERIES = { - 'no_mask_2d': NoMask2d, -} - FEATURE_MASK_FUNCTIONS = { - **FEATURE_MASK_FUNCTIONS_FOR_IMAGE, - **FEATURE_MASK_FUNCTIONS_FOR_TEXT, - **FEATURE_MASK_FUNCTIONS_FOR_TIME_SERIES, + (float, 2): { + 'no_mask_1d': NoMask1d, + }, + (float, 3): { + 'no_mask_2d': NoMask2d, + }, + (float, 4): { + 'checkerboard': Checkerboard, + 'felzenszwalb': Felzenszwalb, + 'quickshift': Quickshift, + 'slic': Slic, + }, + (int, 2): { + 'no_mask_1d': NoMask1d, + } } - - -FeatureMaskMethod = Literal[ - 'felzenszwalb', 'quickshift', 'slic', 'watershed', 'no_mask_1d' -] -FeatureMaskMethodOrFunction = Union[FeatureMaskMethod, FeatureMaskFunction] diff --git a/pnpxai/explainers/utils/function_selectors.py b/pnpxai/explainers/utils/function_selectors.py index 538d2b7e..091dee24 100644 --- 
a/pnpxai/explainers/utils/function_selectors.py +++ b/pnpxai/explainers/utils/function_selectors.py @@ -1,4 +1,5 @@ from typing import Callable, Optional, Dict, Any +from collections import defaultdict class FunctionSelector: @@ -50,20 +51,42 @@ def __init__( self, data: Optional[Dict[str, Callable]] = None, default_kwargs: Optional[Dict[str, Any]] = None, + choicewise_default_kwargs: Optional[Dict[str, Any]] = None, + fallback_options: Optional[Dict[str, Callable]] = None, ): self._data = data or {} self._default_kwargs = default_kwargs or {} + self._choicewise_default_kwargs = defaultdict(dict, choicewise_default_kwargs or {}) + self._fallback_options = fallback_options or {} @property def choices(self): return list(self._data.keys()) + @property + def default_kwargs(self): + return self._default_kwargs + + @property + def data(self): + return self._data + def add(self, key: str, value: Callable): self._data[key] = value return value + def add_default_kwargs(self, key, value, choice=None): + if choice is None: + self._default_kwargs[key] = value + else: + self._choicewise_default_kwargs[choice][key] = value + + def add_fallback_option(self, key, value): + self._fallback_options[key] = value + def get(self, key: str): - return self._data[key] + data = {**self._data, **self._fallback_options} + return data[key] def delete(self, key: str): return self._data.pop(key, None) @@ -73,8 +96,12 @@ def all(self): def select(self, key: str, **kwargs): fn_type = self.get(key) - kwargs = {**self._default_kwargs, **kwargs} + kwargs = { + **self._default_kwargs, + **self._choicewise_default_kwargs[key], + **kwargs, + } return fn_type(**kwargs) - def get_tunables(self): - return {'method': (list, {'choices': self.choices})} + def get_key(self, value): + return {v: k for k, v in self._data.items()}.get(value) diff --git a/pnpxai/explainers/utils/postprocess.py b/pnpxai/explainers/utils/postprocess.py index f7c43b9d..3823bfb9 100644 --- a/pnpxai/explainers/utils/postprocess.py +++ b/pnpxai/explainers/utils/postprocess.py @@ -1,45 +1,47 @@ from torch import Tensor from pnpxai.explainers.utils.base import UtilFunction +from pnpxai.explainers.base import Tunable +from pnpxai.explainers.types import TunableParameter -def sumpos(attrs: Tensor, channel_dim: int) -> Tensor: - return attrs.sum(channel_dim).clamp(min=0) +def sumpos(attrs: Tensor, pooling_dim: int) -> Tensor: + return attrs.sum(pooling_dim).clamp(min=0) -def sumabs(attrs: Tensor, channel_dim: int) -> Tensor: - return attrs.sum(channel_dim).abs() +def sumabs(attrs: Tensor, pooling_dim: int) -> Tensor: + return attrs.sum(pooling_dim).abs() -def l1norm(attrs: Tensor, channel_dim: int) -> Tensor: - return attrs.abs().sum(channel_dim) +def l1norm(attrs: Tensor, pooling_dim: int) -> Tensor: + return attrs.abs().sum(pooling_dim) -def maxnorm(attrs: Tensor, channel_dim: int) -> Tensor: - return attrs.abs().max(channel_dim)[0] +def maxnorm(attrs: Tensor, pooling_dim: int) -> Tensor: + return attrs.abs().max(pooling_dim)[0] -def l2norm(attrs: Tensor, channel_dim: int) -> Tensor: - return attrs.pow(2).sum(channel_dim).sqrt() +def l2norm(attrs: Tensor, pooling_dim: int) -> Tensor: + return attrs.pow(2).sum(pooling_dim).sqrt() -def l2normsq(attrs: Tensor, channel_dim: int) -> Tensor: - return attrs.pow(2).sum(channel_dim) +def l2normsq(attrs: Tensor, pooling_dim: int) -> Tensor: + return attrs.pow(2).sum(pooling_dim) -def possum(attrs: Tensor, channel_dim: int) -> Tensor: - return attrs.clamp(min=0).sum(channel_dim) +def possum(attrs: Tensor, pooling_dim: int) ->
+    return attrs.clamp(min=0).sum(pooling_dim)
-def posmaxnorm(attrs: Tensor, channel_dim: int) -> Tensor:
-    return attrs.clamp(min=0).max(channel_dim)[0]
+def posmaxnorm(attrs: Tensor, pooling_dim: int) -> Tensor:
+    return attrs.clamp(min=0).max(pooling_dim)[0]
-def posl2norm(attrs: Tensor, channel_dim: int) -> Tensor:
-    return attrs.clamp(min=0).pow(2).sum(channel_dim).sqrt()
+def posl2norm(attrs: Tensor, pooling_dim: int) -> Tensor:
+    return attrs.clamp(min=0).pow(2).sum(pooling_dim).sqrt()
-def posl2normsq(attrs: Tensor, channel_dim: int) -> Tensor:
-    return attrs.clamp(min=0).pow(2).sum(channel_dim)
+def posl2normsq(attrs: Tensor, pooling_dim: int) -> Tensor:
+    return attrs.clamp(min=0).pow(2).sum(pooling_dim)
 def identity(attrs: Tensor, *args, **kwargs) -> Tensor:
@@ -53,100 +55,100 @@ class PoolingFunction(UtilFunction):
     attributions and highlight important features.
     Parameters:
-        channel_dim (int):
+        pooling_dim (int):
            The dimension of the input channels. This dimension is used by the pooling function to perform aggregation operations correctly.
     Notes:
        - `PoolingFunction` is intended to be subclassed. Concrete pooling methods should inherit from this class and implement the actual pooling logic.
-       - The pooling operation should be compatible with the `channel_dim` provided during
+       - The pooling operation should be compatible with the `pooling_dim` provided during
          initialization.
     """
-    def __init__(self, channel_dim: int):
+    def __init__(self, pooling_dim: int):
         super().__init__()
-        self.channel_dim = channel_dim
+        self.pooling_dim = pooling_dim
 class SumPos(PoolingFunction):
-    def __init__(self, channel_dim):
-        super().__init__(channel_dim)
+    def __init__(self, pooling_dim):
+        super().__init__(pooling_dim)
     def __call__(self, attrs: Tensor):
-        return sumpos(attrs, self.channel_dim)
+        return sumpos(attrs, self.pooling_dim)
 class SumAbs(PoolingFunction):
-    def __init__(self, channel_dim):
-        super().__init__(channel_dim)
+    def __init__(self, pooling_dim):
+        super().__init__(pooling_dim)
     def __call__(self, attrs: Tensor):
-        return sumabs(attrs, self.channel_dim)
+        return sumabs(attrs, self.pooling_dim)
 class L1Norm(PoolingFunction):
-    def __init__(self, channel_dim):
-        super().__init__(channel_dim)
+    def __init__(self, pooling_dim):
+        super().__init__(pooling_dim)
     def __call__(self, attrs: Tensor):
-        return l1norm(attrs, self.channel_dim)
+        return l1norm(attrs, self.pooling_dim)
 class MaxNorm(PoolingFunction):
-    def __init__(self, channel_dim):
-        super().__init__(channel_dim)
+    def __init__(self, pooling_dim):
+        super().__init__(pooling_dim)
     def __call__(self, attrs: Tensor):
-        return maxnorm(attrs, self.channel_dim)
+        return maxnorm(attrs, self.pooling_dim)
 class L2Norm(PoolingFunction):
-    def __init__(self, channel_dim):
-        super().__init__(channel_dim)
+    def __init__(self, pooling_dim):
+        super().__init__(pooling_dim)
     def __call__(self, attrs: Tensor):
-        return l2norm(attrs, self.channel_dim)
+        return l2norm(attrs, self.pooling_dim)
 class L2NormSquare(PoolingFunction):
-    def __init__(self, channel_dim):
-        super().__init__(channel_dim)
+    def __init__(self, pooling_dim):
+        super().__init__(pooling_dim)
     def __call__(self, attrs: Tensor):
-        return l2normsq(attrs, self.channel_dim)
+        return l2normsq(attrs, self.pooling_dim)
 class PosSum(PoolingFunction):
-    def __init__(self, channel_dim):
-        super().__init__(channel_dim)
+    def __init__(self, pooling_dim):
+        super().__init__(pooling_dim)
     def __call__(self, attrs: Tensor):
-        return possum(attrs, self.channel_dim)
+        return possum(attrs, 
self.pooling_dim) class PosMaxNorm(PoolingFunction): - def __init__(self, channel_dim): - super().__init__(channel_dim) + def __init__(self, pooling_dim): + super().__init__(pooling_dim) def __call__(self, attrs: Tensor): - return posmaxnorm(attrs, self.channel_dim) + return posmaxnorm(attrs, self.pooling_dim) class PosL2Norm(PoolingFunction): - def __init__(self, channel_dim): - super().__init__(channel_dim) + def __init__(self, pooling_dim): + super().__init__(pooling_dim) def __call__(self, attrs: Tensor): - return posl2norm(attrs, self.channel_dim) + return posl2norm(attrs, self.pooling_dim) class PosL2NormSquare(PoolingFunction): - def __init__(self, channel_dim): - super().__init__(channel_dim) + def __init__(self, pooling_dim): + super().__init__(pooling_dim) def __call__(self, attrs: Tensor): - return posl2normsq(attrs, self.channel_dim) + return posl2normsq(attrs, self.pooling_dim) class Identity(UtilFunction): @@ -157,27 +159,34 @@ def __call__(self, inputs: Tensor): return identity(inputs) -POOLING_FUNCTIONS_FOR_IMAGE = { - 'sumpos': SumPos, - 'sumabs': SumAbs, - 'l1norm': L1Norm, - 'maxnorm': MaxNorm, - 'l2norm': L2Norm, - 'l2normsq': L2NormSquare, - 'possum': PosSum, - 'posmaxnorm': PosMaxNorm, - 'posl2norm': PosL2Norm, - 'posl2normsq': PosL2NormSquare, -} - -POOLING_FUNCTIONS_FOR_TEXT = POOLING_FUNCTIONS_FOR_IMAGE - -POOLING_FUNCTIONS_FOR_TIME_SERIES = {'identity': Identity} - POOLING_FUNCTIONS = { - **POOLING_FUNCTIONS_FOR_IMAGE, - **POOLING_FUNCTIONS_FOR_TEXT, - **POOLING_FUNCTIONS_FOR_TIME_SERIES, + (float, 2): { + 'identity': Identity, + }, + (float, 4): { + 'sumpos': SumPos, + 'sumabs': SumAbs, + 'l1norm': L1Norm, + 'maxnorm': MaxNorm, + 'l2norm': L2Norm, + 'l2normsq': L2NormSquare, + 'possum': PosSum, + 'posmaxnorm': PosMaxNorm, + 'posl2norm': PosL2Norm, + 'posl2normsq': PosL2NormSquare, + }, + (int, 2): { + 'sumpos': SumPos, + 'sumabs': SumAbs, + 'l1norm': L1Norm, + 'maxnorm': MaxNorm, + 'l2norm': L2Norm, + 'l2normsq': L2NormSquare, + 'possum': PosSum, + 'posmaxnorm': PosMaxNorm, + 'posl2norm': PosL2Norm, + 'posl2normsq': PosL2NormSquare, + }, } @@ -213,25 +222,20 @@ def __call__(self, attrs: Tensor): return minmax(attrs) -NORMALIZATION_FUNCTIONS_FOR_IMAGE = { - 'minmax': MinMax, - 'identity': Identity, -} - -NORMALIZATION_FUNCTIONS_FOR_TEXT = NORMALIZATION_FUNCTIONS_FOR_IMAGE - -NORMALIZATION_FUNCTIONS_FOR_TIME_SERIES = { - 'identity': Identity, -} - NORMALIZATION_FUNCTIONS = { - **NORMALIZATION_FUNCTIONS_FOR_IMAGE, - **NORMALIZATION_FUNCTIONS_FOR_TEXT, - **NORMALIZATION_FUNCTIONS_FOR_TIME_SERIES, + (float, 2): { + 'identity': Identity, + }, + (float, 4): { + 'minmax': MinMax, + }, + (int, 2): { + 'minmax': MinMax, + }, } -class PostProcessor(UtilFunction): +class PostProcessor(UtilFunction, Tunable): """ A class that applies a series of post-processing steps to the output of an attribution method. This includes pooling and normalization functions to refine and transform the attributions. @@ -245,7 +249,7 @@ class PostProcessor(UtilFunction): scaling or adjusting the attributions to a certain range or format. Methods: - from_name(pooling_method: str, normalization_method: str, channel_dim: int) -> PostProcessor: + from_name(pooling_method: str, normalization_method: str, pooling_dim: int) -> PostProcessor: Creates a `PostProcessor` instance using the specified method names for pooling and normalization, and the channel dimension. This is a convenience method for instantiating `PostProcessor` with predefined methods. 
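The hunk below makes `PostProcessor` resolve its pooling and normalization functions from the modality's `util_functions` registries instead of receiving them directly. A minimal usage sketch, mirroring `PostProcessor(modality=img_modality)` as called in the benchmark scripts later in this diff; the random `attrs` tensor here is a hypothetical stand-in for raw attributions:

```python
import torch
from pnpxai import Modality
from pnpxai.explainers.utils.postprocess import PostProcessor

attrs = torch.randn(8, 3, 224, 224)  # stand-in attribution batch (NCHW)
modality = Modality(dtype=attrs.dtype, ndims=attrs.dim(), pooling_dim=1)

postprocessor = PostProcessor(modality)  # defaults to the first registered methods
heatmap = postprocessor(attrs)           # pooling over dim 1, then normalization
```

Both `pooling_method` and `normalization_method` are registered as tunable parameters, so they can be searched over during optimization rather than fixed up front.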
@@ -261,31 +265,47 @@ class PostProcessor(UtilFunction): def __init__( self, - pooling_fn: PoolingFunction, - normalization_fn: NormalizationFunction, - ): - self.pooling_fn = pooling_fn - self.normalization_fn = normalization_fn - - @classmethod - def from_name( - cls, - pooling_method: str, - normalization_method: str, - channel_dim: int, + modality, + pooling_method=None, + normalization_method=None, ): - return cls( - pooling_fn=POOLING_FUNCTIONS[pooling_method](channel_dim), - normalization_fn=NORMALIZATION_FUNCTIONS[normalization_method](), + UtilFunction.__init__(self) + self.modality = modality + self._pooling_method = TunableParameter( + name='pooling_method', + current_value=pooling_method or modality.util_functions['pooling_fn'].choices[0], + dtype=str, + is_leaf=True, + space={'choices': modality.util_functions['pooling_fn'].choices} ) + self._normalization_method = TunableParameter( + name='normalization_method', + current_value=normalization_method or modality.util_functions['normalization_fn'].choices[0], + dtype=str, + is_leaf=True, + space={'choices': modality.util_functions['normalization_fn'].choices} + ) + Tunable.__init__(self) + self.register_tunable_params([ + self._pooling_method, self._normalization_method]) + + @property + def pooling_method(self): + return self._pooling_method.current_value + + @property + def normalization_method(self): + return self._normalization_method.current_value + + @property + def pooling_fn(self): + return self.modality.util_functions['pooling_fn'].select(self.pooling_method) + + @property + def normalization_fn(self): + return self.modality.util_functions['normalization_fn'].select(self.normalization_method) def __call__(self, attrs): pooled = self.pooling_fn(attrs) normalized = self.normalization_fn(pooled) return normalized - - def get_tunables(self): - return { - 'pooling_fn': (PoolingFunction, {}), - 'normalization_fn': (NormalizationFunction, {}), - } diff --git a/pnpxai/explainers/utils/types.py b/pnpxai/explainers/utils/types.py new file mode 100644 index 00000000..8bdce3c6 --- /dev/null +++ b/pnpxai/explainers/utils/types.py @@ -0,0 +1,20 @@ +from typing import Tuple, Union +from pnpxai.explainers.utils.baselines import BaselineFunction +from pnpxai.explainers.utils.feature_masks import FeatureMaskFunction +from pnpxai.explainers.utils.postprocess import PostProcessor + + +BaselineFunctionOrTupleOfBaselineFunctions = Union[ + BaselineFunction, + Tuple[BaselineFunction, ...] +] + +FeatureMaskFunctionOrTupleOfFeatureMaskFunctions = Union[ + FeatureMaskFunction, + Tuple[FeatureMaskFunction, ...] +] + +PostProcessorOrTupleOfPostProcessors = Union[ + PostProcessor, + Tuple[PostProcessor, ...] 
+] diff --git a/pnpxai/explainers/utils/utils.py b/pnpxai/explainers/utils/utils.py index 99c04e41..b5099ec0 100644 --- a/pnpxai/explainers/utils/utils.py +++ b/pnpxai/explainers/utils/utils.py @@ -1,14 +1,10 @@ -from typing import Sequence, Any, Union -import functools - -import torch from torch import nn -from captum.attr._utils.input_layer_wrapper import ModelInputWrapper -from skimage.segmentation import felzenszwalb +from captum.attr._utils.input_layer_wrapper import ModelInputWrapper, InputIdentity from pnpxai.core.detector import symbolic_trace from pnpxai.core.detector.utils import get_target_module_of, find_nearest_user_of from pnpxai.core.detector.types import Convolution, Pool +from pnpxai.core.utils import ModelWrapper def find_cam_target_layer(model: nn.Module) -> nn.Module: @@ -29,63 +25,15 @@ def find_cam_target_layer(model: nn.Module) -> nn.Module: return target_module -def default_feature_mask_fn_image(inputs: torch.Tensor, scale=250): - feature_mask = [ - torch.tensor(felzenszwalb(input.permute(1,2,0).detach().cpu().numpy(), scale=scale)) - for input in inputs - ] - return torch.LongTensor(torch.stack(feature_mask)).to(inputs.device) - - -def default_feature_mask_fn_text(inputs): return None - - -def default_feature_mask_fn_image_text(images, text): - fm_img = default_feature_mask_fn_image(images) - bsz, text_len = text.size() - fm_text = torch.arange(text_len).repeat(bsz).view(bsz, text_len) - fm_text += fm_img.max().item() + 1 - return fm_img, fm_text - - -def get_default_feature_mask_fn(modality): - if modality == 'image': - return default_feature_mask_fn_image - elif modality == 'text': - return default_feature_mask_fn_text - elif modality == ('image', 'text'): - return default_feature_mask_fn_image_text - else: - raise NotImplementedError(f"default_feature_mask_fn for '{modality}' not supported.") - - -def default_baseline_fn_image(inputs: torch.Tensor): - return torch.zeros_like(inputs) - -def default_baseline_fn_text(inputs: torch.Tensor, mask_token_id: int=0): - return torch.ones_like(inputs, dtype=torch.long) * mask_token_id - -def default_baseline_fn_image_text(images: torch.Tensor, text: torch.Tensor, mask_token_id: int=0): - return default_baseline_fn_image(images), default_baseline_fn_text(text, mask_token_id) - -def get_default_baseline_fn(modality, mask_token_id=0): - if modality == 'image': - return default_baseline_fn_image - elif modality == 'text': - return default_baseline_fn_text - elif modality == ('image', 'text'): - return functools.partial(default_baseline_fn_image_text, mask_token_id=mask_token_id) - else: - raise NotImplementedError(f"default_baseline_fn for '{modality}' not supported.") - -def captum_wrap_model_input(model): - if isinstance(model, nn.DataParallel): - return ModelInputWrapper(model.module) - return ModelInputWrapper(model) - - +class ModelWrapperForLayerAttribution(ModelInputWrapper): + def __init__( + self, + wrapped_model: ModelWrapper, + ): + super().__init__(wrapped_model) -def _format_to_tuple(obj: Union[Any, Sequence[Any]]): - if isinstance(obj, Sequence): - return tuple(obj) - return (obj,) \ No newline at end of file + # override + self.arg_name_list = wrapped_model.required_order + self.input_maps = nn.ModuleDict({ + arg_name: InputIdentity(arg_name) for arg_name in self.arg_name_list + }) diff --git a/pnpxai/explainers/var_grad.py b/pnpxai/explainers/var_grad.py index ac646a22..36118b99 100644 --- a/pnpxai/explainers/var_grad.py +++ b/pnpxai/explainers/var_grad.py @@ -1,10 +1,4 @@ -from typing import ( - Optional, - 
Callable,
-    Tuple,
-    Union,
-    Sequence,
-)
+from typing import Optional, Callable, Tuple, Union, Any, List
 from torch import Tensor
 from torch.nn.modules import Module
@@ -12,76 +6,80 @@
 from pnpxai.core.detector.types import Linear, Convolution, LSTM, RNN, Attention
 from pnpxai.utils import format_into_tuple, format_out_tuple_if_single
 from pnpxai.explainers.smooth_grad import SmoothGrad
+from pnpxai.explainers.types import TargetLayerOrTupleOfTargetLayers
 class VarGrad(SmoothGrad):
-    """
+    """
     VarGrad explainer.

     Supported Modules: `Linear`, `Convolution`, `LSTM`, `RNN`, `Attention`

     Parameters:
         model (Module): The PyTorch model for which attribution is to be computed.
-        noise_level (float): The noise level added during attribution.
-        n_iter (int): The number of iterations, the input is modified.
-        layer (Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]]): The target module to be explained.
-        forward_arg_extractor: A function that extracts forward arguments from the input batch(s) where the attribution scores are assigned.
-        additional_forward_arg_extractor: A secondary function that extract additional forward arguments from the input batch(s).
-        **kwargs: Keyword arguments that are forwarded to the base implementation of the Explainer
+        noise_level (float): The noise level added during attribution.
+        n_iter (int): The number of iterations; on each iteration, a noisy copy of the input is sampled.
+        target_input_keys (Optional[List[Union[str, int]]]): Keys or indices locating, in each input batch, the tensors to which attribution scores are assigned.
+        additional_input_keys (Optional[List[Union[str, int]]]): Keys or indices locating additional forward arguments that are passed to the model but not attributed.
+        output_modifier (Optional[Callable[[Any], Tensor]]): A function mapping raw model outputs to a tensor of scores, e.g. `lambda outputs: outputs.logits`.
+        target_layer (Optional[TargetLayerOrTupleOfTargetLayers]): The target module to be explained.
+        n_classes (Optional[int]): The number of output classes; inferred from a forward pass when omitted.

     Reference:
         Lorenz Richter, Ayman Boustati, Nikolas Nüsken, Francisco J. R. Ruiz, Ömer Deniz Akyildiz. VarGrad: A Low-Variance Gradient Estimator for Variational Inference.
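+    Example:
+        A minimal usage sketch of the keyword-based signature defined below:
+        >>> explainer = VarGrad(model, noise_level=0.1, n_iter=20)
+        >>> attrs = explainer.attribute(inputs, targets)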
""" + SUPPORTED_MODULES = [Linear, Convolution, LSTM, RNN, Attention] + SUPPORTED_DTYPES = [float, int] + SUPPORTED_NDIMS = [2, 4] + alias = ['var_grad', 'vg'] + + def __init__( + self, + model: Module, + noise_level: float = .1, + n_iter: int = 20, + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], Tensor]] = None, + target_layer: Optional[TargetLayerOrTupleOfTargetLayers] = None, + n_classes: Optional[int] = None, + ) -> None: + super().__init__( + model=model, + noise_level=noise_level, + n_iter=n_iter, + target_input_keys=target_input_keys, + additional_input_keys=additional_input_keys, + output_modifier=output_modifier, + target_layer=target_layer, + n_classes=n_classes, + ) - SUPPORTED_MODULES = [Linear, Convolution, LSTM, RNN, Attention] - - def __init__( - self, - model: Module, - noise_level: float=.1, - n_iter: int=20, - forward_arg_extractor: Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]]=None, - additional_forward_arg_extractor: Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]]=None, - layer: Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]]=None, - n_classes: Optional[int]=None, - ) -> None: - super().__init__( - model=model, - noise_level=noise_level, - n_iter=n_iter, - forward_arg_extractor=forward_arg_extractor, - additional_forward_arg_extractor=additional_forward_arg_extractor, - layer=layer, - n_classes=n_classes, - ) - - def attribute( - self, - inputs: Union[Tensor, Tuple[Tensor]], - targets: Tensor, - ) -> Union[Tensor, Tuple[Tensor]]: - """ + def attribute( + self, + inputs: Union[Tensor, Tuple[Tensor]], + targets: Tensor, + ) -> Union[Tensor, Tuple[Tensor]]: + """ Computes attributions for the given inputs and targets. - Args: inputs (torch.Tensor): The input data. targets (torch.Tensor): The target labels for the inputs. - + Returns: torch.Tensor: The result of the explanation. 
""" - forward_args, additional_forward_args = self._extract_forward_args(inputs) - with self.attributor() as attributor: - avg_grads, avg_grads_sq = attributor.forward( - forward_args, - targets, - additional_forward_args, - return_squared=True, - ) - vargrads = tuple( - avg_grad_sq - avg_grad for avg_grad_sq, avg_grad in zip( - format_into_tuple(avg_grads_sq), - format_into_tuple(avg_grads), - ) - ) - return format_out_tuple_if_single(vargrads) + forward_args, additional_forward_args = self.format_inputs(inputs) + with self.attributor() as attributor: + avg_grads, avg_grads_sq = attributor.forward( + format_out_tuple_if_single(forward_args), + targets, + additional_forward_args, + return_squared=True, + ) + vargrads = tuple( + avg_grad_sq - avg_grad for avg_grad_sq, avg_grad in zip( + format_into_tuple(avg_grads_sq), + format_into_tuple(avg_grads), + ) + ) + return format_out_tuple_if_single(vargrads) diff --git a/pnpxai/explainers/zennit/attribution.py b/pnpxai/explainers/zennit/attribution.py index 2fb369df..6ae1f960 100644 --- a/pnpxai/explainers/zennit/attribution.py +++ b/pnpxai/explainers/zennit/attribution.py @@ -1,11 +1,10 @@ -from typing import Optional, Sequence, Union, List, Tuple, Callable, Any, Dict, Literal +from typing import Optional, Sequence, Union, List, Tuple, Callable, Any, Dict from collections import defaultdict import threading import torch import torchvision.transforms.functional as TF -from torch import Tensor, device from torch.nn import Module from captum._utils.common import _run_forward, _sort_key_list, _reduce_list from captum._utils.gradient import ( @@ -27,12 +26,13 @@ class Gradient(ZennitGradient): def __init__( self, model: Module, - composite: Optional[Composite]=None, + composite: Optional[Composite] = None, attr_output=None, create_graph=False, retain_graph=None, ) -> None: - super().__init__(model, composite, attr_output, create_graph, retain_graph) + super().__init__( + model, composite, attr_output, create_graph, retain_graph) def grad(self, forward_args, targets, additional_forward_args=None): self._process_forward_args_before_grad(forward_args) @@ -75,12 +75,13 @@ def __init__( self, model: Module, layer: Union[str, Module, Sequence[Union[str, Module]]], - composite: Optional[Composite]=None, + composite: Optional[Composite] = None, attr_output=None, create_graph=False, retain_graph=None, ) -> None: - super().__init__(model, composite, attr_output, create_graph, retain_graph) + super().__init__( + model, composite, attr_output, create_graph, retain_graph) self.layer = layer def grad(self, forward_args, targets, additional_forward_args=None): @@ -101,23 +102,24 @@ class SmoothGradient(Gradient): def __init__( self, model: Module, - noise_level: Union[float, List[float]]=.1, - n_iter: int=20, - composite: Optional[Composite]=None, + noise_level: Union[float, List[float]] = .1, + n_iter: int = 20, + composite: Optional[Composite] = None, attr_output=None, create_graph=None, retain_graph=None, ) -> None: - super().__init__(model, composite, attr_output, create_graph, retain_graph) + super().__init__( + model, composite, attr_output, create_graph, retain_graph) self.noise_level = noise_level self.n_iter = n_iter def forward( self, - inputs: TensorOrTupleOfTensors, + inputs: Tensor, targets: Tensor, - additional_forward_args: Optional[TensorOrTupleOfTensors]=None, - return_squared: bool=False, + additional_forward_args: Optional[TensorOrTupleOfTensors] = None, + return_squared: bool = False, ): dims = tuple(range(1, inputs.ndim)) std = 
self.noise_level * (inputs.amax(dims, keepdim=True) - inputs.amin(dims, keepdim=True)) @@ -224,8 +226,8 @@ def _forward_layer_distributed_eval_with_noise( forward_hook_with_return: bool = True, require_layer_grads: bool = True, ) -> Union[ - Tuple[Dict[Module, Dict[device, Tuple[Tensor, ...]]], Tensor], - Dict[Module, Dict[device, Tuple[Tensor, ...]]], + Tuple[Dict[Module, Dict[torch.device, Tuple[Tensor, ...]]], Tensor], + Dict[Module, Dict[torch.device, Tuple[Tensor, ...]]], ]: r""" A helper function that allows to set a hook on model's `layer`, run the forward @@ -315,7 +317,6 @@ def forward_hook(module, inp, out=None): return saved_layer - def compute_layer_gradients_and_eval_with_noise( forward_fn: Callable, layer: List[Union[str, Module]], diff --git a/pnpxai/explainers/zennit/base.py b/pnpxai/explainers/zennit/base.py index f9f8469e..a9f87cc4 100644 --- a/pnpxai/explainers/zennit/base.py +++ b/pnpxai/explainers/zennit/base.py @@ -1,26 +1,30 @@ -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Optional, Union, List, Any from torch import Tensor from torch.nn import Module from pnpxai.explainers.base import Explainer from pnpxai.utils import format_into_tuple -from pnpxai.explainers.base import Explainer class ZennitExplainer(Explainer): def __init__( self, model: Module, - forward_arg_extractor: Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]] = None, - additional_forward_arg_extractor: Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]] = None, - n_classes: Optional[int]=None, + target_input_keys: Optional[List[Union[str, int]]] = None, + additional_input_keys: Optional[List[Union[str, int]]] = None, + output_modifier: Optional[Callable[[Any], Tensor]] = None, + n_classes: Optional[int] = None, **kwargs ) -> None: - super().__init__(model, forward_arg_extractor, additional_forward_arg_extractor) + super().__init__( + model, + target_input_keys, + additional_input_keys, + output_modifier, + ) self.n_classes = n_classes - def __init_subclass__(cls) -> None: cls.attribute = set_n_classes_before(cls.attribute) return super().__init_subclass__() @@ -35,7 +39,8 @@ def wrapper(*args, **kwargs): if isinstance(inputs, Tensor): inputs = format_into_tuple(inputs) if self.n_classes is None: - outputs = self.model(*inputs) + formatted = self._wrapped_model.format_inputs(inputs) + outputs = self._wrapped_model(*formatted) self.n_classes = outputs.shape[-1] return func(*args, **kwargs) - return wrapper \ No newline at end of file + return wrapper diff --git a/pnpxai/explainers/zennit/hooks.py b/pnpxai/explainers/zennit/hooks.py index ec20696b..508b972e 100644 --- a/pnpxai/explainers/zennit/hooks.py +++ b/pnpxai/explainers/zennit/hooks.py @@ -1,4 +1,5 @@ -from zennit.core import RemovableHandleList, RemovableHandle, Hook, BasicHook +from zennit.core import RemovableHandleList, RemovableHandle, Hook + class HookWithKwargs(Hook): '''Base class for hooks to be used to compute layer-wise attributions.''' @@ -22,11 +23,3 @@ def register(self, module): module.register_forward_hook(self.post_forward), module.register_forward_hook(self.forward, with_kwargs=True), ]) - - -# myeongjin hi -class BasicHookWithRelevanceModifier(BasicHook): - def __init__( - ) -> None: - pass - diff --git a/pnpxai/llm/fact_score.py b/pnpxai/llm/fact_score.py index c9059d8e..537b2555 100644 --- a/pnpxai/llm/fact_score.py +++ b/pnpxai/llm/fact_score.py @@ -15,7 +15,7 @@ def __init__( atomic_fact_generator: Callable[[str], List[str]], knowledge_source: Callable[[str, 
str], List[Dict[str, str]]],
        scorer: Callable[[str, str, List[Dict[str, str]]], Any],
-        aggregate_fn: Optional[Callable[[List[Any]], Any]]=None,
+        aggregate_fn: Optional[Callable[[List[Any]], Any]] = None,
    ) -> None:
        self.atomic_fact_generator = atomic_fact_generator
        self.knowledge_source = knowledge_source
diff --git a/pnpxai/utils.py b/pnpxai/utils.py
index 32965821..6bf8b130 100644
--- a/pnpxai/utils.py
+++ b/pnpxai/utils.py
@@ -1,4 +1,5 @@
 import random
+import re
 from io import TextIOWrapper
 from contextlib import contextmanager
 from typing import Sequence, Callable, Any, Union, Optional, Tuple, TypeVar
@@ -102,6 +103,8 @@ def linear_from_params(weight: Tensor, bias: Optional[Tensor] = None) -> nn.Line
 T = TypeVar('T')
+
+
 def format_into_tuple(obj: T) -> Tuple[T]:
     if isinstance(obj, Sequence) and not isinstance(obj, str):
         return tuple(obj)
@@ -123,3 +126,29 @@ def format_into_tuple_all(**kwargs):
 def generate_param_key(*args):
     # ensure the uniqueness of param name of optuna
     return '.'.join([str(arg) for arg in args if arg is not None])
+
+
+def _camel_to_snake(name):
+    return re.sub(r'(?<!^)(?=[A-Z])', '_', name).lower()
[...]
= 0:
            with gr.Row():
-                for explainer_id in outputs:
-                    self._plot_output_column(outputs[explainer_id])
-
+                for explainer_key in outputs:
+                    self._plot_output_column(outputs[explainer_key])
         if on_options_change is not None:
             btn_options_change.click(
                 on_options_change,
@@ -80,7 +88,8 @@ def render_outputs(outputs):
                     state_input_id,
                     select_explainers,
                     select_metrics,
-                    select_postprocessors,
+                    select_pooling_fn,
+                    select_norm_fn,
                 ],
                 [state_data],
             )
@@ -91,16 +100,16 @@ def _plot_output_column(self, explainer_data: Dict[int, dict]):
         with gr.Column():
             if len(explainer_data) > 0:
                 datum = next(iter(explainer_data.values()))
-                gr.Markdown(f"{datum['explainer'].__class__.__name__}")
-                explanation = datum["postprocessed"].cpu().detach().squeeze().numpy()
+                gr.Markdown(f"{datum.explainer.__class__.__name__}")
+                explanation = datum.explanations.cpu().detach().squeeze().numpy()
                 exp_plot = plt.figure()
                 axes = exp_plot.subplots(1, 1)
                 axes.imshow(explanation, cmap="twilight")
                 plot_title = "\n".join(
                     [
-                        f"{explainer_data[metric_id]['metric'].__class__.__name__}: {explainer_data[metric_id]['evaluation'].item():.2f}"
-                        for metric_id in explainer_data
+                        f"{explainer_data[metric_key].metric.__class__.__name__}: {explainer_data[metric_key].evaluations.item():.2f}"
+                        for metric_key in explainer_data
                     ]
                 )
                 gr.Plot(exp_plot, label=plot_title, show_label=True)
@@ -124,6 +133,7 @@ def __init__(
         self,
         experiment: Experiment,
         input_visualizer: Optional[callable] = None,
     ):
+        assert len(format_into_tuple(experiment.modality)) == 1, 'Multimodal not supported'
         self.experiment = experiment
         self._vi: VisualizerInterface = None
         self._input_visualizer = input_visualizer
@@ -133,20 +143,17 @@ def build(self):
         self._vi = VisualizerInterface()
         self._vi.build(
             self._get_input_data(),
-            explainers=self._get_explainers_options(),
-            metrics=self._get_metrics_options(),
-            postprocessors=self._get_postprocessors_options(),
+            explainer_options=self._get_options(self.experiment.explainers),
+            metric_options=self._get_options(self.experiment.metrics),
+            pooling_fn_options=self._get_options(
+                self.experiment.modality.util_functions['pooling_fn']),
+            normalization_fn_options=self._get_options(
+                self.experiment.modality.util_functions['normalization_fn']),
             on_options_change=self._on_options_change,
         )
-    def _get_explainers_options(self) -> List[Tuple[int, str]]:
-        return list(zip(*self.experiment.manager.get_explainers()))
-
-    def _get_metrics_options(self) -> 
List[Tuple[int, str]]: - return list(zip(*self.experiment.manager.get_metrics())) - - def _get_postprocessors_options(self) -> List[Tuple[int, str]]: - return list(zip(*self.experiment.manager.get_postprocessors())) + def _get_options(self, selector): + return [(v.__name__, k) for k, v in selector.data.items()] def _get_input_data(self) -> List[np.ndarray]: return [ @@ -161,19 +168,22 @@ def _get_input_data(self) -> List[np.ndarray]: def _on_options_change( self, data_id: int, - explainer_ids: Sequence[int], - metric_ids: Sequence[int], - postprocessor_id: int, + explainer_keys: Sequence[int], + metric_keys: Sequence[int], + pooling_method, + normalization_method, + # postprocessor_id: int, ): outputs = defaultdict(dict) # Convert data_id from range to actual data_id data_id = self.experiment.manager.get_data()[1][data_id] - for explainer_id in explainer_ids: - for metric_id in metric_ids: - outputs[explainer_id][metric_id] = self.experiment.run_batch( - explainer_id=explainer_id, - metric_id=metric_id, - postprocessor_id=postprocessor_id, + for explainer_key in explainer_keys: + for metric_key in metric_keys: + outputs[explainer_key][metric_key] = self.experiment.run_batch( + explainer_key=explainer_key, + metric_key=metric_key, + pooling_method=pooling_method, + normalization_method=normalization_method, data_ids=[data_id], ) return outputs diff --git a/scripts/helpers.py b/scripts/helpers.py new file mode 100644 index 00000000..c75cf0b3 --- /dev/null +++ b/scripts/helpers.py @@ -0,0 +1,245 @@ +from typing import Optional, List +import os +import json +import requests +import functools +from io import BytesIO +from pathlib import Path +from urllib3 import disable_warnings +from urllib3.exceptions import InsecureRequestWarning + +import torch +import torchvision +from torch import Tensor +from torch.utils.data import Dataset, Subset, DataLoader +from transformers import BertTokenizer, BertForSequenceClassification +from transformers import ViltForQuestionAnswering, ViltProcessor + +from PIL import Image + + +# datasets + +class ImageNetDataset(Dataset): + def __init__(self, root_dir, transform=None): + self.root_dir = root_dir + self.img_dir = os.path.join(self.root_dir, 'samples/') + self.label_dir = os.path.join( + self.root_dir, 'imagenet_class_index.json') + + with open(self.label_dir) as json_data: + self.idx_to_labels = json.load(json_data) + + self.img_names = os.listdir(self.img_dir) + self.img_names.sort() + + self.transform = transform + + def __len__(self): + return len(self.img_names) + + def __getitem__(self, idx): + img_path = os.path.join(self.img_dir, self.img_names[idx]) + image = Image.open(img_path).convert('RGB') + label = idx + + if self.transform: + image = self.transform(image) + + return image, label + + def idx_to_label(self, idx): + return self.idx_to_labels[str(idx)][1] + + +def get_imagenet_dataset( + transform, + subset_size: int = 100, # ignored if indices is not None + root_dir="./data/ImageNet", + indices: Optional[List[int]] = None, +): + os.chdir(Path(__file__).parent) # ensure path + dataset = ImageNetDataset(root_dir=root_dir, transform=transform) + if indices is not None: + return Subset(dataset, indices=indices) + indices = list(range(len(dataset))) + subset = Subset(dataset, indices=indices[:subset_size]) + return subset + + +class IMDBDataset(Dataset): + def __init__(self, split='test'): + super().__init__() + # data_iter = IMDB(split=split) + # self.annotations = [(line, label-1) for label, line in tqdm(data_iter)] + + def __len__(self): + return 
len(self.annotations) + + def __getitem__(self, idx): + return self.annotations[idx] + + +def get_imdb_dataset(split='test'): + return IMDBDataset(split=split) + +disable_warnings(InsecureRequestWarning) + + +class VQADataset(Dataset): + def __init__(self): + super().__init__() + res = requests.get('https://visualqa.org/balanced_data.json') + self.annotations = eval(res.text) + + def __len__(self): + return len(self.annotations) + + def __getitem__(self, idx): + data = self.annotations[idx] + if isinstance(data['original_image'], str): + print(f"Requesting {data['original_image']}...") + res = requests.get(data['original_image'], verify=False) + img = Image.open(BytesIO(res.content)).convert('RGB') + data['original_image'] = img + return data['original_image'], data['question'], data['original_answer'] + + +def get_vqa_dataset(): + return VQADataset() + + +# models +def get_torchvision_model(model_name): + weights = torchvision.models.get_model_weights(model_name).DEFAULT + model = torchvision.models.get_model(model_name, weights=weights).eval() + transform = weights.transforms() + return model, transform + + +class Bert(BertForSequenceClassification): + def forward(self, input_ids, token_type_ids, attention_mask): + return super().forward( + input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask + ).logits + + +def get_bert_model(model_name, num_labels): + return Bert.from_pretrained(model_name, num_labels=num_labels) + + +class Vilt(ViltForQuestionAnswering): + def forward( + self, + pixel_values, + input_ids, + token_type_ids, + attention_mask, + pixel_mask, + ): + return super().forward( + input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + pixel_mask=pixel_mask, + ).logits + + +def get_vilt_model(model_name): + return Vilt.from_pretrained(model_name) + + +# utils +def img_to_np(img): return img.permute(1, 2, 0).detach().numpy() + + +def denormalize_image(inputs, mean, std): + return img_to_np( + inputs + * Tensor(std)[:, None, None] + + Tensor(mean)[:, None, None] + ) + + +def bert_collate_fn(batch, tokenizer=None): + inputs = tokenizer( + [d[0] for d in batch], + padding=True, + truncation=True, + return_tensors='pt', + ) + labels = torch.tensor([d[1] for d in batch]) + return tuple(inputs.values()), labels + + +def get_bert_tokenizer(model_name): + return BertTokenizer.from_pretrained(model_name) + + +def get_vilt_processor(model_name): + return ViltProcessor.from_pretrained(model_name) + + +def vilt_collate_fn(batch, processor=None, label2id=None): + imgs = [d[0] for d in batch] + qsts = [d[1] for d in batch] + inputs = processor( + images=imgs, + text=qsts, + padding=True, + truncation=True, + return_tensors='pt', + ) + labels = torch.tensor([label2id[d[2]] for d in batch]) + return ( + inputs['pixel_values'], + inputs['input_ids'], + inputs['token_type_ids'], + inputs['attention_mask'], + inputs['pixel_mask'], + labels, + ) + + +def load_model_and_dataloader_for_tutorial(modality, device): + if modality == 'image': + model, transform = get_torchvision_model('resnet18') + model = model.to(device) + model.eval() + dataset = get_imagenet_dataset(transform) + loader = DataLoader(dataset, batch_size=8, shuffle=False) + return model, loader, transform + elif modality == 'text': + model = get_bert_model( + 'fabriceyhc/bert-base-uncased-imdb', num_labels=2) + model = model.to(device) + model.eval() + dataset = get_imdb_dataset(split='test') + tokenizer = 
get_bert_tokenizer('fabriceyhc/bert-base-uncased-imdb')
+        loader = DataLoader(
+            dataset,
+            batch_size=8,
+            shuffle=False,
+            collate_fn=functools.partial(bert_collate_fn, tokenizer=tokenizer)
+        )
+        return model, loader, tokenizer
+    elif modality == ('image', 'text'):
+        model = get_vilt_model('dandelin/vilt-b32-finetuned-vqa')
+        model.to(device)
+        model.eval()
+        dataset = get_vqa_dataset()
+        processor = get_vilt_processor('dandelin/vilt-b32-finetuned-vqa')
+        loader = DataLoader(
+            dataset,
+            batch_size=2,
+            shuffle=False,
+            collate_fn=functools.partial(
+                vilt_collate_fn,
+                processor=processor,
+                label2id=model.config.label2id,
+            ),
+        )
+        return model, loader, processor
diff --git a/scripts/test_baf.py b/scripts/test_baf.py
new file mode 100644
index 00000000..b6f8f394
--- /dev/null
+++ b/scripts/test_baf.py
@@ -0,0 +1,562 @@
+'''
+This script runs a benchmark over various explainers and finds the best
+performing explainer on the BAF (bank account fraud detection) task, using the
+PnPXAI framework.
+
+Prerequisites:
+- This script creates `--data_dir` and downloads the BAF data from Kaggle if it
+  does not already exist. Please install kaggle first, e.g.
+
+  ```bash
+  pip install kaggle
+  ```
+
+Flags:
+--fast_dev_run: runs the script with small samples and trials
+
+Example:
+
+```bash
+python -m scripts.test_baf --model tab_resnet --data_dir baf --log_dir baf --fast_dev_run
+```
+'''
+
+import argparse
+import os
+import re
+import itertools
+from collections import defaultdict
+from pprint import pprint
+
+import numpy as np
+import pandas as pd
+import torch
+from torch import nn
+from tqdm import tqdm
+from torch.utils.data import Dataset, DataLoader, Subset
+from PIL import Image
+from sklearn.preprocessing import StandardScaler, OneHotEncoder
+from sklearn.model_selection import train_test_split
+from sklearn.metrics import classification_report, roc_auc_score
+import xgboost
+
+from pnpxai import XaiRecommender, Experiment, AutoExplanation
+from pnpxai.core.modality.modality import Modality
+from pnpxai.explainers import Lime, KernelShap
+from pnpxai.evaluator.metrics import AbPC, MoRF, LeRF
+
+
+BAF_MODEL_CHOICES = ['tab_resnet', 'xgb']
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--model', type=str, choices=BAF_MODEL_CHOICES, required=True)
+parser.add_argument('--data_dir', type=str, required=True)
+parser.add_argument('--log_dir', type=str, required=True)
+parser.add_argument('--seed', type=int, default=42)
+parser.add_argument('--num_workers', type=int, default=0)
+parser.add_argument('--batch_size', type=int, default=128)
+parser.add_argument('--disable_gpu', action='store_true')
+parser.add_argument('--fast_dev_run', action='store_true')
+
+
+#------------------------------------------------------------------------------#
+#----------------------------------- data -------------------------------------#
+#------------------------------------------------------------------------------#
+
+class PandasDataset(Dataset):
+    def __init__(self, inputs: pd.DataFrame, labels: pd.Series):
+        super().__init__()
+        self.inputs = inputs
+        self.labels = labels
+
+    def __len__(self):
+        return len(self.inputs)
+
+    def __getitem__(self, idx):
+        return self.inputs.iloc[idx], self.labels.iloc[idx]
+
+
+def collate_fn(batch):
+    inputs = torch.stack([torch.from_numpy(d[0].values) for d in batch]).to(torch.float)
+    labels = torch.tensor([d[1] for d in batch], dtype=torch.long)
+    return inputs, labels
+
+
+def download_data(root_dir):
+    ZIPFILE_NAME = 'bank-account-fraud-dataset-neurips-2022'
+    FILE_NAME = 
'Base.csv' + raw_dir = os.path.join(root_dir, 'raw') + zipfile_path = os.path.join(raw_dir, ZIPFILE_NAME) + file_path = os.path.join(raw_dir, FILE_NAME) + if not os.path.exists(file_path): + print("Downloading the dataset...") + os.makedirs(raw_dir, exist_ok=True) + os.system(f"kaggle datasets download -d sgpjesus/{ZIPFILE_NAME} -p {raw_dir}") + os.system(f"unzip {zipfile_path} -d {raw_dir}") + return file_path + + +def preprocess_data(file_path, mul=2, random_state=42): + df = pd.read_csv(file_path) + is_train = df['month'] < 5 + is_valid = (df['month'] >= 5) & (df['month'] < 6) + is_test = df['month'] >= 6 + is_fraud = df['fraud_bool'] == 1 + + balance_df = lambda is_data: pd.concat([ + df.loc[is_data&is_fraud], + df.loc[is_data&(~is_fraud)].sample((is_data&is_fraud).sum()*mul, random_state=random_state) + ]).reset_index(drop=True).drop(columns=['month']) + dfs = {'train': balance_df(is_train), 'valid': balance_df(is_valid), 'test': balance_df(is_test)} + + scaler = StandardScaler() + ohe = OneHotEncoder() + preprocessed = {} + for split, features in dfs.items(): + preprocessed[f'y_{split}'] = features.pop('fraud_bool') + + float_cols = features.select_dtypes(include=[float, int]).columns + float_features = pd.concat([features.pop(c) for c in float_cols], axis=1) + float_features_data = scaler.fit_transform(float_features) if split == 'train' else \ + scaler.transform(float_features) + float_features = pd.DataFrame( + data=float_features_data, + index=float_features.index, + columns=float_features.columns, + ) + + categ_cols = features.select_dtypes(include=['object', int]).columns + categ_features = pd.concat([features.pop(c) for c in categ_cols], axis=1) + categ_features_data = ohe.fit_transform(categ_features) if split == 'train' else \ + ohe.transform(categ_features) + categ_features = pd.DataFrame( + data=categ_features_data.toarray(), + index=categ_features.index, + columns=[f'{c}_{v}' for c, values in zip(categ_cols, ohe.categories_) for v in values] + ) + preprocessed[f'x_{split}'] = pd.concat([float_features, categ_features], axis=1) + return preprocessed + + +#------------------------------------------------------------------------------# +#----------------------------------- model ------------------------------------# +#------------------------------------------------------------------------------# + +# tab resnet +class ResNetBlock(nn.Module): + def __init__(self, in_features, out_features): + super(ResNetBlock, self).__init__() + self.bn = nn.BatchNorm1d(in_features) + self.fc1 = nn.Linear(in_features, out_features) + self.fc2 = nn.Linear(out_features, out_features) + self.dropout = nn.Dropout(0.2) + + def forward(self, x): + y = torch.relu(self.fc1(self.bn(x))) + y = self.dropout(y) + y = self.fc2(y) + y = self.dropout(y) + return torch.add(x, y) + + +class TabResNet(nn.Module): + def __init__(self, in_features, out_features, num_blocks=1, embedding_dim=128): + super(TabResNet, self).__init__() + self.embedding = nn.Linear(in_features, embedding_dim) + self.res_blocks = [] + for i in range(num_blocks): + self.res_blocks.append(ResNetBlock(embedding_dim, embedding_dim)) + self.res_blocks = nn.ModuleList(self.res_blocks) + self.bn = nn.BatchNorm1d(embedding_dim) + self.fc = nn.Linear(embedding_dim, out_features) + + def forward(self, x): + x = self.embedding(x) + for block in self.res_blocks: + x = block(x) + x = torch.relu(self.bn(x)) + x = self.fc(x) + return x + + +def train( + model_dir, + checkpoint_nm, + model, + train_loader, + valid_inputs, + valid_labels, + device, 
+): + os.makedirs(model_dir, exist_ok=True) + model_path = os.path.join(model_dir, checkpoint_nm) + if os.path.exists(model_path): + return model_path + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + loss_fn = nn.CrossEntropyLoss() + + model.to(device) + + # Train the model + for epoch in range(10): + model.train() + train_loss = 0 + for i, (inputs, labels) in enumerate(train_loader): + inputs, labels = inputs.to(device), labels.to(device) + optimizer.zero_grad() + output = model(inputs) + loss = loss_fn(output, labels) + loss.backward() + optimizer.step() + train_loss += loss.item() + print(f"Epoch {epoch}, Loss: {train_loss / len(train_loader)}") + + model.eval() + valid_loss = 0 + with torch.no_grad(): + y_pred = model(valid_inputs) + loss = loss_fn(y_pred, valid_labels) + valid_loss = loss.item() + + print(f"Validation Loss: {valid_loss}") + y_pred = torch.argmax(y_pred, dim=1) + print(classification_report(valid_labels.to('cpu'), y_pred.to('cpu'))) + torch.save(model.state_dict(), model_path) + return model_path + + +class TorchModelForXGBoost(nn.Module): + def __init__(self, xgb_model): + super().__init__() + self.xgb_model = xgb_model + self._dummy_layer = nn.Linear(1, 1) + + def forward(self, x: torch.Tensor): + out = self.xgb_model.predict_proba(x.cpu().numpy()) + return torch.from_numpy(out) + + + +#------------------------------------------------------------------------------# +#----------------------------------- kmeans -----------------------------------# +#------------------------------------------------------------------------------# + +from sklearn.cluster import KMeans as SklearnKMeans +from pnpxai.explainers.utils.baselines import BaselineFunction +from pnpxai.explainers.base import Tunable +from pnpxai.explainers.types import TunableParameter + + +class KMeans(BaselineFunction, Tunable): + def __init__(self, background_data, n_clusters=8): + self.background_data = background_data + self.n_clusters = TunableParameter( + name='n_clusters', + current_value=n_clusters, + dtype=int, + is_leaf=True, + space={'low': 8, 'high': len(background_data)//10, 'step': 10}, + ) + self.kmeans_ = SklearnKMeans(n_clusters).fit(background_data) + BaselineFunction.__init__(self) + Tunable.__init__(self) + self.register_tunable_params([self.n_clusters]) + + def __call__(self, inputs): + cluster_ids = self.kmeans_.predict(inputs.numpy()) + cluster_centers = self.kmeans_.cluster_centers_[cluster_ids] + return torch.from_numpy(cluster_centers).to(inputs.device) + + +#------------------------------------------------------------------------------# +#----------------------------------- main -------------------------------------# +#------------------------------------------------------------------------------# + +def main_resnet(args): + # setup + use_gpu = torch.cuda.is_available() and not args.disable_gpu + device = torch.device('cuda' if use_gpu else 'cpu') + torch.manual_seed(args.seed) + + # prepare data + data_fpth = download_data(args.data_dir) # download data and get the filepath + data = preprocess_data(data_fpth) # load and preprocess data + + train_set = PandasDataset(data['x_train'], data['y_train']) + train_loader = DataLoader( + train_set, + batch_size=args.batch_size, + collate_fn=collate_fn, + num_workers=args.num_workers, + shuffle=True, + pin_memory=use_gpu, + ) + valid_set = PandasDataset(data['x_valid'], data['y_valid']) + valid_loader = DataLoader( + valid_set, + batch_size=args.batch_size, + collate_fn=collate_fn, + num_workers=args.num_workers, + shuffle=False, + 
pin_memory=use_gpu,
+    )
+    test_set = PandasDataset(data['x_test'], data['y_test'])
+    if args.fast_dev_run:
+        indices = list(range(args.batch_size*2))
+        test_set = Subset(test_set, indices=indices)
+    test_loader = DataLoader(
+        test_set,
+        batch_size=args.batch_size,
+        collate_fn=collate_fn,
+        num_workers=args.num_workers,
+        shuffle=False,
+        pin_memory=use_gpu,
+    )
+    valid_inputs, valid_labels = next(iter(valid_loader))  # use a small sample for validation
+    valid_inputs, valid_labels = valid_inputs.to(device), valid_labels.to(device)
+
+
+    # prepare model
+    model = TabResNet(
+        in_features=len(data['x_train'].columns),
+        out_features=len(data['y_train'].unique()),
+    )
+    ckpt_fpth = train(
+        model_dir=args.log_dir,
+        checkpoint_nm='tabresnet.pkl',
+        model=model,
+        train_loader=train_loader,
+        valid_inputs=valid_inputs,
+        valid_labels=valid_labels,
+        device=device,
+    )
+    model.load_state_dict(torch.load(ckpt_fpth))
+    model.to(device).eval()
+
+    # prepare modality
+    sample_batch = next(iter(test_loader))
+    modality = Modality(
+        dtype=sample_batch[0].dtype,
+        ndims=sample_batch[0].dim(),
+    )
+
+    '''
+    #--------------------------------------------------------------------------#
+    #------------------------------- recommend --------------------------------#
+    #--------------------------------------------------------------------------#
+
+    # You can get pnpxai recommendation results without AutoExplanation as follows:
+
+    recommended = XaiRecommender().recommend(
+        modality=modality,
+        model=model,
+    )
+
+    recommended.print_tabular()
+    '''
+
+    '''
+    #--------------------------------------------------------------------------#
+    #------------------------------ experiment --------------------------------#
+    #--------------------------------------------------------------------------#
+
+    # You can manually create an experiment as follows:
+    expr = Experiment(
+        model=model,
+        data=test_loader,
+        modality=modality,
+        target_input_keys=[0],  # feature location in batch from dataloader
+        target_class_extractor=lambda outputs: outputs.argmax(-1),  # extract target class from output batch
+        label_key=-1,  # label location in input batch from dataloader
+    )
+
+    # add recommended explainers
+    camel_to_snake = lambda name: re.sub(r'(?<!^)(?=[A-Z])', '_', name).lower()
[...]
0, + self.annotations) + ) + + self.images = glob.glob(os.path.join(images_dir, '*.jpg')) + self.id_to_image = { + int(os.path.basename(image).replace('.', '_').split('_')[-2]): image + for image in self.images + } + + def __len__(self): + return len(self.annotations) + + def __getitem__(self, idx): + # get image + text + annotation = self.annotations[idx] + question = self.questions[idx]['question'] + image = Image.open(self.id_to_image[annotation['image_id']]).convert('RGB') + label_loc = np.argmax(annotation['scores']) + label_id = annotation['labels'][label_loc] + label = self.id2label[label_id] + return { + 'image': image, + 'question': question, + 'answer': label + } + + +class CollateSamplesFromVqaDataset: + def __init__(self, processor, label2id): + self.processor = processor + self.label2id = label2id + + def __call__(self, batch): + images = [d.pop('image') for d in batch] + questions = [d.pop('question') for d in batch] + answers = [d.pop('answer') for d in batch] + batch = self.processor( + images, questions, + return_tensors='pt', + padding=True, + truncation=True, + ) + batch['labels'] = torch.tensor([ + self.label2id[ans] for ans in answers]) + return batch + + +#------------------------------------------------------------------------------# +#------------------------------ control group ---------------------------------# +#------------------------------------------------------------------------------# + + +# The other frameworks (omnixai, autoxai, xaitk, openxai) do not support vqa +CAPTUM_EXPLAINERS = { + 'grad_x_input': captum.attr.InputXGradient, + 'integrated_gradients': captum.attr.IntegratedGradients, + 'kernel_shap': captum.attr.KernelShap, + 'lime': captum.attr.Lime, + # 'lrp_uniform_epsilon': captum.attr.LRP, # rule is not assigned to embedding layer +} + + +#------------------------------------------------------------------------------# +#---------------------------------- metrics -----------------------------------# +#------------------------------------------------------------------------------# + +class CompoundMetric(Metric): + def __init__( + self, + model, + metrics, + weights, + explainer=None, + target_input_keys=None, + additional_input_keys=None, + output_modifier=None, + ): + super().__init__( + model, explainer, target_input_keys, + additional_input_keys, output_modifier, + ) + assert len(metrics) == len(weights) + self.metrics = metrics + self.weights = weights + + def evaluate(self, inputs, targets, attrs): + values = torch.zeros(attrs.size(0)).to(attrs.device) + for weight, metric in zip(self.weights, self.metrics): + values += weight * metric.set_explainer(self.explainer).evaluate(inputs, targets, attrs) + return values + + +#------------------------------------------------------------------------------# +#---------------------------------- records ------------------------------------# +#------------------------------------------------------------------------------# + + +FIELDS = ['id', 'model', 'source', 'explainer', 'metric', 'value', 'params'] + + +def get_last_index(jsonl_file): + if not os.path.exists(jsonl_file): + return -1 # No previous records + + with open(jsonl_file, "r", encoding="utf-8") as file: + lines = file.readlines() + if lines: + last_record = json.loads(lines[-1]) # Read last JSON object + return last_record["id"] # Get last recorded index + return -1 + + +def filter_records(jsonl_file, **filters): + results = [] + if not os.path.exists(jsonl_file): + return results + with open(jsonl_file, 'r', encoding='utf-8') as f: + for 
line in f:
+            record = json.loads(line)
+            if all(record.get(k) == v for k, v in filters.items() if v is not None):
+                results.append(record)
+    return results
+
+def write_record(jsonl_file, record):
+    with open(jsonl_file, 'a', encoding='utf-8') as f:
+        f.write(json.dumps(record) + '\n')
+
+#------------------------------------------------------------------------------#
+#----------------------------------- main -------------------------------------#
+#------------------------------------------------------------------------------#
+
+def main(args):
+    # setup
+    use_gpu = torch.cuda.is_available() and not args.disable_gpu
+    device = torch.device('cuda' if use_gpu else 'cpu')
+
+    # prepare model
+    model = load_pretrained_model(args.model)
+    model.to(device)
+    model.eval()
+
+    # prepare dataloader
+    dataset = VqaDataset(
+        data_dir=args.data_dir,
+        label2id=model.config.label2id,
+        id2label=model.config.id2label,
+        fast_dev_run=args.fast_dev_run,
+    )
+    if not args.fast_dev_run:
+        random.seed(args.seed)
+        sample_indices = random.sample(range(len(dataset)), args.n_samples)
+        dataset = Subset(dataset, indices=sample_indices)
+    collate_fn = CollateSamplesFromVqaDataset(
+        processor={
+            'vilt': ViltProcessor.from_pretrained('dandelin/vilt-b32-finetuned-vqa'),
+            'visual_bert': None,
+        }.get(args.model),
+        label2id=model.config.label2id,
+    )
+    dataloader = DataLoader(
+        dataset,
+        collate_fn=collate_fn,
+        batch_size=args.batch_size,
+        num_workers=args.num_workers,
+        pin_memory=use_gpu,
+    )
+
+    # prepare modality
+    sample_batch = next(iter(dataloader))
+    img_modality = Modality(
+        dtype=sample_batch['pixel_values'].dtype,
+        ndims=sample_batch['pixel_values'].dim(),
+        pooling_dim=1,
+    )
+
+
+    #--------------------------------------------------------------------------#
+    #--------------------------- auto explanation -----------------------------#
+    #--------------------------------------------------------------------------#
+
+    # create experiment
+    expr = AutoExplanation(
+        model=model,
+        data=dataloader,
+        modality=img_modality,
+        target_input_keys=['pixel_values'],  # img
+        additional_input_keys=[
+            'input_ids',  # treat input_ids as an additional input rather than an attribution target
+            'token_type_ids',
+            'attention_mask',
+            'pixel_mask',
+        ],
+        output_modifier=lambda outputs: outputs.logits,
+        target_class_extractor=lambda modified_outputs: modified_outputs.argmax(-1),
+        label_key='labels',
+    )
+
+    # update explainers
+    target_explainer_keys = [] if args.explainers is None else args.explainers.split(',')
+    for explainer_key in expr.explainers.choices:
+        if explainer_key not in target_explainer_keys:
+            expr.explainers.delete(explainer_key)
+
+    # update metrics
+    expr.metrics.delete('morf')
+    expr.metrics.delete('lerf')
+
+    expr.metrics.add('cmpx', Complexity)
+    expr.metrics.add('cmpd', CompoundMetric)
+
+    # update util functions
+    for feature_mask_fn_key in expr.modality.util_functions['feature_mask_fn'].choices:
+        if feature_mask_fn_key != 'checkerboard':
+            expr.modality.util_functions['feature_mask_fn'].delete(feature_mask_fn_key)
+
+    # set log file
+    result_dir = os.path.join(args.base_dir, 'logs')
+    os.makedirs(result_dir, exist_ok=True)
+    record_file_name = os.path.basename(__file__).split('.')[0] + '.jsonl'
+    if args.fast_dev_run:
+        record_file_name = 'dev_' + record_file_name
+    record_file_name = os.path.join(result_dir, record_file_name)  # write records under the log dir
+
+
+    # optimize all
+    best_params = defaultdict(dict)
+    combs = list(itertools.product(
+        expr.explainers.choices,
+        expr.metrics.choices,
+    ))
+    pbar = tqdm(combs, total=len(combs))
+    captum_explainers = {}
+    for idx, 
(explainer_key, metric_key) in enumerate(pbar): + if expr.is_tunable(explainer_key): + results = filter_records( + record_file_name, + model=args.model, + source='pnpxai', + explainer=explainer_key, + metric=metric_key, + ) + if not results or args.fast_dev_run: + pbar.set_description(f'[{idx}] Optimizing {explainer_key} on {metric_key}') + metric_options = {} + if metric_key == 'cmpd': + metric_options['metrics'] = [ + expr.create_metric('abpc'), + expr.create_metric('cmpx'), + ] + metric_options['weights'] = [.8, -.2] + disable_tunable_params = {} + if explainer_key in ['lime', 'kernel_shap']: + disable_tunable_params['n_samples'] = 30 + opt_results = expr.optimize( + explainer_key=explainer_key, + metric_key=metric_key, + metric_options=metric_options, + direction={ + 'abpc': 'maximize', + 'cmpx': 'minimize', + 'cmpd': 'maximize', + }.get(metric_key), + disable_tunable_params=disable_tunable_params, + sampler='random', + seed=args.seed, + show_progress=not args.fast_dev_run, + n_trials=10 if args.fast_dev_run else 100, + num_threads=16, + errors='raise' if args.fast_dev_run else 'ignore', + ) + write_record( + record_file_name, + { + 'id': get_last_index(record_file_name)+1, + 'model': args.model, + 'source': 'pnpxai', + 'explainer': explainer_key, + 'metric': metric_key, + 'value': opt_results.value, + 'params': opt_results.params, + } + ) + captum_explainer_cls = CAPTUM_EXPLAINERS.get(explainer_key) + if captum_explainer_cls is not None: + captum_explainers[explainer_key] = captum_explainer_cls(expr._wrapped_model) + + captum_evals = defaultdict(lambda: defaultdict(int)) + pp_default = PostProcessor(modality=img_modality) + feature_mask_fn_default = Felzenszwalb() + pbar = tqdm(dataloader, total=len(dataloader)) + for batch in pbar: + inputs = expr._wrapped_model.extract_inputs(batch) # dict input + forward_args = tuple(expr._wrapped_model.extract_target_inputs(batch).values()) # tuple target inputs for captum + additional_forward_args = tuple(expr._wrapped_model.extract_additional_inputs(batch).values()) # tuple additional inputs for captum + outputs = expr._forward_batch(batch) + targets = expr.target_class_extractor(outputs) + for explainer_key, explainer in captum_explainers.items(): + attr_kwargs = { + 'inputs': forward_args[0], + 'target': targets, + 'additional_forward_args': additional_forward_args, + } + if 'feature_mask' in inspect.signature(explainer.attribute).parameters: + attr_kwargs['feature_mask'] = feature_mask_fn_default(forward_args[0]) + attrs = explainer.attribute(**attr_kwargs) + attrs_pp = pp_default(attrs) + for metric_key in expr.metrics.choices: + pbar.set_description(f'Evaluating captum {explainer_key} on {metric_key}') + metric_options = {} + if metric_key == 'cmpd': + metric_options['metrics'] = [ + expr.create_metric('abpc'), + expr.create_metric('cmpx'), + ] + metric_options['weights'] = [.8, -.2] + metric = expr.create_metric(metric_key, **metric_options) + evals_ = metric.set_explainer(explainer).evaluate(inputs, targets, attrs) + captum_evals[explainer_key][metric_key] += evals_.sum().item() + for explainer_key, evals in captum_evals.items(): + for metric_key, value in evals.items(): + write_record( + record_file_name, + { + 'id': get_last_index(record_file_name)+1, + 'model': args.model, + 'source': 'captum', + 'explainer': explainer_key, + 'metric': metric_key, + 'value': value / len(dataloader.dataset), + 'params': None, + } + ) + + +if __name__ == '__main__': + args = parser.parse_args() + main(args) + diff --git a/tutorials/visualizer_example.py 
b/tutorials/visualizer_example.py index 34865d84..4b15f226 100644 --- a/tutorials/visualizer_example.py +++ b/tutorials/visualizer_example.py @@ -1,5 +1,5 @@ from pnpxai.visualizer.visualizer import Visualizer -from pnpxai import AutoExplanationForImageClassification +from pnpxai import AutoExplanation, Modality import torch import numpy as np import os @@ -12,6 +12,12 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model, loader, transform = load_model_and_dataloader_for_tutorial("image", device) +sample_batch = next(iter(loader)) +modality = Modality( + dtype=sample_batch[0].dtype, + ndims=sample_batch[0].dim(), + pooling_dim=1, +) # ------------------------------------------------------------------------------# @@ -19,13 +25,14 @@ # ------------------------------------------------------------------------------# -expr = AutoExplanationForImageClassification( +expr = AutoExplanation( model=model, data=loader, - input_extractor=lambda batch: batch[0].to(device), - label_extractor=lambda batch: batch[1].to(device), - target_extractor=lambda outputs: outputs.argmax(-1).to(device), - channel_dim=1, + modality=modality, + target_input_keys=[0], + target_class_extractor=lambda outputs: outputs.argmax(-1), + label_key=-1, + target_labels=True, ) @@ -42,4 +49,4 @@ def input_visualizer(datum: torch.Tensor) -> np.ndarray: experiment=expr, input_visualizer=input_visualizer, ) -visualizer.launch() +visualizer.launch(share=True)
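As a closing reference, the reworked experiment API used throughout this diff addresses explainers, metrics, and postprocessing by string key. A minimal sketch against the `expr` object built in the tutorial above; the keys 'lime' and 'abpc' and the method names 'l2norm' and 'minmax' are illustrative choices drawn from the registries earlier in this diff, and the actual options come from `expr.explainers.choices` and the modality's `util_functions`:

```python
# One-off run for a single sample, mirroring Experiment.run_batch as called
# by the visualizer above.
outputs = expr.run_batch(
    explainer_key='lime',
    metric_key='abpc',
    pooling_method='l2norm',
    normalization_method='minmax',
    data_ids=[0],
)

# Hyperparameter search over an explainer's tunable parameters,
# mirroring expr.optimize in the benchmark scripts above.
opt_results = expr.optimize(
    explainer_key='lime',
    metric_key='abpc',
    direction='maximize',
    sampler='random',
    n_trials=50,
)
print(opt_results.value, opt_results.params)
```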