Source code for tf_keras_vis.activation_maximization.callbacks

import warnings
from abc import ABC
from contextlib import contextmanager
from inspect import signature

import imageio
import numpy as np
import tensorflow as tf
from deprecated import deprecated
from PIL import Image, ImageDraw, ImageFont

from ..utils import listify


class Callback(ABC):
    """Base class for callbacks invoked around the optimization loop.

    Every hook defined here is a no-op; subclasses override the ones they need.
    """

    def on_begin(self, **kwargs) -> None:
        """Hook invoked once, before the optimization process starts.

        Args:
            kwargs: The parameters that were passed to
                :obj:`tf_keras_vis.activation_maximization.ActivationMaximization.__call__()`.
        """
        return None

    def __call__(self, i, values, grads, scores, model_outputs, **kwargs) -> None:
        """Hook invoked after the input values were updated by gradient descent in
        :obj:`tf_keras_vis.activation_maximization.ActivationMaximization.__call__()`.

        Args:
            i: The current number of optimizer iteration.
            values: A list of tf.Tensor holding the current `values`.
            grads: A list of tf.Tensor holding the gradients with respect to
                the model input.
            scores: A list of tf.Tensor holding the score values with respect
                to each of the model outputs.
            model_outputs: A list of tf.Tensor holding the model outputs.
            regularizations: A list of tuples of (str, tf.Tensor) holding the
                regularizer values.
            overall_score: A list of tf.Tensor holding the overall scores,
                i.e. the scores combined with the regularization values.
        """
        return None

    def on_end(self) -> None:
        """Hook invoked once, when the optimization process has finished.
        """
        return None
@deprecated(version='0.7.0', reason="Use `Progress` instead.")
class PrintLogger(Callback):
    """Callback to print values during optimization.

    Warnings:
        This class is now **deprecated**!
        Please use :obj:`tf_keras_vis.activation_maximization.callbacks.Progress` instead.
    """
    def __init__(self, interval=10):
        """
        Args:
            interval: An integer that indicates the interval of printing.
                Defaults to 10.
        """
        self.interval = interval

    def __call__(self, i, values, grads, scores, model_outputs, regularizations, **kwargs):
        """Print the scores and regularization values every `interval` steps.

        Args:
            i: The zero-based iteration index; the printed step is `i + 1`.
            values: Unused here.
            grads: Unused here.
            scores: The score values to print.
            model_outputs: Unused here.
            regularizations: The regularizer values to print.
        """
        step = i + 1
        if step % self.interval == 0:
            tf.print('Steps: {:03d}\tScores: {},\tRegularization: {}'.format(
                step, self._tolist(scores), self._tolist(regularizations)))

    def _tolist(self, ary):
        """Recursively convert `ary` into plain Python containers and scalars.

        Lists and tuples keep their container type, NumPy arrays and tensors
        become (nested) lists, NumPy scalars become Python scalars, and any
        other object is returned unchanged.
        """
        if isinstance(ary, list):
            return [self._tolist(e) for e in ary]
        if isinstance(ary, tuple):
            return tuple(self._tolist(e) for e in ary)
        if isinstance(ary, np.ndarray):
            # Iterating an array element by element eventually reaches NumPy
            # scalars, which are not iterable; let NumPy do the conversion and
            # recurse only to handle object arrays holding tensors and the like.
            return self._tolist(ary.tolist())
        if isinstance(ary, np.generic):
            # NumPy scalars cannot be iterated; unwrap them directly.
            return ary.item()
        if tf.is_tensor(ary):
            return ary.numpy().tolist()
        return ary
class GifGenerator2D(Callback):
    """Callback to construct a gif of optimized image.

    One frame per optimization step is recorded for each entry of `values`,
    taken from the first image of the batch, and written out in `on_end()`.

    NOTE(review): every entry of `values` is written to the same file path, so
    when there is more than one entry only the last one's animation remains on
    disk. Behavior is kept as-is; confirm whether separate files are wanted.
    """
    def __init__(self, path) -> None:
        """
        Args:
            path: The file path to save gif. A `.gif` suffix is appended when
                it is missing.
        """
        self.path = path
        # Recorded frames, one list per entry of `values`; None until the
        # first step has been recorded.
        self.data = None

    def on_begin(self, **kwargs) -> None:
        """Discard frames left over from a previous optimization run."""
        self.data = None

    def __call__(self, i, values, *args, **kwargs) -> None:
        """Record one captioned frame per entry of `values`.

        Args:
            i: The zero-based iteration index; the caption shows `i + 1`.
            values: The current values; only the first element of each entry
                is rendered.
        """
        if self.data is None:
            self.data = [[] for _ in range(len(values))]
        # The font and caption are the same for every entry of this step.
        font = ImageFont.load_default()
        caption = f"Step {i + 1}"
        for n, value in enumerate(values):
            # 1st image in the batch
            value = value[0].numpy() if tf.is_tensor(value[0]) else value[0]
            img = Image.fromarray(value.astype(np.uint8))
            ImageDraw.Draw(img).text((10, 10), caption, font=font)
            self.data[n].append(np.asarray(img))

    def on_end(self) -> None:
        """Write the recorded frames to the gif file.

        Does nothing when no frame was recorded, e.g. because the optimization
        ended before its first step.
        """
        if not self.data:
            return
        # str() lets path-like objects through as well as plain strings.
        path = str(self.path)
        if not path.endswith(".gif"):
            path = f"{path}.gif"
        for frames in self.data:
            with imageio.get_writer(path, mode='I', loop=0) as writer:
                for frame in frames:
                    writer.append_data(frame)
class Progress(Callback):
    """Callback to print values during optimization.

    The scores and regularization values of each step are reported through a
    `tf.keras.utils.Progbar`.
    """
    def on_begin(self, steps=None, **kwargs) -> None:
        """Create the progress bar.

        Args:
            steps: Forwarded to `tf.keras.utils.Progbar` as its target.
                Defaults to None.
        """
        self.progbar = tf.keras.utils.Progbar(steps)

    def __call__(self, i, values, grads, scores, model_outputs, regularizations,
                 **kwargs) -> None:
        """Advance the progress bar to step `i + 1` and report the metrics.

        Args:
            i: The zero-based iteration index.
            values: Unused here.
            grads: Unused here.
            scores: The score values; labelled `Score` when there is at most
                one of them, `Score[j]` otherwise.
            model_outputs: Unused here.
            regularizations: An iterable of (name, value) tuples appended to
                the reported metrics.
        """
        if len(scores) > 1:
            metrics = [(f"Score[{j}]", score_value) for j, score_value in enumerate(scores)]
        else:
            metrics = [("Score", score_value) for score_value in scores]
        # Report each regularization value once; they used to be appended to
        # the scores and then concatenated again in the update call.
        metrics.extend(regularizations)
        self.progbar.update(i + 1, metrics)
@contextmanager
def managed_callbacks(callbacks=None, **kwargs):
    """Context manager that brackets a block with `on_begin()` / `on_end()` calls.

    `on_begin()` is called on each callback in order on entry. On a normal exit
    `on_end()` is called on each callback in the same order. If anything raises
    along the way, `on_end()` is still attempted on every callback that was
    begun but not yet ended; exceptions from those calls are printed and
    suppressed, so the original error propagates.

    Args:
        callbacks: The callback(s) to manage, normalized with `listify()`.
        kwargs: Keyword arguments forwarded to each `on_begin()`.

    Yields:
        The list of callbacks whose `on_begin()` has completed. The same list
        object is emptied as the callbacks are ended on a normal exit.
    """
    begun = []
    try:
        for callback in listify(callbacks):
            if not signature(callback.on_begin).parameters:
                # Legacy callbacks whose on_begin() accepts no arguments at all.
                warnings.warn("`Callback#on_begin()` now must accept keyword arguments.",
                              DeprecationWarning)
                callback.on_begin()
            else:
                callback.on_begin(**kwargs)
            begun.append(callback)
        yield begun
        # Pop before calling on_end() so that a callback is never ended twice,
        # even when its on_end() raises and the cleanup below takes over.
        for _ in range(len(begun)):
            finished = begun.pop(0)
            finished.on_end()
    finally:
        # Best-effort cleanup for whatever is still pending.
        for callback in begun:
            try:
                callback.on_end()
            except Exception as e:
                tf.print("Exception args: ", e)