Model Interpretation Examples

This page demonstrates model interpretation and visualization techniques for understanding predictions.

Basic Interpretation

Visualize and interpret model predictions using various techniques.

from autotimm import ImageClassifier
from autotimm.interpretation import ModelInterpreter
from PIL import Image


def main():
    # Load trained model
    model = ImageClassifier.load_from_checkpoint("best-model.ckpt", compile_model=False)
    model.eval()

    # Create interpreter
    interpreter = ModelInterpreter(model)

    # Load and preprocess image
    image = Image.open("test_image.jpg")

    # Get prediction with interpretation
    result = interpreter.interpret(
        image,
        methods=["gradcam", "integrated_gradients", "occlusion"],
        target_layer="layer4",  # ResNet final layer
    )

    # Visualize results
    interpreter.visualize(
        image,
        result,
        save_path="interpretation_results.png",
        show_original=True,
    )

    # Print prediction details
    print(f"Predicted class: {result['class']}")
    print(f"Confidence: {result['confidence']:.2%}")
    print(f"Attribution scores: {result['attributions']}")


if __name__ == "__main__":
    main()

Interpretation Methods:

Method                 Description                                  Best For
GradCAM                Gradient-based class activation mapping      Quick visualization
Integrated Gradients   Path integral of gradients                   Precise attribution
Occlusion              Systematically masks input regions           Understanding importance
LIME                   Local interpretable model approximation      Decision boundaries
SHAP                   Shapley values for feature importance        Comprehensive analysis
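
To make the table concrete, here is a minimal GradCAM implementation in plain PyTorch. It is a sketch of the general technique, not autotimm's internal implementation, and it assumes a ResNet-style model whose final convolutional block is passed in as target_layer:

import torch
import torch.nn.functional as F


def gradcam(model, x, target_layer, class_idx=None):
    """Minimal GradCAM: weight the target layer's activations by the
    spatially averaged gradients of the target class score."""
    activations, gradients = {}, {}
    fwd = target_layer.register_forward_hook(
        lambda _m, _i, out: activations.update(value=out)
    )
    bwd = target_layer.register_full_backward_hook(
        lambda _m, _gi, go: gradients.update(value=go[0])
    )
    try:
        logits = model(x)  # x: (1, C, H, W)
        if class_idx is None:
            class_idx = int(logits.argmax(dim=1))
        model.zero_grad()
        logits[0, class_idx].backward()
        # Channel weights: global-average-pooled gradients
        weights = gradients["value"].mean(dim=(2, 3), keepdim=True)
        cam = F.relu((weights * activations["value"]).sum(dim=1, keepdim=True))
        cam = F.interpolate(cam, size=x.shape[-2:], mode="bilinear",
                            align_corners=False)
        return cam.squeeze().detach()  # (H, W) heatmap, unnormalized
    finally:
        fwd.remove()
        bwd.remove()

Normalizing the returned map to [0, 1] and overlaying it on the input reproduces the familiar heatmap view.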

Interpretation Metrics

Evaluate interpretation quality with quantitative metrics.

from autotimm import ImageClassifier
from autotimm.interpretation import ModelInterpreter, InterpretationMetrics
from PIL import Image


def main():
    model = ImageClassifier.load_from_checkpoint("best-model.ckpt", compile_model=False)
    interpreter = ModelInterpreter(model)
    metrics = InterpretationMetrics()

    # Paths to the test images
    image_paths = ["img1.jpg", "img2.jpg", "img3.jpg"]

    results = []
    for img_path in image_paths:
        image = Image.open(img_path)

        # Get interpretation
        result = interpreter.interpret(
            image,
            methods=["gradcam"],
            target_layer="layer4",
        )

        # Compute metrics
        scores = metrics.evaluate(
            result,
            metrics=["faithfulness", "sensitivity", "complexity"],
        )

        results.append({
            "image": img_path,
            "prediction": result["class"],
            "faithfulness": scores["faithfulness"],
            "sensitivity": scores["sensitivity"],
            "complexity": scores["complexity"],
        })

    # Summary statistics
    for r in results:
        print(f"{r['image']}: {r['prediction']}")
        print(f"  Faithfulness: {r['faithfulness']:.3f}")
        print(f"  Sensitivity: {r['sensitivity']:.3f}")
        print(f"  Complexity: {r['complexity']:.3f}")


if __name__ == "__main__":
    main()

Interpretation Metrics:

  • Faithfulness: How well the interpretation reflects the model's actual reasoning
  • Sensitivity: How much the interpretation changes with small input perturbations
  • Complexity: Sparsity and simplicity of the interpretation
  • Stability: Consistency of interpretations across similar inputs (see the sketch below)
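
Stability is listed above but not computed in the example. A rough estimate is to correlate the attributions for the original image with those for mildly perturbed copies. The sketch below assumes that result["attributions"] for a single-method call is an array-like map, which is an assumption for illustration rather than a documented guarantee:

import numpy as np
from PIL import Image


def stability_score(interpreter, image, n_trials=5, noise_std=2.0):
    """Rough stability estimate: mean correlation between the attribution
    map of the original image and those of noise-perturbed copies."""
    def attributions(img):
        result = interpreter.interpret(img, methods=["gradcam"],
                                       target_layer="layer4")
        return np.asarray(result["attributions"], dtype=np.float64).ravel()

    base = attributions(image)
    arr = np.asarray(image).astype(np.float32)
    correlations = []
    for _ in range(n_trials):
        noisy = arr + np.random.normal(0.0, noise_std, arr.shape)
        noisy_img = Image.fromarray(np.clip(noisy, 0, 255).astype(np.uint8))
        correlations.append(np.corrcoef(base, attributions(noisy_img))[0, 1])
    return float(np.mean(correlations))  # closer to 1.0 means more stable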

Interpretation Phase 2

Advanced interpretation techniques including counterfactual explanations.

from autotimm import ImageClassifier
from autotimm.interpretation import ModelInterpreter, CounterfactualGenerator
from PIL import Image


def main():
    model = ImageClassifier.load_from_checkpoint("best-model.ckpt", compile_model=False)
    interpreter = ModelInterpreter(model)
    cf_generator = CounterfactualGenerator(model)

    image = Image.open("test_image.jpg")

    # Original prediction
    original_result = interpreter.interpret(image)
    print(f"Original: {original_result['class']} ({original_result['confidence']:.2%})")

    # Generate counterfactual
    counterfactual = cf_generator.generate(
        image,
        target_class="different_class",
        max_iterations=100,
        regularization=0.01,
    )

    # Compare original and counterfactual
    cf_result = interpreter.interpret(counterfactual["image"])
    print(f"Counterfactual: {cf_result['class']} ({cf_result['confidence']:.2%})")
    print(f"Distance: {counterfactual['distance']:.4f}")

    # Visualize changes
    interpreter.visualize_counterfactual(
        original_image=image,
        counterfactual_image=counterfactual["image"],
        save_path="counterfactual_comparison.png",
    )


if __name__ == "__main__":
    main()

Counterfactual Explanations:

  • What-If Analysis: "What changes would flip the prediction?"
  • Minimal Perturbations: Find the smallest changes needed (see the sketch below)
  • Feature Importance: Identify most influential features
  • Decision Boundaries: Understand classification boundaries
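
Conceptually, counterfactual generation solves a small optimization problem: find the smallest perturbation that flips the prediction to a target class. The sketch below shows the general gradient-based recipe in plain PyTorch; the target_idx class index and hyperparameters are illustrative, and this is not CounterfactualGenerator's actual implementation:

import torch
import torch.nn.functional as F


def counterfactual_search(model, x, target_idx, steps=100, lr=0.01, reg=0.01):
    """Find a small perturbation delta so that model(x + delta) predicts
    target_idx, with an L2 penalty that keeps the change minimal."""
    delta = torch.zeros_like(x, requires_grad=True)
    optimizer = torch.optim.Adam([delta], lr=lr)
    for _ in range(steps):
        logits = model(x + delta)
        # Cross-entropy pulls the prediction toward the target class;
        # the regularizer keeps the counterfactual close to the original.
        loss = F.cross_entropy(logits, torch.tensor([target_idx])) \
            + reg * delta.pow(2).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # delta.norm() plays the role of the "distance" reported above
    return (x + delta).detach(), delta.norm().item()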

Interpretation Phase 3

Comprehensive interpretation with multiple techniques and analysis.

from autotimm import ImageClassifier
from autotimm.interpretation import (
    ModelInterpreter,
    InterpretationMetrics,
    InterpretationComparison,
)
from PIL import Image


def main():
    model = ImageClassifier.load_from_checkpoint("best-model.ckpt", compile_model=False)
    interpreter = ModelInterpreter(model)
    metrics = InterpretationMetrics()
    comparison = InterpretationComparison()

    image = Image.open("test_image.jpg")

    # Run multiple interpretation methods
    methods = ["gradcam", "integrated_gradients", "occlusion", "lime", "shap"]
    results = {}

    for method in methods:
        result = interpreter.interpret(
            image,
            methods=[method],
            target_layer="layer4",
        )

        # Evaluate each method
        scores = metrics.evaluate(result)
        results[method] = {
            "interpretation": result,
            "scores": scores,
        }

    # Compare methods
    comparison_result = comparison.compare(
        results,
        metrics=["faithfulness", "sensitivity", "agreement"],
    )

    # Visualize comparison
    comparison.visualize_comparison(
        comparison_result,
        save_path="method_comparison.png",
    )

    # Print best method
    best_method = comparison_result["best_method"]
    print(f"Best interpretation method: {best_method}")
    print(f"Scores: {results[best_method]['scores']}")


if __name__ == "__main__":
    main()

Method Comparison Criteria:

  • Agreement: How similar are different methods' results? (see the sketch below)
  • Consistency: Do methods agree on important regions?
  • Computational Cost: Time and memory requirements
  • Interpretability: How easy to understand and explain?
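
Agreement between two methods can be quantified, for example, as the overlap between the pixels each method ranks as most important. The helper below is one plausible proxy, not necessarily what InterpretationComparison computes internally:

import numpy as np


def topk_agreement(attr_a, attr_b, k_frac=0.1):
    """Jaccard overlap of the top-k% most important pixels in two
    attribution maps; 1.0 means both methods highlight the same set."""
    a = np.asarray(attr_a).ravel()
    b = np.asarray(attr_b).ravel()
    k = max(1, int(k_frac * a.size))
    top_a = set(np.argsort(-np.abs(a))[:k])
    top_b = set(np.argsort(-np.abs(b))[:k])
    return len(top_a & top_b) / len(top_a | top_b)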

Interactive Visualization

Create interactive visualizations for exploring interpretations.

from autotimm import ImageClassifier
from autotimm.interpretation import InteractiveVisualizer
import gradio as gr


def main():
    model = ImageClassifier.load_from_checkpoint("best-model.ckpt", compile_model=False)
    visualizer = InteractiveVisualizer(model)

    # Create interactive interface
    def interpret_image(image, method, target_layer):
        result = visualizer.interpret(
            image,
            method=method,
            target_layer=target_layer,
        )
        return visualizer.create_visualization(result)

    # Build Gradio interface
    interface = gr.Interface(
        fn=interpret_image,
        inputs=[
            gr.Image(type="pil", label="Input Image"),
            gr.Dropdown(
                choices=["gradcam", "integrated_gradients", "occlusion", "lime"],
                value="gradcam",
                label="Interpretation Method",
            ),
            gr.Dropdown(
                choices=["layer1", "layer2", "layer3", "layer4"],
                value="layer4",
                label="Target Layer",
            ),
        ],
        outputs=gr.Image(type="pil", label="Interpretation"),
        title="Model Interpretation Explorer",
        description="Explore different interpretation methods for image classification",
    )

    interface.launch(share=True)


if __name__ == "__main__":
    main()

Interactive Features:

  • Real-time interpretation: See results as you change settings
  • Method comparison: Switch between interpretation methods
  • Layer selection: Explore different network layers
  • Parameter tuning: Adjust interpretation parameters
  • Export results: Save interpretations and visualizations (see the sketch below)
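
Results can also be persisted outside the UI with standard tools. A minimal export sketch, assuming the attribution map is a 2-D array:

import numpy as np
import matplotlib.pyplot as plt


def export_result(result, prefix="interpretation"):
    """Save raw attribution values and a heatmap rendering to disk."""
    attr = np.asarray(result["attributions"])  # assumed 2-D here
    np.save(f"{prefix}_attributions.npy", attr)  # raw values for later analysis
    plt.imshow(attr, cmap="jet")
    plt.axis("off")
    plt.savefig(f"{prefix}_heatmap.png", bbox_inches="tight")
    plt.close()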

Comprehensive Tutorial

The comprehensive_interpretation_tutorial.ipynb notebook provides a complete guide to model interpretation.

Topics Covered:

  1. Setup and Installation
     • Required dependencies
     • Model loading and preparation

  2. Basic Interpretation
     • GradCAM visualization
     • Integrated Gradients
     • Occlusion sensitivity

  3. Advanced Techniques
     • LIME explanations
     • SHAP values
     • Counterfactual generation

  4. Quantitative Evaluation
     • Interpretation metrics
     • Method comparison
     • Statistical analysis

  5. Practical Applications
     • Debugging models
     • Understanding failures
     • Building trust in predictions

  6. Interactive Tools
     • Building dashboards
     • User interfaces
     • Real-time interpretation

Running the Notebook:

# Install Jupyter
pip install jupyter notebook

# Launch notebook
jupyter notebook examples/interpretation/comprehensive_interpretation_tutorial.ipynb

Running Examples

# Run the demo scripts
python examples/interpretation/interpretation_demo.py
python examples/interpretation/interpretation_metrics_demo.py
python examples/interpretation/interpretation_phase2_demo.py
python examples/interpretation/interpretation_phase3_demo.py
python examples/interpretation/interactive_visualization_demo.py

# Run the comprehensive notebook
jupyter notebook examples/interpretation/comprehensive_interpretation_tutorial.ipynb

See Also