To truly trust a model, we must look beyond its inputs. Visualizing its internal layers and attention mechanisms shows us the patterns it has actually learned.
1. Visualizing Vision: Grad-CAM
For Convolutional Neural Networks (CNNs), we use Grad-CAM (Gradient-weighted Class Activation Mapping). This technique takes the gradients of a target class score flowing into the final convolutional layer and uses them to weight that layer's feature maps. The result is a Heatmap overlaid on the original image, highlighting the regions that were most 'responsible' for the classification. If a model classifies an image as 'Pneumonia', Grad-CAM shows the doctor which region of the lung X-ray the AI was focusing on.
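To make the weighting step concrete, here is a minimal sketch of the Grad-CAM combination rule in plain JavaScript. It assumes the final convolutional layer's feature maps and the gradients of the class score with respect to them have already been extracted by a deep-learning framework; the function name `gradCamFromMaps` and the `[channel][row][column]` array layout are our own illustrative choices. Each channel's importance weight is the spatial average of its gradients, and the heatmap is the ReLU of the weighted sum of feature maps, normalized to [0, 1]. The result has the (usually much coarser) resolution of the final convolutional layer, so in practice it is upsampled to the image size before being overlaid.

```javascript
// Minimal Grad-CAM core (illustrative sketch, hypothetical function name).
// Assumes featureMaps[k][i][j] and gradients[k][i][j] (K channels, H x W)
// were already extracted from the final convolutional layer.
function gradCamFromMaps(featureMaps, gradients) {
  const K = featureMaps.length;
  const H = featureMaps[0].length;
  const W = featureMaps[0][0].length;

  // alpha_k: global-average-pool channel k's gradients (channel importance)
  const alphas = gradients.map(
    (g) => g.flat().reduce((sum, v) => sum + v, 0) / (H * W)
  );

  // Weighted sum of feature maps, then ReLU to keep only positive evidence
  const cam = [];
  for (let i = 0; i < H; i++) {
    const row = [];
    for (let j = 0; j < W; j++) {
      let s = 0;
      for (let k = 0; k < K; k++) s += alphas[k] * featureMaps[k][i][j];
      row.push(Math.max(0, s));
    }
    cam.push(row);
  }

  // Normalize to [0, 1] so the map can be rendered as an overlay
  const max = Math.max(...cam.flat());
  return max > 0 ? cam.map((r) => r.map((v) => v / max)) : cam;
}

// Toy example: two 2x2 feature maps with hand-picked gradients
const featureMaps = [
  [[1, 0], [0, 0]],
  [[0, 0], [0, 2]],
];
const gradients = [
  [[0.5, 0.5], [0.5, 0.5]], // alpha_0 = 0.5
  [[-1, -1], [-1, -1]],     // alpha_1 = -1
];
console.log(gradCamFromMaps(featureMaps, gradients)); // → [[1, 0], [0, 0]]
```

The second channel has negative gradients, so its activation is suppressed by the ReLU and only the top-left cell lights up.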
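// Grad-CAM Implementation Concept
// (getActivations, computeGradients and generateHeatmap are placeholder helpers)
function getGradCAM(image, model, targetClass) {
  const finalConvLayer = model.getLayer('conv_final');
  // Forward pass: feature maps of the final conv layer for this image
  const featureMaps = getActivations(model, image, finalConvLayer);
  // Gradients of the target class score with respect to those feature maps
  const gradients = computeGradients(model, image, targetClass, finalConvLayer);
  // Weight each feature map by its pooled gradient, sum, and apply ReLU
  return generateHeatmap(gradients, featureMaps);
}
2. The Focus of Language: Attention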
In Transformer models (like BERT or GPT), the Attention Mechanism is key to understanding how the model reads a sentence: it decides how strongly each token draws on every other token. An Attention Map visualizes these 'attention weights' between the tokens of a sentence. It can show, for example, whether the model links a pronoun (like 'it') to the correct noun ('the ball'). If a model's attention concentrates on irrelevant words, that can be a sign that it is missing the context needed for high-quality language generation.
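The weights plotted in an Attention Map come from scaled dot-product attention, softmax(QKᵀ/√d), computed separately in every layer and head. The sketch below computes them for a single head in plain JavaScript. The four tokens and their 2-dimensional query and key vectors are hand-picked toy values, not taken from BERT or GPT, and `attentionWeights` / `mostAttended` are illustrative names.

```javascript
// Single-head scaled dot-product attention weights: softmax(Q K^T / sqrt(d)).
// Q and K are toy, hand-picked vectors (one per token), not from a real model.
function attentionWeights(Q, K) {
  const d = Q[0].length;
  return Q.map((q) => {
    // Similarity of this query with every key, scaled by sqrt(d)
    const scores = K.map(
      (k) => k.reduce((s, kv, i) => s + kv * q[i], 0) / Math.sqrt(d)
    );
    // Numerically stable softmax: subtract the max score before exponentiating
    const m = Math.max(...scores);
    const exps = scores.map((s) => Math.exp(s - m));
    const total = exps.reduce((a, b) => a + b, 0);
    return exps.map((e) => e / total); // each row sums to 1
  });
}

// Index of the token that a given query token attends to most strongly
function mostAttended(weights, queryIndex) {
  const row = weights[queryIndex];
  return row.indexOf(Math.max(...row));
}

const tokens = ["the", "ball", "bounced", "it"];
const Q = [
  [1, 0], // "the"
  [0, 1], // "ball"
  [1, 1], // "bounced"
  [2, 0], // "it" (hand-picked to align with the key of "ball")
];
const K = [
  [0, 0], // "the"
  [2, 0], // "ball"
  [0, 2], // "bounced"
  [1, 0], // "it"
];

const weights = attentionWeights(Q, K);
console.log(`"it" attends most to "${tokens[mostAttended(weights, 3)]}"`);
// prints: "it" attends most to "ball"
```

Each row of `weights` is one token's attention distribution over the sentence, which is exactly one row of the grid an Attention Map draws; a real model produces one such matrix per layer and head.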
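// Extracting Attention Weights
// (tokenize, model.getAttentionWeights and plotAttentionMap are placeholder helpers)
function visualizeAttention(sentence, model) {
  const tokens = tokenize(sentence);
  // Get the attention matrix from Layer 12, Head 4:
  // attentionMatrix[i][j] = how strongly token i attends to token j
  const attentionMatrix = model.getAttentionWeights(tokens, 12, 4);
  plotAttentionMap(tokens, attentionMatrix);
}
3. Shortcut Learning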
Internal interpretation is vital for detecting Shortcut Learning (also called the 'Clever Hans' effect). This occurs when a model exploits a simple, unintended correlation to solve a task. For example, a model might detect 'Cancer' with 99% accuracy because all the cancer images in its training data were taken at one hospital with a specific ruler in the frame. Without XAI heatmaps, you might deploy this seemingly perfect model, only to watch it fail at a different hospital where that ruler is absent.
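// Debugging a Clever Hans Model
// (rulerRegion, overlapsRegion and flagForRetraining are placeholder helpers)
function runAudit(model, testImages, targetClass) {
  for (const img of testImages) {
    const heatmap = getGradCAM(img, model, targetClass);
    // If the heatmap concentrates on the ruler instead of
    // the tissue, we have a shortcut learning problem.
    if (overlapsRegion(heatmap, rulerRegion(img))) {
      flagForRetraining(model);
      return true; // one clear case is enough to flag the model
    }
  }
  return false;
}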
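The ruler check in the audit above is left abstract. One simple way to make it concrete, assuming the Grad-CAM heatmap has already been upsampled to the image size and that someone has marked the ruler's pixels in a binary mask, is to measure what fraction of the heatmap's total weight falls inside that mask. The helper names and the 0.5 threshold below are illustrative assumptions, not an established standard.

```javascript
// Share of the heatmap's total weight that lies inside an artifact mask.
// Assumes heatmap and mask have the same H x W shape (hypothetical setup:
// the heatmap is upsampled and the ruler mask is annotated by hand).
function artifactAttentionShare(heatmap, mask) {
  let inside = 0;
  let total = 0;
  for (let i = 0; i < heatmap.length; i++) {
    for (let j = 0; j < heatmap[i].length; j++) {
      total += heatmap[i][j];
      if (mask[i][j]) inside += heatmap[i][j];
    }
  }
  return total > 0 ? inside / total : 0;
}

// Return the images whose explanation is dominated by the artifact region.
// The default threshold of 0.5 is an arbitrary illustrative choice.
function auditShortcut(heatmaps, masks, threshold = 0.5) {
  return heatmaps
    .map((h, idx) => ({ idx, share: artifactAttentionShare(h, masks[idx]) }))
    .filter((r) => r.share > threshold);
}

// Toy 2x2 example: most of the heatmap sits on the (bottom-left) ruler pixel
const heatmap = [
  [0.25, 0.0],
  [0.75, 0.0],
];
const rulerMask = [
  [false, false],
  [true, false],
];
console.log(artifactAttentionShare(heatmap, rulerMask)); // → 0.75
console.log(auditShortcut([heatmap], [rulerMask])); // one flagged image: idx 0, share 0.75
```

Reporting which images were flagged, rather than only a yes/no verdict, makes it easier to inspect the offending cases before deciding to retrain.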