sicara/tf-explain
Interpretability Methods for tf.keras models with Tensorflow 2.0
repo name | sicara/tf-explain |
repo link | https://github.com/sicara/tf-explain |
homepage | https://tf-explain.readthedocs.io |
language | Python |
size (curr.) | 917 kB |
stars (curr.) | 550 |
created | 2019-07-15 |
license | MIT License |
tf-explain
tf-explain implements interpretability methods as TensorFlow 2.0 callbacks to make neural networks easier to understand.
See Introducing tf-explain, Interpretability for Tensorflow 2.0
Documentation: https://tf-explain.readthedocs.io
Installation
tf-explain is available on PyPI as an alpha release. To install it:
virtualenv venv -p python3.6
# Activate the environment so that pip installs into it (Unix shells)
source venv/bin/activate
pip install tf-explain
tf-explain is compatible with TensorFlow 2. TensorFlow is not declared as a dependency, so that you can choose between the CPU and GPU versions. In addition to the previous install, run:
# For CPU or GPU
pip install tensorflow==2.1.0
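Because TensorFlow is not pulled in automatically, it can be worth checking that both packages are visible to the interpreter before training. The snippet below is a minimal sketch using only the standard library; the helper name `is_importable` is hypothetical and not part of tf-explain.

```python
import importlib.util


def is_importable(module_name):
    """Return True if a top-level module can be found on the current Python path."""
    return importlib.util.find_spec(module_name) is not None


# Report whether each package is visible, without importing it.
for module_name in ("tensorflow", "tf_explain"):
    status = "found" if is_importable(module_name) else "missing"
    print(f"{module_name}: {status}")
```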
Available Methods
- Activations Visualization
- Vanilla Gradients
- Gradients*Inputs
- Occlusion Sensitivity
- Grad CAM (Class Activation Maps)
- SmoothGrad
- Integrated Gradients
Activations Visualization
Visualize how a given input comes out of a specific activation layer
from tf_explain.callbacks.activations_visualization import ActivationsVisualizationCallback
model = [...]
callbacks = [
    ActivationsVisualizationCallback(
        validation_data=(x_val, y_val),
        layers_name=["activation_1"],
        output_dir=output_dir,
    ),
]
model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
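To illustrate what "the output of a specific activation layer" means, here is a toy sketch in plain Python. The `LAYERS` list and the `activations_at` helper are hypothetical stand-ins for a real tf.keras model; this is not how tf-explain is implemented.

```python
def relu(values):
    """Element-wise ReLU."""
    return [max(v, 0.0) for v in values]


def double(values):
    """Toy 'dense' layer that doubles every value."""
    return [2.0 * v for v in values]


# Hypothetical model: ordered (layer name, layer function) pairs.
LAYERS = [("dense_1", double), ("activation_1", relu), ("dense_2", double)]


def activations_at(layers, inputs, layer_name):
    """Run the input through the layers and return the output of the named layer."""
    values = inputs
    for name, layer in layers:
        values = layer(values)
        if name == layer_name:
            return values
    raise KeyError(layer_name)


activations = activations_at(LAYERS, [-1.0, 0.5], "activation_1")
print(activations)
```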
Vanilla Gradients
Visualize the importance of the gradients on the input image
from tf_explain.callbacks.vanilla_gradients import VanillaGradientsCallback
model = [...]
callbacks = [
    VanillaGradientsCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        output_dir=output_dir,
    ),
]
model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
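For intuition, the idea can be sketched without TensorFlow: the saliency of each input feature is the magnitude of the partial derivative of the class score with respect to that feature. The toy `score` function and the finite-difference helper `numerical_gradient` are hypothetical stand-ins for a real model and automatic differentiation, not tf-explain's implementation.

```python
def score(x):
    """Toy class score: a hypothetical stand-in for one output of a model."""
    return 3.0 * x[0] - 2.0 * x[1] + 0.0 * x[2]


def numerical_gradient(f, x, eps=1e-6):
    """Central finite-difference estimate of df/dx_i for each input feature."""
    grads = []
    for i in range(len(x)):
        up = list(x)
        down = list(x)
        up[i] += eps
        down[i] -= eps
        grads.append((f(up) - f(down)) / (2 * eps))
    return grads


def vanilla_saliency(f, x):
    """Saliency map: magnitude of the gradient of the score w.r.t. each input."""
    return [abs(g) for g in numerical_gradient(f, x)]


saliency = vanilla_saliency(score, [1.0, 1.0, 1.0])
print(saliency)
```

The third feature does not influence the score, so its saliency is close to zero.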
Gradients*Inputs
Variant of Vanilla Gradients that weights the gradients by the input values
from tf_explain.callbacks.gradients_inputs import GradientsInputsCallback
model = [...]
callbacks = [
    GradientsInputsCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        output_dir=output_dir,
    ),
]
model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
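A minimal sketch of the weighting step, assuming a toy linear model whose gradient is known analytically. The `WEIGHTS` values and function names are hypothetical and only serve the illustration.

```python
WEIGHTS = [0.5, 1.5, -2.0]  # hypothetical parameters of a toy linear model


def score(x):
    """Toy class score: dot product of the weights and the input."""
    return sum(w * v for w, v in zip(WEIGHTS, x))


def gradient(x):
    """Gradient of the linear score w.r.t. the input (constant: the weights)."""
    return list(WEIGHTS)


def gradients_times_inputs(x):
    """Element-wise product of the gradient and the input values."""
    return [g * v for g, v in zip(gradient(x), x)]


x = [1.0, 0.0, 3.0]
attributions = gradients_times_inputs(x)
print(attributions)
```

The second feature has a non-zero gradient but a zero input value, so its attribution is zero. This is where the method differs from Vanilla Gradients.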
Occlusion Sensitivity
Visualize how parts of the image affect the neural network’s confidence by occluding parts iteratively
from tf_explain.callbacks.occlusion_sensitivity import OcclusionSensitivityCallback
model = [...]
callbacks = [
    OcclusionSensitivityCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        patch_size=4,
        output_dir=output_dir,
    ),
]
model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
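The occlusion loop can be sketched on a tiny image in plain Python. The toy `score` function (which only looks at the top-left corner) and the zero fill value are assumptions made for the illustration; the library's actual patching details may differ.

```python
def score(image):
    """Toy confidence: hypothetical model that only looks at the top-left 2x2 corner."""
    return sum(image[r][c] for r in range(2) for c in range(2)) / 4.0


def occlude(image, top, left, patch_size, fill=0.0):
    """Return a copy of the image with one square patch replaced by `fill`."""
    occluded = [row[:] for row in image]
    for r in range(top, min(top + patch_size, len(image))):
        for c in range(left, min(left + patch_size, len(image[0]))):
            occluded[r][c] = fill
    return occluded


def occlusion_sensitivity(image, patch_size):
    """Heatmap of the confidence drop caused by occluding each patch in turn."""
    baseline = score(image)
    heatmap = []
    for top in range(0, len(image), patch_size):
        row = []
        for left in range(0, len(image[0]), patch_size):
            row.append(baseline - score(occlude(image, top, left, patch_size)))
        heatmap.append(row)
    return heatmap


image = [[1.0] * 4 for _ in range(4)]
heatmap = occlusion_sensitivity(image, patch_size=2)
print(heatmap)
```

Only the top-left patch changes the score when hidden, so it is the only non-zero cell of the heatmap.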
Grad CAM
Visualize how parts of the image affect the neural network’s output by looking into the activation maps
From Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization
from tf_explain.callbacks.grad_cam import GradCAMCallback
model = [...]
callbacks = [
    GradCAMCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        output_dir=output_dir,
    ),
]
model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
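The core computation described in the Grad-CAM paper can be sketched as follows: each activation map is weighted by the spatial average of the gradient of the class score with respect to it, the weighted maps are summed, and a ReLU keeps only positive evidence. The feature maps and gradients below are made-up numbers; in practice they come from a convolutional layer, and the resulting map is usually resized to the input image, which this sketch omits.

```python
def grad_cam(feature_maps, gradients):
    """Weighted sum of activation maps followed by ReLU.

    feature_maps, gradients: one 2D list per channel, all with the same shape.
    """
    height, width = len(feature_maps[0]), len(feature_maps[0][0])
    # Channel weights: global average of the gradients over spatial positions.
    weights = [sum(sum(row) for row in g) / (height * width) for g in gradients]
    cam = [[0.0] * width for _ in range(height)]
    for weight, fmap in zip(weights, feature_maps):
        for r in range(height):
            for c in range(width):
                cam[r][c] += weight * fmap[r][c]
    # ReLU: keep only locations that push the class score up.
    return [[max(v, 0.0) for v in row] for row in cam]


feature_maps = [
    [[1.0, 0.0], [0.0, 0.0]],  # channel 0 fires in the top-left corner
    [[0.0, 0.0], [0.0, 1.0]],  # channel 1 fires in the bottom-right corner
]
gradients = [
    [[2.0, 2.0], [2.0, 2.0]],      # channel 0 increases the class score
    [[-1.0, -1.0], [-1.0, -1.0]],  # channel 1 decreases it
]
cam = grad_cam(feature_maps, gradients)
print(cam)
```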
SmoothGrad
Visualize stabilized gradients of the decision with respect to the inputs
From SmoothGrad: removing noise by adding noise
from tf_explain.callbacks.smoothgrad import SmoothGradCallback
model = [...]
callbacks = [
    SmoothGradCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        num_samples=20,
        noise=1.,
        output_dir=output_dir,
    ),
]
model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
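The averaging idea can be sketched in plain Python: gradients are computed on several noisy copies of the input and averaged. The toy ReLU-style gradient is hypothetical, and `noise` is assumed here to be the standard deviation of Gaussian noise; the parameter names simply mirror the callback arguments above.

```python
import random


def gradient(x):
    """Gradient of a toy score sum(max(0, x_i)): a hard 0/1 step per feature."""
    return [1.0 if v > 0 else 0.0 for v in x]


def smoothgrad(x, num_samples=20, noise=1.0, seed=0):
    """Average the gradient over `num_samples` noisy copies of the input."""
    rng = random.Random(seed)  # seeded so the sketch is reproducible
    total = [0.0] * len(x)
    for _ in range(num_samples):
        noisy = [v + rng.gauss(0.0, noise) for v in x]
        for i, g in enumerate(gradient(noisy)):
            total[i] += g
    return [t / num_samples for t in total]


x = [0.1, -0.1]
raw = gradient(x)
smooth = smoothgrad(x, num_samples=2000)
print(raw, smooth)
```

The raw gradient jumps between 0 and 1 for two almost identical features, while the smoothed values end up close to each other.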
Integrated Gradients
Visualize an average of the gradients taken along the path that constructs the input, with respect to the decision
From Axiomatic Attribution for Deep Networks
from tf_explain.callbacks.integrated_gradients import IntegratedGradientsCallback
model = [...]
callbacks = [
    IntegratedGradientsCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        n_steps=20,
        output_dir=output_dir,
    ),
]
model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
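A minimal sketch of the path average, assuming a toy score with an analytic gradient and an all-zero baseline (both are assumptions for the illustration). Gradients are evaluated at `n_steps` points on the straight line from the baseline to the input, averaged, and scaled by the difference between input and baseline.

```python
def score(x):
    """Toy class score (hypothetical model): x0 squared plus three times x1."""
    return x[0] ** 2 + 3.0 * x[1]


def gradient(x):
    """Analytic gradient of the toy score."""
    return [2.0 * x[0], 3.0]


def integrated_gradients(x, baseline, n_steps=20):
    """Midpoint Riemann-sum approximation of integrated gradients."""
    avg_grad = [0.0] * len(x)
    for step in range(n_steps):
        alpha = (step + 0.5) / n_steps  # position along the path, in (0, 1)
        point = [b + alpha * (v - b) for v, b in zip(x, baseline)]
        for i, g in enumerate(gradient(point)):
            avg_grad[i] += g / n_steps
    return [(v - b) * g for v, b, g in zip(x, baseline, avg_grad)]


x = [2.0, 1.0]
baseline = [0.0, 0.0]
attributions = integrated_gradients(x, baseline)
print(attributions)
```

For this toy score the attributions sum to `score(x) - score(baseline)`, which is the completeness property stated in the paper.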
Visualizing the results
When you use the callbacks, the output files are created in the logs directory.
You can view them in TensorBoard with the following command: tensorboard --logdir logs
Roadmap
- Subclassing API Support
- Additional Methods
- Auto-generated API Documentation & Documentation Testing
Contributing
To contribute to the project, please read the dedicated section.